agtrace_types/tool/call.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use super::args::{ExecuteArgs, FileEditArgs, FileReadArgs, FileWriteArgs, McpArgs, SearchArgs};
5use super::kind::ToolKind;
6
/// Normalized tool call with structured arguments.
///
/// This enum provides type-safe access to common tool call patterns while
/// maintaining compatibility with the original JSON structure.
///
/// # Serialization
///
/// The enum is `#[serde(untagged)]`, so no variant tag is written. Every
/// variant serializes as a plain object with `name`, `arguments` and
/// (only when `Some`) `provider_call_id`. On input, a missing
/// `provider_call_id` defaults to `None`.
///
/// When deserializing, serde tries the variants in declaration order. It keeps
/// the first one whose fields decode successfully and ends at `Generic`, whose
/// `arguments` field accepts any JSON value. Variant order is therefore part of
/// the wire contract: do not reorder the variants.
///
/// NOTE(review): the variant is inferred only from the shape of `arguments`.
/// A serialized `FileEdit`/`FileWrite`/`Execute`/... payload may therefore
/// decode as an earlier variant (e.g. `FileRead`) if that variant's args type
/// also accepts the same object. Confirm the args types are restrictive
/// enough, or dispatch on `name` when exact variant round-tripping matters.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolCallPayload {
    /// File read operation (Read, Glob, etc.)
    FileRead {
        name: String,
        arguments: FileReadArgs,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// File edit operation (Edit)
    FileEdit {
        name: String,
        arguments: FileEditArgs,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// File write operation (Write)
    FileWrite {
        name: String,
        arguments: FileWriteArgs,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// Execute/shell command (Bash, etc.)
    Execute {
        name: String,
        arguments: ExecuteArgs,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// Search operation (Grep, WebSearch, etc.)
    Search {
        name: String,
        arguments: SearchArgs,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// MCP (Model Context Protocol) tool call
    Mcp {
        name: String,
        arguments: McpArgs,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// Generic/fallback for unknown or custom tools.
    ///
    /// Declared last on purpose: its `arguments` is an arbitrary JSON value,
    /// so during untagged deserialization it catches any payload that none of
    /// the structured variants above accepted.
    Generic {
        name: String,
        arguments: Value,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },
}
70
71impl ToolCallPayload {
72    /// Get tool name regardless of variant
73    pub fn name(&self) -> &str {
74        match self {
75            ToolCallPayload::FileRead { name, .. } => name,
76            ToolCallPayload::FileEdit { name, .. } => name,
77            ToolCallPayload::FileWrite { name, .. } => name,
78            ToolCallPayload::Execute { name, .. } => name,
79            ToolCallPayload::Search { name, .. } => name,
80            ToolCallPayload::Mcp { name, .. } => name,
81            ToolCallPayload::Generic { name, .. } => name,
82        }
83    }
84
85    /// Get provider call ID regardless of variant
86    pub fn provider_call_id(&self) -> Option<&str> {
87        match self {
88            ToolCallPayload::FileRead {
89                provider_call_id, ..
90            } => provider_call_id.as_deref(),
91            ToolCallPayload::FileEdit {
92                provider_call_id, ..
93            } => provider_call_id.as_deref(),
94            ToolCallPayload::FileWrite {
95                provider_call_id, ..
96            } => provider_call_id.as_deref(),
97            ToolCallPayload::Execute {
98                provider_call_id, ..
99            } => provider_call_id.as_deref(),
100            ToolCallPayload::Search {
101                provider_call_id, ..
102            } => provider_call_id.as_deref(),
103            ToolCallPayload::Mcp {
104                provider_call_id, ..
105            } => provider_call_id.as_deref(),
106            ToolCallPayload::Generic {
107                provider_call_id, ..
108            } => provider_call_id.as_deref(),
109        }
110    }
111
112    /// Derive semantic ToolKind from ToolCallPayload variant
113    pub fn kind(&self) -> ToolKind {
114        match self {
115            ToolCallPayload::FileRead { .. } => ToolKind::Read,
116            ToolCallPayload::FileEdit { .. } => ToolKind::Write,
117            ToolCallPayload::FileWrite { .. } => ToolKind::Write,
118            ToolCallPayload::Execute { .. } => ToolKind::Execute,
119            ToolCallPayload::Search { .. } => ToolKind::Search,
120            ToolCallPayload::Mcp { .. } => ToolKind::Other,
121            ToolCallPayload::Generic { .. } => ToolKind::Other,
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips a FileRead payload through JSON and checks every field survives.
    #[test]
    fn test_tool_call_serialization_roundtrip() {
        let original = ToolCallPayload::FileRead {
            name: "Read".to_string(),
            arguments: FileReadArgs {
                file_path: Some("/path/to/file.rs".to_string()),
                path: None,
                pattern: None,
                extra: serde_json::json!({}),
            },
            provider_call_id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ToolCallPayload = serde_json::from_str(&json).unwrap();

        match deserialized {
            ToolCallPayload::FileRead {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "Read");
                assert_eq!(arguments.file_path, Some("/path/to/file.rs".to_string()));
                assert_eq!(provider_call_id, Some("call_123".to_string()));
            }
            _ => panic!("Expected FileRead variant"),
        }
    }

    // `name()` and `provider_call_id()` must read the fields of whichever variant is present.
    #[test]
    fn test_tool_call_accessors() {
        let with_id = ToolCallPayload::Generic {
            name: "CustomTool".to_string(),
            arguments: serde_json::json!({ "key": 1 }),
            provider_call_id: Some("call_456".to_string()),
        };
        assert_eq!(with_id.name(), "CustomTool");
        assert_eq!(with_id.provider_call_id(), Some("call_456"));

        let without_id = ToolCallPayload::FileWrite {
            name: "Write".to_string(),
            arguments: FileWriteArgs {
                file_path: "/path".to_string(),
                content: "content".to_string(),
            },
            provider_call_id: None,
        };
        assert_eq!(without_id.name(), "Write");
        assert_eq!(without_id.provider_call_id(), None);

        let mcp_payload = ToolCallPayload::Mcp {
            name: "mcp__o3__search".to_string(),
            arguments: McpArgs {
                inner: serde_json::json!({}),
            },
            provider_call_id: Some("call_789".to_string()),
        };
        assert_eq!(mcp_payload.name(), "mcp__o3__search");
        assert_eq!(mcp_payload.provider_call_id(), Some("call_789"));
    }

    // A `None` provider_call_id is left out of the JSON (`skip_serializing_if`).
    // A missing one decodes back to `None` (`default`).
    #[test]
    fn test_missing_provider_call_id_is_omitted_and_defaulted() {
        let payload = ToolCallPayload::Generic {
            name: "CustomTool".to_string(),
            arguments: serde_json::json!({}),
            provider_call_id: None,
        };

        let value = serde_json::to_value(&payload).unwrap();
        assert!(value.get("provider_call_id").is_none());
        assert_eq!(value["name"], "CustomTool");

        // Whichever untagged variant wins, name and (absent) call ID must be preserved.
        let decoded: ToolCallPayload = serde_json::from_value(value).unwrap();
        assert_eq!(decoded.name(), "CustomTool");
        assert_eq!(decoded.provider_call_id(), None);
    }

    // Every variant maps to its documented ToolKind.
    #[test]
    fn test_tool_call_kind_derivation() {
        let read_payload = ToolCallPayload::FileRead {
            name: "Read".to_string(),
            arguments: FileReadArgs {
                file_path: Some("/path".to_string()),
                path: None,
                pattern: None,
                extra: serde_json::json!({}),
            },
            provider_call_id: None,
        };
        assert_eq!(read_payload.kind(), ToolKind::Read);

        let edit_payload = ToolCallPayload::FileEdit {
            name: "Edit".to_string(),
            arguments: FileEditArgs {
                file_path: "/path".to_string(),
                old_string: "old".to_string(),
                new_string: "new".to_string(),
                replace_all: false,
            },
            provider_call_id: None,
        };
        assert_eq!(edit_payload.kind(), ToolKind::Write);

        let write_payload = ToolCallPayload::FileWrite {
            name: "Write".to_string(),
            arguments: FileWriteArgs {
                file_path: "/path".to_string(),
                content: "content".to_string(),
            },
            provider_call_id: None,
        };
        assert_eq!(write_payload.kind(), ToolKind::Write);

        let exec_payload = ToolCallPayload::Execute {
            name: "Bash".to_string(),
            arguments: ExecuteArgs {
                command: Some("ls".to_string()),
                description: None,
                timeout: None,
                extra: serde_json::json!({}),
            },
            provider_call_id: None,
        };
        assert_eq!(exec_payload.kind(), ToolKind::Execute);

        let search_payload = ToolCallPayload::Search {
            name: "Grep".to_string(),
            arguments: SearchArgs {
                pattern: Some("pattern".to_string()),
                query: None,
                input: None,
                path: None,
                extra: serde_json::json!({}),
            },
            provider_call_id: None,
        };
        assert_eq!(search_payload.kind(), ToolKind::Search);

        let mcp_payload = ToolCallPayload::Mcp {
            name: "mcp__o3__search".to_string(),
            arguments: McpArgs {
                inner: serde_json::json!({}),
            },
            provider_call_id: None,
        };
        assert_eq!(mcp_payload.kind(), ToolKind::Other);

        let generic_payload = ToolCallPayload::Generic {
            name: "CustomTool".to_string(),
            arguments: serde_json::json!({}),
            provider_call_id: None,
        };
        assert_eq!(generic_payload.kind(), ToolKind::Other);
    }
}