1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use super::args::{ExecuteArgs, FileEditArgs, FileReadArgs, FileWriteArgs, McpArgs, SearchArgs};
5use super::kind::ToolKind;
6
/// A single tool invocation, with its arguments parsed into a tool-specific shape where
/// one exists and kept as raw JSON (`Generic`) otherwise.
///
/// Serialized with `#[serde(untagged)]`: every variant is written as a flat object holding
/// `name`, `arguments`, and (only when it is `Some`) `provider_call_id`, with no variant tag.
///
/// Deserialization tries the variants in declaration order and keeps the first one whose
/// fields parse, so the order of the variants below is significant. `Generic` accepts any
/// JSON value as `arguments` and is therefore the final catch-all.
///
/// NOTE(review): `name` does not take part in variant selection when deserializing. If an
/// earlier variant's argument type accepts arbitrary objects (e.g. `FileReadArgs`, whose
/// fields appear to be all optional plus an `extra` value), later variants may never be
/// produced by deserialization — confirm against `super::args`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolCallPayload {
    /// File-read tool call; classified as `ToolKind::Read` by [`ToolCallPayload::kind`].
    FileRead {
        /// Tool name as carried in the payload.
        name: String,
        /// Arguments parsed as [`FileReadArgs`].
        arguments: FileReadArgs,
        /// Optional call identifier; skipped on output when `None`, defaults to `None` when
        /// absent on input.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// File-edit tool call; classified as `ToolKind::Write`.
    FileEdit {
        /// Tool name as carried in the payload.
        name: String,
        /// Arguments parsed as [`FileEditArgs`].
        arguments: FileEditArgs,
        /// Optional call identifier; skipped on output when `None`, defaults to `None` when
        /// absent on input.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// File-write tool call; classified as `ToolKind::Write`.
    FileWrite {
        /// Tool name as carried in the payload.
        name: String,
        /// Arguments parsed as [`FileWriteArgs`].
        arguments: FileWriteArgs,
        /// Optional call identifier; skipped on output when `None`, defaults to `None` when
        /// absent on input.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// Command-execution tool call; classified as `ToolKind::Execute`.
    Execute {
        /// Tool name as carried in the payload.
        name: String,
        /// Arguments parsed as [`ExecuteArgs`].
        arguments: ExecuteArgs,
        /// Optional call identifier; skipped on output when `None`, defaults to `None` when
        /// absent on input.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// Search tool call; classified as `ToolKind::Search`.
    Search {
        /// Tool name as carried in the payload.
        name: String,
        /// Arguments parsed as [`SearchArgs`].
        arguments: SearchArgs,
        /// Optional call identifier; skipped on output when `None`, defaults to `None` when
        /// absent on input.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// MCP tool call (presumably names like `mcp__server__tool`, as in the tests);
    /// classified as `ToolKind::Other`.
    Mcp {
        /// Tool name as carried in the payload.
        name: String,
        /// Arguments wrapped in [`McpArgs`].
        arguments: McpArgs,
        /// Optional call identifier; skipped on output when `None`, defaults to `None` when
        /// absent on input.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },

    /// Any other tool call; arguments are kept as raw JSON. Classified as `ToolKind::Other`.
    /// Must stay last: it matches any `name`/`arguments` pair during untagged deserialization.
    Generic {
        /// Tool name as carried in the payload.
        name: String,
        /// Unparsed arguments.
        arguments: Value,
        /// Optional call identifier; skipped on output when `None`, defaults to `None` when
        /// absent on input.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_call_id: Option<String>,
    },
}
70
71impl ToolCallPayload {
72 pub fn name(&self) -> &str {
74 match self {
75 ToolCallPayload::FileRead { name, .. } => name,
76 ToolCallPayload::FileEdit { name, .. } => name,
77 ToolCallPayload::FileWrite { name, .. } => name,
78 ToolCallPayload::Execute { name, .. } => name,
79 ToolCallPayload::Search { name, .. } => name,
80 ToolCallPayload::Mcp { name, .. } => name,
81 ToolCallPayload::Generic { name, .. } => name,
82 }
83 }
84
85 pub fn provider_call_id(&self) -> Option<&str> {
87 match self {
88 ToolCallPayload::FileRead {
89 provider_call_id, ..
90 } => provider_call_id.as_deref(),
91 ToolCallPayload::FileEdit {
92 provider_call_id, ..
93 } => provider_call_id.as_deref(),
94 ToolCallPayload::FileWrite {
95 provider_call_id, ..
96 } => provider_call_id.as_deref(),
97 ToolCallPayload::Execute {
98 provider_call_id, ..
99 } => provider_call_id.as_deref(),
100 ToolCallPayload::Search {
101 provider_call_id, ..
102 } => provider_call_id.as_deref(),
103 ToolCallPayload::Mcp {
104 provider_call_id, ..
105 } => provider_call_id.as_deref(),
106 ToolCallPayload::Generic {
107 provider_call_id, ..
108 } => provider_call_id.as_deref(),
109 }
110 }
111
112 pub fn kind(&self) -> ToolKind {
114 match self {
115 ToolCallPayload::FileRead { .. } => ToolKind::Read,
116 ToolCallPayload::FileEdit { .. } => ToolKind::Write,
117 ToolCallPayload::FileWrite { .. } => ToolKind::Write,
118 ToolCallPayload::Execute { .. } => ToolKind::Execute,
119 ToolCallPayload::Search { .. } => ToolKind::Search,
120 ToolCallPayload::Mcp { .. } => ToolKind::Other,
121 ToolCallPayload::Generic { .. } => ToolKind::Other,
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `FileRead` payload for `/path/to/file.rs` with the given call id.
    fn file_read_payload(provider_call_id: Option<String>) -> ToolCallPayload {
        ToolCallPayload::FileRead {
            name: "Read".to_string(),
            arguments: FileReadArgs {
                file_path: Some("/path/to/file.rs".to_string()),
                path: None,
                pattern: None,
                extra: serde_json::json!({}),
            },
            provider_call_id,
        }
    }

    /// Builds a `Generic` payload with empty arguments and no call id.
    fn generic_payload() -> ToolCallPayload {
        ToolCallPayload::Generic {
            name: "CustomTool".to_string(),
            arguments: serde_json::json!({}),
            provider_call_id: None,
        }
    }

    #[test]
    fn test_tool_call_serialization_roundtrip() {
        let original = file_read_payload(Some("call_123".to_string()));

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ToolCallPayload = serde_json::from_str(&json).unwrap();

        match deserialized {
            ToolCallPayload::FileRead {
                name,
                arguments,
                provider_call_id,
            } => {
                assert_eq!(name, "Read");
                assert_eq!(arguments.file_path, Some("/path/to/file.rs".to_string()));
                assert_eq!(provider_call_id, Some("call_123".to_string()));
            }
            _ => panic!("Expected FileRead variant"),
        }
    }

    #[test]
    fn test_tool_call_kind_derivation() {
        assert_eq!(file_read_payload(None).kind(), ToolKind::Read);

        let edit_payload = ToolCallPayload::FileEdit {
            name: "Edit".to_string(),
            arguments: FileEditArgs {
                file_path: "/path".to_string(),
                old_string: "old".to_string(),
                new_string: "new".to_string(),
                replace_all: false,
            },
            provider_call_id: None,
        };
        assert_eq!(edit_payload.kind(), ToolKind::Write);

        let write_payload = ToolCallPayload::FileWrite {
            name: "Write".to_string(),
            arguments: FileWriteArgs {
                file_path: "/path".to_string(),
                content: "content".to_string(),
            },
            provider_call_id: None,
        };
        assert_eq!(write_payload.kind(), ToolKind::Write);

        let exec_payload = ToolCallPayload::Execute {
            name: "Bash".to_string(),
            arguments: ExecuteArgs {
                command: Some("ls".to_string()),
                description: None,
                timeout: None,
                extra: serde_json::json!({}),
            },
            provider_call_id: None,
        };
        assert_eq!(exec_payload.kind(), ToolKind::Execute);

        let search_payload = ToolCallPayload::Search {
            name: "Grep".to_string(),
            arguments: SearchArgs {
                pattern: Some("pattern".to_string()),
                query: None,
                input: None,
                path: None,
                extra: serde_json::json!({}),
            },
            provider_call_id: None,
        };
        assert_eq!(search_payload.kind(), ToolKind::Search);

        let mcp_payload = ToolCallPayload::Mcp {
            name: "mcp__o3__search".to_string(),
            arguments: McpArgs {
                inner: serde_json::json!({}),
            },
            provider_call_id: None,
        };
        assert_eq!(mcp_payload.kind(), ToolKind::Other);

        assert_eq!(generic_payload().kind(), ToolKind::Other);
    }

    #[test]
    fn test_accessors_return_name_and_provider_call_id() {
        let with_id = file_read_payload(Some("call_123".to_string()));
        assert_eq!(with_id.name(), "Read");
        assert_eq!(with_id.provider_call_id(), Some("call_123"));

        let without_id = generic_payload();
        assert_eq!(without_id.name(), "CustomTool");
        assert_eq!(without_id.provider_call_id(), None);
    }

    #[test]
    fn test_missing_provider_call_id_is_skipped_and_defaulted() {
        // `skip_serializing_if = "Option::is_none"` must omit the key entirely.
        let json = serde_json::to_value(generic_payload()).unwrap();
        assert!(json.get("provider_call_id").is_none());

        // `#[serde(default)]` must let the key be absent on input. Only variant-independent
        // accessors are checked, because untagged deserialization picks the first variant
        // whose fields parse.
        let parsed: ToolCallPayload = serde_json::from_value(json).unwrap();
        assert_eq!(parsed.name(), "CustomTool");
        assert_eq!(parsed.provider_call_id(), None);
    }
}