agtrace_providers/
normalization.rs

// Tool call normalization from raw provider data to typed ToolCallPayload variants
//
// Rationale for provider-layer placement:
//   This module contains provider-specific knowledge about tool names and their
//   argument schemas. While the ToolCallPayload enum itself is in agtrace-types
//   (domain model), the logic to map raw tool names to typed variants belongs here.
//
// Design principle:
//   - agtrace-types: Defines domain model structure (ToolCallPayload enum)
//   - agtrace-providers: Knows how to normalize provider data into the domain model
//   - This separation keeps types pure and provider logic centralized

13use agtrace_types::ToolCallPayload;
14use serde_json::Value;
15
16/// Normalize raw tool call data into a typed ToolCallPayload variant
17///
18/// This function encapsulates provider-specific knowledge about:
19/// - Tool name mapping (e.g., "Read" -> FileRead variant)
20/// - Argument schema parsing (e.g., JSON -> FileReadArgs)
21/// - Fallback handling (unknown tools -> Generic variant)
22///
23/// # Arguments
24/// * `name` - Tool name from provider (e.g., "Read", "Bash", "mcp__o3__search")
25/// * `arguments` - Raw JSON arguments from provider
26/// * `provider_call_id` - Optional provider-specific call identifier
27///
28/// # Returns
29/// Typed ToolCallPayload variant with parsed arguments, or Generic variant as fallback
30pub fn normalize_tool_call(
31    name: String,
32    arguments: Value,
33    provider_call_id: Option<String>,
34) -> ToolCallPayload {
35    // Try to parse into specific variants based on name
36    match name.as_str() {
37        "Read" | "Glob" => {
38            if let Ok(args) = serde_json::from_value(arguments.clone()) {
39                return ToolCallPayload::FileRead {
40                    name,
41                    arguments: args,
42                    provider_call_id,
43                };
44            }
45        }
46        "Edit" => {
47            if let Ok(args) = serde_json::from_value(arguments.clone()) {
48                return ToolCallPayload::FileEdit {
49                    name,
50                    arguments: args,
51                    provider_call_id,
52                };
53            }
54        }
55        "Write" => {
56            if let Ok(args) = serde_json::from_value(arguments.clone()) {
57                return ToolCallPayload::FileWrite {
58                    name,
59                    arguments: args,
60                    provider_call_id,
61                };
62            }
63        }
64        "Bash" | "KillShell" | "BashOutput" => {
65            if let Ok(args) = serde_json::from_value(arguments.clone()) {
66                return ToolCallPayload::Execute {
67                    name,
68                    arguments: args,
69                    provider_call_id,
70                };
71            }
72        }
73        "Grep" | "WebSearch" | "WebFetch" => {
74            if let Ok(args) = serde_json::from_value(arguments.clone()) {
75                return ToolCallPayload::Search {
76                    name,
77                    arguments: args,
78                    provider_call_id,
79                };
80            }
81        }
82        _ if name.starts_with("mcp__") => {
83            if let Ok(args) = serde_json::from_value(arguments.clone()) {
84                return ToolCallPayload::Mcp {
85                    name,
86                    arguments: args,
87                    provider_call_id,
88                };
89            }
90        }
91        _ => {}
92    }
93
94    // Fallback to generic
95    ToolCallPayload::Generic {
96        name,
97        arguments,
98        provider_call_id,
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_normalize_file_read() {
        let payload = normalize_tool_call(
            "Read".to_string(),
            json!({"file_path": "/path/to/file.rs"}),
            Some("call_123".to_string()),
        );

        match payload {
            ToolCallPayload::FileRead {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "Read");
                assert_eq!(arguments.path(), Some("/path/to/file.rs"));
                assert_eq!(provider_call_id, Some("call_123".to_string()));
            }
            _ => panic!("Expected FileRead variant"),
        }
    }

    #[test]
    fn test_normalize_glob_maps_to_file_read() {
        // "Glob" shares the FileRead mapping with "Read"; its `pattern`
        // argument should land in the typed `pattern` field.
        let payload = normalize_tool_call(
            "Glob".to_string(),
            json!({"pattern": "**/*.rs"}),
            None,
        );

        match payload {
            ToolCallPayload::FileRead {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "Glob");
                assert!(arguments.pattern.is_some());
                assert_eq!(provider_call_id, None);
            }
            _ => panic!("Expected FileRead variant, got: {:?}", payload.kind()),
        }
    }

    #[test]
    fn test_normalize_execute() {
        let payload = normalize_tool_call(
            "Bash".to_string(),
            json!({"command": "ls -la"}),
            Some("call_456".to_string()),
        );

        match payload {
            ToolCallPayload::Execute {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "Bash");
                assert_eq!(arguments.command, Some("ls -la".to_string()));
                assert_eq!(provider_call_id, Some("call_456".to_string()));
            }
            _ => panic!("Expected Execute variant"),
        }
    }

    #[test]
    fn test_normalize_mcp_tool() {
        let payload = normalize_tool_call(
            "mcp__o3__search".to_string(),
            json!({"query": "test"}),
            Some("call_789".to_string()),
        );

        match payload {
            ToolCallPayload::Mcp {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "mcp__o3__search");
                // McpArgs wraps raw JSON, verify it contains the query
                assert_eq!(
                    arguments.inner.get("query").and_then(|v| v.as_str()),
                    Some("test")
                );
                assert_eq!(provider_call_id, Some("call_789".to_string()));
            }
            _ => panic!("Expected Mcp variant"),
        }
    }

    #[test]
    fn test_normalize_unknown_tool_fallback() {
        let payload = normalize_tool_call(
            "UnknownTool".to_string(),
            json!({"foo": "bar"}),
            Some("call_999".to_string()),
        );

        match payload {
            ToolCallPayload::Generic {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "UnknownTool");
                assert_eq!(arguments, json!({"foo": "bar"}));
                assert_eq!(provider_call_id, Some("call_999".to_string()));
            }
            _ => panic!("Expected Generic variant for unknown tool"),
        }
    }

    #[test]
    fn test_normalize_unknown_fields_captured_in_extra() {
        // FileReadArgs has an `extra: Value` field, so an object with only
        // unrecognized keys still parses; those keys are kept in `extra`
        // rather than causing a fallback to Generic.
        let payload = normalize_tool_call(
            "Read".to_string(),
            json!({"invalid_field": 123}),
            Some("call_000".to_string()),
        );

        match payload {
            ToolCallPayload::FileRead {
                name, arguments, ..
            } => {
                assert_eq!(name, "Read");
                assert_eq!(arguments.file_path, None);
                assert_eq!(arguments.path, None);
                assert_eq!(arguments.pattern, None);
                assert_eq!(arguments.extra.get("invalid_field"), Some(&json!(123)));
            }
            _ => panic!("Expected FileRead variant, got: {:?}", payload.kind()),
        }
    }

    #[test]
    fn test_normalize_unparseable_arguments_fallback() {
        // A known tool name whose arguments cannot deserialize into the typed
        // argument struct (a bare JSON string is not an object) must fall back
        // to Generic with the raw JSON and call id preserved.
        let payload = normalize_tool_call(
            "Read".to_string(),
            json!("not an object"),
            Some("call_001".to_string()),
        );

        match payload {
            ToolCallPayload::Generic {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "Read");
                assert_eq!(arguments, json!("not an object"));
                assert_eq!(provider_call_id, Some("call_001".to_string()));
            }
            _ => panic!("Expected Generic fallback, got: {:?}", payload.kind()),
        }
    }
}