Skip to main content

mermaid_cli/models/
tools.rs

//! Ollama Tools API support for native function calling.
//!
//! This module defines Mermaid's available tools in Ollama's JSON Schema format,
//! replacing the legacy text-based action block system.
5
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use std::sync::LazyLock;
9
10/// A tool available to the model (Ollama format)
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Tool {
13    #[serde(rename = "type")]
14    pub type_: String,
15    pub function: ToolFunction,
16}
17
18/// Function definition for a tool
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ToolFunction {
21    pub name: String,
22    pub description: String,
23    pub parameters: serde_json::Value,
24}
25
26/// Registry of all available Mermaid tools
27pub struct ToolRegistry {
28    tools: Vec<Tool>,
29}
30
31/// Cached Ollama JSON format for the static tool definitions.
32/// Built once on first access, reused for every chat() call.
33static OLLAMA_TOOLS_CACHE: LazyLock<Vec<serde_json::Value>> = LazyLock::new(|| {
34    let registry = ToolRegistry::mermaid_tools();
35    registry.tools.iter().map(|t| json!(t)).collect()
36});
37
38impl ToolRegistry {
39    /// Create a new registry with all Mermaid tools
40    pub fn mermaid_tools() -> Self {
41        Self {
42            tools: vec![
43                Self::read_file_tool(),
44                Self::write_file_tool(),
45                Self::delete_file_tool(),
46                Self::create_directory_tool(),
47                Self::execute_command_tool(),
48                Self::git_diff_tool(),
49                Self::git_status_tool(),
50                Self::git_commit_tool(),
51                Self::edit_file_tool(),
52                Self::web_search_tool(),
53                Self::web_fetch_tool(),
54            ],
55        }
56    }
57
58    /// Get all tools in Ollama JSON format (cached statically)
59    pub fn to_ollama_format(&self) -> Vec<serde_json::Value> {
60        OLLAMA_TOOLS_CACHE.clone()
61    }
62
63    /// Get a reference to the cached Ollama tool definitions without constructing a registry
64    pub fn ollama_tools_cached() -> &'static [serde_json::Value] {
65        &OLLAMA_TOOLS_CACHE
66    }
67
68    /// Get all tools
69    pub fn tools(&self) -> &[Tool] {
70        &self.tools
71    }
72
73    // Tool Definitions
74
75    fn read_file_tool() -> Tool {
76        Tool {
77            type_: "function".to_string(),
78            function: ToolFunction {
79                name: "read_file".to_string(),
80                description: "Read a file from the filesystem. Can read files anywhere on the system the user has access to, including outside the current project directory. Supports text files, PDFs (sent to vision models), and images.".to_string(),
81                parameters: json!({
82                    "type": "object",
83                    "properties": {
84                        "path": {
85                            "type": "string",
86                            "description": "Absolute or relative path to the file to read. Use absolute paths (e.g., /home/user/file.pdf) for files outside the project."
87                        }
88                    },
89                    "required": ["path"]
90                }),
91            },
92        }
93    }
94
95    fn write_file_tool() -> Tool {
96        Tool {
97            type_: "function".to_string(),
98            function: ToolFunction {
99                name: "write_file".to_string(),
100                description: "Write or create a file in the current project directory. Creates parent directories if they don't exist. Creates a timestamped backup if the file already exists.".to_string(),
101                parameters: json!({
102                    "type": "object",
103                    "properties": {
104                        "path": {
105                            "type": "string",
106                            "description": "Path to the file to write, relative to the project root or absolute (must be within project)"
107                        },
108                        "content": {
109                            "type": "string",
110                            "description": "The complete file content to write"
111                        }
112                    },
113                    "required": ["path", "content"]
114                }),
115            },
116        }
117    }
118
119    fn delete_file_tool() -> Tool {
120        Tool {
121            type_: "function".to_string(),
122            function: ToolFunction {
123                name: "delete_file".to_string(),
124                description: "Delete a file from the project directory. Creates a timestamped backup before deletion for recovery.".to_string(),
125                parameters: json!({
126                    "type": "object",
127                    "properties": {
128                        "path": {
129                            "type": "string",
130                            "description": "Path to the file to delete"
131                        }
132                    },
133                    "required": ["path"]
134                }),
135            },
136        }
137    }
138
139    fn create_directory_tool() -> Tool {
140        Tool {
141            type_: "function".to_string(),
142            function: ToolFunction {
143                name: "create_directory".to_string(),
144                description: "Create a new directory in the project. Creates parent directories if needed.".to_string(),
145                parameters: json!({
146                    "type": "object",
147                    "properties": {
148                        "path": {
149                            "type": "string",
150                            "description": "Path to the directory to create"
151                        }
152                    },
153                    "required": ["path"]
154                }),
155            },
156        }
157    }
158
159    fn execute_command_tool() -> Tool {
160        Tool {
161            type_: "function".to_string(),
162            function: ToolFunction {
163                name: "execute_command".to_string(),
164                description: "Execute a shell command. Use for running tests, builds, git operations, or any terminal command.".to_string(),
165                parameters: json!({
166                    "type": "object",
167                    "properties": {
168                        "command": {
169                            "type": "string",
170                            "description": "The shell command to execute (e.g., 'cargo test', 'npm install')"
171                        },
172                        "working_dir": {
173                            "type": "string",
174                            "description": "Optional working directory to run the command in. Defaults to project root."
175                        }
176                    },
177                    "required": ["command"]
178                }),
179            },
180        }
181    }
182
183    fn git_diff_tool() -> Tool {
184        Tool {
185            type_: "function".to_string(),
186            function: ToolFunction {
187                name: "git_diff".to_string(),
188                description: "Show git diff for staged and unstaged changes. Can show diff for specific files or entire repository.".to_string(),
189                parameters: json!({
190                    "type": "object",
191                    "properties": {
192                        "path": {
193                            "type": "string",
194                            "description": "Optional specific file path to show diff for. If omitted, shows diff for entire repository."
195                        }
196                    },
197                    "required": []
198                }),
199            },
200        }
201    }
202
203    fn git_status_tool() -> Tool {
204        Tool {
205            type_: "function".to_string(),
206            function: ToolFunction {
207                name: "git_status".to_string(),
208                description: "Show the current git repository status including staged, unstaged, and untracked files.".to_string(),
209                parameters: json!({
210                    "type": "object",
211                    "properties": {},
212                    "required": []
213                }),
214            },
215        }
216    }
217
218    fn git_commit_tool() -> Tool {
219        Tool {
220            type_: "function".to_string(),
221            function: ToolFunction {
222                name: "git_commit".to_string(),
223                description: "Create a git commit with specified message and files.".to_string(),
224                parameters: json!({
225                    "type": "object",
226                    "properties": {
227                        "message": {
228                            "type": "string",
229                            "description": "Commit message"
230                        },
231                        "files": {
232                            "type": "array",
233                            "items": {
234                                "type": "string"
235                            },
236                            "description": "List of file paths to include in the commit"
237                        }
238                    },
239                    "required": ["message", "files"]
240                }),
241            },
242        }
243    }
244
245    fn edit_file_tool() -> Tool {
246        Tool {
247            type_: "function".to_string(),
248            function: ToolFunction {
249                name: "edit_file".to_string(),
250                description: "Make targeted edits to a file by replacing specific text. \
251                    The old_string must match exactly and uniquely in the file. \
252                    Prefer this over write_file for modifying existing files.".to_string(),
253                parameters: json!({
254                    "type": "object",
255                    "properties": {
256                        "path": {
257                            "type": "string",
258                            "description": "Path to the file to edit"
259                        },
260                        "old_string": {
261                            "type": "string",
262                            "description": "The exact text to find and replace (must be unique in the file)"
263                        },
264                        "new_string": {
265                            "type": "string",
266                            "description": "The new text to replace old_string with"
267                        }
268                    },
269                    "required": ["path", "old_string", "new_string"]
270                }),
271            },
272        }
273    }
274
275    fn web_search_tool() -> Tool {
276        Tool {
277            type_: "function".to_string(),
278            function: ToolFunction {
279                name: "web_search".to_string(),
280                description: "Search the web for information. Returns full page content in markdown format for deep analysis. Use for current information, library documentation, version-specific questions, or any time-sensitive data.".to_string(),
281                parameters: json!({
282                    "type": "object",
283                    "properties": {
284                        "query": {
285                            "type": "string",
286                            "description": "Search query. Be specific and include version numbers when relevant (e.g., 'Rust async tokio 1.40 new features')"
287                        },
288                        "max_results": {
289                            "type": "integer",
290                            "description": "Number of results to fetch (1-10). Use 3 for simple facts, 5-7 for research, 10 for comprehensive analysis.",
291                            "minimum": 1,
292                            "maximum": 10
293                        }
294                    },
295                    "required": ["query", "max_results"]
296                }),
297            },
298        }
299    }
300
301    fn web_fetch_tool() -> Tool {
302        Tool {
303            type_: "function".to_string(),
304            function: ToolFunction {
305                name: "web_fetch".to_string(),
306                description: "Fetch content from a URL and return it as clean markdown. Use for reading documentation pages, articles, GitHub READMEs, or any web page the user references.".to_string(),
307                parameters: json!({
308                    "type": "object",
309                    "properties": {
310                        "url": {
311                            "type": "string",
312                            "description": "The URL to fetch content from (e.g., 'https://docs.rs/tokio/latest')"
313                        }
314                    },
315                    "required": ["url"]
316                }),
317            },
318        }
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_tool_registry_creation() {
        let registry = ToolRegistry::mermaid_tools();
        assert_eq!(registry.tools().len(), 11, "Should have 11 tools defined");
    }

    #[test]
    fn test_tool_serialization() {
        let registry = ToolRegistry::mermaid_tools();
        let ollama_tools = registry.to_ollama_format();

        assert_eq!(ollama_tools.len(), 11);

        // Verify first tool has correct structure
        let first_tool = &ollama_tools[0];
        assert!(first_tool.get("type").is_some());
        assert!(first_tool.get("function").is_some());
    }

    #[test]
    fn test_read_file_tool_schema() {
        let tool = ToolRegistry::read_file_tool();
        assert_eq!(tool.function.name, "read_file");
        assert!(tool.function.description.contains("Read a file"));

        let params = tool.function.parameters.as_object().unwrap();
        assert!(params.get("properties").is_some());
        assert!(params.get("required").is_some());
    }

    /// Duplicate names would make a tool call ambiguous, so every name must be distinct.
    #[test]
    fn test_tool_names_are_unique() {
        let registry = ToolRegistry::mermaid_tools();
        let mut seen = HashSet::new();
        for tool in registry.tools() {
            assert!(
                seen.insert(tool.function.name.as_str()),
                "duplicate tool name: {}",
                tool.function.name
            );
        }
    }

    /// Checks every tool, not just the first, for the following:
    /// - it is a `"function"`;
    /// - its parameters form a JSON Schema object;
    /// - each name listed in `required` is declared under `properties`.
    #[test]
    fn test_every_tool_schema_is_consistent() {
        for tool in ToolRegistry::mermaid_tools().tools() {
            let name = &tool.function.name;
            assert_eq!(tool.type_, "function", "tool {name} has the wrong type");

            let params = &tool.function.parameters;
            assert_eq!(params["type"], "object", "tool {name}: parameters must be an object");

            let properties = params["properties"]
                .as_object()
                .unwrap_or_else(|| panic!("tool {name}: missing `properties` object"));
            let required = params["required"]
                .as_array()
                .unwrap_or_else(|| panic!("tool {name}: missing `required` array"));

            for field in required {
                let field = field
                    .as_str()
                    .unwrap_or_else(|| panic!("tool {name}: non-string entry in `required`"));
                assert!(
                    properties.contains_key(field),
                    "tool {name}: required field `{field}` is not declared in `properties`"
                );
            }
        }
    }

    /// The static cache must equal a fresh serialization of the registry, in
    /// the same order, whether read via `to_ollama_format` or `ollama_tools_cached`.
    #[test]
    fn test_cached_format_matches_registry() {
        let registry = ToolRegistry::mermaid_tools();
        let fresh: Vec<serde_json::Value> = registry
            .tools()
            .iter()
            .map(|tool| serde_json::to_value(tool).expect("Tool serializes to JSON"))
            .collect();

        assert_eq!(registry.to_ollama_format(), fresh);
        assert_eq!(ToolRegistry::ollama_tools_cached(), fresh.as_slice());
    }
}