1use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15 #[serde(default)]
16 pub id: Option<String>,
17 pub function: FunctionCall,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23 pub name: String,
24 pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28 pub fn to_agent_action(&self) -> Result<AgentAction> {
30 let args = &self.function.arguments;
31
32 let action = match self.function.name.as_str() {
33 "read_file" => {
34 let path = Self::get_string_arg(args, "path")?;
35 AgentAction::ReadFile { paths: vec![path] }
36 }
37
38 "write_file" => {
39 let path = Self::get_string_arg(args, "path")?;
40 let content = Self::get_string_arg(args, "content")?;
41 AgentAction::WriteFile { path, content }
42 }
43
44 "delete_file" => {
45 let path = Self::get_string_arg(args, "path")?;
46 AgentAction::DeleteFile { path }
47 }
48
49 "create_directory" => {
50 let path = Self::get_string_arg(args, "path")?;
51 AgentAction::CreateDirectory { path }
52 }
53
54 "execute_command" => {
55 let command = Self::get_string_arg(args, "command")?;
56 let working_dir = Self::get_optional_string_arg(args, "working_dir");
57 AgentAction::ExecuteCommand {
58 command,
59 working_dir,
60 }
61 }
62
63 "git_diff" => {
64 let path = Self::get_optional_string_arg(args, "path");
65 AgentAction::GitDiff { paths: vec![path] }
66 }
67
68 "git_status" => AgentAction::GitStatus,
69
70 "git_commit" => {
71 let message = Self::get_string_arg(args, "message")?;
72 let files = Self::get_string_array_arg(args, "files")?;
73 AgentAction::GitCommit { message, files }
74 }
75
76 "web_search" => {
77 let query = Self::get_string_arg(args, "query")?;
78 let result_count = Self::get_int_arg(args, "result_count")
79 .unwrap_or(5)
80 .clamp(1, 10);
81 AgentAction::WebSearch {
82 queries: vec![(query, result_count)],
83 }
84 }
85
86 name => {
87 return Err(anyhow!(
88 "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
89 name
90 ))
91 }
92 };
93
94 Ok(action)
95 }
96
97 fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
100 args.get(key)
101 .and_then(|v| v.as_str())
102 .map(|s| s.to_string())
103 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
104 }
105
106 fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
107 args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
108 }
109
110 fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
111 args.get(key)
112 .and_then(|v| v.as_u64())
113 .map(|n| n as usize)
114 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
115 }
116
117 fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
118 args.get(key)
119 .and_then(|v| v.as_array())
120 .map(|arr| {
121 arr.iter()
122 .filter_map(|item| item.as_str().map(|s| s.to_string()))
123 .collect()
124 })
125 .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
126 }
127}
128
129pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
131 tool_calls
132 .iter()
133 .filter_map(|tc| match tc.to_agent_action() {
134 Ok(action) => Some(action),
135 Err(e) => {
136 warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
137 None
138 }
139 })
140 .collect()
141}
142
143pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
146 if actions.is_empty() {
147 return actions;
148 }
149
150 let mut result = Vec::new();
151 let mut current_group: Vec<String> = Vec::new();
152
153 for action in actions {
154 match action {
155 AgentAction::ReadFile { paths } => {
156 current_group.extend(paths);
157 }
158 other => {
159 if !current_group.is_empty() {
161 result.push(AgentAction::ReadFile {
162 paths: std::mem::take(&mut current_group),
163 });
164 }
165 result.push(other);
166 }
167 }
168 }
169
170 if !current_group.is_empty() {
172 result.push(AgentAction::ReadFile {
173 paths: current_group,
174 });
175 }
176
177 result
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an id-less `ToolCall` for `name` with the given JSON arguments.
    fn tool_call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "src/main.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "write_file".to_string(),
                arguments: json!({
                    "path": "test.txt",
                    "content": "Hello, world!"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({
                    "command": "cargo test",
                    "working_dir": "/path/to/project"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: json!({
                    "query": "Rust async features",
                    "result_count": 5
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                assert_eq!(queries[0].0, "Rust async features");
                assert_eq!(queries[0].1, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_web_search_result_count_defaults_and_clamps() {
        let count_for = |args: serde_json::Value| match tool_call("web_search", args)
            .to_agent_action()
            .unwrap()
        {
            AgentAction::WebSearch { queries } => queries[0].1,
            _ => panic!("Expected WebSearch action"),
        };

        // Missing -> default of 5; out-of-range values are clamped into 1..=10.
        assert_eq!(count_for(json!({ "query": "q" })), 5);
        assert_eq!(count_for(json!({ "query": "q", "result_count": 50 })), 10);
        assert_eq!(count_for(json!({ "query": "q", "result_count": 0 })), 1);
    }

    #[test]
    fn test_git_status_needs_no_arguments() {
        let action = tool_call("git_status", json!({})).to_agent_action().unwrap();
        assert!(matches!(action, AgentAction::GitStatus));
    }

    #[test]
    fn test_missing_required_argument_returns_error() {
        // `content` is required for write_file.
        assert!(tool_call("write_file", json!({ "path": "a.txt" }))
            .to_agent_action()
            .is_err());
        // `path` must be a JSON string, not a number.
        assert!(tool_call("read_file", json!({ "path": 42 }))
            .to_agent_action()
            .is_err());
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "unknown_tool".to_string(),
                arguments: json!({}),
            },
        };

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_parse_tool_calls_skips_invalid_calls() {
        let calls = vec![
            tool_call("read_file", json!({ "path": "src/lib.rs" })),
            tool_call("unknown_tool", json!({})),
            tool_call("write_file", json!({ "path": "missing_content.txt" })),
        ];

        let actions = parse_tool_calls(&calls);
        assert_eq!(actions.len(), 1);
        assert!(matches!(
            &actions[0],
            AgentAction::ReadFile { paths } if paths.len() == 1 && paths[0] == "src/lib.rs"
        ));
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["file1.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file2.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file3.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 3);
                assert_eq!(paths[0], "file1.rs");
                assert_eq!(paths[1], "file2.rs");
                assert_eq!(paths[2], "file3.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![AgentAction::ReadFile {
            paths: vec!["file1.rs".to_string()],
        }];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_keeps_non_read_boundaries() {
        let read = |p: &str| AgentAction::ReadFile {
            paths: vec![p.to_string()],
        };

        // A non-read action splits the reads into two separate groups.
        let grouped = group_parallel_reads(vec![
            read("a.rs"),
            read("b.rs"),
            AgentAction::GitStatus,
            read("c.rs"),
        ]);
        assert_eq!(grouped.len(), 3);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths, &vec!["a.rs".to_string(), "b.rs".to_string()]);
            }
            _ => panic!("Expected ReadFile action"),
        }
        assert!(matches!(grouped[1], AgentAction::GitStatus));
        match &grouped[2] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths, &vec!["c.rs".to_string()]);
            }
            _ => panic!("Expected ReadFile action"),
        }

        assert!(group_parallel_reads(Vec::new()).is_empty());
    }
}