Skip to main content

mermaid_cli/models/tools.rs

//! Ollama Tools API support for native function calling.
//!
//! This module defines Mermaid's available tools in Ollama's JSON Schema format,
//! replacing the legacy text-based action block system.

use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::LazyLock;

10/// A tool available to the model (Ollama format)
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Tool {
13    #[serde(rename = "type")]
14    pub type_: String,
15    pub function: ToolFunction,
16}
17
18/// Function definition for a tool
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ToolFunction {
21    pub name: String,
22    pub description: String,
23    pub parameters: serde_json::Value,
24}
25
26/// Registry of all available Mermaid tools
27pub struct ToolRegistry {
28    tools: Vec<Tool>,
29}
30
31/// Cached Ollama JSON format for the static tool definitions.
32/// Built once on first access, reused for every chat() call.
33static OLLAMA_TOOLS_CACHE: LazyLock<Vec<serde_json::Value>> = LazyLock::new(|| {
34    let registry = ToolRegistry::mermaid_tools();
35    registry.tools.iter().map(|t| json!(t)).collect()
36});
37
38impl ToolRegistry {
39    /// Create a new registry with all Mermaid tools
40    pub fn mermaid_tools() -> Self {
41        Self {
42            tools: vec![
43                Self::read_file_tool(),
44                Self::write_file_tool(),
45                Self::delete_file_tool(),
46                Self::create_directory_tool(),
47                Self::execute_command_tool(),
48                Self::git_diff_tool(),
49                Self::git_status_tool(),
50                Self::git_commit_tool(),
51                Self::web_search_tool(),
52                Self::web_fetch_tool(),
53            ],
54        }
55    }
56
57    /// Get all tools in Ollama JSON format (cached statically)
58    pub fn to_ollama_format(&self) -> Vec<serde_json::Value> {
59        OLLAMA_TOOLS_CACHE.clone()
60    }
61
62    /// Get a reference to the cached Ollama tool definitions without constructing a registry
63    pub fn ollama_tools_cached() -> &'static [serde_json::Value] {
64        &OLLAMA_TOOLS_CACHE
65    }
66
67    /// Get all tools
68    pub fn tools(&self) -> &[Tool] {
69        &self.tools
70    }
71
72    // Tool Definitions
73
74    fn read_file_tool() -> Tool {
75        Tool {
76            type_: "function".to_string(),
77            function: ToolFunction {
78                name: "read_file".to_string(),
79                description: "Read a file from the filesystem. Can read files anywhere on the system the user has access to, including outside the current project directory. Supports text files, PDFs (sent to vision models), and images.".to_string(),
80                parameters: json!({
81                    "type": "object",
82                    "properties": {
83                        "path": {
84                            "type": "string",
85                            "description": "Absolute or relative path to the file to read. Use absolute paths (e.g., /home/user/file.pdf) for files outside the project."
86                        }
87                    },
88                    "required": ["path"]
89                }),
90            },
91        }
92    }
93
94    fn write_file_tool() -> Tool {
95        Tool {
96            type_: "function".to_string(),
97            function: ToolFunction {
98                name: "write_file".to_string(),
99                description: "Write or create a file in the current project directory. Creates parent directories if they don't exist. Creates a timestamped backup if the file already exists.".to_string(),
100                parameters: json!({
101                    "type": "object",
102                    "properties": {
103                        "path": {
104                            "type": "string",
105                            "description": "Path to the file to write, relative to the project root or absolute (must be within project)"
106                        },
107                        "content": {
108                            "type": "string",
109                            "description": "The complete file content to write"
110                        }
111                    },
112                    "required": ["path", "content"]
113                }),
114            },
115        }
116    }
117
118    fn delete_file_tool() -> Tool {
119        Tool {
120            type_: "function".to_string(),
121            function: ToolFunction {
122                name: "delete_file".to_string(),
123                description: "Delete a file from the project directory. Creates a timestamped backup before deletion for recovery.".to_string(),
124                parameters: json!({
125                    "type": "object",
126                    "properties": {
127                        "path": {
128                            "type": "string",
129                            "description": "Path to the file to delete"
130                        }
131                    },
132                    "required": ["path"]
133                }),
134            },
135        }
136    }
137
138    fn create_directory_tool() -> Tool {
139        Tool {
140            type_: "function".to_string(),
141            function: ToolFunction {
142                name: "create_directory".to_string(),
143                description: "Create a new directory in the project. Creates parent directories if needed.".to_string(),
144                parameters: json!({
145                    "type": "object",
146                    "properties": {
147                        "path": {
148                            "type": "string",
149                            "description": "Path to the directory to create"
150                        }
151                    },
152                    "required": ["path"]
153                }),
154            },
155        }
156    }
157
158    fn execute_command_tool() -> Tool {
159        Tool {
160            type_: "function".to_string(),
161            function: ToolFunction {
162                name: "execute_command".to_string(),
163                description: "Execute a shell command. Use for running tests, builds, git operations, or any terminal command.".to_string(),
164                parameters: json!({
165                    "type": "object",
166                    "properties": {
167                        "command": {
168                            "type": "string",
169                            "description": "The shell command to execute (e.g., 'cargo test', 'npm install')"
170                        },
171                        "working_dir": {
172                            "type": "string",
173                            "description": "Optional working directory to run the command in. Defaults to project root."
174                        }
175                    },
176                    "required": ["command"]
177                }),
178            },
179        }
180    }
181
182    fn git_diff_tool() -> Tool {
183        Tool {
184            type_: "function".to_string(),
185            function: ToolFunction {
186                name: "git_diff".to_string(),
187                description: "Show git diff for staged and unstaged changes. Can show diff for specific files or entire repository.".to_string(),
188                parameters: json!({
189                    "type": "object",
190                    "properties": {
191                        "path": {
192                            "type": "string",
193                            "description": "Optional specific file path to show diff for. If omitted, shows diff for entire repository."
194                        }
195                    },
196                    "required": []
197                }),
198            },
199        }
200    }
201
202    fn git_status_tool() -> Tool {
203        Tool {
204            type_: "function".to_string(),
205            function: ToolFunction {
206                name: "git_status".to_string(),
207                description: "Show the current git repository status including staged, unstaged, and untracked files.".to_string(),
208                parameters: json!({
209                    "type": "object",
210                    "properties": {},
211                    "required": []
212                }),
213            },
214        }
215    }
216
217    fn git_commit_tool() -> Tool {
218        Tool {
219            type_: "function".to_string(),
220            function: ToolFunction {
221                name: "git_commit".to_string(),
222                description: "Create a git commit with specified message and files.".to_string(),
223                parameters: json!({
224                    "type": "object",
225                    "properties": {
226                        "message": {
227                            "type": "string",
228                            "description": "Commit message"
229                        },
230                        "files": {
231                            "type": "array",
232                            "items": {
233                                "type": "string"
234                            },
235                            "description": "List of file paths to include in the commit"
236                        }
237                    },
238                    "required": ["message", "files"]
239                }),
240            },
241        }
242    }
243
244    fn web_search_tool() -> Tool {
245        Tool {
246            type_: "function".to_string(),
247            function: ToolFunction {
248                name: "web_search".to_string(),
249                description: "Search the web for information. Returns full page content in markdown format for deep analysis. Use for current information, library documentation, version-specific questions, or any time-sensitive data.".to_string(),
250                parameters: json!({
251                    "type": "object",
252                    "properties": {
253                        "query": {
254                            "type": "string",
255                            "description": "Search query. Be specific and include version numbers when relevant (e.g., 'Rust async tokio 1.40 new features')"
256                        },
257                        "max_results": {
258                            "type": "integer",
259                            "description": "Number of results to fetch (1-10). Use 3 for simple facts, 5-7 for research, 10 for comprehensive analysis.",
260                            "minimum": 1,
261                            "maximum": 10
262                        }
263                    },
264                    "required": ["query", "max_results"]
265                }),
266            },
267        }
268    }
269
270    fn web_fetch_tool() -> Tool {
271        Tool {
272            type_: "function".to_string(),
273            function: ToolFunction {
274                name: "web_fetch".to_string(),
275                description: "Fetch content from a URL and return it as clean markdown. Use for reading documentation pages, articles, GitHub READMEs, or any web page the user references.".to_string(),
276                parameters: json!({
277                    "type": "object",
278                    "properties": {
279                        "url": {
280                            "type": "string",
281                            "description": "The URL to fetch content from (e.g., 'https://docs.rs/tokio/latest')"
282                        }
283                    },
284                    "required": ["url"]
285                }),
286            },
287        }
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_registry_creation() {
        let registry = ToolRegistry::mermaid_tools();
        assert_eq!(registry.tools().len(), 10, "Should have 10 tools defined");
    }

    #[test]
    fn test_tool_serialization() {
        let registry = ToolRegistry::mermaid_tools();
        let ollama_tools = registry.to_ollama_format();

        assert_eq!(ollama_tools.len(), 10);

        // Verify first tool has correct structure
        let first_tool = &ollama_tools[0];
        assert!(first_tool.get("type").is_some());
        assert!(first_tool.get("function").is_some());
    }

    #[test]
    fn test_read_file_tool_schema() {
        let tool = ToolRegistry::read_file_tool();
        assert_eq!(tool.function.name, "read_file");
        assert!(tool.function.description.contains("Read a file"));

        let params = tool.function.parameters.as_object().unwrap();
        assert!(params.get("properties").is_some());
        assert!(params.get("required").is_some());
    }

    /// The model selects tools by name, so duplicates would be ambiguous.
    #[test]
    fn test_tool_names_are_unique() {
        let registry = ToolRegistry::mermaid_tools();
        let mut names: Vec<&str> = registry
            .tools()
            .iter()
            .map(|tool| tool.function.name.as_str())
            .collect();
        let total = names.len();
        names.sort_unstable();
        names.dedup();
        assert_eq!(names.len(), total, "tool names must be unique");
    }

    /// Every tool must be a function tool whose schema is an object and whose
    /// `required` list only names parameters declared in `properties`.
    #[test]
    fn test_every_schema_is_well_formed() {
        let registry = ToolRegistry::mermaid_tools();
        for tool in registry.tools() {
            let name = &tool.function.name;
            assert_eq!(tool.type_, "function", "{name}: unexpected tool type");

            let params = tool
                .function
                .parameters
                .as_object()
                .unwrap_or_else(|| panic!("{name}: parameters must be a JSON object"));
            assert_eq!(
                params.get("type").and_then(serde_json::Value::as_str),
                Some("object"),
                "{name}: schema type must be \"object\""
            );

            let properties = params
                .get("properties")
                .and_then(serde_json::Value::as_object)
                .unwrap_or_else(|| panic!("{name}: missing `properties` object"));
            let required = params
                .get("required")
                .and_then(serde_json::Value::as_array)
                .unwrap_or_else(|| panic!("{name}: missing `required` array"));

            for key in required {
                let key = key
                    .as_str()
                    .unwrap_or_else(|| panic!("{name}: `required` entries must be strings"));
                assert!(
                    properties.contains_key(key),
                    "{name}: required parameter `{key}` is not declared in properties"
                );
            }
        }
    }

    /// The static cache must mirror the registry, in the same order.
    #[test]
    fn test_cached_format_matches_registry() {
        let registry = ToolRegistry::mermaid_tools();
        let cached = ToolRegistry::ollama_tools_cached();

        assert_eq!(registry.to_ollama_format().as_slice(), cached);
        assert_eq!(cached.len(), registry.tools().len());
        for (value, tool) in cached.iter().zip(registry.tools()) {
            assert_eq!(value["type"], "function");
            assert_eq!(value["function"]["name"], tool.function.name);
        }
    }
}