1use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15 #[serde(default)]
16 pub id: Option<String>,
17 pub function: FunctionCall,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23 pub name: String,
24 pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28 pub fn to_agent_action(&self) -> Result<AgentAction> {
30 let args = &self.function.arguments;
31
32 let action = match self.function.name.as_str() {
33 "read_file" => {
34 let path = Self::get_string_arg(args, "path")?;
35 AgentAction::ReadFile { paths: vec![path] }
36 }
37
38 "write_file" => {
39 let path = Self::get_string_arg(args, "path")?;
40 let content = Self::get_string_arg(args, "content")?;
41 AgentAction::WriteFile { path, content }
42 }
43
44 "delete_file" => {
45 let path = Self::get_string_arg(args, "path")?;
46 AgentAction::DeleteFile { path }
47 }
48
49 "create_directory" => {
50 let path = Self::get_string_arg(args, "path")?;
51 AgentAction::CreateDirectory { path }
52 }
53
54 "execute_command" => {
55 let command = Self::get_string_arg(args, "command")?;
56 let working_dir = Self::get_optional_string_arg(args, "working_dir");
57 AgentAction::ExecuteCommand {
58 command,
59 working_dir,
60 }
61 }
62
63 "git_diff" => {
64 let path = Self::get_optional_string_arg(args, "path");
65 AgentAction::GitDiff { paths: vec![path] }
66 }
67
68 "git_status" => AgentAction::GitStatus,
69
70 "git_commit" => {
71 let message = Self::get_string_arg(args, "message")?;
72 let files = Self::get_string_array_arg(args, "files")?;
73 AgentAction::GitCommit { message, files }
74 }
75
76 "web_search" => {
77 let query = Self::get_string_arg(args, "query")?;
78 let max_results = Self::get_int_arg(args, "max_results")
79 .or_else(|_| Self::get_int_arg(args, "result_count"))
80 .unwrap_or(5)
81 .clamp(1, 10);
82 AgentAction::WebSearch {
83 queries: vec![(query, max_results)],
84 }
85 }
86
87 "web_fetch" => {
88 let url = Self::get_string_arg(args, "url")?;
89 AgentAction::WebFetch { url }
90 }
91
92 name => {
93 return Err(anyhow!(
94 "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
95 name
96 ))
97 }
98 };
99
100 Ok(action)
101 }
102
103 fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
106 args.get(key)
107 .and_then(|v| v.as_str())
108 .map(|s| s.to_string())
109 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
110 }
111
112 fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
113 args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
114 }
115
116 fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
117 args.get(key)
118 .and_then(|v| v.as_u64())
119 .map(|n| n as usize)
120 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
121 }
122
123 fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
124 args.get(key)
125 .and_then(|v| v.as_array())
126 .map(|arr| {
127 arr.iter()
128 .filter_map(|item| item.as_str().map(|s| s.to_string()))
129 .collect()
130 })
131 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
132 }
133}
134
135pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
137 tool_calls
138 .iter()
139 .filter_map(|tc| match tc.to_agent_action() {
140 Ok(action) => Some(action),
141 Err(e) => {
142 warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
143 None
144 }
145 })
146 .collect()
147}
148
149pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
152 if actions.is_empty() {
153 return actions;
154 }
155
156 let mut result = Vec::new();
157 let mut current_group: Vec<String> = Vec::new();
158
159 for action in actions {
160 match action {
161 AgentAction::ReadFile { paths } => {
162 current_group.extend(paths);
163 }
164 other => {
165 if !current_group.is_empty() {
167 result.push(AgentAction::ReadFile {
168 paths: std::mem::take(&mut current_group),
169 });
170 }
171 result.push(other);
172 }
173 }
174 }
175
176 if !current_group.is_empty() {
178 result.push(AgentAction::ReadFile {
179 paths: current_group,
180 });
181 }
182
183 result
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an id-less tool call so tests can focus on the name and arguments.
    fn call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "src/main.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "write_file".to_string(),
                arguments: json!({
                    "path": "test.txt",
                    "content": "Hello, world!"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({
                    "command": "cargo test",
                    "working_dir": "/path/to/project"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: json!({
                    "query": "Rust async features",
                    "result_count": 5
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                assert_eq!(queries[0].0, "Rust async features");
                assert_eq!(queries[0].1, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_web_search_defaults_and_clamps_max_results() {
        let defaulted = call("web_search", json!({ "query": "rust" }))
            .to_agent_action()
            .unwrap();
        match defaulted {
            AgentAction::WebSearch { queries } => assert_eq!(queries[0].1, 5),
            _ => panic!("Expected WebSearch action"),
        }

        let clamped = call("web_search", json!({ "query": "rust", "max_results": 50 }))
            .to_agent_action()
            .unwrap();
        match clamped {
            AgentAction::WebSearch { queries } => assert_eq!(queries[0].1, 10),
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "unknown_tool".to_string(),
                arguments: json!({}),
            },
        };

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_missing_required_argument_returns_error() {
        // `write_file` requires both `path` and `content`.
        let tool_call = call("write_file", json!({ "path": "test.txt" }));
        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_parse_tool_calls_skips_invalid_calls() {
        let calls = vec![
            call("read_file", json!({ "path": "a.rs" })),
            call("unknown_tool", json!({})),
            call("git_status", json!({})),
        ];

        let actions = parse_tool_calls(&calls);
        assert_eq!(actions.len(), 2);
        assert!(matches!(actions[0], AgentAction::ReadFile { .. }));
        assert!(matches!(actions[1], AgentAction::GitStatus));
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["file1.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file2.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file3.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 3);
                assert_eq!(paths[0], "file1.rs");
                assert_eq!(paths[1], "file2.rs");
                assert_eq!(paths[2], "file3.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![AgentAction::ReadFile {
            paths: vec!["file1.rs".to_string()],
        }];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_splits_on_other_actions() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["a.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["b.rs".to_string()],
            },
            AgentAction::GitStatus,
            AgentAction::ReadFile {
                paths: vec!["c.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 3);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 2);
                assert_eq!(paths[0], "a.rs");
                assert_eq!(paths[1], "b.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
        assert!(matches!(grouped[1], AgentAction::GitStatus));
        match &grouped[2] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "c.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_empty() {
        assert!(group_parallel_reads(Vec::new()).is_empty());
    }
}