Skip to main content

mermaid_cli/models/
tool_call.rs

//! Tool call parsing and conversion to `AgentAction`.
//!
//! Handles deserialization of Ollama `tool_calls` responses and converts
//! them to Mermaid's internal `AgentAction` enum.
5
6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use tracing::warn;
9
10use crate::agents::AgentAction;
11
12/// A tool call from the model (Ollama format)
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15    #[serde(default)]
16    pub id: Option<String>,
17    pub function: FunctionCall,
18}
19
20/// Function call details
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FunctionCall {
23    pub name: String,
24    pub arguments: serde_json::Value,
25}
26
27impl ToolCall {
28    /// Convert Ollama tool call to Mermaid AgentAction
29    pub fn to_agent_action(&self) -> Result<AgentAction> {
30        let args = &self.function.arguments;
31
32        let action = match self.function.name.as_str() {
33            "read_file" => {
34                let path = Self::get_string_arg(args, "path")?;
35                AgentAction::ReadFile { paths: vec![path] }
36            }
37
38            "write_file" => {
39                let path = Self::get_string_arg(args, "path")?;
40                let content = Self::get_string_arg(args, "content")?;
41                AgentAction::WriteFile { path, content }
42            }
43
44            "delete_file" => {
45                let path = Self::get_string_arg(args, "path")?;
46                AgentAction::DeleteFile { path }
47            }
48
49            "create_directory" => {
50                let path = Self::get_string_arg(args, "path")?;
51                AgentAction::CreateDirectory { path }
52            }
53
54            "execute_command" => {
55                let command = Self::get_string_arg(args, "command")?;
56                let working_dir = Self::get_optional_string_arg(args, "working_dir");
57                AgentAction::ExecuteCommand {
58                    command,
59                    working_dir,
60                }
61            }
62
63            "git_diff" => {
64                let path = Self::get_optional_string_arg(args, "path");
65                AgentAction::GitDiff { paths: vec![path] }
66            }
67
68            "git_status" => AgentAction::GitStatus,
69
70            "git_commit" => {
71                let message = Self::get_string_arg(args, "message")?;
72                let files = Self::get_string_array_arg(args, "files")?;
73                AgentAction::GitCommit { message, files }
74            }
75
76            "web_search" => {
77                let query = Self::get_string_arg(args, "query")?;
78                let max_results = Self::get_int_arg(args, "max_results")
79                    .or_else(|_| Self::get_int_arg(args, "result_count"))
80                    .unwrap_or(5)
81                    .clamp(1, 10);
82                AgentAction::WebSearch {
83                    queries: vec![(query, max_results)],
84                }
85            }
86
87            "web_fetch" => {
88                let url = Self::get_string_arg(args, "url")?;
89                AgentAction::WebFetch { url }
90            }
91
92            name => {
93                return Err(anyhow!(
94                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
95                    name
96                ))
97            }
98        };
99
100        Ok(action)
101    }
102
103    // Helper methods for argument extraction
104
105    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
106        args.get(key)
107            .and_then(|v| v.as_str())
108            .map(|s| s.to_string())
109            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
110    }
111
112    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
113        args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
114    }
115
116    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
117        args.get(key)
118            .and_then(|v| v.as_u64())
119            .map(|n| n as usize)
120            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
121    }
122
123    fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
124        args.get(key)
125            .and_then(|v| v.as_array())
126            .map(|arr| {
127                arr.iter()
128                    .filter_map(|item| item.as_str().map(|s| s.to_string()))
129                    .collect()
130            })
131            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
132    }
133}
134
135/// Parse multiple tool calls into agent actions
136pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
137    tool_calls
138        .iter()
139        .filter_map(|tc| match tc.to_agent_action() {
140            Ok(action) => Some(action),
141            Err(e) => {
142                warn!(tool = %tc.function.name, "Failed to parse tool call: {}", e);
143                None
144            }
145        })
146        .collect()
147}
148
149/// Group consecutive same-type read operations into a single ReadFile action
150/// The executor will decide whether to parallelize based on the number of paths
151pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
152    if actions.is_empty() {
153        return actions;
154    }
155
156    let mut result = Vec::new();
157    let mut current_group: Vec<String> = Vec::new();
158
159    for action in actions {
160        match action {
161            AgentAction::ReadFile { paths } => {
162                current_group.extend(paths);
163            }
164            other => {
165                // Flush current read group
166                if !current_group.is_empty() {
167                    result.push(AgentAction::ReadFile {
168                        paths: std::mem::take(&mut current_group),
169                    });
170                }
171                result.push(other);
172            }
173        }
174    }
175
176    // Flush remaining read group
177    if !current_group.is_empty() {
178        result.push(AgentAction::ReadFile {
179            paths: current_group,
180        });
181    }
182
183    result
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build an id-less tool call for `name` carrying the given JSON arguments.
    fn make_call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    /// Build a `ReadFile` action for a single path.
    fn read_of(path: &str) -> AgentAction {
        AgentAction::ReadFile {
            paths: vec![path.to_string()],
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            ..make_call("read_file", json!({ "path": "src/main.rs" }))
        };

        if let AgentAction::ReadFile { paths } = tool_call.to_agent_action().unwrap() {
            assert_eq!(paths, ["src/main.rs"]);
        } else {
            panic!("Expected ReadFile action");
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = make_call(
            "write_file",
            json!({ "path": "test.txt", "content": "Hello, world!" }),
        );

        if let AgentAction::WriteFile { path, content } = tool_call.to_agent_action().unwrap() {
            assert_eq!(path, "test.txt");
            assert_eq!(content, "Hello, world!");
        } else {
            panic!("Expected WriteFile action");
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = make_call(
            "execute_command",
            json!({ "command": "cargo test", "working_dir": "/path/to/project" }),
        );

        if let AgentAction::ExecuteCommand {
            command,
            working_dir,
        } = tool_call.to_agent_action().unwrap()
        {
            assert_eq!(command, "cargo test");
            assert_eq!(working_dir, Some("/path/to/project".to_string()));
        } else {
            panic!("Expected ExecuteCommand action");
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = make_call(
            "web_search",
            json!({ "query": "Rust async features", "result_count": 5 }),
        );

        if let AgentAction::WebSearch { queries } = tool_call.to_agent_action().unwrap() {
            assert_eq!(queries.len(), 1);
            let (query, max_results) = &queries[0];
            assert_eq!(query, "Rust async features");
            assert_eq!(*max_results, 5);
        } else {
            panic!("Expected WebSearch action");
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = make_call("unknown_tool", json!({}));

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![read_of("file1.rs"), read_of("file2.rs"), read_of("file3.rs")];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        if let AgentAction::ReadFile { paths } = &grouped[0] {
            assert_eq!(paths, &["file1.rs", "file2.rs", "file3.rs"]);
        } else {
            panic!("Expected ReadFile action");
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let grouped = group_parallel_reads(vec![read_of("file1.rs")]);
        assert_eq!(grouped.len(), 1);

        if let AgentAction::ReadFile { paths } = &grouped[0] {
            assert_eq!(paths, &["file1.rs"]);
        } else {
            panic!("Expected ReadFile action");
        }
    }
}