Skip to main content

mermaid_cli/models/
tool_call.rs

1//! Tool call parsing and conversion to AgentAction
2//!
3//! Handles deserialization of Ollama tool_calls responses and converts
4//! them to Mermaid's internal AgentAction enum.
5
6use anyhow::{Result, anyhow};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12/// A tool call from the model (Ollama format)
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15    #[serde(default)]
16    pub id: Option<String>,
17    pub function: FunctionCall,
18}
19
20/// Function call details
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23    pub name: String,
24    pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28    /// Convert Ollama tool call to Mermaid AgentAction
29    pub fn to_agent_action(&self) -> Result<AgentAction> {
30        let args = &self.function.arguments;
31
32        let action = match self.function.name.as_str() {
33            "read_file" => {
34                let path = Self::get_string_arg(args, "path")?;
35                AgentAction::ReadFile { paths: vec![path] }
36            },
37
38            "write_file" => {
39                let path = Self::get_string_arg(args, "path")?;
40                let content = Self::get_string_arg(args, "content")?;
41                AgentAction::WriteFile { path, content }
42            },
43
44            "delete_file" => {
45                let path = Self::get_string_arg(args, "path")?;
46                AgentAction::DeleteFile { path }
47            },
48
49            "create_directory" => {
50                let path = Self::get_string_arg(args, "path")?;
51                AgentAction::CreateDirectory { path }
52            },
53
54            "execute_command" => {
55                let command = Self::get_string_arg(args, "command")?;
56                let working_dir = Self::get_optional_string_arg(args, "working_dir");
57                let timeout = args.get("timeout").and_then(|v| v.as_u64());
58                AgentAction::ExecuteCommand {
59                    command,
60                    working_dir,
61                    timeout,
62                }
63            },
64
65            "web_search" => {
66                let query = Self::get_string_arg(args, "query")?;
67                let max_results = Self::get_int_arg(args, "max_results")
68                    .or_else(|_| Self::get_int_arg(args, "result_count"))
69                    .unwrap_or(5)
70                    .clamp(1, 10);
71                AgentAction::WebSearch {
72                    queries: vec![(query, max_results)],
73                }
74            },
75
76            "edit_file" => {
77                let path = Self::get_string_arg(args, "path")?;
78                let old_string = Self::get_string_arg(args, "old_string")?;
79                let new_string = Self::get_string_arg(args, "new_string")?;
80                AgentAction::EditFile {
81                    path,
82                    old_string,
83                    new_string,
84                }
85            },
86
87            "web_fetch" => {
88                let url = Self::get_string_arg(args, "url")?;
89                AgentAction::WebFetch { url }
90            },
91
92            "agent" => {
93                let prompt = Self::get_string_arg(args, "prompt")?;
94                let description = Self::get_string_arg(args, "description")?;
95                AgentAction::SpawnAgent { prompt, description }
96            },
97
98            "screenshot" => {
99                let mode = Self::get_optional_string_arg(args, "mode")
100                    .unwrap_or_else(|| "fullscreen".to_string());
101                let monitor = Self::get_optional_string_arg(args, "monitor");
102                let region = Self::get_optional_string_arg(args, "region");
103                let window = Self::get_optional_string_arg(args, "window");
104                AgentAction::Screenshot { mode, monitor, region, window }
105            },
106
107            "list_windows" => AgentAction::ListWindows,
108
109            "click" => {
110                let x = Self::get_int_arg(args, "x")? as i32;
111                let y = Self::get_int_arg(args, "y")? as i32;
112                let button = Self::get_optional_string_arg(args, "button")
113                    .unwrap_or_else(|| "left".to_string());
114                AgentAction::Click { x, y, button }
115            },
116
117            "type_text" => {
118                let text = Self::get_string_arg(args, "text")?;
119                AgentAction::TypeText { text }
120            },
121
122            "press_key" => {
123                let key = Self::get_string_arg(args, "key")?;
124                AgentAction::PressKey { key }
125            },
126
127            "scroll" => {
128                let direction = Self::get_string_arg(args, "direction")?;
129                let amount = Self::get_int_arg(args, "amount").unwrap_or(3) as i32;
130                AgentAction::Scroll { direction, amount }
131            },
132
133            "mouse_move" => {
134                let x = Self::get_int_arg(args, "x")? as i32;
135                let y = Self::get_int_arg(args, "y")? as i32;
136                AgentAction::MouseMove { x, y }
137            },
138
139            // MCP tools: mcp__{server_name}__{tool_name}
140            name if name.starts_with("mcp__") => {
141                let rest = &name[5..]; // skip "mcp__"
142                if let Some((server_name, tool_name)) = rest.split_once("__") {
143                    AgentAction::McpToolCall {
144                        server_name: server_name.to_string(),
145                        tool_name: tool_name.to_string(),
146                        arguments: args.clone(),
147                    }
148                } else {
149                    return Err(anyhow!(
150                        "Invalid MCP tool name format: '{}'. Expected 'mcp__{{server}}__{{tool}}'.",
151                        name
152                    ));
153                }
154            },
155
156            name => {
157                return Err(anyhow!(
158                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
159                    name
160                ));
161            },
162        };
163
164        Ok(action)
165    }
166
167    // Helper methods for argument extraction
168
169    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
170        args.get(key)
171            .and_then(|v| v.as_str())
172            .map(|s| s.to_string())
173            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
174    }
175
176    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
177        args.get(key)
178            .and_then(|v| v.as_str())
179            .map(|s| s.to_string())
180    }
181
182    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
183        args.get(key)
184            .and_then(|v| v.as_u64())
185            .map(|n| n as usize)
186            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
187    }
188
189}
190
191/// Parse multiple tool calls into agent actions
192pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
193    tool_calls
194        .iter()
195        .filter_map(|tc| match tc.to_agent_action() {
196            Ok(action) => Some(action),
197            Err(e) => {
198                warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
199                None
200            },
201        })
202        .collect()
203}
204
205/// Group consecutive same-type read operations into a single ReadFile action
206/// The executor will decide whether to parallelize based on the number of paths
207pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
208    if actions.is_empty() {
209        return actions;
210    }
211
212    let mut result = Vec::new();
213    let mut current_group: Vec<String> = Vec::new();
214
215    for action in actions {
216        match action {
217            AgentAction::ReadFile { paths } => {
218                current_group.extend(paths);
219            },
220            other => {
221                // Flush current read group
222                if !current_group.is_empty() {
223                    result.push(AgentAction::ReadFile {
224                        paths: std::mem::take(&mut current_group),
225                    });
226                }
227                result.push(other);
228            },
229        }
230    }
231
232    // Flush remaining read group
233    if !current_group.is_empty() {
234        result.push(AgentAction::ReadFile {
235            paths: current_group,
236        });
237    }
238
239    result
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Build a `ToolCall` from an optional id, a tool name and JSON arguments.
    fn tool_call(id: Option<&str>, name: &str, arguments: Value) -> ToolCall {
        ToolCall {
            id: id.map(str::to_string),
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    /// Wrap each path in its own single-path `ReadFile` action.
    fn single_reads(paths: &[&str]) -> Vec<AgentAction> {
        paths
            .iter()
            .map(|p| AgentAction::ReadFile {
                paths: vec![p.to_string()],
            })
            .collect()
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let call = tool_call(Some("call_123"), "read_file", json!({ "path": "src/main.rs" }));

        let AgentAction::ReadFile { paths } = call.to_agent_action().unwrap() else {
            panic!("Expected ReadFile action");
        };
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "src/main.rs");
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let call = tool_call(
            None,
            "write_file",
            json!({ "path": "test.txt", "content": "Hello, world!" }),
        );

        let AgentAction::WriteFile { path, content } = call.to_agent_action().unwrap() else {
            panic!("Expected WriteFile action");
        };
        assert_eq!(path, "test.txt");
        assert_eq!(content, "Hello, world!");
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let call = tool_call(
            None,
            "execute_command",
            json!({ "command": "cargo test", "working_dir": "/path/to/project" }),
        );

        let AgentAction::ExecuteCommand {
            command,
            working_dir,
            timeout,
        } = call.to_agent_action().unwrap()
        else {
            panic!("Expected ExecuteCommand action");
        };
        assert_eq!(command, "cargo test");
        assert_eq!(working_dir.as_deref(), Some("/path/to/project"));
        assert!(timeout.is_none());
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let call = tool_call(
            None,
            "web_search",
            json!({ "query": "Rust async features", "result_count": 5 }),
        );

        let AgentAction::WebSearch { queries } = call.to_agent_action().unwrap() else {
            panic!("Expected WebSearch action");
        };
        assert_eq!(queries.len(), 1);
        assert_eq!(queries[0].0, "Rust async features");
        assert_eq!(queries[0].1, 5);
    }

    #[test]
    fn test_parse_agent_tool_call() {
        let call = tool_call(
            Some("call_agent_1"),
            "agent",
            json!({
                "prompt": "Read all files in src/models/ and summarize them",
                "description": "Read src/models/ files"
            }),
        );

        let AgentAction::SpawnAgent {
            prompt,
            description,
        } = call.to_agent_action().unwrap()
        else {
            panic!("Expected SpawnAgent action");
        };
        assert!(prompt.contains("src/models/"));
        assert_eq!(description, "Read src/models/ files");
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let call = tool_call(None, "unknown_tool", json!({}));
        let outcome = call.to_agent_action();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_group_parallel_reads() {
        let grouped = group_parallel_reads(single_reads(&["file1.rs", "file2.rs", "file3.rs"]));
        assert_eq!(grouped.len(), 1);

        let AgentAction::ReadFile { paths } = &grouped[0] else {
            panic!("Expected ReadFile action");
        };
        assert_eq!(paths.len(), 3);
        for (actual, expected) in paths.iter().zip(["file1.rs", "file2.rs", "file3.rs"]) {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let grouped = group_parallel_reads(single_reads(&["file1.rs"]));
        assert_eq!(grouped.len(), 1);

        let AgentAction::ReadFile { paths } = &grouped[0] else {
            panic!("Expected ReadFile action");
        };
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "file1.rs");
    }
}