rexis_rag/
tools.rs

1//! # RRAG Tools System
2//!
3//! Type-safe tool system leveraging Rust's trait system for zero-cost abstractions.
4//! Designed for async execution with proper error handling and resource management.
5
6use crate::{RragError, RragResult};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::Instant;
12
13/// Tool execution result with comprehensive metadata
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ToolResult {
16    /// Whether the tool executed successfully
17    pub success: bool,
18
19    /// Tool output content
20    pub output: String,
21
22    /// Execution metadata
23    pub metadata: HashMap<String, serde_json::Value>,
24
25    /// Execution time in milliseconds
26    pub execution_time_ms: u64,
27
28    /// Resource usage information
29    pub resource_usage: Option<ResourceUsage>,
30}
31
32impl ToolResult {
33    /// Create a successful result
34    pub fn success(output: impl Into<String>) -> Self {
35        Self {
36            success: true,
37            output: output.into(),
38            metadata: HashMap::new(),
39            execution_time_ms: 0,
40            resource_usage: None,
41        }
42    }
43
44    /// Create an error result
45    pub fn error(error: impl Into<String>) -> Self {
46        Self {
47            success: false,
48            output: error.into(),
49            metadata: HashMap::new(),
50            execution_time_ms: 0,
51            resource_usage: None,
52        }
53    }
54
55    /// Add metadata using builder pattern
56    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
57        self.metadata.insert(key.into(), value);
58        self
59    }
60
61    /// Set execution timing
62    pub fn with_timing(mut self, execution_time_ms: u64) -> Self {
63        self.execution_time_ms = execution_time_ms;
64        self
65    }
66
67    /// Set resource usage
68    pub fn with_resource_usage(mut self, usage: ResourceUsage) -> Self {
69        self.resource_usage = Some(usage);
70        self
71    }
72}
73
74/// Resource usage tracking for tools
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ResourceUsage {
77    /// Memory allocated in bytes
78    pub memory_bytes: Option<u64>,
79
80    /// CPU time used in microseconds
81    pub cpu_time_us: Option<u64>,
82
83    /// Network requests made
84    pub network_requests: Option<u32>,
85
86    /// Files accessed
87    pub files_accessed: Option<u32>,
88}
89
90/// Core tool trait optimized for Rust's async ecosystem
91#[async_trait]
92pub trait Tool: Send + Sync {
93    /// Tool identifier (used for registration and calling)
94    fn name(&self) -> &str;
95
96    /// Human-readable description for LLM context
97    fn description(&self) -> &str;
98
99    /// JSON schema for parameter validation (optional)
100    fn schema(&self) -> Option<serde_json::Value> {
101        None
102    }
103
104    /// Execute the tool with string input
105    async fn execute(&self, input: &str) -> RragResult<ToolResult>;
106
107    /// Execute with structured parameters (default delegates to execute)
108    async fn execute_with_params(&self, params: serde_json::Value) -> RragResult<ToolResult> {
109        let input = match params {
110            serde_json::Value::String(s) => s,
111            _ => params.to_string(),
112        };
113        self.execute(&input).await
114    }
115
116    /// Tool capabilities for filtering and discovery
117    fn capabilities(&self) -> Vec<&'static str> {
118        vec![]
119    }
120
121    /// Whether this tool requires authentication
122    fn requires_auth(&self) -> bool {
123        false
124    }
125
126    /// Tool category for organization
127    fn category(&self) -> &'static str {
128        "general"
129    }
130
131    /// Whether this tool can be cached
132    fn is_cacheable(&self) -> bool {
133        false
134    }
135
136    /// Cost estimate for execution (arbitrary units)
137    fn cost_estimate(&self) -> u32 {
138        1
139    }
140}
141
/// Generates a unit struct named `GeneratedTool` together with its `Tool` implementation.
///
/// Two forms are accepted: `name` / `description` / `execute`, and the same with a
/// `category` entry before `execute`. Without `category` the tool reports `"general"`.
///
/// `execute` must be callable with the `&str` input and return something awaitable
/// that yields a `Result`: the `Ok` value becomes the output of a successful
/// `ToolResult`, while the `Err` value is turned into an error result via `to_string()`.
/// The elapsed time in milliseconds is recorded in both cases.
///
/// The expansion refers to `Tool`, `RragResult`, `ToolResult` and
/// `async_trait::async_trait` by those exact paths, so they must resolve at the call
/// site. The struct name is fixed, so the macro can be used once per scope.
#[macro_export]
macro_rules! rrag_tool {
    // Short form: identical to the long form with the category set to "general".
    (
        name: $name:expr,
        description: $desc:expr,
        execute: $exec:expr
    ) => {
        $crate::rrag_tool! {
            name: $name,
            description: $desc,
            category: "general",
            execute: $exec
        }
    };

    // Long form: explicit category.
    (
        name: $name:expr,
        description: $desc:expr,
        category: $category:expr,
        execute: $exec:expr
    ) => {
        #[derive(Debug)]
        pub struct GeneratedTool;

        #[async_trait::async_trait]
        impl Tool for GeneratedTool {
            fn name(&self) -> &str {
                $name
            }

            fn description(&self) -> &str {
                $desc
            }

            fn category(&self) -> &'static str {
                $category
            }

            async fn execute(&self, input: &str) -> RragResult<ToolResult> {
                let started = std::time::Instant::now();
                let outcome = ($exec)(input).await;
                let elapsed_ms = started.elapsed().as_millis() as u64;

                let report = match outcome {
                    Ok(output) => ToolResult::success(output),
                    Err(e) => ToolResult::error(e.to_string()),
                };
                Ok(report.with_timing(elapsed_ms))
            }
        }
    };
}
212
213/// Thread-safe tool registry using Arc for efficient sharing
214#[derive(Clone)]
215pub struct ToolRegistry {
216    tools: HashMap<String, Arc<dyn Tool>>,
217}
218
219impl ToolRegistry {
220    /// Create a new empty registry
221    pub fn new() -> Self {
222        Self {
223            tools: HashMap::new(),
224        }
225    }
226
227    /// Create registry with pre-registered tools
228    pub fn with_tools(tools: Vec<Arc<dyn Tool>>) -> Self {
229        let mut registry = HashMap::new();
230        for tool in tools {
231            registry.insert(tool.name().to_string(), tool);
232        }
233
234        Self { tools: registry }
235    }
236
237    /// Register a new tool
238    pub fn register(&mut self, tool: Arc<dyn Tool>) -> RragResult<()> {
239        let name = tool.name().to_string();
240
241        if self.tools.contains_key(&name) {
242            return Err(RragError::config(
243                "tool_name",
244                "unique name",
245                format!("duplicate: {}", name),
246            ));
247        }
248
249        self.tools.insert(name, tool);
250        Ok(())
251    }
252
253    /// Get a tool by name
254    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
255        self.tools.get(name).cloned()
256    }
257
258    /// List all registered tool names
259    pub fn list_tools(&self) -> Vec<String> {
260        self.tools.keys().cloned().collect()
261    }
262
263    /// List tools by category
264    pub fn list_by_category(&self, category: &str) -> Vec<Arc<dyn Tool>> {
265        self.tools
266            .values()
267            .filter(|tool| tool.category() == category)
268            .cloned()
269            .collect()
270    }
271
272    /// List tools by capability
273    pub fn list_by_capability(&self, capability: &str) -> Vec<Arc<dyn Tool>> {
274        self.tools
275            .values()
276            .filter(|tool| tool.capabilities().contains(&capability))
277            .cloned()
278            .collect()
279    }
280
281    /// Execute a tool by name
282    pub async fn execute(&self, tool_name: &str, input: &str) -> RragResult<ToolResult> {
283        let tool = self
284            .get(tool_name)
285            .ok_or_else(|| RragError::tool_execution(tool_name, "Tool not found"))?;
286
287        tool.execute(input).await
288    }
289
290    /// Get tool schemas for LLM context
291    pub fn get_tool_schemas(&self) -> HashMap<String, serde_json::Value> {
292        self.tools
293            .iter()
294            .filter_map(|(name, tool)| tool.schema().map(|schema| (name.clone(), schema)))
295            .collect()
296    }
297
298    /// Get tool descriptions for LLM context
299    pub fn get_tool_descriptions(&self) -> HashMap<String, String> {
300        self.tools
301            .iter()
302            .map(|(name, tool)| (name.clone(), tool.description().to_string()))
303            .collect()
304    }
305}
306
307impl Default for ToolRegistry {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
/// Built-in tool that evaluates arithmetic expressions.
///
/// It carries no state; all behaviour lives in its `Tool` implementation.
#[derive(Debug)]
pub struct Calculator;
316
317#[async_trait]
318impl Tool for Calculator {
319    fn name(&self) -> &str {
320        "calculator"
321    }
322
323    fn description(&self) -> &str {
324        "Performs mathematical calculations. Input should be a mathematical expression like '2+2', '10*5', or '15/3'."
325    }
326
327    fn category(&self) -> &'static str {
328        "math"
329    }
330
331    fn capabilities(&self) -> Vec<&'static str> {
332        vec!["math", "calculation", "arithmetic"]
333    }
334
335    fn is_cacheable(&self) -> bool {
336        true // Math results are deterministic
337    }
338
339    async fn execute(&self, input: &str) -> RragResult<ToolResult> {
340        let start = Instant::now();
341
342        match calculate(input) {
343            Ok(result) => {
344                let execution_time = start.elapsed().as_millis() as u64;
345                Ok(ToolResult::success(result.to_string())
346                    .with_timing(execution_time)
347                    .with_metadata("expression", serde_json::Value::String(input.to_string()))
348                    .with_metadata(
349                        "result_type",
350                        serde_json::Value::String("number".to_string()),
351                    ))
352            }
353            Err(e) => {
354                let execution_time = start.elapsed().as_millis() as u64;
355                Ok(ToolResult::error(format!("Calculation error: {}", e))
356                    .with_timing(execution_time))
357            }
358        }
359    }
360
361    fn schema(&self) -> Option<serde_json::Value> {
362        Some(serde_json::json!({
363            "type": "object",
364            "properties": {
365                "expression": {
366                    "type": "string",
367                    "description": "Mathematical expression to evaluate",
368                    "examples": ["2+2", "10*5", "15/3", "sqrt(16)", "2^3"]
369                }
370            },
371            "required": ["expression"]
372        }))
373    }
374}
375
376/// Simple calculator implementation
377fn calculate(expr: &str) -> RragResult<f64> {
378    let expr = expr.trim().replace(" ", "");
379
380    // Handle basic operations in order of precedence
381    if let Some(result) = try_parse_number(&expr) {
382        return Ok(result);
383    }
384
385    // Addition and subtraction (lowest precedence)
386    if let Some(pos) = expr.rfind('+') {
387        let (left, right) = expr.split_at(pos);
388        let right = &right[1..];
389        return Ok(calculate(left)? + calculate(right)?);
390    }
391
392    if let Some(pos) = expr.rfind('-') {
393        if pos > 0 {
394            // Avoid treating negative numbers as subtraction
395            let (left, right) = expr.split_at(pos);
396            let right = &right[1..];
397            return Ok(calculate(left)? - calculate(right)?);
398        }
399    }
400
401    // Multiplication and division
402    if let Some(pos) = expr.rfind('*') {
403        let (left, right) = expr.split_at(pos);
404        let right = &right[1..];
405        return Ok(calculate(left)? * calculate(right)?);
406    }
407
408    if let Some(pos) = expr.rfind('/') {
409        let (left, right) = expr.split_at(pos);
410        let right = &right[1..];
411        let right_val = calculate(right)?;
412        if right_val == 0.0 {
413            return Err(RragError::tool_execution("calculator", "Division by zero"));
414        }
415        return Ok(calculate(left)? / right_val);
416    }
417
418    // Power operation
419    if let Some(pos) = expr.find('^') {
420        let (left, right) = expr.split_at(pos);
421        let right = &right[1..];
422        return Ok(calculate(left)?.powf(calculate(right)?));
423    }
424
425    // Functions
426    if expr.starts_with("sqrt(") && expr.ends_with(')') {
427        let inner = &expr[5..expr.len() - 1];
428        let value = calculate(inner)?;
429        if value < 0.0 {
430            return Err(RragError::tool_execution(
431                "calculator",
432                "Square root of negative number",
433            ));
434        }
435        return Ok(value.sqrt());
436    }
437
438    if expr.starts_with("sin(") && expr.ends_with(')') {
439        let inner = &expr[4..expr.len() - 1];
440        return Ok(calculate(inner)?.sin());
441    }
442
443    if expr.starts_with("cos(") && expr.ends_with(')') {
444        let inner = &expr[4..expr.len() - 1];
445        return Ok(calculate(inner)?.cos());
446    }
447
448    // Parentheses
449    if expr.starts_with('(') && expr.ends_with(')') {
450        let inner = &expr[1..expr.len() - 1];
451        return calculate(inner);
452    }
453
454    Err(RragError::tool_execution(
455        "calculator",
456        format!("Invalid expression: {}", expr),
457    ))
458}
459
460fn try_parse_number(s: &str) -> Option<f64> {
461    s.parse().ok()
462}
463
/// Stateless tool that returns its input prefixed with `"Echo: "`; meant for testing
/// and debugging.
#[derive(Debug)]
pub struct EchoTool;
467
468#[async_trait]
469impl Tool for EchoTool {
470    fn name(&self) -> &str {
471        "echo"
472    }
473
474    fn description(&self) -> &str {
475        "Echoes back the input text. Useful for testing and debugging."
476    }
477
478    fn category(&self) -> &'static str {
479        "utility"
480    }
481
482    fn capabilities(&self) -> Vec<&'static str> {
483        vec!["test", "debug", "echo"]
484    }
485
486    async fn execute(&self, input: &str) -> RragResult<ToolResult> {
487        let start = Instant::now();
488        let output = format!("Echo: {}", input);
489        let execution_time = start.elapsed().as_millis() as u64;
490
491        Ok(ToolResult::success(output)
492            .with_timing(execution_time)
493            .with_metadata(
494                "input_length",
495                serde_json::Value::Number(input.len().into()),
496            ))
497    }
498}
499
/// Tool that fetches web content over HTTP.
///
/// Only compiled when the `http` feature is enabled.
#[cfg(feature = "http")]
#[derive(Debug)]
pub struct HttpTool {
    /// Client used for every request made by this tool.
    client: reqwest::Client,
}
506
#[cfg(feature = "http")]
impl HttpTool {
    /// Creates an HTTP tool whose client is configured with a 30-second timeout.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` client cannot be built.
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(30))
                .build()
                .expect("Failed to create HTTP client"),
        }
    }
}

/// The default tool is the one produced by `HttpTool::new`.
#[cfg(feature = "http")]
impl Default for HttpTool {
    fn default() -> Self {
        Self::new()
    }
}
518
#[cfg(feature = "http")]
#[async_trait]
impl Tool for HttpTool {
    fn name(&self) -> &str {
        "http"
    }

    fn description(&self) -> &str {
        "Makes HTTP GET requests to fetch web content. Input should be a valid URL."
    }

    fn category(&self) -> &'static str {
        "web"
    }

    fn capabilities(&self) -> Vec<&'static str> {
        vec!["web", "http", "fetch", "scraping"]
    }

    /// Issues a GET request to the URL in `input` and returns the response body.
    ///
    /// Failures (a URL without an `http://` or `https://` prefix, a request error, or
    /// an unreadable body) are reported as a `ToolResult` with `success == false`
    /// rather than as `Err`. Bodies longer than 10 000 characters are truncated.
    /// Successful results carry `status_code`, `headers_count` and `url` metadata.
    async fn execute(&self, input: &str) -> RragResult<ToolResult> {
        let start = Instant::now();
        let url = input.trim();

        let result = if !url.starts_with("http://") && !url.starts_with("https://") {
            ToolResult::error("URL must start with http:// or https://")
        } else {
            match self.client.get(url).send().await {
                Ok(response) => {
                    let status = response.status();
                    let headers_count = response.headers().len();

                    match response.text().await {
                        Ok(body) => ToolResult::success(truncate_http_body(body))
                            .with_metadata(
                                "status_code",
                                serde_json::Value::Number(status.as_u16().into()),
                            )
                            .with_metadata(
                                "headers_count",
                                serde_json::Value::Number(headers_count.into()),
                            )
                            .with_metadata("url", serde_json::Value::String(url.to_string())),
                        Err(e) => {
                            ToolResult::error(format!("Failed to read response body: {}", e))
                        }
                    }
                }
                Err(e) => ToolResult::error(format!("HTTP request failed: {}", e)),
            }
        };

        let execution_time = start.elapsed().as_millis() as u64;
        Ok(result.with_timing(execution_time))
    }

    /// Executes with structured parameters.
    ///
    /// Parameters shaped like the advertised schema (`{"url": "<address>"}`) are
    /// unpacked so the address itself is fetched. A bare JSON string is used verbatim,
    /// and anything else is passed on as its serialized JSON text.
    async fn execute_with_params(&self, params: serde_json::Value) -> RragResult<ToolResult> {
        // Copy the URL out first so `params` can be consumed below.
        let url = params
            .get("url")
            .and_then(|value| value.as_str())
            .map(str::to_owned);

        let input = match (url, params) {
            (Some(url), _) => url,
            (None, serde_json::Value::String(text)) => text,
            (None, other) => other.to_string(),
        };

        self.execute(&input).await
    }

    fn schema(&self) -> Option<serde_json::Value> {
        Some(serde_json::json!({
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "format": "uri",
                    "description": "The URL to fetch"
                }
            },
            "required": ["url"]
        }))
    }
}

/// Maximum number of characters of a response body that `HttpTool` returns.
#[cfg(feature = "http")]
const HTTP_BODY_CHAR_LIMIT: usize = 10_000;

/// Shortens `body` to at most `HTTP_BODY_CHAR_LIMIT` characters, appending a note
/// with the original character count when something was cut off.
///
/// The cut is located with `char_indices`, so it always falls on a character boundary
/// and multi-byte UTF-8 text cannot cause a slicing panic.
#[cfg(feature = "http")]
fn truncate_http_body(body: String) -> String {
    // Byte offset of the first character beyond the limit, if the body is that long.
    let cut = body
        .char_indices()
        .nth(HTTP_BODY_CHAR_LIMIT)
        .map(|(offset, _)| offset);

    match cut {
        Some(offset) => format!(
            "{}... [truncated from {} chars]",
            &body[..offset],
            body.chars().count()
        ),
        None => body,
    }
}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// The calculator tool formats whole-number results without a fractional part.
    #[tokio::test]
    async fn test_calculator_tool() {
        let calc = Calculator;

        let result = calc.execute("2+2").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");

        let result = calc.execute("10*5").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "50");

        let result = calc.execute("sqrt(16)").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    /// The echo tool prefixes its input and records the input length.
    #[tokio::test]
    async fn test_echo_tool() {
        let echo = EchoTool;
        let result = echo.execute("hello world").await.unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Echo: hello world");
        // Timing is reported in whole milliseconds and echoing takes far less than
        // one, so assert on the deterministic metadata instead of the elapsed time.
        assert_eq!(
            result.metadata.get("input_length"),
            Some(&serde_json::json!(11))
        );
    }

    /// Registered tools can be listed and executed; duplicate names are rejected.
    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();

        registry.register(Arc::new(Calculator)).unwrap();
        registry.register(Arc::new(EchoTool)).unwrap();

        assert_eq!(registry.list_tools().len(), 2);
        assert!(registry.list_tools().contains(&"calculator".to_string()));
        assert!(registry.list_tools().contains(&"echo".to_string()));

        // Registering a second tool under an existing name must fail and change nothing.
        assert!(registry.register(Arc::new(EchoTool)).is_err());
        assert_eq!(registry.list_tools().len(), 2);

        let result = registry.execute("calculator", "5*5").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "25");
    }

    /// Each supported operator and function evaluates to the expected value.
    #[test]
    fn test_calculator_functions() {
        assert_eq!(calculate("2+2").unwrap(), 4.0);
        assert_eq!(calculate("10-3").unwrap(), 7.0);
        assert_eq!(calculate("4*5").unwrap(), 20.0);
        assert_eq!(calculate("15/3").unwrap(), 5.0);
        assert_eq!(calculate("2^3").unwrap(), 8.0);
        assert_eq!(calculate("sqrt(9)").unwrap(), 3.0);
        assert_eq!(calculate("(2+3)*4").unwrap(), 20.0);
    }

    /// Invalid input and undefined operations are reported as errors.
    #[test]
    fn test_calculator_errors() {
        assert!(calculate("5/0").is_err());
        assert!(calculate("sqrt(-1)").is_err());
        assert!(calculate("invalid").is_err());
    }

    /// Built-in tools expose the expected category, capabilities and cacheability.
    #[test]
    fn test_tool_categories() {
        let calc = Calculator;
        assert_eq!(calc.category(), "math");
        assert!(calc.capabilities().contains(&"math"));
        assert!(calc.is_cacheable());

        let echo = EchoTool;
        assert_eq!(echo.category(), "utility");
        assert!(echo.capabilities().contains(&"test"));
    }
}