Skip to main content

mermaid_cli/models/
tool_call.rs

//! Tool call parsing and conversion to `AgentAction`.
//!
//! Handles deserialization of Ollama `tool_calls` responses and converts
//! them to Mermaid's internal `AgentAction` enum.

6use anyhow::{Result, anyhow};
7use serde::{Deserialize, Serialize};
8
9use crate::agents::AgentAction;
10
11/// A tool call from the model (Ollama format)
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolCall {
14    #[serde(default)]
15    pub id: Option<String>,
16    pub function: FunctionCall,
17}
18
19/// Function call details
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct FunctionCall {
22    pub name: String,
23    pub arguments: serde_json::Value,
24}
25
26impl ToolCall {
27    /// Convert Ollama tool call to Mermaid AgentAction
28    pub fn to_agent_action(&self) -> Result<AgentAction> {
29        let args = &self.function.arguments;
30
31        let action = match self.function.name.as_str() {
32            "read_file" => {
33                let path = Self::get_string_arg(args, "path")?;
34                AgentAction::ReadFile { paths: vec![path] }
35            },
36
37            "write_file" => {
38                let path = Self::get_string_arg(args, "path")?;
39                let content = Self::get_string_arg(args, "content")?;
40                AgentAction::WriteFile { path, content }
41            },
42
43            "delete_file" => {
44                let path = Self::get_string_arg(args, "path")?;
45                AgentAction::DeleteFile { path }
46            },
47
48            "create_directory" => {
49                let path = Self::get_string_arg(args, "path")?;
50                AgentAction::CreateDirectory { path }
51            },
52
53            "execute_command" => {
54                let command = Self::get_string_arg(args, "command")?;
55                let working_dir = Self::get_optional_string_arg(args, "working_dir");
56                let timeout = args.get("timeout").and_then(|v| v.as_u64());
57                AgentAction::ExecuteCommand {
58                    command,
59                    working_dir,
60                    timeout,
61                }
62            },
63
64            "web_search" => {
65                let query = Self::get_string_arg(args, "query")?;
66                let max_results = Self::get_int_arg(args, "max_results")
67                    .or_else(|_| Self::get_int_arg(args, "result_count"))
68                    .unwrap_or(5)
69                    .clamp(1, 10);
70                AgentAction::WebSearch {
71                    queries: vec![(query, max_results)],
72                }
73            },
74
75            "edit_file" => {
76                let path = Self::get_string_arg(args, "path")?;
77                let old_string = Self::get_string_arg(args, "old_string")?;
78                let new_string = Self::get_string_arg(args, "new_string")?;
79                AgentAction::EditFile {
80                    path,
81                    old_string,
82                    new_string,
83                }
84            },
85
86            "web_fetch" => {
87                let url = Self::get_string_arg(args, "url")?;
88                AgentAction::WebFetch { url }
89            },
90
91            "agent" => {
92                let prompt = Self::get_string_arg(args, "prompt")?;
93                let description = Self::get_string_arg(args, "description")?;
94                AgentAction::SpawnAgent {
95                    prompt,
96                    description,
97                }
98            },
99
100            "screenshot" => {
101                let mode = Self::get_optional_string_arg(args, "mode")
102                    .unwrap_or_else(|| "fullscreen".to_string());
103                let monitor = Self::get_optional_string_arg(args, "monitor");
104                let region = Self::get_optional_string_arg(args, "region");
105                let window = Self::get_optional_string_arg(args, "window");
106                AgentAction::Screenshot {
107                    mode,
108                    monitor,
109                    region,
110                    window,
111                }
112            },
113
114            "list_windows" => AgentAction::ListWindows,
115
116            "click" => {
117                let x = Self::get_int_arg(args, "x")? as i32;
118                let y = Self::get_int_arg(args, "y")? as i32;
119                let button = Self::get_optional_string_arg(args, "button")
120                    .unwrap_or_else(|| "left".to_string());
121                let screenshot_id = Self::get_int_arg(args, "screenshot_id")
122                    .ok()
123                    .map(|v| v as u64);
124                AgentAction::Click {
125                    x,
126                    y,
127                    button,
128                    screenshot_id,
129                }
130            },
131
132            "type_text" => {
133                let text = Self::get_string_arg(args, "text")?;
134                AgentAction::TypeText { text }
135            },
136
137            "press_key" => {
138                let key = Self::get_string_arg(args, "key")?;
139                AgentAction::PressKey { key }
140            },
141
142            "scroll" => {
143                let direction = Self::get_string_arg(args, "direction")?;
144                let amount = Self::get_int_arg(args, "amount").unwrap_or(3) as i32;
145                AgentAction::Scroll { direction, amount }
146            },
147
148            "mouse_move" => {
149                let x = Self::get_int_arg(args, "x")? as i32;
150                let y = Self::get_int_arg(args, "y")? as i32;
151                let screenshot_id = Self::get_int_arg(args, "screenshot_id")
152                    .ok()
153                    .map(|v| v as u64);
154                AgentAction::MouseMove {
155                    x,
156                    y,
157                    screenshot_id,
158                }
159            },
160
161            // MCP tools: mcp__{server_name}__{tool_name}
162            name if name.starts_with("mcp__") => {
163                let rest = &name[5..]; // skip "mcp__"
164                if let Some((server_name, tool_name)) = rest.split_once("__") {
165                    AgentAction::McpToolCall {
166                        server_name: server_name.to_string(),
167                        tool_name: tool_name.to_string(),
168                        arguments: args.clone(),
169                    }
170                } else {
171                    return Err(anyhow!(
172                        "Invalid MCP tool name format: '{}'. Expected 'mcp__{{server}}__{{tool}}'.",
173                        name
174                    ));
175                }
176            },
177
178            name => {
179                return Err(anyhow!(
180                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
181                    name
182                ));
183            },
184        };
185
186        Ok(action)
187    }
188
189    // Helper methods for argument extraction
190
191    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
192        args.get(key)
193            .and_then(|v| v.as_str())
194            .map(|s| s.to_string())
195            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
196    }
197
198    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
199        args.get(key)
200            .and_then(|v| v.as_str())
201            .map(|s| s.to_string())
202    }
203
204    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
205        args.get(key)
206            .and_then(|v| v.as_u64())
207            .map(|n| n as usize)
208            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build an id-less `ToolCall` for `name` with the given `arguments`.
    fn call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "src/main.rs");
            },
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "write_file".to_string(),
                arguments: json!({
                    "path": "test.txt",
                    "content": "Hello, world!"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            },
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({
                    "command": "cargo test",
                    "working_dir": "/path/to/project"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
                timeout,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
                assert_eq!(timeout, None);
            },
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: json!({
                    "query": "Rust async features",
                    "result_count": 5
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                assert_eq!(queries[0].0, "Rust async features");
                assert_eq!(queries[0].1, 5);
            },
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_parse_agent_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_agent_1".to_string()),
            function: FunctionCall {
                name: "agent".to_string(),
                arguments: json!({
                    "prompt": "Read all files in src/models/ and summarize them",
                    "description": "Read src/models/ files"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::SpawnAgent {
                prompt,
                description,
            } => {
                assert!(prompt.contains("src/models/"));
                assert_eq!(description, "Read src/models/ files");
            },
            _ => panic!("Expected SpawnAgent action"),
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "unknown_tool".to_string(),
                arguments: json!({}),
            },
        };

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_parse_mcp_tool_call() {
        let arguments = json!({ "owner": "acme", "repo": "widgets" });
        let action = call("mcp__github__create_issue", arguments.clone())
            .to_agent_action()
            .unwrap();
        match action {
            AgentAction::McpToolCall {
                server_name,
                tool_name,
                arguments: forwarded,
            } => {
                assert_eq!(server_name, "github");
                assert_eq!(tool_name, "create_issue");
                assert_eq!(forwarded, arguments);
            },
            _ => panic!("Expected McpToolCall action"),
        }
    }

    #[test]
    fn test_mcp_tool_name_without_separator_is_rejected() {
        // No `__` after the server name, so there is no tool part.
        assert!(call("mcp__github", json!({})).to_agent_action().is_err());
    }

    #[test]
    fn test_missing_required_argument_is_rejected() {
        let err = call("write_file", json!({ "path": "test.txt" }))
            .to_agent_action()
            .err()
            .expect("write_file without content should fail");
        assert!(err.to_string().contains("'content'"));
    }

    #[test]
    fn test_parse_click_with_defaults() {
        let action = call("click", json!({ "x": 100, "y": 200, "screenshot_id": 7 }))
            .to_agent_action()
            .unwrap();
        match action {
            AgentAction::Click {
                x,
                y,
                button,
                screenshot_id,
            } => {
                assert_eq!(x, 100);
                assert_eq!(y, 200);
                assert_eq!(button, "left");
                assert_eq!(screenshot_id, Some(7));
            },
            _ => panic!("Expected Click action"),
        }
    }

    #[test]
    fn test_parse_scroll_default_amount() {
        let action = call("scroll", json!({ "direction": "down" }))
            .to_agent_action()
            .unwrap();
        match action {
            AgentAction::Scroll { direction, amount } => {
                assert_eq!(direction, "down");
                assert_eq!(amount, 3);
            },
            _ => panic!("Expected Scroll action"),
        }
    }

    #[test]
    fn test_list_windows_ignores_arguments() {
        let action = call("list_windows", json!({})).to_agent_action().unwrap();
        assert!(matches!(action, AgentAction::ListWindows));
    }
}