1use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15 #[serde(default)]
16 pub id: Option<String>,
17 pub function: FunctionCall,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23 pub name: String,
24 pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28 pub fn to_agent_action(&self) -> Result<AgentAction> {
30 let args = &self.function.arguments;
31
32 let action = match self.function.name.as_str() {
33 "read_file" => {
34 let path = Self::get_string_arg(args, "path")?;
35 AgentAction::ReadFile { paths: vec![path] }
36 }
37
38 "write_file" => {
39 let path = Self::get_string_arg(args, "path")?;
40 let content = Self::get_string_arg(args, "content")?;
41 AgentAction::WriteFile { path, content }
42 }
43
44 "delete_file" => {
45 let path = Self::get_string_arg(args, "path")?;
46 AgentAction::DeleteFile { path }
47 }
48
49 "create_directory" => {
50 let path = Self::get_string_arg(args, "path")?;
51 AgentAction::CreateDirectory { path }
52 }
53
54 "execute_command" => {
55 let command = Self::get_string_arg(args, "command")?;
56 let working_dir = Self::get_optional_string_arg(args, "working_dir");
57 let timeout = args.get("timeout").and_then(|v| v.as_u64());
58 AgentAction::ExecuteCommand {
59 command,
60 working_dir,
61 timeout,
62 }
63 }
64
65 "git_diff" => {
66 let path = Self::get_optional_string_arg(args, "path");
67 AgentAction::GitDiff { paths: vec![path] }
68 }
69
70 "git_status" => AgentAction::GitStatus,
71
72 "git_commit" => {
73 let message = Self::get_string_arg(args, "message")?;
74 let files = Self::get_string_array_arg(args, "files")?;
75 AgentAction::GitCommit { message, files }
76 }
77
78 "web_search" => {
79 let query = Self::get_string_arg(args, "query")?;
80 let max_results = Self::get_int_arg(args, "max_results")
81 .or_else(|_| Self::get_int_arg(args, "result_count"))
82 .unwrap_or(5)
83 .clamp(1, 10);
84 AgentAction::WebSearch {
85 queries: vec![(query, max_results)],
86 }
87 }
88
89 "edit_file" => {
90 let path = Self::get_string_arg(args, "path")?;
91 let old_string = Self::get_string_arg(args, "old_string")?;
92 let new_string = Self::get_string_arg(args, "new_string")?;
93 AgentAction::EditFile { path, old_string, new_string }
94 }
95
96 "web_fetch" => {
97 let url = Self::get_string_arg(args, "url")?;
98 AgentAction::WebFetch { url }
99 }
100
101 name => {
102 return Err(anyhow!(
103 "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
104 name
105 ))
106 }
107 };
108
109 Ok(action)
110 }
111
112 fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
115 args.get(key)
116 .and_then(|v| v.as_str())
117 .map(|s| s.to_string())
118 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
119 }
120
121 fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
122 args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
123 }
124
125 fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
126 args.get(key)
127 .and_then(|v| v.as_u64())
128 .map(|n| n as usize)
129 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
130 }
131
132 fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
133 args.get(key)
134 .and_then(|v| v.as_array())
135 .map(|arr| {
136 arr.iter()
137 .filter_map(|item| item.as_str().map(|s| s.to_string()))
138 .collect()
139 })
140 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
141 }
142}
143
144pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
146 tool_calls
147 .iter()
148 .filter_map(|tc| match tc.to_agent_action() {
149 Ok(action) => Some(action),
150 Err(e) => {
151 warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
152 None
153 }
154 })
155 .collect()
156}
157
158pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
161 if actions.is_empty() {
162 return actions;
163 }
164
165 let mut result = Vec::new();
166 let mut current_group: Vec<String> = Vec::new();
167
168 for action in actions {
169 match action {
170 AgentAction::ReadFile { paths } => {
171 current_group.extend(paths);
172 }
173 other => {
174 if !current_group.is_empty() {
176 result.push(AgentAction::ReadFile {
177 paths: std::mem::take(&mut current_group),
178 });
179 }
180 result.push(other);
181 }
182 }
183 }
184
185 if !current_group.is_empty() {
187 result.push(AgentAction::ReadFile {
188 paths: current_group,
189 });
190 }
191
192 result
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an id-less `ToolCall` for the given tool name and JSON arguments.
    fn call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "src/main.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "write_file".to_string(),
                arguments: json!({
                    "path": "test.txt",
                    "content": "Hello, world!"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_missing_required_argument_returns_error() {
        // `content` is required for write_file.
        let tool_call = call("write_file", json!({ "path": "test.txt" }));
        assert!(tool_call.to_agent_action().is_err());

        // Wrong JSON type for a required string argument.
        let tool_call = call("read_file", json!({ "path": 42 }));
        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({
                    "command": "cargo test",
                    "working_dir": "/path/to/project"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
                timeout,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
                assert_eq!(timeout, None);
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: json!({
                    "query": "Rust async features",
                    "result_count": 5
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                assert_eq!(queries[0].0, "Rust async features");
                assert_eq!(queries[0].1, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_web_search_result_count_default_and_clamp() {
        let cases = [
            (json!({ "query": "q" }), 5),
            (json!({ "query": "q", "max_results": 50 }), 10),
            (json!({ "query": "q", "max_results": 0 }), 1),
        ];
        for (arguments, expected) in cases {
            match call("web_search", arguments).to_agent_action().unwrap() {
                AgentAction::WebSearch { queries } => assert_eq!(queries[0].1, expected),
                _ => panic!("Expected WebSearch action"),
            }
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "unknown_tool".to_string(),
                arguments: json!({}),
            },
        };

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_parse_tool_calls_skips_invalid_calls() {
        let calls = vec![
            call("read_file", json!({ "path": "a.rs" })),
            call("unknown_tool", json!({})),
            call("write_file", json!({ "path": "b.rs" })),
        ];

        let actions = parse_tool_calls(&calls);
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            AgentAction::ReadFile { paths } => assert_eq!(paths[0], "a.rs"),
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["file1.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file2.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file3.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 3);
                assert_eq!(paths[0], "file1.rs");
                assert_eq!(paths[1], "file2.rs");
                assert_eq!(paths[2], "file3.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![AgentAction::ReadFile {
            paths: vec!["file1.rs".to_string()],
        }];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_preserves_order_around_other_actions() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["a.rs".to_string()],
            },
            AgentAction::GitStatus,
            AgentAction::ReadFile {
                paths: vec!["b.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["c.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 3);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "a.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
        assert!(matches!(grouped[1], AgentAction::GitStatus));
        match &grouped[2] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 2);
                assert_eq!(paths[0], "b.rs");
                assert_eq!(paths[1], "c.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_empty() {
        assert!(group_parallel_reads(Vec::new()).is_empty());
    }
}