Skip to main content

batuta/serve/banco/
tools.rs

1//! Tool calling framework — OpenAI-compatible function calling with self-healing.
2//!
3//! Tools are registered functions that models can invoke during chat.
//! Built-in tools: calculator, code_execution (sandbox dry-run), web_search (disabled by default).
5//! Custom tools can be registered via the API.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::RwLock;
10
/// Tool definition — describes a callable function.
///
/// This is the schema advertised for a tool. It is (de)serialized with serde, so
/// custom tools can be supplied as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name; [`ToolRegistry`] stores definitions keyed by this value.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// JSON Schema for the tool's parameters.
    pub parameters: serde_json::Value,
    /// Whether this tool is enabled.
    /// `ToolRegistry::list_for_tier` only returns tools with this set to `true`.
    pub enabled: bool,
    /// Privacy tier requirement (None = all tiers).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required_tier: Option<String>,
}
24
/// A tool call request from the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call; copied into the matching [`ToolResult::tool_call_id`].
    pub id: String,
    /// Name of the tool to invoke; `ToolRegistry::execute` dispatches on it.
    pub name: String,
    /// Call arguments as a JSON value. The built-in tools read named fields from
    /// it (e.g. `"expression"`, `"language"`, `"code"`) and fall back to defaults
    /// when a field is missing or not a string.
    pub arguments: serde_json::Value,
}
32
/// A tool call result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// The [`ToolCall::id`] this result answers.
    pub tool_call_id: String,
    /// Name of the tool that was called.
    pub name: String,
    /// Tool output. The built-in executors leave this empty when `error` is set.
    pub content: String,
    /// Error message when the call failed; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
42
/// Tool registry — manages available tools.
///
/// Definitions live behind an `RwLock`, so every method takes `&self`.
pub struct ToolRegistry {
    /// Registered tools keyed by [`ToolDefinition::name`].
    tools: RwLock<HashMap<String, ToolDefinition>>,
}
47
48impl ToolRegistry {
49    /// Create registry with built-in tools.
50    #[must_use]
51    pub fn new() -> Self {
52        let mut tools = HashMap::new();
53
54        // Built-in: calculator
55        tools.insert(
56            "calculator".to_string(),
57            ToolDefinition {
58                name: "calculator".to_string(),
59                description: "Evaluate a mathematical expression".to_string(),
60                parameters: serde_json::json!({
61                    "type": "object",
62                    "properties": {
63                        "expression": {
64                            "type": "string",
65                            "description": "Mathematical expression to evaluate (e.g., '2 + 2 * 3')"
66                        }
67                    },
68                    "required": ["expression"]
69                }),
70                enabled: true,
71                required_tier: None,
72            },
73        );
74
75        // Built-in: code_execution
76        tools.insert(
77            "code_execution".to_string(),
78            ToolDefinition {
79                name: "code_execution".to_string(),
80                description: "Execute code in a sandboxed environment".to_string(),
81                parameters: serde_json::json!({
82                    "type": "object",
83                    "properties": {
84                        "language": {
85                            "type": "string",
86                            "enum": ["python", "bash", "rust"],
87                            "description": "Programming language"
88                        },
89                        "code": {
90                            "type": "string",
91                            "description": "Code to execute"
92                        }
93                    },
94                    "required": ["language", "code"]
95                }),
96                enabled: true,
97                required_tier: None,
98            },
99        );
100
101        // Built-in: web_search (Standard tier only)
102        tools.insert(
103            "web_search".to_string(),
104            ToolDefinition {
105                name: "web_search".to_string(),
106                description: "Search the web for information".to_string(),
107                parameters: serde_json::json!({
108                    "type": "object",
109                    "properties": {
110                        "query": {
111                            "type": "string",
112                            "description": "Search query"
113                        },
114                        "max_results": {
115                            "type": "integer",
116                            "description": "Maximum results to return",
117                            "default": 5
118                        }
119                    },
120                    "required": ["query"]
121                }),
122                enabled: false, // disabled by default — requires Standard tier
123                required_tier: Some("Standard".to_string()),
124            },
125        );
126
127        Self { tools: RwLock::new(tools) }
128    }
129
130    /// List all tools.
131    #[must_use]
132    pub fn list(&self) -> Vec<ToolDefinition> {
133        let store = self.tools.read().unwrap_or_else(|e| e.into_inner());
134        let mut tools: Vec<ToolDefinition> = store.values().cloned().collect();
135        tools.sort_by(|a, b| a.name.cmp(&b.name));
136        tools
137    }
138
139    /// List enabled tools for a given privacy tier.
140    #[must_use]
141    pub fn list_for_tier(&self, tier: &str) -> Vec<ToolDefinition> {
142        self.list()
143            .into_iter()
144            .filter(|t| t.enabled)
145            .filter(|t| {
146                t.required_tier.as_ref().is_none_or(|req| req == tier || tier == "Standard")
147            })
148            .collect()
149    }
150
151    /// Get a tool by name.
152    #[must_use]
153    pub fn get(&self, name: &str) -> Option<ToolDefinition> {
154        self.tools.read().unwrap_or_else(|e| e.into_inner()).get(name).cloned()
155    }
156
157    /// Register a custom tool.
158    pub fn register(&self, tool: ToolDefinition) {
159        if let Ok(mut store) = self.tools.write() {
160            store.insert(tool.name.clone(), tool);
161        }
162    }
163
164    /// Enable/disable a tool.
165    pub fn set_enabled(&self, name: &str, enabled: bool) -> bool {
166        if let Ok(mut store) = self.tools.write() {
167            if let Some(tool) = store.get_mut(name) {
168                tool.enabled = enabled;
169                return true;
170            }
171        }
172        false
173    }
174
175    /// Execute a tool call. Returns the result.
176    #[must_use]
177    pub fn execute(&self, call: &ToolCall) -> ToolResult {
178        match call.name.as_str() {
179            "calculator" => execute_calculator(call),
180            "code_execution" => execute_code_sandbox(call),
181            "web_search" => ToolResult {
182                tool_call_id: call.id.clone(),
183                name: call.name.clone(),
184                content: String::new(),
185                error: Some("Web search not implemented in sovereign mode".to_string()),
186            },
187            _ => ToolResult {
188                tool_call_id: call.id.clone(),
189                name: call.name.clone(),
190                content: String::new(),
191                error: Some(format!("Unknown tool: {}", call.name)),
192            },
193        }
194    }
195
196    /// Execute with self-healing retry.
197    /// If the call fails, injects the error and allows the caller to re-prompt.
198    /// Returns (result, retry_messages) where retry_messages is the error context for re-prompting.
199    pub fn execute_with_retry(&self, call: &ToolCall, max_retries: usize) -> ToolCallOutcome {
200        let result = self.execute(call);
201
202        if result.error.is_some() && max_retries > 0 {
203            // Return the error context for the caller to inject and re-prompt
204            let error_context = format!(
205                "Tool call to '{}' failed: {}. Please fix the arguments and try again.",
206                call.name,
207                result.error.as_deref().unwrap_or("unknown error")
208            );
209            ToolCallOutcome {
210                result,
211                should_retry: true,
212                error_context: Some(error_context),
213                retries_remaining: max_retries - 1,
214            }
215        } else {
216            ToolCallOutcome {
217                result,
218                should_retry: false,
219                error_context: None,
220                retries_remaining: 0,
221            }
222        }
223    }
224}
225
/// Outcome of a tool call with retry information.
///
/// Produced by `ToolRegistry::execute_with_retry`.
#[derive(Debug, Clone, Serialize)]
pub struct ToolCallOutcome {
    /// Result of the single execution attempt.
    pub result: ToolResult,
    /// `true` when the call failed and `max_retries > 0`; the caller should re-prompt.
    pub should_retry: bool,
    /// Message to inject before re-prompting; `None` (and omitted from serialized
    /// output) when no retry is suggested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_context: Option<String>,
    /// `max_retries - 1` after a retryable failure, otherwise `0`.
    pub retries_remaining: usize,
}
235
236impl Default for ToolRegistry {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242/// Built-in calculator — evaluates simple math expressions.
243fn execute_calculator(call: &ToolCall) -> ToolResult {
244    let expr = call.arguments.get("expression").and_then(|v| v.as_str()).unwrap_or("");
245
246    let result = eval_math(expr);
247
248    ToolResult {
249        tool_call_id: call.id.clone(),
250        name: call.name.clone(),
251        content: match &result {
252            Ok(val) => val.to_string(),
253            Err(_) => String::new(),
254        },
255        error: result.err().map(|e| e.to_string()),
256    }
257}
258
/// Maximum nesting depth (parentheses plus unary minus) accepted by [`eval_math`].
///
/// The parser is recursive. Without a bound, an expression such as `((((…` or
/// `----…1` (tool arguments come from the model) could recurse until the thread's
/// stack overflows and aborts the whole process.
const MAX_NESTING_DEPTH: usize = 128;

/// Simple math expression evaluator (supports +, -, *, /, unary minus, parentheses).
///
/// Whitespace between tokens is ignored. Two numbers separated only by whitespace
/// (e.g. `"2 3"`) are rejected rather than glued into `23`.
///
/// # Errors
///
/// Returns a human-readable message when the expression:
/// * is empty or malformed,
/// * is nested deeper than [`MAX_NESTING_DEPTH`],
/// * divides by zero,
/// * evaluates to a non-finite value (overflow).
fn eval_math(expr: &str) -> Result<f64, String> {
    let tokens: Vec<char> = expr.chars().filter(|c| !c.is_whitespace()).collect();
    if tokens.is_empty() {
        return Err("Empty expression".to_string());
    }

    // Whitespace is stripped below, which would silently turn "2 3" into 23.
    let pieces: Vec<&str> = expr.split_whitespace().collect();
    if pieces
        .windows(2)
        .any(|w| w[0].ends_with(is_number_char) && w[1].starts_with(is_number_char))
    {
        return Err("Missing operator between numbers".to_string());
    }

    let mut pos = 0;
    let result = parse_expr(&tokens, &mut pos, 0)?;
    if pos < tokens.len() {
        return Err(format!("Unexpected character at position {pos}"));
    }
    // Like the division-by-zero guard: never hand back inf/NaN as an answer.
    if !result.is_finite() {
        return Err("Result is not a finite number".to_string());
    }
    Ok(result)
}

/// Characters that may appear inside a numeric literal.
fn is_number_char(c: char) -> bool {
    c.is_ascii_digit() || c == '.'
}

/// Parse `term (('+' | '-') term)*` starting at `*pos`; `depth` is the current nesting level.
fn parse_expr(tokens: &[char], pos: &mut usize, depth: usize) -> Result<f64, String> {
    let mut left = parse_term(tokens, pos, depth)?;
    while let Some(&op) = tokens.get(*pos) {
        if op != '+' && op != '-' {
            break;
        }
        *pos += 1;
        let right = parse_term(tokens, pos, depth)?;
        left = if op == '+' { left + right } else { left - right };
    }
    Ok(left)
}

/// Parse `factor (('*' | '/') factor)*` starting at `*pos`; `depth` is the current nesting level.
fn parse_term(tokens: &[char], pos: &mut usize, depth: usize) -> Result<f64, String> {
    let mut left = parse_factor(tokens, pos, depth)?;
    while let Some(&op) = tokens.get(*pos) {
        if op != '*' && op != '/' {
            break;
        }
        *pos += 1;
        let right = parse_factor(tokens, pos, depth)?;
        if op == '/' && right == 0.0 {
            return Err("Division by zero".to_string());
        }
        left = if op == '*' { left * right } else { left / right };
    }
    Ok(left)
}

/// Parse one factor: a unary minus, a parenthesized sub-expression, or a number.
fn parse_factor(tokens: &[char], pos: &mut usize, depth: usize) -> Result<f64, String> {
    let Some(&token) = tokens.get(*pos) else {
        return Err("Unexpected end of expression".to_string());
    };

    // Unary minus and parentheses are the only constructs that recurse.
    if (token == '-' || token == '(') && depth >= MAX_NESTING_DEPTH {
        return Err("Expression is nested too deeply".to_string());
    }

    // Negation
    if token == '-' {
        *pos += 1;
        let val = parse_factor(tokens, pos, depth + 1)?;
        return Ok(-val);
    }

    // Parentheses
    if token == '(' {
        *pos += 1;
        let val = parse_expr(tokens, pos, depth + 1)?;
        if tokens.get(*pos) != Some(&')') {
            return Err("Missing closing parenthesis".to_string());
        }
        *pos += 1;
        return Ok(val);
    }

    // Number: a run of ASCII digits and '.'; `f64` parsing rejects forms like "1.2.3".
    let start = *pos;
    while tokens.get(*pos).is_some_and(|&c| is_number_char(c)) {
        *pos += 1;
    }
    if start == *pos {
        return Err(format!("Expected number at position {start}"));
    }
    let num_str: String = tokens[start..*pos].iter().collect();
    num_str.parse::<f64>().map_err(|e| e.to_string())
}
332
333/// Code execution sandbox (dry-run — actual sandbox requires jugar-probar).
334fn execute_code_sandbox(call: &ToolCall) -> ToolResult {
335    let language = call.arguments.get("language").and_then(|v| v.as_str()).unwrap_or("unknown");
336    let code = call.arguments.get("code").and_then(|v| v.as_str()).unwrap_or("");
337
338    // Dry-run: echo what would be executed
339    let content = format!(
340        "{{\"stdout\": \"[sandbox dry-run] Would execute {language} code ({} chars)\", \"stderr\": \"\", \"exit_code\": 0}}",
341        code.len()
342    );
343
344    ToolResult { tool_call_id: call.id.clone(), name: call.name.clone(), content, error: None }
345}