Skip to main content

mermaid_cli/models/
tool_call.rs

//! Tool call parsing and conversion to AgentAction
//!
//! Handles deserialization of Ollama tool_calls responses and converts
//! them to Mermaid's internal AgentAction enum.
5
6use anyhow::{Result, anyhow};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12/// A tool call from the model (Ollama format)
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15    #[serde(default)]
16    pub id: Option<String>,
17    pub function: FunctionCall,
18}
19
20/// Function call details
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23    pub name: String,
24    pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28    /// Convert Ollama tool call to Mermaid AgentAction
29    pub fn to_agent_action(&self) -> Result<AgentAction> {
30        let args = &self.function.arguments;
31
32        let action = match self.function.name.as_str() {
33            "read_file" => {
34                let path = Self::get_string_arg(args, "path")?;
35                AgentAction::ReadFile { paths: vec![path] }
36            },
37
38            "write_file" => {
39                let path = Self::get_string_arg(args, "path")?;
40                let content = Self::get_string_arg(args, "content")?;
41                AgentAction::WriteFile { path, content }
42            },
43
44            "delete_file" => {
45                let path = Self::get_string_arg(args, "path")?;
46                AgentAction::DeleteFile { path }
47            },
48
49            "create_directory" => {
50                let path = Self::get_string_arg(args, "path")?;
51                AgentAction::CreateDirectory { path }
52            },
53
54            "execute_command" => {
55                let command = Self::get_string_arg(args, "command")?;
56                let working_dir = Self::get_optional_string_arg(args, "working_dir");
57                let timeout = args.get("timeout").and_then(|v| v.as_u64());
58                AgentAction::ExecuteCommand {
59                    command,
60                    working_dir,
61                    timeout,
62                }
63            },
64
65            "web_search" => {
66                let query = Self::get_string_arg(args, "query")?;
67                let max_results = Self::get_int_arg(args, "max_results")
68                    .or_else(|_| Self::get_int_arg(args, "result_count"))
69                    .unwrap_or(5)
70                    .clamp(1, 10);
71                AgentAction::WebSearch {
72                    queries: vec![(query, max_results)],
73                }
74            },
75
76            "edit_file" => {
77                let path = Self::get_string_arg(args, "path")?;
78                let old_string = Self::get_string_arg(args, "old_string")?;
79                let new_string = Self::get_string_arg(args, "new_string")?;
80                AgentAction::EditFile {
81                    path,
82                    old_string,
83                    new_string,
84                }
85            },
86
87            "web_fetch" => {
88                let url = Self::get_string_arg(args, "url")?;
89                AgentAction::WebFetch { url }
90            },
91
92            "agent" => {
93                let prompt = Self::get_string_arg(args, "prompt")?;
94                let description = Self::get_string_arg(args, "description")?;
95                AgentAction::SpawnAgent { prompt, description }
96            },
97
98            name => {
99                return Err(anyhow!(
100                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
101                    name
102                ));
103            },
104        };
105
106        Ok(action)
107    }
108
109    // Helper methods for argument extraction
110
111    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
112        args.get(key)
113            .and_then(|v| v.as_str())
114            .map(|s| s.to_string())
115            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
116    }
117
118    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
119        args.get(key)
120            .and_then(|v| v.as_str())
121            .map(|s| s.to_string())
122    }
123
124    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
125        args.get(key)
126            .and_then(|v| v.as_u64())
127            .map(|n| n as usize)
128            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
129    }
130
131}
132
133/// Parse multiple tool calls into agent actions
134pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
135    tool_calls
136        .iter()
137        .filter_map(|tc| match tc.to_agent_action() {
138            Ok(action) => Some(action),
139            Err(e) => {
140                warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
141                None
142            },
143        })
144        .collect()
145}
146
147/// Group consecutive same-type read operations into a single ReadFile action
148/// The executor will decide whether to parallelize based on the number of paths
149pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
150    if actions.is_empty() {
151        return actions;
152    }
153
154    let mut result = Vec::new();
155    let mut current_group: Vec<String> = Vec::new();
156
157    for action in actions {
158        match action {
159            AgentAction::ReadFile { paths } => {
160                current_group.extend(paths);
161            },
162            other => {
163                // Flush current read group
164                if !current_group.is_empty() {
165                    result.push(AgentAction::ReadFile {
166                        paths: std::mem::take(&mut current_group),
167                    });
168                }
169                result.push(other);
170            },
171        }
172    }
173
174    // Flush remaining read group
175    if !current_group.is_empty() {
176        result.push(AgentAction::ReadFile {
177            paths: current_group,
178        });
179    }
180
181    result
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a `ToolCall` for the given tool name and JSON arguments.
    fn make_call(id: Option<&str>, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.map(|s| s.to_string()),
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    /// Build a `ReadFile` action for a single path.
    fn read_of(path: &str) -> AgentAction {
        AgentAction::ReadFile {
            paths: vec![path.to_string()],
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let call = make_call(
            Some("call_123"),
            "read_file",
            json!({ "path": "src/main.rs" }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths, ["src/main.rs"]);
            },
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let call = make_call(
            None,
            "write_file",
            json!({ "path": "test.txt", "content": "Hello, world!" }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            },
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let call = make_call(
            None,
            "execute_command",
            json!({ "command": "cargo test", "working_dir": "/path/to/project" }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
                timeout,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir.as_deref(), Some("/path/to/project"));
                // No timeout was supplied, so none should be set.
                assert!(timeout.is_none());
            },
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        // Uses the alternative `result_count` key rather than `max_results`.
        let call = make_call(
            None,
            "web_search",
            json!({ "query": "Rust async features", "result_count": 5 }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                let (query, count) = &queries[0];
                assert_eq!(query, "Rust async features");
                assert_eq!(*count, 5);
            },
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_parse_agent_tool_call() {
        let call = make_call(
            Some("call_agent_1"),
            "agent",
            json!({
                "prompt": "Read all files in src/models/ and summarize them",
                "description": "Read src/models/ files"
            }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::SpawnAgent {
                prompt,
                description,
            } => {
                assert!(prompt.contains("src/models/"));
                assert_eq!(description, "Read src/models/ files");
            },
            _ => panic!("Expected SpawnAgent action"),
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let call = make_call(None, "unknown_tool", json!({}));

        assert!(call.to_agent_action().is_err());
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![read_of("file1.rs"), read_of("file2.rs"), read_of("file3.rs")];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(*paths, ["file1.rs", "file2.rs", "file3.rs"]);
            },
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let grouped = group_parallel_reads(vec![read_of("file1.rs")]);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(*paths, ["file1.rs"]);
            },
            _ => panic!("Expected ReadFile action"),
        }
    }
}