1use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15 #[serde(default)]
16 pub id: Option<String>,
17 pub function: FunctionCall,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23 pub name: String,
24 pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28 pub fn to_agent_action(&self) -> Result<AgentAction> {
30 let args = &self.function.arguments;
31
32 let action = match self.function.name.as_str() {
33 "read_file" => {
34 let path = Self::get_string_arg(args, "path")?;
35 AgentAction::ReadFile { paths: vec![path] }
36 }
37
38 "write_file" => {
39 let path = Self::get_string_arg(args, "path")?;
40 let content = Self::get_string_arg(args, "content")?;
41 AgentAction::WriteFile { path, content }
42 }
43
44 "delete_file" => {
45 let path = Self::get_string_arg(args, "path")?;
46 AgentAction::DeleteFile { path }
47 }
48
49 "create_directory" => {
50 let path = Self::get_string_arg(args, "path")?;
51 AgentAction::CreateDirectory { path }
52 }
53
54 "execute_command" => {
55 let command = Self::get_string_arg(args, "command")?;
56 let working_dir = Self::get_optional_string_arg(args, "working_dir");
57 AgentAction::ExecuteCommand {
58 command,
59 working_dir,
60 }
61 }
62
63 "git_diff" => {
64 let path = Self::get_optional_string_arg(args, "path");
65 AgentAction::GitDiff { paths: vec![path] }
66 }
67
68 "git_status" => AgentAction::GitStatus,
69
70 "git_commit" => {
71 let message = Self::get_string_arg(args, "message")?;
72 let files = Self::get_string_array_arg(args, "files")?;
73 AgentAction::GitCommit { message, files }
74 }
75
76 "web_search" => {
77 let query = Self::get_string_arg(args, "query")?;
78 let max_results = Self::get_int_arg(args, "max_results")
79 .or_else(|_| Self::get_int_arg(args, "result_count"))
80 .unwrap_or(5)
81 .clamp(1, 10);
82 AgentAction::WebSearch {
83 queries: vec![(query, max_results)],
84 }
85 }
86
87 "edit_file" => {
88 let path = Self::get_string_arg(args, "path")?;
89 let old_string = Self::get_string_arg(args, "old_string")?;
90 let new_string = Self::get_string_arg(args, "new_string")?;
91 AgentAction::EditFile { path, old_string, new_string }
92 }
93
94 "web_fetch" => {
95 let url = Self::get_string_arg(args, "url")?;
96 AgentAction::WebFetch { url }
97 }
98
99 name => {
100 return Err(anyhow!(
101 "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
102 name
103 ))
104 }
105 };
106
107 Ok(action)
108 }
109
110 fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
113 args.get(key)
114 .and_then(|v| v.as_str())
115 .map(|s| s.to_string())
116 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
117 }
118
119 fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
120 args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
121 }
122
123 fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
124 args.get(key)
125 .and_then(|v| v.as_u64())
126 .map(|n| n as usize)
127 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
128 }
129
130 fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
131 args.get(key)
132 .and_then(|v| v.as_array())
133 .map(|arr| {
134 arr.iter()
135 .filter_map(|item| item.as_str().map(|s| s.to_string()))
136 .collect()
137 })
138 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
139 }
140}
141
142pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
144 tool_calls
145 .iter()
146 .filter_map(|tc| match tc.to_agent_action() {
147 Ok(action) => Some(action),
148 Err(e) => {
149 warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
150 None
151 }
152 })
153 .collect()
154}
155
156pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
159 if actions.is_empty() {
160 return actions;
161 }
162
163 let mut result = Vec::new();
164 let mut current_group: Vec<String> = Vec::new();
165
166 for action in actions {
167 match action {
168 AgentAction::ReadFile { paths } => {
169 current_group.extend(paths);
170 }
171 other => {
172 if !current_group.is_empty() {
174 result.push(AgentAction::ReadFile {
175 paths: std::mem::take(&mut current_group),
176 });
177 }
178 result.push(other);
179 }
180 }
181 }
182
183 if !current_group.is_empty() {
185 result.push(AgentAction::ReadFile {
186 paths: current_group,
187 });
188 }
189
190 result
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an id-less `ToolCall` for `name` with the given JSON arguments.
    fn tool_call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    /// Parses a `web_search` call and returns the result count it resolved to.
    fn web_search_count(arguments: serde_json::Value) -> usize {
        match tool_call("web_search", arguments).to_agent_action().unwrap() {
            AgentAction::WebSearch { queries } => queries[0].1,
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "src/main.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "write_file".to_string(),
                arguments: json!({
                    "path": "test.txt",
                    "content": "Hello, world!"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({
                    "command": "cargo test",
                    "working_dir": "/path/to/project"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_edit_file_tool_call() {
        let action = tool_call(
            "edit_file",
            json!({ "path": "src/lib.rs", "old_string": "foo", "new_string": "bar" }),
        )
        .to_agent_action()
        .unwrap();

        match action {
            AgentAction::EditFile {
                path,
                old_string,
                new_string,
            } => {
                assert_eq!(path, "src/lib.rs");
                assert_eq!(old_string, "foo");
                assert_eq!(new_string, "bar");
            }
            _ => panic!("Expected EditFile action"),
        }
    }

    #[test]
    fn test_parse_git_status_ignores_arguments() {
        let action = tool_call("git_status", json!({})).to_agent_action().unwrap();
        assert!(matches!(action, AgentAction::GitStatus));
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: json!({
                    "query": "Rust async features",
                    "result_count": 5
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                assert_eq!(queries[0].0, "Rust async features");
                assert_eq!(queries[0].1, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_web_search_result_count_defaults_clamps_and_precedence() {
        // No count supplied: default of 5.
        assert_eq!(web_search_count(json!({ "query": "q" })), 5);
        // Out-of-range counts are clamped into 1..=10.
        assert_eq!(web_search_count(json!({ "query": "q", "max_results": 50 })), 10);
        assert_eq!(web_search_count(json!({ "query": "q", "max_results": 0 })), 1);
        // `max_results` takes precedence over `result_count`.
        assert_eq!(
            web_search_count(json!({ "query": "q", "max_results": 3, "result_count": 8 })),
            3
        );
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "unknown_tool".to_string(),
                arguments: json!({}),
            },
        };

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_missing_required_argument_returns_error() {
        let result = tool_call("write_file", json!({ "path": "test.txt" })).to_agent_action();
        let err = result
            .err()
            .expect("write_file without content must fail");
        assert!(err.to_string().contains("'content'"));
    }

    #[test]
    fn test_wrongly_typed_argument_returns_error() {
        assert!(tool_call("read_file", json!({ "path": 42 }))
            .to_agent_action()
            .is_err());
    }

    #[test]
    fn test_parse_tool_calls_skips_invalid_calls() {
        let calls = vec![
            tool_call("read_file", json!({ "path": "a.rs" })),
            tool_call("unknown_tool", json!({})),
            tool_call("write_file", json!({ "path": "b.rs" })),
        ];

        let actions = parse_tool_calls(&calls);
        assert_eq!(actions.len(), 1);
        assert!(matches!(
            &actions[0],
            AgentAction::ReadFile { paths } if paths[0] == "a.rs"
        ));
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["file1.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file2.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file3.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 3);
                assert_eq!(paths[0], "file1.rs");
                assert_eq!(paths[1], "file2.rs");
                assert_eq!(paths[2], "file3.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![AgentAction::ReadFile {
            paths: vec!["file1.rs".to_string()],
        }];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_interleaved_actions_split_groups() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["a.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["b.rs".to_string()],
            },
            AgentAction::GitStatus,
            AgentAction::ReadFile {
                paths: vec!["c.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 3);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 2);
                assert_eq!(paths[0], "a.rs");
                assert_eq!(paths[1], "b.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
        assert!(matches!(grouped[1], AgentAction::GitStatus));
        match &grouped[2] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "c.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_empty_input() {
        assert!(group_parallel_reads(Vec::new()).is_empty());
    }
}