Skip to main content

mermaid_cli/models/
tool_call.rs

//! Tool call parsing and conversion to `AgentAction`.
//!
//! Handles deserialization of Ollama `tool_calls` responses and converts
//! them to Mermaid's internal `AgentAction` enum.
5
6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12/// A tool call from the model (Ollama format)
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15    #[serde(default)]
16    pub id: Option<String>,
17    pub function: FunctionCall,
18}
19
20/// Function call details
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23    pub name: String,
24    pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28    /// Convert Ollama tool call to Mermaid AgentAction
29    pub fn to_agent_action(&self) -> Result<AgentAction> {
30        let args = &self.function.arguments;
31
32        let action = match self.function.name.as_str() {
33            "read_file" => {
34                let path = Self::get_string_arg(args, "path")?;
35                AgentAction::ReadFile { paths: vec![path] }
36            }
37
38            "write_file" => {
39                let path = Self::get_string_arg(args, "path")?;
40                let content = Self::get_string_arg(args, "content")?;
41                AgentAction::WriteFile { path, content }
42            }
43
44            "delete_file" => {
45                let path = Self::get_string_arg(args, "path")?;
46                AgentAction::DeleteFile { path }
47            }
48
49            "create_directory" => {
50                let path = Self::get_string_arg(args, "path")?;
51                AgentAction::CreateDirectory { path }
52            }
53
54            "execute_command" => {
55                let command = Self::get_string_arg(args, "command")?;
56                let working_dir = Self::get_optional_string_arg(args, "working_dir");
57                let timeout = args.get("timeout").and_then(|v| v.as_u64());
58                AgentAction::ExecuteCommand {
59                    command,
60                    working_dir,
61                    timeout,
62                }
63            }
64
65            "git_diff" => {
66                let path = Self::get_optional_string_arg(args, "path");
67                AgentAction::GitDiff { paths: vec![path] }
68            }
69
70            "git_status" => AgentAction::GitStatus,
71
72            "git_commit" => {
73                let message = Self::get_string_arg(args, "message")?;
74                let files = Self::get_string_array_arg(args, "files")?;
75                AgentAction::GitCommit { message, files }
76            }
77
78            "web_search" => {
79                let query = Self::get_string_arg(args, "query")?;
80                let max_results = Self::get_int_arg(args, "max_results")
81                    .or_else(|_| Self::get_int_arg(args, "result_count"))
82                    .unwrap_or(5)
83                    .clamp(1, 10);
84                AgentAction::WebSearch {
85                    queries: vec![(query, max_results)],
86                }
87            }
88
89            "edit_file" => {
90                let path = Self::get_string_arg(args, "path")?;
91                let old_string = Self::get_string_arg(args, "old_string")?;
92                let new_string = Self::get_string_arg(args, "new_string")?;
93                AgentAction::EditFile { path, old_string, new_string }
94            }
95
96            "web_fetch" => {
97                let url = Self::get_string_arg(args, "url")?;
98                AgentAction::WebFetch { url }
99            }
100
101            name => {
102                return Err(anyhow!(
103                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
104                    name
105                ))
106            }
107        };
108
109        Ok(action)
110    }
111
112    // Helper methods for argument extraction
113
114    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
115        args.get(key)
116            .and_then(|v| v.as_str())
117            .map(|s| s.to_string())
118            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
119    }
120
121    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
122        args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
123    }
124
125    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
126        args.get(key)
127            .and_then(|v| v.as_u64())
128            .map(|n| n as usize)
129            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
130    }
131
132    fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
133        args.get(key)
134            .and_then(|v| v.as_array())
135            .map(|arr| {
136                arr.iter()
137                    .filter_map(|item| item.as_str().map(|s| s.to_string()))
138                    .collect()
139            })
140            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
141    }
142}
143
144/// Parse multiple tool calls into agent actions
145pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
146    tool_calls
147        .iter()
148        .filter_map(|tc| match tc.to_agent_action() {
149            Ok(action) => Some(action),
150            Err(e) => {
151                warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
152                None
153            }
154        })
155        .collect()
156}
157
158/// Group consecutive same-type read operations into a single ReadFile action
159/// The executor will decide whether to parallelize based on the number of paths
160pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
161    if actions.is_empty() {
162        return actions;
163    }
164
165    let mut result = Vec::new();
166    let mut current_group: Vec<String> = Vec::new();
167
168    for action in actions {
169        match action {
170            AgentAction::ReadFile { paths } => {
171                current_group.extend(paths);
172            }
173            other => {
174                // Flush current read group
175                if !current_group.is_empty() {
176                    result.push(AgentAction::ReadFile {
177                        paths: std::mem::take(&mut current_group),
178                    });
179                }
180                result.push(other);
181            }
182        }
183    }
184
185    // Flush remaining read group
186    if !current_group.is_empty() {
187        result.push(AgentAction::ReadFile {
188            paths: current_group,
189        });
190    }
191
192    result
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a `ToolCall` for the given tool name and JSON arguments.
    fn make_call(id: Option<&str>, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.map(|s| s.to_string()),
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    /// Build a `ReadFile` action holding exactly one path.
    fn single_read(path: &str) -> AgentAction {
        AgentAction::ReadFile {
            paths: vec![path.to_string()],
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = make_call(
            Some("call_123"),
            "read_file",
            json!({
                "path": "src/main.rs"
            }),
        );

        let action = tool_call.to_agent_action().unwrap();
        if let AgentAction::ReadFile { paths } = action {
            assert_eq!(paths.len(), 1);
            assert_eq!(paths[0], "src/main.rs");
        } else {
            panic!("Expected ReadFile action");
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = make_call(
            None,
            "write_file",
            json!({
                "path": "test.txt",
                "content": "Hello, world!"
            }),
        );

        let action = tool_call.to_agent_action().unwrap();
        if let AgentAction::WriteFile { path, content } = action {
            assert_eq!(path, "test.txt");
            assert_eq!(content, "Hello, world!");
        } else {
            panic!("Expected WriteFile action");
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = make_call(
            None,
            "execute_command",
            json!({
                "command": "cargo test",
                "working_dir": "/path/to/project"
            }),
        );

        let action = tool_call.to_agent_action().unwrap();
        if let AgentAction::ExecuteCommand {
            command,
            working_dir,
            timeout,
        } = action
        {
            assert_eq!(command, "cargo test");
            assert_eq!(working_dir, Some("/path/to/project".to_string()));
            assert_eq!(timeout, None);
        } else {
            panic!("Expected ExecuteCommand action");
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        // Uses the `result_count` alias rather than `max_results`.
        let tool_call = make_call(
            None,
            "web_search",
            json!({
                "query": "Rust async features",
                "result_count": 5
            }),
        );

        let action = tool_call.to_agent_action().unwrap();
        if let AgentAction::WebSearch { queries } = action {
            assert_eq!(queries.len(), 1);
            assert_eq!(queries[0].0, "Rust async features");
            assert_eq!(queries[0].1, 5);
        } else {
            panic!("Expected WebSearch action");
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = make_call(None, "unknown_tool", json!({}));

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            single_read("file1.rs"),
            single_read("file2.rs"),
            single_read("file3.rs"),
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        if let AgentAction::ReadFile { paths } = &grouped[0] {
            assert_eq!(paths.len(), 3);
            assert_eq!(paths[0], "file1.rs");
            assert_eq!(paths[1], "file2.rs");
            assert_eq!(paths[2], "file3.rs");
        } else {
            panic!("Expected ReadFile action");
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![single_read("file1.rs")];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        if let AgentAction::ReadFile { paths } = &grouped[0] {
            assert_eq!(paths.len(), 1);
            assert_eq!(paths[0], "file1.rs");
        } else {
            panic!("Expected ReadFile action");
        }
    }
}
355}