1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::RwLock;
10
/// Describes a callable tool: its name, argument schema, and availability.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool identifier; `ToolRegistry` keys its map by this value.
    pub name: String,
    /// Human-readable summary of what the tool does.
    pub description: String,
    /// Argument schema. The built-in tools use a JSON-Schema-style object with
    /// `type`, `properties`, and `required` keys.
    pub parameters: serde_json::Value,
    /// Whether the tool is currently available; `ToolRegistry::list_for_tier`
    /// omits disabled tools.
    pub enabled: bool,
    /// Tier gate consulted by `ToolRegistry::list_for_tier`; `None` means the
    /// tool is visible to every tier. Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required_tier: Option<String>,
}
24
/// A request to run one tool by name with JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Caller-assigned identifier, echoed back as `ToolResult::tool_call_id`.
    pub id: String,
    /// Name of the tool to run; `ToolRegistry::execute` dispatches on it.
    pub name: String,
    /// Tool arguments. Each built-in tool reads the keys it needs and falls back
    /// to a default when a key is absent or not a string.
    pub arguments: serde_json::Value,
}
32
/// Outcome of executing a single `ToolCall`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// The `id` of the `ToolCall` this result answers.
    pub tool_call_id: String,
    /// Name of the tool that was called, copied from the call.
    pub name: String,
    /// Tool output. The built-in tools leave this empty when `error` is set.
    pub content: String,
    /// Failure message, or `None` on success. Left out of serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
42
/// Registry of tool definitions keyed by tool name.
///
/// The map sits behind a `RwLock`, so tools can be listed, registered, and
/// toggled through a shared `&self`.
pub struct ToolRegistry {
    /// Definitions keyed by `ToolDefinition::name`.
    tools: RwLock<HashMap<String, ToolDefinition>>,
}
47
48impl ToolRegistry {
49 #[must_use]
51 pub fn new() -> Self {
52 let mut tools = HashMap::new();
53
54 tools.insert(
56 "calculator".to_string(),
57 ToolDefinition {
58 name: "calculator".to_string(),
59 description: "Evaluate a mathematical expression".to_string(),
60 parameters: serde_json::json!({
61 "type": "object",
62 "properties": {
63 "expression": {
64 "type": "string",
65 "description": "Mathematical expression to evaluate (e.g., '2 + 2 * 3')"
66 }
67 },
68 "required": ["expression"]
69 }),
70 enabled: true,
71 required_tier: None,
72 },
73 );
74
75 tools.insert(
77 "code_execution".to_string(),
78 ToolDefinition {
79 name: "code_execution".to_string(),
80 description: "Execute code in a sandboxed environment".to_string(),
81 parameters: serde_json::json!({
82 "type": "object",
83 "properties": {
84 "language": {
85 "type": "string",
86 "enum": ["python", "bash", "rust"],
87 "description": "Programming language"
88 },
89 "code": {
90 "type": "string",
91 "description": "Code to execute"
92 }
93 },
94 "required": ["language", "code"]
95 }),
96 enabled: true,
97 required_tier: None,
98 },
99 );
100
101 tools.insert(
103 "web_search".to_string(),
104 ToolDefinition {
105 name: "web_search".to_string(),
106 description: "Search the web for information".to_string(),
107 parameters: serde_json::json!({
108 "type": "object",
109 "properties": {
110 "query": {
111 "type": "string",
112 "description": "Search query"
113 },
114 "max_results": {
115 "type": "integer",
116 "description": "Maximum results to return",
117 "default": 5
118 }
119 },
120 "required": ["query"]
121 }),
122 enabled: false, required_tier: Some("Standard".to_string()),
124 },
125 );
126
127 Self { tools: RwLock::new(tools) }
128 }
129
130 #[must_use]
132 pub fn list(&self) -> Vec<ToolDefinition> {
133 let store = self.tools.read().unwrap_or_else(|e| e.into_inner());
134 let mut tools: Vec<ToolDefinition> = store.values().cloned().collect();
135 tools.sort_by(|a, b| a.name.cmp(&b.name));
136 tools
137 }
138
139 #[must_use]
141 pub fn list_for_tier(&self, tier: &str) -> Vec<ToolDefinition> {
142 self.list()
143 .into_iter()
144 .filter(|t| t.enabled)
145 .filter(|t| {
146 t.required_tier.as_ref().is_none_or(|req| req == tier || tier == "Standard")
147 })
148 .collect()
149 }
150
151 #[must_use]
153 pub fn get(&self, name: &str) -> Option<ToolDefinition> {
154 self.tools.read().unwrap_or_else(|e| e.into_inner()).get(name).cloned()
155 }
156
157 pub fn register(&self, tool: ToolDefinition) {
159 if let Ok(mut store) = self.tools.write() {
160 store.insert(tool.name.clone(), tool);
161 }
162 }
163
164 pub fn set_enabled(&self, name: &str, enabled: bool) -> bool {
166 if let Ok(mut store) = self.tools.write() {
167 if let Some(tool) = store.get_mut(name) {
168 tool.enabled = enabled;
169 return true;
170 }
171 }
172 false
173 }
174
175 #[must_use]
177 pub fn execute(&self, call: &ToolCall) -> ToolResult {
178 match call.name.as_str() {
179 "calculator" => execute_calculator(call),
180 "code_execution" => execute_code_sandbox(call),
181 "web_search" => ToolResult {
182 tool_call_id: call.id.clone(),
183 name: call.name.clone(),
184 content: String::new(),
185 error: Some("Web search not implemented in sovereign mode".to_string()),
186 },
187 _ => ToolResult {
188 tool_call_id: call.id.clone(),
189 name: call.name.clone(),
190 content: String::new(),
191 error: Some(format!("Unknown tool: {}", call.name)),
192 },
193 }
194 }
195
196 pub fn execute_with_retry(&self, call: &ToolCall, max_retries: usize) -> ToolCallOutcome {
200 let result = self.execute(call);
201
202 if result.error.is_some() && max_retries > 0 {
203 let error_context = format!(
205 "Tool call to '{}' failed: {}. Please fix the arguments and try again.",
206 call.name,
207 result.error.as_deref().unwrap_or("unknown error")
208 );
209 ToolCallOutcome {
210 result,
211 should_retry: true,
212 error_context: Some(error_context),
213 retries_remaining: max_retries - 1,
214 }
215 } else {
216 ToolCallOutcome {
217 result,
218 should_retry: false,
219 error_context: None,
220 retries_remaining: 0,
221 }
222 }
223 }
224}
225
/// Result of `ToolRegistry::execute_with_retry`: one attempt's result plus retry guidance.
#[derive(Debug, Clone, Serialize)]
pub struct ToolCallOutcome {
    /// Result of the single execution attempt.
    pub result: ToolResult,
    /// `true` when the attempt failed and the retry budget was non-zero.
    pub should_retry: bool,
    /// Failure description asking for corrected arguments; set only when
    /// `should_retry` is true. Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_context: Option<String>,
    /// Retry budget left after this attempt; 0 when no retry is suggested.
    pub retries_remaining: usize,
}
235
236impl Default for ToolRegistry {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
242fn execute_calculator(call: &ToolCall) -> ToolResult {
244 let expr = call.arguments.get("expression").and_then(|v| v.as_str()).unwrap_or("");
245
246 let result = eval_math(expr);
247
248 ToolResult {
249 tool_call_id: call.id.clone(),
250 name: call.name.clone(),
251 content: match &result {
252 Ok(val) => val.to_string(),
253 Err(_) => String::new(),
254 },
255 error: result.err().map(|e| e.to_string()),
256 }
257}
258
/// Maximum parenthesis nesting accepted by [`eval_math`].
///
/// The parser recurses on every `(`. Without a cap, a hostile expression such as
/// `"((((…"` arriving in tool-call arguments can overflow the stack, which
/// aborts the whole process instead of returning an error.
const MAX_PAREN_DEPTH: usize = 256;

/// Evaluates an arithmetic expression over `f64`.
///
/// Supported syntax:
/// - decimal literals made of digits and `.`;
/// - binary `+ - * /` with the usual precedence and left associativity;
/// - unary minus;
/// - parentheses.
///
/// Whitespace anywhere in the input is ignored.
///
/// # Errors
///
/// Returns a message in these cases:
/// - an empty expression;
/// - malformed syntax;
/// - division by zero;
/// - parenthesis nesting deeper than [`MAX_PAREN_DEPTH`].
///
/// Positions quoted in error messages index the expression *after* whitespace
/// has been removed.
fn eval_math(expr: &str) -> Result<f64, String> {
    let tokens: Vec<char> = expr.chars().filter(|c| !c.is_whitespace()).collect();
    if tokens.is_empty() {
        return Err("Empty expression".to_string());
    }
    // Bound the recursion depth up front so the parser below cannot blow the stack.
    check_paren_depth(&tokens)?;
    let mut pos = 0;
    let result = parse_expr(&tokens, &mut pos)?;
    if pos < tokens.len() {
        return Err(format!("Unexpected character at position {pos}"));
    }
    Ok(result)
}

/// Rejects token streams whose `(` nesting exceeds [`MAX_PAREN_DEPTH`].
///
/// An unmatched `)` only saturates the counter at zero; the parser reports it.
fn check_paren_depth(tokens: &[char]) -> Result<(), String> {
    let mut depth = 0usize;
    for &token in tokens {
        match token {
            '(' => {
                depth += 1;
                if depth > MAX_PAREN_DEPTH {
                    return Err(format!(
                        "Expression nested too deeply (limit is {MAX_PAREN_DEPTH} parentheses)"
                    ));
                }
            }
            ')' => depth = depth.saturating_sub(1),
            _ => {}
        }
    }
    Ok(())
}

/// `expr := term (('+' | '-') term)*`, evaluated left to right.
fn parse_expr(tokens: &[char], pos: &mut usize) -> Result<f64, String> {
    let mut left = parse_term(tokens, pos)?;
    while *pos < tokens.len() && (tokens[*pos] == '+' || tokens[*pos] == '-') {
        let op = tokens[*pos];
        *pos += 1;
        let right = parse_term(tokens, pos)?;
        left = if op == '+' { left + right } else { left - right };
    }
    Ok(left)
}

/// `term := factor (('*' | '/') factor)*`, evaluated left to right; a zero divisor is an error.
fn parse_term(tokens: &[char], pos: &mut usize) -> Result<f64, String> {
    let mut left = parse_factor(tokens, pos)?;
    while *pos < tokens.len() && (tokens[*pos] == '*' || tokens[*pos] == '/') {
        let op = tokens[*pos];
        *pos += 1;
        let right = parse_factor(tokens, pos)?;
        if op == '/' && right == 0.0 {
            return Err("Division by zero".to_string());
        }
        left = if op == '*' { left * right } else { left / right };
    }
    Ok(left)
}

/// `factor := '-'* primary`.
///
/// A run of unary minus signs is folded in a loop, toggling a sign flag per `-`,
/// instead of recursing once per sign. The result is identical, but
/// `"------…1"` can no longer exhaust the stack.
fn parse_factor(tokens: &[char], pos: &mut usize) -> Result<f64, String> {
    let mut negate = false;
    while *pos < tokens.len() && tokens[*pos] == '-' {
        negate = !negate;
        *pos += 1;
    }
    let value = parse_primary(tokens, pos)?;
    Ok(if negate { -value } else { value })
}

/// `primary := '(' expr ')' | number`, where a number is a run of digits and `.`.
fn parse_primary(tokens: &[char], pos: &mut usize) -> Result<f64, String> {
    if *pos >= tokens.len() {
        return Err("Unexpected end of expression".to_string());
    }

    if tokens[*pos] == '(' {
        *pos += 1;
        let val = parse_expr(tokens, pos)?;
        if *pos >= tokens.len() || tokens[*pos] != ')' {
            return Err("Missing closing parenthesis".to_string());
        }
        *pos += 1;
        return Ok(val);
    }

    let start = *pos;
    while *pos < tokens.len() && (tokens[*pos].is_ascii_digit() || tokens[*pos] == '.') {
        *pos += 1;
    }
    if start == *pos {
        return Err(format!("Expected number at position {start}"));
    }
    let num_str: String = tokens[start..*pos].iter().collect();
    num_str.parse::<f64>().map_err(|e| e.to_string())
}
332
333fn execute_code_sandbox(call: &ToolCall) -> ToolResult {
335 let language = call.arguments.get("language").and_then(|v| v.as_str()).unwrap_or("unknown");
336 let code = call.arguments.get("code").and_then(|v| v.as_str()).unwrap_or("");
337
338 let content = format!(
340 "{{\"stdout\": \"[sandbox dry-run] Would execute {language} code ({} chars)\", \"stderr\": \"\", \"exit_code\": 0}}",
341 code.len()
342 );
343
344 ToolResult { tool_call_id: call.id.clone(), name: call.name.clone(), content, error: None }
345}