1use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8
9use crate::agents::AgentAction;
10
/// A single tool invocation: an optional call identifier plus the function to run.
///
/// Convert it into an executable action with [`ToolCall::to_agent_action`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub id: Option<String>,
    /// Which tool is being called and with which arguments.
    pub function: FunctionCall,
}
18
/// The function part of a [`ToolCall`]: the tool name and its arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Tool name, matched against the tool names handled by [`ToolCall::to_agent_action`].
    pub name: String,
    /// Tool arguments as raw JSON; which keys are read depends on the tool name.
    pub arguments: serde_json::Value,
}
25
26impl ToolCall {
27 pub fn to_agent_action(&self) -> Result<AgentAction> {
29 let args = &self.function.arguments;
30
31 let action = match self.function.name.as_str() {
32 "read_file" => {
33 let path = Self::get_string_arg(args, "path")?;
34 AgentAction::ReadFile { path }
35 }
36
37 "write_file" => {
38 let path = Self::get_string_arg(args, "path")?;
39 let content = Self::get_string_arg(args, "content")?;
40 AgentAction::WriteFile { path, content }
41 }
42
43 "delete_file" => {
44 let path = Self::get_string_arg(args, "path")?;
45 AgentAction::DeleteFile { path }
46 }
47
48 "create_directory" => {
49 let path = Self::get_string_arg(args, "path")?;
50 AgentAction::CreateDirectory { path }
51 }
52
53 "execute_command" => {
54 let command = Self::get_string_arg(args, "command")?;
55 let working_dir = Self::get_optional_string_arg(args, "working_dir");
56 AgentAction::ExecuteCommand {
57 command,
58 working_dir,
59 }
60 }
61
62 "git_diff" => {
63 let path = Self::get_optional_string_arg(args, "path");
64 AgentAction::GitDiff { path }
65 }
66
67 "git_status" => AgentAction::GitStatus,
68
69 "git_commit" => {
70 let message = Self::get_string_arg(args, "message")?;
71 let files = Self::get_string_array_arg(args, "files")?;
72 AgentAction::GitCommit { message, files }
73 }
74
75 "web_search" => {
76 let query = Self::get_string_arg(args, "query")?;
77 let result_count = Self::get_int_arg(args, "result_count")
78 .unwrap_or(5)
79 .clamp(1, 10);
80 AgentAction::WebSearch {
81 query,
82 result_count,
83 }
84 }
85
86 name => {
87 return Err(anyhow!(
88 "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
89 name
90 ))
91 }
92 };
93
94 Ok(action)
95 }
96
97 fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
100 args.get(key)
101 .and_then(|v| v.as_str())
102 .map(|s| s.to_string())
103 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
104 }
105
106 fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
107 args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
108 }
109
110 fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
111 args.get(key)
112 .and_then(|v| v.as_u64())
113 .map(|n| n as usize)
114 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
115 }
116
117 fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
118 args.get(key)
119 .and_then(|v| v.as_array())
120 .map(|arr| {
121 arr.iter()
122 .filter_map(|item| item.as_str().map(|s| s.to_string()))
123 .collect()
124 })
125 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
126 }
127}
128
129pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
131 tool_calls
132 .iter()
133 .filter_map(|tc| match tc.to_agent_action() {
134 Ok(action) => Some(action),
135 Err(e) => {
136 eprintln!("Failed to parse tool call '{}': {}", tc.function.name, e);
137 None
138 }
139 })
140 .collect()
141}
142
143pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
145 if actions.is_empty() {
146 return actions;
147 }
148
149 let mut result = Vec::new();
150 let mut current_group: Vec<String> = Vec::new();
151
152 for action in actions {
153 match action {
154 AgentAction::ReadFile { path } => {
155 current_group.push(path);
156 }
157 other => {
158 if current_group.len() > 1 {
160 result.push(AgentAction::ParallelRead {
161 paths: current_group.clone(),
162 });
163 } else if current_group.len() == 1 {
164 result.push(AgentAction::ReadFile {
165 path: current_group[0].clone(),
166 });
167 }
168 current_group.clear();
169
170 result.push(other);
171 }
172 }
173 }
174
175 if current_group.len() > 1 {
177 result.push(AgentAction::ParallelRead {
178 paths: current_group,
179 });
180 } else if current_group.len() == 1 {
181 result.push(AgentAction::ReadFile {
182 path: current_group[0].clone(),
183 });
184 }
185
186 result
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `web_search` call with the given arguments and returns its clamped count.
    fn web_search_result_count(arguments: serde_json::Value) -> usize {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments,
            },
        };

        match tool_call.to_agent_action().unwrap() {
            AgentAction::WebSearch { result_count, .. } => result_count,
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { path } => assert_eq!(path, "src/main.rs"),
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "write_file".to_string(),
                arguments: json!({
                    "path": "test.txt",
                    "content": "Hello, world!"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({
                    "command": "cargo test",
                    "working_dir": "/path/to/project"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_execute_command_without_working_dir() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({ "command": "ls" }),
            },
        };

        match tool_call.to_agent_action().unwrap() {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "ls");
                assert_eq!(working_dir, None);
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: json!({
                    "query": "Rust async features",
                    "result_count": 5
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch {
                query,
                result_count,
            } => {
                assert_eq!(query, "Rust async features");
                assert_eq!(result_count, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_web_search_result_count_defaults_and_clamps() {
        // Missing or non-integer counts fall back to the default of 5.
        assert_eq!(web_search_result_count(json!({ "query": "q" })), 5);
        assert_eq!(
            web_search_result_count(json!({ "query": "q", "result_count": "3" })),
            5
        );
        // Out-of-range counts are clamped into 1..=10.
        assert_eq!(
            web_search_result_count(json!({ "query": "q", "result_count": 100 })),
            10
        );
        assert_eq!(
            web_search_result_count(json!({ "query": "q", "result_count": 0 })),
            1
        );
    }

    #[test]
    fn test_parse_git_commit_skips_non_string_files() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "git_commit".to_string(),
                arguments: json!({
                    "message": "fix",
                    "files": ["a.rs", 1, "b.rs"]
                }),
            },
        };

        match tool_call.to_agent_action().unwrap() {
            AgentAction::GitCommit { message, files } => {
                assert_eq!(message, "fix");
                assert_eq!(files, vec!["a.rs".to_string(), "b.rs".to_string()]);
            }
            _ => panic!("Expected GitCommit action"),
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "unknown_tool".to_string(),
                arguments: json!({}),
            },
        };

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_missing_or_mistyped_required_argument_returns_error() {
        for arguments in [json!({}), json!({ "path": 42 })] {
            let tool_call = ToolCall {
                id: None,
                function: FunctionCall {
                    name: "read_file".to_string(),
                    arguments,
                },
            };

            let err = tool_call
                .to_agent_action()
                .err()
                .expect("read_file without a string 'path' must be rejected");
            assert!(err.to_string().contains("'path'"));
        }
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            AgentAction::ReadFile {
                path: "file1.rs".to_string(),
            },
            AgentAction::ReadFile {
                path: "file2.rs".to_string(),
            },
            AgentAction::ReadFile {
                path: "file3.rs".to_string(),
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ParallelRead { paths } => {
                assert_eq!(paths.len(), 3);
            }
            _ => panic!("Expected ParallelRead action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![AgentAction::ReadFile {
            path: "file1.rs".to_string(),
        }];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { path } => {
                assert_eq!(path, "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_empty() {
        assert!(group_parallel_reads(Vec::new()).is_empty());
    }

    #[test]
    fn test_group_parallel_reads_preserves_order_around_other_actions() {
        let actions = vec![
            AgentAction::ReadFile {
                path: "a.rs".to_string(),
            },
            AgentAction::ReadFile {
                path: "b.rs".to_string(),
            },
            AgentAction::GitStatus,
            AgentAction::ReadFile {
                path: "c.rs".to_string(),
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 3);

        match &grouped[0] {
            AgentAction::ParallelRead { paths } => {
                assert_eq!(paths, &vec!["a.rs".to_string(), "b.rs".to_string()]);
            }
            _ => panic!("Expected ParallelRead action"),
        }
        assert!(matches!(grouped[1], AgentAction::GitStatus));
        match &grouped[2] {
            AgentAction::ReadFile { path } => assert_eq!(path, "c.rs"),
            _ => panic!("Expected ReadFile action"),
        }
    }
}