Skip to main content

llama_cpp_v3_agent_sdk/
tool.rs

1use crate::error::AgentError;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5// ---------------------------------------------------------------------------
6// Tool trait
7// ---------------------------------------------------------------------------
8
9/// A tool that the agent can invoke.
10///
11/// Implement this trait to create custom tools. Built-in tools (bash, read,
12/// write, edit, glob) already ship with the crate.
13pub trait Tool: Send + Sync {
14    /// Unique name used by the model to invoke this tool (e.g. `"bash"`).
15    fn name(&self) -> &str;
16
17    /// One-line human-readable description shown to the model.
18    fn description(&self) -> &str;
19
20    /// JSON Schema description of the parameters object.
21    /// Must be a valid JSON object schema, e.g.:
22    /// ```json
23    /// {
24    ///     "type": "object",
25    ///     "properties": {
26    ///         "command": { "type": "string", "description": "Shell command" }
27    ///     },
28    ///     "required": ["command"]
29    /// }
30    /// ```
31    fn parameters_schema(&self) -> serde_json::Value;
32
33    /// Execute the tool with the given JSON arguments.
34    /// Returns the tool output as a string (will be injected back into the
35    /// conversation as the tool result).
36    fn execute(&self, args: &serde_json::Value) -> Result<ToolResult, AgentError>;
37
38    /// Whether this tool requires user permission before execution.
39    /// Defaults to `true` for safety.
40    fn requires_permission(&self) -> bool {
41        true
42    }
43
44    /// Whether this tool is considered dangerous (shown as a warning).
45    fn is_dangerous(&self, _args: &serde_json::Value) -> bool {
46        false
47    }
48}
49
50// ---------------------------------------------------------------------------
51// Tool call / result types
52// ---------------------------------------------------------------------------
53
54/// A parsed tool call extracted from model output.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ToolCall {
57    /// Name of the tool to invoke.
58    pub name: String,
59    /// JSON arguments for the tool.
60    pub arguments: serde_json::Value,
61}
62
63/// Result returned after executing a tool.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ToolResult {
66    /// Whether the tool execution succeeded.
67    pub success: bool,
68    /// Output text from the tool.
69    pub output: String,
70}
71
72impl ToolResult {
73    pub fn ok(output: impl Into<String>) -> Self {
74        Self {
75            success: true,
76            output: output.into(),
77        }
78    }
79
80    pub fn err(output: impl Into<String>) -> Self {
81        Self {
82            success: false,
83            output: output.into(),
84        }
85    }
86}
87
88// ---------------------------------------------------------------------------
89// Tool registry
90// ---------------------------------------------------------------------------
91
92/// Registry of available tools, keyed by name.
93pub struct ToolRegistry {
94    tools: HashMap<String, Box<dyn Tool>>,
95}
96
97impl ToolRegistry {
98    pub fn new() -> Self {
99        Self {
100            tools: HashMap::new(),
101        }
102    }
103
104    /// Register a tool. Overwrites any existing tool with the same name.
105    pub fn register(&mut self, tool: Box<dyn Tool>) {
106        let name = tool.name().to_string();
107        self.tools.insert(name, tool);
108    }
109
110    /// Look up a tool by name.
111    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
112        self.tools.get(name).map(|t| t.as_ref())
113    }
114
115    /// Execute a tool call, returning the result.
116    pub fn execute(&self, call: &ToolCall) -> Result<ToolResult, AgentError> {
117        let tool = self
118            .get(&call.name)
119            .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?;
120        tool.execute(&call.arguments)
121    }
122
123    /// Returns an iterator over all registered tools.
124    pub fn iter(&self) -> impl Iterator<Item = &dyn Tool> {
125        self.tools.values().map(|t| t.as_ref())
126    }
127
128    /// Generate a system prompt fragment describing all available tools.
129    ///
130    /// This produces JSON tool descriptions that can be appended to the system
131    /// prompt so the model knows what tools are available and how to call them.
132    pub fn tools_prompt(&self) -> String {
133        if self.is_empty() {
134            return String::new();
135        }
136        let mut lines = Vec::new();
137        lines.push("# Tools\n".to_string());
138        lines.push("You have access to the following tools. To use a tool, output a tool call in this exact format:\n".to_string());
139        lines.push("<tool_call>".to_string());
140        lines.push(r#"{"name": "<tool_name>", "arguments": {<json_args>}}"#.to_string());
141        lines.push("</tool_call>\n".to_string());
142        lines.push("Available tools:\n".to_string());
143
144        for tool in self.iter() {
145            let schema = serde_json::json!({
146                "name": tool.name(),
147                "description": tool.description(),
148                "parameters": tool.parameters_schema(),
149            });
150            lines.push(format!(
151                "- {}\n```json\n{}\n```\n",
152                tool.name(),
153                serde_json::to_string_pretty(&schema).unwrap_or_default()
154            ));
155        }
156
157        lines.push("When you want to use a tool, output ONLY the <tool_call> block. You may use multiple tool calls in a single response. After each tool call, wait for the tool result before continuing.".to_string());
158
159        lines.join("\n")
160    }
161
162    /// Total count of registered tools.
163    pub fn len(&self) -> usize {
164        self.tools.len()
165    }
166
167    pub fn is_empty(&self) -> bool {
168        self.tools.is_empty()
169    }
170}
171
172impl Default for ToolRegistry {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177
178// ---------------------------------------------------------------------------
179// Tool call parser
180// ---------------------------------------------------------------------------
181
182/// Parse `<tool_call>...</tool_call>` blocks from model output.
183///
184/// Returns a list of parsed tool calls and the remaining text fragments.
185pub fn parse_tool_calls(text: &str) -> (Vec<ToolCall>, Vec<String>) {
186    let mut calls = Vec::new();
187    let mut text_parts = Vec::new();
188    let mut remaining = text;
189
190    loop {
191        if let Some(start) = remaining.find("<tool_call>") {
192            let before = &remaining[..start];
193            if !before.trim().is_empty() {
194                text_parts.push(before.trim().to_string());
195            }
196
197            let after_tag = &remaining[start + "<tool_call>".len()..];
198            if let Some(end) = after_tag.find("</tool_call>") {
199                let json_str = after_tag[..end].trim();
200                match serde_json::from_str::<ToolCall>(json_str) {
201                    Ok(call) => calls.push(call),
202                    Err(e) => {
203                        // Try to be lenient — maybe the model wrapped it differently
204                        text_parts.push(format!("[Failed to parse tool call: {}]", e));
205                    }
206                }
207                remaining = &after_tag[end + "</tool_call>".len()..];
208            } else {
209                // Unclosed tag — treat the rest as text
210                text_parts.push(remaining.to_string());
211                break;
212            }
213        } else {
214            if !remaining.trim().is_empty() {
215                text_parts.push(remaining.trim().to_string());
216            }
217            break;
218        }
219    }
220
221    (calls, text_parts)
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used to exercise the registry: echoes its arguments back.
    struct EchoTool(&'static str);

    impl Tool for EchoTool {
        fn name(&self) -> &str {
            self.0
        }

        fn description(&self) -> &str {
            "Echo the arguments back"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object", "properties": {} })
        }

        fn execute(&self, args: &serde_json::Value) -> Result<ToolResult, AgentError> {
            Ok(ToolResult::ok(args.to_string()))
        }
    }

    #[test]
    fn test_parse_single_tool_call() {
        let text = r#"Let me check that for you.
<tool_call>
{"name": "bash", "arguments": {"command": "ls -la"}}
</tool_call>
"#;
        let (calls, text_parts) = parse_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "bash");
        assert_eq!(text_parts.len(), 1);
        assert!(text_parts[0].contains("Let me check"));
    }

    #[test]
    fn test_parse_multiple_tool_calls() {
        let text = r#"I'll read both files.
<tool_call>
{"name": "read", "arguments": {"path": "a.txt"}}
</tool_call>
And now the second one:
<tool_call>
{"name": "read", "arguments": {"path": "b.txt"}}
</tool_call>
Done."#;
        let (calls, text_parts) = parse_tool_calls(text);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "read");
        assert_eq!(calls[1].name, "read");
        assert_eq!(text_parts.len(), 3);
    }

    #[test]
    fn test_parse_no_tool_calls() {
        let text = "Just a normal response with no tools.";
        let (calls, text_parts) = parse_tool_calls(text);
        assert_eq!(calls.len(), 0);
        assert_eq!(text_parts.len(), 1);
    }

    #[test]
    fn test_parse_malformed_tool_call_reports_error() {
        let text = "<tool_call>\nnot json\n</tool_call>";
        let (calls, text_parts) = parse_tool_calls(text);
        assert!(calls.is_empty());
        assert_eq!(text_parts.len(), 1);
        assert!(text_parts[0].starts_with("[Failed to parse tool call:"));
    }

    #[test]
    fn test_parse_unclosed_tool_call_kept_as_text() {
        let text = r#"Before <tool_call>{"name": "bash""#;
        let (calls, text_parts) = parse_tool_calls(text);
        assert!(calls.is_empty());
        assert_eq!(text_parts.first().map(String::as_str), Some("Before"));
        let last = text_parts.last().expect("unclosed tag should be kept as text");
        assert!(last.contains("<tool_call>"));
    }

    #[test]
    fn test_registry_execute_dispatches_by_name() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool("echo")));
        // Re-registering the same name overwrites rather than duplicates.
        registry.register(Box::new(EchoTool("echo")));
        assert_eq!(registry.len(), 1);
        assert!(registry.get("echo").is_some());

        let call = ToolCall {
            name: "echo".to_string(),
            arguments: serde_json::json!({ "x": 1 }),
        };
        match registry.execute(&call) {
            Ok(result) => {
                assert!(result.success);
                assert_eq!(result.output, r#"{"x":1}"#);
            }
            Err(_) => panic!("registered tool should execute"),
        }

        let missing = ToolCall {
            name: "missing".to_string(),
            arguments: serde_json::Value::Null,
        };
        assert!(matches!(
            registry.execute(&missing),
            Err(AgentError::ToolNotFound(name)) if name == "missing"
        ));
    }

    #[test]
    fn test_tools_prompt_lists_every_tool() {
        let empty = ToolRegistry::default();
        assert!(empty.is_empty());
        assert!(empty.tools_prompt().is_empty());

        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool("beta")));
        registry.register(Box::new(EchoTool("alpha")));
        let prompt = registry.tools_prompt();
        assert!(prompt.starts_with("# Tools\n"));
        assert!(prompt.contains("- alpha\n"));
        assert!(prompt.contains("- beta\n"));
        assert_eq!(prompt, registry.tools_prompt());
    }
}