// mermaid_cli/models/tool_call.rs

//! Tool call parsing and conversion to `AgentAction`.
//!
//! Handles deserialization of Ollama `tool_calls` responses and converts
//! them to Mermaid's internal `AgentAction` enum.
5
6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8
9use crate::agents::AgentAction;
10
11/// A tool call from the model (Ollama format)
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolCall {
14    #[serde(default)]
15    pub id: Option<String>,
16    pub function: FunctionCall,
17}
18
19/// Function call details
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct FunctionCall {
22    pub name: String,
23    pub arguments: serde_json::Value,
24}
25
26impl ToolCall {
27    /// Convert Ollama tool call to Mermaid AgentAction
28    pub fn to_agent_action(&self) -> Result<AgentAction> {
29        let args = &self.function.arguments;
30
31        let action = match self.function.name.as_str() {
32            "read_file" => {
33                let path = Self::get_string_arg(args, "path")?;
34                AgentAction::ReadFile { path }
35            }
36
37            "write_file" => {
38                let path = Self::get_string_arg(args, "path")?;
39                let content = Self::get_string_arg(args, "content")?;
40                AgentAction::WriteFile { path, content }
41            }
42
43            "delete_file" => {
44                let path = Self::get_string_arg(args, "path")?;
45                AgentAction::DeleteFile { path }
46            }
47
48            "create_directory" => {
49                let path = Self::get_string_arg(args, "path")?;
50                AgentAction::CreateDirectory { path }
51            }
52
53            "execute_command" => {
54                let command = Self::get_string_arg(args, "command")?;
55                let working_dir = Self::get_optional_string_arg(args, "working_dir");
56                AgentAction::ExecuteCommand {
57                    command,
58                    working_dir,
59                }
60            }
61
62            "git_diff" => {
63                let path = Self::get_optional_string_arg(args, "path");
64                AgentAction::GitDiff { path }
65            }
66
67            "git_status" => AgentAction::GitStatus,
68
69            "git_commit" => {
70                let message = Self::get_string_arg(args, "message")?;
71                let files = Self::get_string_array_arg(args, "files")?;
72                AgentAction::GitCommit { message, files }
73            }
74
75            "web_search" => {
76                let query = Self::get_string_arg(args, "query")?;
77                let result_count = Self::get_int_arg(args, "result_count")
78                    .unwrap_or(5)
79                    .clamp(1, 10);
80                AgentAction::WebSearch {
81                    query,
82                    result_count,
83                }
84            }
85
86            name => {
87                return Err(anyhow!(
88                    "Unknown tool: '{}'. Model attempted to call a tool that doesn't exist.",
89                    name
90                ))
91            }
92        };
93
94        Ok(action)
95    }
96
97    // Helper methods for argument extraction
98
99    fn get_string_arg(args: &serde_json::Value, key: &str) -> Result<String> {
100        args.get(key)
101            .and_then(|v| v.as_str())
102            .map(|s| s.to_string())
103            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
104    }
105
106    fn get_optional_string_arg(args: &serde_json::Value, key: &str) -> Option<String> {
107        args.get(key).and_then(|v| v.as_str()).map(|s| s.to_string())
108    }
109
110    fn get_int_arg(args: &serde_json::Value, key: &str) -> Result<usize> {
111        args.get(key)
112            .and_then(|v| v.as_u64())
113            .map(|n| n as usize)
114            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
115    }
116
117    fn get_string_array_arg(args: &serde_json::Value, key: &str) -> Result<Vec<String>> {
118        args.get(key)
119            .and_then(|v| v.as_array())
120            .map(|arr| {
121                arr.iter()
122                    .filter_map(|item| item.as_str().map(|s| s.to_string()))
123                    .collect()
124            })
125            .ok_or_else(|| anyhow!("Missing or invalid required argument: '{}'", key))
126    }
127}
128
129/// Parse multiple tool calls into agent actions
130pub fn parse_tool_calls(tool_calls: &[ToolCall]) -> Vec<AgentAction> {
131    tool_calls
132        .iter()
133        .filter_map(|tc| match tc.to_agent_action() {
134            Ok(action) => Some(action),
135            Err(e) => {
136                eprintln!("Failed to parse tool call '{}': {}", tc.function.name, e);
137                None
138            }
139        })
140        .collect()
141}
142
143/// Group consecutive same-type read operations into parallel reads
144pub fn group_parallel_reads(actions: Vec<AgentAction>) -> Vec<AgentAction> {
145    if actions.is_empty() {
146        return actions;
147    }
148
149    let mut result = Vec::new();
150    let mut current_group: Vec<String> = Vec::new();
151
152    for action in actions {
153        match action {
154            AgentAction::ReadFile { path } => {
155                current_group.push(path);
156            }
157            other => {
158                // Flush current read group if it has multiple items
159                if current_group.len() > 1 {
160                    result.push(AgentAction::ParallelRead {
161                        paths: current_group.clone(),
162                    });
163                } else if current_group.len() == 1 {
164                    result.push(AgentAction::ReadFile {
165                        path: current_group[0].clone(),
166                    });
167                }
168                current_group.clear();
169
170                result.push(other);
171            }
172        }
173    }
174
175    // Flush remaining read group
176    if current_group.len() > 1 {
177        result.push(AgentAction::ParallelRead {
178            paths: current_group,
179        });
180    } else if current_group.len() == 1 {
181        result.push(AgentAction::ReadFile {
182            path: current_group[0].clone(),
183        });
184    }
185
186    result
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a tool call without an id for the given tool name and arguments.
    fn tool_call(name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: None,
            function: FunctionCall {
                name: name.to_string(),
                arguments,
            },
        }
    }

    #[test]
    fn test_parse_read_file_tool_call() {
        let tool_call = ToolCall {
            id: Some("call_123".to_string()),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({
                    "path": "src/main.rs"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ReadFile { path } => assert_eq!(path, "src/main.rs"),
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_parse_write_file_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "write_file".to_string(),
                arguments: json!({
                    "path": "test.txt",
                    "content": "Hello, world!"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WriteFile { path, content } => {
                assert_eq!(path, "test.txt");
                assert_eq!(content, "Hello, world!");
            }
            _ => panic!("Expected WriteFile action"),
        }
    }

    #[test]
    fn test_parse_execute_command_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "execute_command".to_string(),
                arguments: json!({
                    "command": "cargo test",
                    "working_dir": "/path/to/project"
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::ExecuteCommand {
                command,
                working_dir,
            } => {
                assert_eq!(command, "cargo test");
                assert_eq!(working_dir, Some("/path/to/project".to_string()));
            }
            _ => panic!("Expected ExecuteCommand action"),
        }
    }

    #[test]
    fn test_parse_web_search_tool_call() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "web_search".to_string(),
                arguments: json!({
                    "query": "Rust async features",
                    "result_count": 5
                }),
            },
        };

        let action = tool_call.to_agent_action().unwrap();
        match action {
            AgentAction::WebSearch {
                query,
                result_count,
            } => {
                assert_eq!(query, "Rust async features");
                assert_eq!(result_count, 5);
            }
            _ => panic!("Expected WebSearch action"),
        }
    }

    #[test]
    fn test_web_search_result_count_default_and_clamp() {
        // (arguments, expected result_count): missing -> default 5,
        // below range -> 1, above range -> 10.
        let cases = vec![
            (json!({ "query": "q" }), 5),
            (json!({ "query": "q", "result_count": 0 }), 1),
            (json!({ "query": "q", "result_count": 50 }), 10),
        ];

        for (arguments, expected) in cases {
            match tool_call("web_search", arguments).to_agent_action().unwrap() {
                AgentAction::WebSearch { result_count, .. } => {
                    assert_eq!(result_count, expected);
                }
                _ => panic!("Expected WebSearch action"),
            }
        }
    }

    #[test]
    fn test_parse_git_commit_tool_call() {
        let call = tool_call(
            "git_commit",
            json!({
                "message": "Initial commit",
                "files": ["a.rs", "b.rs"]
            }),
        );

        match call.to_agent_action().unwrap() {
            AgentAction::GitCommit { message, files } => {
                assert_eq!(message, "Initial commit");
                assert_eq!(files, vec!["a.rs".to_string(), "b.rs".to_string()]);
            }
            _ => panic!("Expected GitCommit action"),
        }
    }

    #[test]
    fn test_unknown_tool_returns_error() {
        let tool_call = ToolCall {
            id: None,
            function: FunctionCall {
                name: "unknown_tool".to_string(),
                arguments: json!({}),
            },
        };

        assert!(tool_call.to_agent_action().is_err());
    }

    #[test]
    fn test_missing_or_mistyped_required_argument_returns_error() {
        // `read_file` requires a string `path`.
        assert!(tool_call("read_file", json!({})).to_agent_action().is_err());
        // A present argument of the wrong JSON type is rejected as well.
        assert!(tool_call("read_file", json!({ "path": 42 }))
            .to_agent_action()
            .is_err());
    }

    #[test]
    fn test_parse_tool_calls_skips_failures() {
        let calls = vec![
            tool_call("read_file", json!({ "path": "a.rs" })),
            tool_call("unknown_tool", json!({})),
            tool_call("git_status", json!({})),
        ];

        let actions = parse_tool_calls(&calls);
        assert_eq!(actions.len(), 2);
        assert!(matches!(actions[0], AgentAction::ReadFile { .. }));
        assert!(matches!(actions[1], AgentAction::GitStatus));
    }

    #[test]
    fn test_group_parallel_reads() {
        let actions = vec![
            AgentAction::ReadFile {
                path: "file1.rs".to_string(),
            },
            AgentAction::ReadFile {
                path: "file2.rs".to_string(),
            },
            AgentAction::ReadFile {
                path: "file3.rs".to_string(),
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ParallelRead { paths } => {
                assert_eq!(paths.len(), 3);
            }
            _ => panic!("Expected ParallelRead action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_single_read() {
        let actions = vec![AgentAction::ReadFile {
            path: "file1.rs".to_string(),
        }];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 1);

        match &grouped[0] {
            AgentAction::ReadFile { path } => {
                assert_eq!(path, "file1.rs");
            }
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_split_by_other_action() {
        let actions = vec![
            AgentAction::ReadFile {
                path: "a.rs".to_string(),
            },
            AgentAction::ReadFile {
                path: "b.rs".to_string(),
            },
            AgentAction::GitStatus,
            AgentAction::ReadFile {
                path: "c.rs".to_string(),
            },
        ];

        let grouped = group_parallel_reads(actions);
        assert_eq!(grouped.len(), 3);

        match &grouped[0] {
            AgentAction::ParallelRead { paths } => {
                assert_eq!(*paths, vec!["a.rs".to_string(), "b.rs".to_string()]);
            }
            _ => panic!("Expected ParallelRead action"),
        }
        assert!(matches!(grouped[1], AgentAction::GitStatus));
        match &grouped[2] {
            AgentAction::ReadFile { path } => assert_eq!(path, "c.rs"),
            _ => panic!("Expected ReadFile action"),
        }
    }

    #[test]
    fn test_group_parallel_reads_empty() {
        assert!(group_parallel_reads(Vec::new()).is_empty());
    }
}