// mermaid_cli/models/tool_call.rs

//! Tool call parsing and conversion to `AgentAction`.
//!
//! Handles deserialization of Ollama `tool_calls` responses and converts
//! them to Mermaid's internal `AgentAction` enum.

6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12/// A tool call from the model (Ollama format)
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15    #[serde(default)]
16    pub id: Option<String>,
17    pub function: FunctionCall,
18}
19
20/// Function call details
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23    pub name: String,
24    pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28    /// Convert Ollama tool call to Mermaid AgentAction
29    pub fn to_agent_action(&self) -> Result<AgentAction> {
30        let args = &self.function.arguments;
31
32        let action = match self.function.name.as_str() {
33            "read_file" => {
34                let path = Self::get_string_arg(args, "path")?;
35                AgentAction::ReadFile { paths: vec![path] }
36            }
37
38            "write_file" => {
39                let path = Self::get_string_arg(args, "path")?;
40                let content = Self::get_string_arg(args, "content")?;
41                AgentAction::WriteFile { path, content }
42            }
43
44            "delete_file" => {
45                let path = Self::get_string_arg(args, "path")?;
46                AgentAction::DeleteFile { path }
47            }
48
49            "create_directory" => {
50                let path = Self::get_string_arg(args, "path")?;
51                AgentAction::CreateDirectory { path }
52            }
53
54            "execute_command" => {
55                let command = Self::get_string_arg(args, "command")?;
56                let working_dir = Self::get_optional_string_arg(args, "working_dir");
57                AgentAction::ExecuteCommand {
58                    command,
59                    working_dir,
60                }
61            }
62
63            "git_diff" => {
64                let path = Self::get_optional_string_arg(args, "path");
65                AgentAction::GitDiff { paths: vec![path] }
66            }
67
68            "git_status" => AgentAction::GitStatus,
69
70            "git_commit" => {
71                let message = Self::get_string_arg(args, "message")?;
72                let files = Self::get_string_array_arg(args, "files")?;
73                AgentAction::GitCommit { message, files }
74            }
75
76            "web_search" => {
77                let query = Self::get_string_arg(args, "query")?;
78                let result_count = Self::get_int_arg(args, "result_count")
79                    .unwrap_or(5)
80                    .clamp(1, 10);
81                AgentAction::WebSearch {
82                    queries: vec![(query, result_count)],
83                }
84            }
85
86            name => {
87                return Err(anyhow!(
88                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
89                    name
90                ))
91            }
92        };
93
94        Ok(action)
95    }
96
97    // Helper methods for argument extraction
98
99    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
100        args.get(key)
101            .and_then(|v| v.as_str())
102            .map(|s| s.to_string())
103            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
104    }
105
106    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
107        args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
108    }
109
110    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
111        args.get(key)
112            .and_then(|v| v.as_u64())
113            .map(|n| n as usize)
114            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
115    }
116
117    fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
118        args.get(key)
119            .and_then(|v| v.as_array())
120            .map(|arr| {
121                arr.iter()
122                    .filter_map(|item| item.as_str().map(|s| s.to_string()))
123                    .collect()
124            })
125            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
126    }
127}
128
129/// Parse multiple tool calls into agent actions
130pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
131    tool_calls
132        .iter()
133        .filter_map(|tc| match tc.to_agent_action() {
134            Ok(action) => Some(action),
135            Err(e) => {
136                warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
137                None
138            }
139        })
140        .collect()
141}
142
143/// Group consecutive same-type read operations into a single ReadFile action
144/// The executor will decide whether to parallelize based on the number of paths
145pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
146    if actions.is_empty() {
147        return actions;
148    }
149
150    let mut result = Vec::new();
151    let mut current_group: Vec<String> = Vec::new();
152
153    for action in actions {
154        match action {
155            AgentAction::ReadFile { paths } => {
156                current_group.extend(paths);
157            }
158            other => {
159                // Flush current read group
160                if !current_group.is_empty() {
161                    result.push(AgentAction::ReadFile {
162                        paths: std::mem::take(&mut current_group),
163                    });
164                }
165                result.push(other);
166            }
167        }
168    }
169
170    // Flush remaining read group
171    if !current_group.is_empty() {
172        result.push(AgentAction::ReadFile {
173            paths: current_group,
174        });
175    }
176
177    result
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build an id-less `ToolCall` for `name` with the given JSON arguments.
    fn tool_call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "src/main.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let call = tool_call(
            "write_file",
            json!({
                "path": "test.txt",
                "content": "Hello, world!"
            }),
        );

        let action = call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let call = tool_call(
            "execute_command",
            json!({
                "command": "cargo test",
                "working_dir": "/path/to/project"
            }),
        );

        let action = call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let call = tool_call(
            "web_search",
            json!({
                "query": "Rust async features",
                "result_count": 5
            }),
        );

        let action = call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                assert_eq!(queries[0].0, "Rust async features");
                assert_eq!(queries[0].1, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_web_search_result_count_defaults_and_clamps() {
        let cases: Vec<(serde_json::Value, usize)> = vec![
            // Absent -> default of 5.
            (json!({ "query": "q" }), 5),
            // Above the cap -> clamped to 10.
            (json!({ "query": "q", "result_count": 50 }), 10),
            // Below the floor -> clamped to 1.
            (json!({ "query": "q", "result_count": 0 }), 1),
            // Wrong JSON type -> default of 5.
            (json!({ "query": "q", "result_count": "3" }), 5),
        ];

        for (arguments, expected) in cases.into_iter() {
            match tool_call("web_search", arguments).to_agent_action().unwrap() {
                AgentAction::WebSearch { queries } => assert_eq!(queries[0].1, expected),
                _ => panic!("Expected WebSearch action"),
            }
        }
    }

    #[test]
    fn test_parse_git_commit_tool_call() {
        let call = tool_call(
            "git_commit",
            json!({
                "message": "Fix bug",
                "files": ["a.rs", "b.rs"]
            }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::GitCommit { message, files } => {
                assert_eq!(message, "Fix bug");
                assert_eq!(files, vec!["a.rs".to_string(), "b.rs".to_string()]);
            }
            _ => panic!("Expected GitCommit action"),
        }
    }

    #[test]
    fn test_parse_git_status_tool_call() {
        let action = tool_call("git_status", json!({})).to_agent_action().unwrap();
        assert!(matches!(action, AgentAction::GitStatus));
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let call = tool_call("unknown_tool", json!({}));
        assert!(call.to_agent_action().is_err());
    }

    #[test]
    fn test_missing_required_argument_returns_error() {
        let call = tool_call("write_file", json!({ "path": "test.txt" }));
        let err = call
            .to_agent_action()
            .err()
            .expect("missing `content` should be rejected");
        assert!(err.to_string().contains("content"));
    }

    #[test]
    fn test_wrong_argument_type_returns_error() {
        let call = tool_call("read_file", json!({ "path": 42 }));
        assert!(call.to_agent_action().is_err());
    }

    #[test]
    fn test_parse_tool_calls_skips_invalid_calls() {
        let calls = vec![
            tool_call("read_file", json!({ "path": "a.rs" })),
            tool_call("unknown_tool", json!({})),
            // Missing the required `content` argument.
            tool_call("write_file", json!({ "path": "b.rs" })),
            tool_call("git_status", json!({})),
        ];

        let actions = parse_tool_calls(&calls);
        assert_eq!(actions.len(), 2);
        assert!(matches!(&actions[0], AgentAction::ReadFile { paths } if paths[0] == "a.rs"));
        assert!(matches!(actions[1], AgentAction::GitStatus));
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["file1.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file2.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["file3.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 3);
                assert_eq!(paths[0], "file1.rs");
                assert_eq!(paths[1], "file2.rs");
                assert_eq!(paths[2], "file3.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![AgentAction::ReadFile {
            paths: vec!["file1.rs".to_string()],
        }];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_empty() {
        assert!(group_parallel_reads(Vec::new()).is_empty());
    }

    #[test]
    fn test_group_parallel_reads_flushes_around_other_actions() {
        let actions = vec![
            AgentAction::ReadFile {
                paths: vec!["a.rs".to_string()],
            },
            AgentAction::GitStatus,
            AgentAction::ReadFile {
                paths: vec!["b.rs".to_string()],
            },
            AgentAction::ReadFile {
                paths: vec!["c.rs".to_string()],
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 3);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "a.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }

        assert!(matches!(grouped[1], AgentAction::GitStatus));

        match &grouped[2] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 2);
                assert_eq!(paths[0], "b.rs");
                assert_eq!(paths[1], "c.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }
}