rexis_llm/
tools.rs

1//! # Tool Calling Support
2//!
3//! Simple, ergonomic tool/function calling for LLMs.
4//! Inspired by LangChain but with Rust-native patterns.
5//!
6//! ## Quick Start
7//!
8//! ```rust,ignore
//! use rsllm::tools::{Tool, ToolRegistry, ToolCall};
//! use serde_json::json;
//! use std::error::Error;
10//!
11//! // Define a tool
12//! struct Calculator;
13//!
14//! impl Tool for Calculator {
15//!     fn name(&self) -> &str { "calculator" }
16//!     fn description(&self) -> &str { "Performs basic arithmetic" }
17//!
18//!     fn parameters_schema(&self) -> serde_json::Value {
19//!         json!({
20//!             "type": "object",
21//!             "properties": {
22//!                 "operation": {"type": "string", "enum": ["add", "subtract", "multiply", "divide"]},
23//!                 "a": {"type": "number"},
24//!                 "b": {"type": "number"}
25//!             },
26//!             "required": ["operation", "a", "b"]
27//!         })
28//!     }
29//!
//!     fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, Box<dyn Error + Send + Sync>> {
31//!         // Implementation
32//!     }
33//! }
34//!
35//! // Use the tool
36//! let mut registry = ToolRegistry::new();
//! registry.register(Box::new(Calculator)).expect("tool names are unique");
38//! ```
39
40use serde::{Deserialize, Serialize};
41use serde_json::Value as JsonValue;
42use std::collections::HashMap;
43use std::error::Error;
44use std::fmt;
45
46#[cfg(feature = "json-schema")]
47use schemars::{schema_for, JsonSchema};
48
49/// A tool that can be called by an LLM
50pub trait Tool: Send + Sync {
51    /// The name of the tool (must be unique)
52    fn name(&self) -> &str;
53
54    /// Human-readable description of what the tool does
55    fn description(&self) -> &str;
56
57    /// JSON Schema describing the tool's parameters
58    fn parameters_schema(&self) -> JsonValue;
59
60    /// Execute the tool with the given arguments
61    fn execute(&self, args: JsonValue) -> Result<JsonValue, Box<dyn Error + Send + Sync>>;
62
63    /// Optional: Validate arguments before execution
64    fn validate(&self, _args: &JsonValue) -> Result<(), Box<dyn Error + Send + Sync>> {
65        Ok(())
66    }
67}
68
/// A [`Tool`] whose parameters are described by a Rust type.
///
/// Instead of writing JSON Schema by hand, declare a parameter struct that
/// implements `JsonSchema` and `Deserialize`. The schema is generated from
/// that type, and incoming JSON arguments are decoded into it before
/// [`execute_typed`](SchemaBasedTool::execute_typed) runs. The trait does not
/// require `Serialize`; deriving it as well is harmless.
///
/// ```rust,ignore
/// #[derive(JsonSchema, Serialize, Deserialize)]
/// struct MyParams {
///     name: String,
///     age: u32,
/// }
///
/// impl SchemaBasedTool for MyTool {
///     type Params = MyParams;
///
///     fn name(&self) -> &str { "my_tool" }
///     fn description(&self) -> &str { "Does something" }
///
///     fn execute_typed(&self, params: Self::Params) -> Result<JsonValue, Box<dyn Error + Send + Sync>> {
///         // `params` is already a decoded `MyParams`.
///         Ok(json!({"result": params.name}))
///     }
/// }
/// ```
#[cfg(feature = "json-schema")]
pub trait SchemaBasedTool: Send + Sync {
    /// Strongly-typed parameters; must implement `JsonSchema` and `Deserialize`.
    type Params: JsonSchema + for<'de> Deserialize<'de>;

    /// Tool name reported to the LLM (becomes [`Tool::name`]).
    fn name(&self) -> &str;

    /// Tool description reported to the LLM (becomes [`Tool::description`]).
    fn description(&self) -> &str;

    /// Run the tool with parameters that have already been decoded.
    fn execute_typed(
        &self,
        typed_params: Self::Params,
    ) -> Result<JsonValue, Box<dyn Error + Send + Sync>>;

    /// Extra checks on the decoded parameters; the default accepts everything.
    fn validate_typed(
        &self,
        _typed_params: &Self::Params,
    ) -> Result<(), Box<dyn Error + Send + Sync>> {
        Ok(())
    }
}
115
/// Every [`SchemaBasedTool`] is automatically a [`Tool`].
///
/// JSON arguments are deserialized into `T::Params`, checked with
/// [`SchemaBasedTool::validate_typed`], and then passed to
/// [`SchemaBasedTool::execute_typed`].
#[cfg(feature = "json-schema")]
impl<T: SchemaBasedTool> Tool for T {
    fn name(&self) -> &str {
        SchemaBasedTool::name(self)
    }

    fn description(&self) -> &str {
        SchemaBasedTool::description(self)
    }

    /// Generate the parameter schema from `T::Params` with `schemars`.
    ///
    /// The JSON Schema draft is whatever `schema_for!` defaults to in the
    /// `schemars` version this crate builds against. NOTE(review): this is
    /// believed to be Draft 7 (schemars 0.8); confirm if the dependency is
    /// upgraded. If the schema cannot be converted to JSON, `null` is
    /// returned instead.
    fn parameters_schema(&self) -> JsonValue {
        let schema = schema_for!(T::Params);
        serde_json::to_value(&schema).unwrap_or_else(|_| JsonValue::Null)
    }

    /// Decode `args` into `T::Params`, validate them, then execute.
    ///
    /// # Errors
    /// Fails if `args` does not match `T::Params`, if `validate_typed`
    /// rejects the decoded value, or if `execute_typed` fails.
    fn execute(&self, args: JsonValue) -> Result<JsonValue, Box<dyn Error + Send + Sync>> {
        // `args` is owned here, so it can be consumed without copying.
        let params: T::Params = serde_json::from_value(args)
            .map_err(|e| format!("Failed to deserialize parameters: {}", e))?;

        self.validate_typed(&params)?;

        self.execute_typed(params)
    }

    /// Check that `args` decodes into `T::Params` and passes `validate_typed`.
    fn validate(&self, args: &JsonValue) -> Result<(), Box<dyn Error + Send + Sync>> {
        // `&serde_json::Value` is itself a serde `Deserializer`, so decode
        // straight from the borrow instead of deep-cloning the argument tree.
        let params = <T::Params as Deserialize>::deserialize(args)
            .map_err(|e| format!("Invalid parameters: {}", e))?;
        self.validate_typed(&params)
    }
}
153
154/// Tool definition for serialization to LLM API
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct ToolDefinition {
157    /// Tool name
158    pub name: String,
159
160    /// Tool description
161    pub description: String,
162
163    /// Parameters schema (JSON Schema)
164    pub parameters: JsonValue,
165
166    /// Tool type (default: "function")
167    #[serde(rename = "type", default = "default_tool_type")]
168    pub tool_type: String,
169}
170
/// Serde default for `ToolDefinition::tool_type`: the `"function"` kind.
fn default_tool_type() -> String {
    String::from("function")
}
174
175impl ToolDefinition {
176    /// Create a new tool definition
177    pub fn new(
178        name: impl Into<String>,
179        description: impl Into<String>,
180        parameters: JsonValue,
181    ) -> Self {
182        Self {
183            name: name.into(),
184            description: description.into(),
185            parameters,
186            tool_type: "function".to_string(),
187        }
188    }
189
190    /// Create from a Tool trait object
191    pub fn from_tool(tool: &dyn Tool) -> Self {
192        Self::new(tool.name(), tool.description(), tool.parameters_schema())
193    }
194}
195
196/// A tool call request from the LLM
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct ToolCall {
199    /// Unique ID for this tool call
200    pub id: String,
201
202    /// Tool name to call
203    pub name: String,
204
205    /// Arguments as JSON
206    pub arguments: JsonValue,
207}
208
209impl ToolCall {
210    /// Create a new tool call
211    pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: JsonValue) -> Self {
212        Self {
213            id: id.into(),
214            name: name.into(),
215            arguments,
216        }
217    }
218}
219
220/// Result of executing a tool
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct ToolResult {
223    /// Tool call ID this result is for
224    pub tool_call_id: String,
225
226    /// Tool name
227    pub tool_name: String,
228
229    /// Result content as JSON
230    pub content: JsonValue,
231
232    /// Whether the tool execution was successful
233    pub success: bool,
234
235    /// Error message if execution failed
236    pub error: Option<String>,
237}
238
239impl ToolResult {
240    /// Create a successful tool result
241    pub fn success(
242        tool_call_id: impl Into<String>,
243        tool_name: impl Into<String>,
244        content: JsonValue,
245    ) -> Self {
246        Self {
247            tool_call_id: tool_call_id.into(),
248            tool_name: tool_name.into(),
249            content,
250            success: true,
251            error: None,
252        }
253    }
254
255    /// Create a failed tool result
256    pub fn error(
257        tool_call_id: impl Into<String>,
258        tool_name: impl Into<String>,
259        error: impl Into<String>,
260    ) -> Self {
261        Self {
262            tool_call_id: tool_call_id.into(),
263            tool_name: tool_name.into(),
264            content: JsonValue::Null,
265            success: false,
266            error: Some(error.into()),
267        }
268    }
269}
270
271/// Registry for managing tools
272pub struct ToolRegistry {
273    tools: HashMap<String, Box<dyn Tool>>,
274}
275
276impl ToolRegistry {
277    /// Create a new empty tool registry
278    pub fn new() -> Self {
279        Self {
280            tools: HashMap::new(),
281        }
282    }
283
284    /// Register a tool
285    pub fn register(&mut self, tool: Box<dyn Tool>) -> Result<(), ToolRegistryError> {
286        let name = tool.name().to_string();
287
288        if self.tools.contains_key(&name) {
289            return Err(ToolRegistryError::DuplicateTool(name));
290        }
291
292        self.tools.insert(name, tool);
293        Ok(())
294    }
295
296    /// Get a tool by name
297    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
298        self.tools.get(name).map(|b| b.as_ref())
299    }
300
301    /// Check if a tool exists
302    pub fn contains(&self, name: &str) -> bool {
303        self.tools.contains_key(name)
304    }
305
306    /// Get all registered tool names
307    pub fn tool_names(&self) -> Vec<&str> {
308        self.tools.keys().map(|s| s.as_str()).collect()
309    }
310
311    /// Get all tool definitions for LLM API
312    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
313        self.tools
314            .values()
315            .map(|tool| ToolDefinition::from_tool(tool.as_ref()))
316            .collect()
317    }
318
319    /// Execute a tool call
320    pub fn execute(&self, tool_call: &ToolCall) -> ToolResult {
321        match self.get(&tool_call.name) {
322            Some(tool) => {
323                // Validate arguments
324                if let Err(e) = tool.validate(&tool_call.arguments) {
325                    return ToolResult::error(
326                        &tool_call.id,
327                        &tool_call.name,
328                        format!("Validation failed: {}", e),
329                    );
330                }
331
332                // Execute the tool
333                match tool.execute(tool_call.arguments.clone()) {
334                    Ok(result) => ToolResult::success(&tool_call.id, &tool_call.name, result),
335                    Err(e) => ToolResult::error(&tool_call.id, &tool_call.name, e.to_string()),
336                }
337            }
338            None => ToolResult::error(
339                &tool_call.id,
340                &tool_call.name,
341                format!("Tool '{}' not found", tool_call.name),
342            ),
343        }
344    }
345
346    /// Execute multiple tool calls
347    pub fn execute_batch(&self, tool_calls: &[ToolCall]) -> Vec<ToolResult> {
348        tool_calls.iter().map(|tc| self.execute(tc)).collect()
349    }
350
351    /// Number of registered tools
352    pub fn len(&self) -> usize {
353        self.tools.len()
354    }
355
356    /// Check if registry is empty
357    pub fn is_empty(&self) -> bool {
358        self.tools.is_empty()
359    }
360}
361
362impl Default for ToolRegistry {
363    fn default() -> Self {
364        Self::new()
365    }
366}
367
368impl fmt::Debug for ToolRegistry {
369    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370        f.debug_struct("ToolRegistry")
371            .field("tools", &self.tool_names())
372            .finish()
373    }
374}
375
/// Errors returned by [`ToolRegistry`] operations.
///
/// Implements `PartialEq`/`Eq` so callers can compare errors directly,
/// e.g. `assert_eq!(registry.register(t), Err(ToolRegistryError::DuplicateTool(..)))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolRegistryError {
    /// A tool with this name is already registered.
    DuplicateTool(String),
}

impl fmt::Display for ToolRegistryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ToolRegistryError::DuplicateTool(name) => {
                write!(f, "Tool '{}' is already registered", name)
            }
        }
    }
}

impl Error for ToolRegistryError {}
394
/// Build a boxed [`Tool`] from a name, description, parameter schema and a
/// body expression.
///
/// The body is evaluated inside the generated `execute`, with the JSON
/// arguments bound to the given identifier, and its value is wrapped in
/// `Ok(..)`. A `?` inside the body returns early from `execute` with that
/// error.
///
/// The expansion names `serde_json::Value` at the call site, so the calling
/// crate must depend on `serde_json`. Standard-library items are referenced
/// by absolute path so that a caller's own `Result` alias (for example) does
/// not break the expansion.
///
/// ```rust,ignore
/// let echo = simple_tool!(
///     name: "echo",
///     description: "Echoes input",
///     parameters: json!({"type": "object", "properties": {"text": {"type": "string"}}}),
///     execute: |args| json!({"echo": args["text"]})
/// );
/// registry.register(echo).expect("tool names are unique");
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        execute: |$args:ident| $body:expr
    ) => {{
        struct SimpleTool;
        impl $crate::tools::Tool for SimpleTool {
            fn name(&self) -> &str {
                $name
            }
            fn description(&self) -> &str {
                $desc
            }
            fn parameters_schema(&self) -> serde_json::Value {
                $params
            }
            fn execute(
                &self,
                $args: serde_json::Value,
            ) -> ::std::result::Result<
                serde_json::Value,
                ::std::boxed::Box<dyn std::error::Error + Send + Sync>,
            > {
                ::std::result::Result::Ok($body)
            }
        }
        ::std::boxed::Box::new(SimpleTool)
    }};
}
425
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Tool that wraps its `input` argument in a `"Processed: ..."` string.
    struct TestTool;
    impl Tool for TestTool {
        fn name(&self) -> &str {
            "test_tool"
        }
        fn description(&self) -> &str {
            "A test tool"
        }
        fn parameters_schema(&self) -> JsonValue {
            json!({"type": "object", "properties": {"input": {"type": "string"}}})
        }
        fn execute(&self, args: JsonValue) -> Result<JsonValue, Box<dyn Error + Send + Sync>> {
            Ok(json!({"result": format!("Processed: {}", args["input"])}))
        }
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());

        registry.register(Box::new(TestTool)).unwrap();
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("test_tool"));
        assert_eq!(registry.tool_names(), vec!["test_tool"]);
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(TestTool)).unwrap();

        let err = registry.register(Box::new(TestTool)).unwrap_err();
        assert!(matches!(err, ToolRegistryError::DuplicateTool(ref name) if name == "test_tool"));
        assert_eq!(err.to_string(), "Tool 'test_tool' is already registered");
        // The registry is unchanged by the failed registration.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_tool_execution() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(TestTool)).unwrap();

        let call = ToolCall::new("call-1", "test_tool", json!({"input": "hello"}));
        let result = registry.execute(&call);

        assert!(result.success);
        assert_eq!(result.tool_call_id, "call-1");
        assert_eq!(result.tool_name, "test_tool");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_unknown_tool_returns_error_result() {
        let registry = ToolRegistry::new();
        let call = ToolCall::new("call-2", "missing_tool", json!({}));
        let result = registry.execute(&call);

        assert!(!result.success);
        assert_eq!(result.tool_call_id, "call-2");
        assert_eq!(result.tool_name, "missing_tool");
        assert_eq!(result.content, JsonValue::Null);
        assert_eq!(result.error.as_deref(), Some("Tool 'missing_tool' not found"));
    }

    #[test]
    fn test_execute_batch_preserves_order() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(TestTool)).unwrap();

        let calls = [
            ToolCall::new("a", "test_tool", json!({"input": "x"})),
            ToolCall::new("b", "nope", json!({})),
        ];
        let results = registry.execute_batch(&calls);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].tool_call_id, "a");
        assert!(results[0].success);
        assert_eq!(results[1].tool_call_id, "b");
        assert!(!results[1].success);
    }

    #[test]
    fn test_tool_definition_serialization() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(TestTool)).unwrap();

        let definitions = registry.tool_definitions();
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].name, "test_tool");
        assert_eq!(definitions[0].tool_type, "function");

        // `tool_type` is serialized under the key `type`.
        let value = serde_json::to_value(&definitions[0]).unwrap();
        assert_eq!(value["type"], "function");
        assert_eq!(value["description"], "A test tool");

        // A missing `type` key falls back to "function" on deserialization.
        let parsed: ToolDefinition = serde_json::from_value(json!({
            "name": "x",
            "description": "d",
            "parameters": {}
        }))
        .unwrap();
        assert_eq!(parsed.tool_type, "function");
    }

    #[test]
    fn test_simple_tool_macro() {
        let tool = simple_tool!(
            name: "echo",
            description: "Echoes input",
            parameters: json!({"type": "object", "properties": {"text": {"type": "string"}}}),
            execute: |args| {
                json!({"echo": args["text"]})
            }
        );

        assert_eq!(tool.name(), "echo");
        let result = tool.execute(json!({"text": "hello"})).unwrap();
        assert_eq!(result["echo"], "hello");
    }
}