Skip to main content

mermaid_cli/models/
tool_call.rs

//! Tool call parsing and conversion to `AgentAction`.
//!
//! Handles deserialization of Ollama `tool_calls` responses and converts
//! them to Mermaid's internal `AgentAction` enum.
5
6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12/// A tool call from the model (Ollama format)
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15    #[serde(default)]
16    pub id: Option<String>,
17    pub function: FunctionCall,
18}
19
20/// Function call details
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23    pub name: String,
24    pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28    /// Convert Ollama tool call to Mermaid AgentAction
29    pub fn to_agent_action(&self) -> Result<AgentAction> {
30        let args = &self.function.arguments;
31
32        let action = match self.function.name.as_str() {
33            "read_file" => {
34                let path = Self::get_string_arg(args, "path")?;
35                AgentAction::ReadFile { paths: vec![path] }
36            }
37
38            "write_file" => {
39                let path = Self::get_string_arg(args, "path")?;
40                let content = Self::get_string_arg(args, "content")?;
41                AgentAction::WriteFile { path, content }
42            }
43
44            "delete_file" => {
45                let path = Self::get_string_arg(args, "path")?;
46                AgentAction::DeleteFile { path }
47            }
48
49            "create_directory" => {
50                let path = Self::get_string_arg(args, "path")?;
51                AgentAction::CreateDirectory { path }
52            }
53
54            "execute_command" => {
55                let command = Self::get_string_arg(args, "command")?;
56                let working_dir = Self::get_optional_string_arg(args, "working_dir");
57                AgentAction::ExecuteCommand {
58                    command,
59                    working_dir,
60                }
61            }
62
63            "git_diff" => {
64                let path = Self::get_optional_string_arg(args, "path");
65                AgentAction::GitDiff { paths: vec![path] }
66            }
67
68            "git_status" => AgentAction::GitStatus,
69
70            "git_commit" => {
71                let message = Self::get_string_arg(args, "message")?;
72                let files = Self::get_string_array_arg(args, "files")?;
73                AgentAction::GitCommit { message, files }
74            }
75
76            "web_search" => {
77                let query = Self::get_string_arg(args, "query")?;
78                let max_results = Self::get_int_arg(args, "max_results")
79                    .or_else(|_| Self::get_int_arg(args, "result_count"))
80                    .unwrap_or(5)
81                    .clamp(1, 10);
82                AgentAction::WebSearch {
83                    queries: vec![(query, max_results)],
84                }
85            }
86
87            "edit_file" => {
88                let path = Self::get_string_arg(args, "path")?;
89                let old_string = Self::get_string_arg(args, "old_string")?;
90                let new_string = Self::get_string_arg(args, "new_string")?;
91                AgentAction::EditFile { path, old_string, new_string }
92            }
93
94            "web_fetch" => {
95                let url = Self::get_string_arg(args, "url")?;
96                AgentAction::WebFetch { url }
97            }
98
99            name => {
100                return Err(anyhow!(
101                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
102                    name
103                ))
104            }
105        };
106
107        Ok(action)
108    }
109
110    // Helper methods for argument extraction
111
112    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
113        args.get(key)
114            .and_then(|v| v.as_str())
115            .map(|s| s.to_string())
116            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
117    }
118
119    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
120        args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
121    }
122
123    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
124        args.get(key)
125            .and_then(|v| v.as_u64())
126            .map(|n| n as usize)
127            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
128    }
129
130    fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
131        args.get(key)
132            .and_then(|v| v.as_array())
133            .map(|arr| {
134                arr.iter()
135                    .filter_map(|item| item.as_str().map(|s| s.to_string()))
136                    .collect()
137            })
138            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
139    }
140}
141
142/// Parse multiple tool calls into agent actions
143pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
144    tool_calls
145        .iter()
146        .filter_map(|tc| match tc.to_agent_action() {
147            Ok(action) => Some(action),
148            Err(e) => {
149                warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
150                None
151            }
152        })
153        .collect()
154}
155
156/// Group consecutive same-type read operations into a single ReadFile action
157/// The executor will decide whether to parallelize based on the number of paths
158pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
159    if actions.is_empty() {
160        return actions;
161    }
162
163    let mut result = Vec::new();
164    let mut current_group: Vec<String> = Vec::new();
165
166    for action in actions {
167        match action {
168            AgentAction::ReadFile { paths } => {
169                current_group.extend(paths);
170            }
171            other => {
172                // Flush current read group
173                if !current_group.is_empty() {
174                    result.push(AgentAction::ReadFile {
175                        paths: std::mem::take(&mut current_group),
176                    });
177                }
178                result.push(other);
179            }
180        }
181    }
182
183    // Flush remaining read group
184    if !current_group.is_empty() {
185        result.push(AgentAction::ReadFile {
186            paths: current_group,
187        });
188    }
189
190    result
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Assemble a `ToolCall` from its parts so each test only spells out
    /// what it actually varies.
    fn make_call(id: Option<&str>, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.map(str::to_owned),
            function: FunctionCall {
                name: name.to_owned(),
                arguments,
            },
        }
    }

    /// A `ReadFile` action for a single path.
    fn single_read(path: &str) -> AgentAction {
        AgentAction::ReadFile {
            paths: vec![path.to_owned()],
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let call = make_call(
            Some("call_123"),
            "read_file",
            json!({ "path": "src/main.rs" }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "src/main.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let call = make_call(
            None,
            "write_file",
            json!({
                "path": "test.txt",
                "content": "Hello, world!"
            }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let call = make_call(
            None,
            "execute_command",
            json!({
                "command": "cargo test",
                "working_dir": "/path/to/project"
            }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let call = make_call(
            None,
            "web_search",
            json!({
                "query": "Rust async features",
                "result_count": 5
            }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::WebSearch { queries } => {
                assert_eq!(queries.len(), 1);
                assert_eq!(queries[0].0, "Rust async features");
                assert_eq!(queries[0].1, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let call = make_call(None, "unknown_tool", json!({}));

        assert!(call.to_agent_action().is_err());
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            single_read("file1.rs"),
            single_read("file2.rs"),
            single_read("file3.rs"),
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 3);
                assert_eq!(paths[0], "file1.rs");
                assert_eq!(paths[1], "file2.rs");
                assert_eq!(paths[2], "file3.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let grouped = group_parallel_reads(vec![single_read("file1.rs")]);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { paths } => {
                assert_eq!(paths.len(), 1);
                assert_eq!(paths[0], "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }
}