1use crate::{RragError, RragResult};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::Instant;
12
/// Outcome of a single tool invocation.
///
/// Values are normally created with [`ToolResult::success`] or
/// [`ToolResult::error`] and then refined with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// `true` for results built with [`ToolResult::success`], `false` for
    /// results built with [`ToolResult::error`].
    pub success: bool,

    /// The tool's output text, or the error message for a failed result.
    pub output: String,

    /// Free-form key/value annotations attached by the tool.
    pub metadata: HashMap<String, serde_json::Value>,

    /// Execution time in milliseconds; the constructors initialise it to 0
    /// and [`ToolResult::with_timing`] overwrites it.
    pub execution_time_ms: u64,

    /// Optional resource accounting; `None` unless set through
    /// [`ToolResult::with_resource_usage`].
    pub resource_usage: Option<ResourceUsage>,
}
31
32impl ToolResult {
33 pub fn success(output: impl Into<String>) -> Self {
35 Self {
36 success: true,
37 output: output.into(),
38 metadata: HashMap::new(),
39 execution_time_ms: 0,
40 resource_usage: None,
41 }
42 }
43
44 pub fn error(error: impl Into<String>) -> Self {
46 Self {
47 success: false,
48 output: error.into(),
49 metadata: HashMap::new(),
50 execution_time_ms: 0,
51 resource_usage: None,
52 }
53 }
54
55 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
57 self.metadata.insert(key.into(), value);
58 self
59 }
60
61 pub fn with_timing(mut self, execution_time_ms: u64) -> Self {
63 self.execution_time_ms = execution_time_ms;
64 self
65 }
66
67 pub fn with_resource_usage(mut self, usage: ResourceUsage) -> Self {
69 self.resource_usage = Some(usage);
70 self
71 }
72}
73
/// Optional resource accounting attached to a [`ToolResult`].
///
/// Every field is optional, so a tool can report only what it measured.
/// NOTE(review): the units below are taken from the field names; confirm
/// against the code that fills this struct in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUsage {
    /// Memory used, in bytes (whether peak or total is not specified here).
    pub memory_bytes: Option<u64>,

    /// CPU time, presumably in microseconds as the `_us` suffix suggests.
    pub cpu_time_us: Option<u64>,

    /// Number of network requests made.
    pub network_requests: Option<u32>,

    /// Number of files accessed.
    pub files_accessed: Option<u32>,
}
89
90#[async_trait]
92pub trait Tool: Send + Sync {
93 fn name(&self) -> &str;
95
96 fn description(&self) -> &str;
98
99 fn schema(&self) -> Option<serde_json::Value> {
101 None
102 }
103
104 async fn execute(&self, input: &str) -> RragResult<ToolResult>;
106
107 async fn execute_with_params(&self, params: serde_json::Value) -> RragResult<ToolResult> {
109 let input = match params {
110 serde_json::Value::String(s) => s,
111 _ => params.to_string(),
112 };
113 self.execute(&input).await
114 }
115
116 fn capabilities(&self) -> Vec<&'static str> {
118 vec![]
119 }
120
121 fn requires_auth(&self) -> bool {
123 false
124 }
125
126 fn category(&self) -> &'static str {
128 "general"
129 }
130
131 fn is_cacheable(&self) -> bool {
133 false
134 }
135
136 fn cost_estimate(&self) -> u32 {
138 1
139 }
140}
141
/// Generates a unit struct named `GeneratedTool` together with a `Tool`
/// implementation for it.
///
/// Two forms are accepted; they differ only in the optional `category:` entry:
///
/// * `name: <expr>, description: <expr>, execute: <expr>`
/// * `name: <expr>, description: <expr>, category: <expr>, execute: <expr>`
///
/// `$exec` is invoked as `($exec)(input).await`, so it has to be callable with
/// the `&str` input and return a future. The future's `Ok` value is handed to
/// `ToolResult::success` and its `Err` value is rendered with `to_string()`
/// into `ToolResult::error`; either way the generated `execute` returns `Ok`
/// and records the elapsed wall-clock time in milliseconds.
///
/// NOTE(review): the expansion refers to `Tool`, `RragResult` and `ToolResult`
/// by their bare names, so those must be in scope where the macro is invoked.
/// The struct name is fixed, so two invocations in the same module would
/// define `GeneratedTool` twice.
#[macro_export]
macro_rules! rrag_tool {
    // Form without a category: `category()` falls back to the trait default.
    (
        name: $name:expr,
        description: $desc:expr,
        execute: $exec:expr
    ) => {
        #[derive(Debug)]
        pub struct GeneratedTool;

        #[async_trait::async_trait]
        impl Tool for GeneratedTool {
            fn name(&self) -> &str {
                $name
            }

            fn description(&self) -> &str {
                $desc
            }

            async fn execute(&self, input: &str) -> RragResult<ToolResult> {
                let start = std::time::Instant::now();
                let result = ($exec)(input).await;
                let execution_time = start.elapsed().as_millis() as u64;

                // A failing `$exec` is reported as an unsuccessful ToolResult,
                // not as an `Err` of the outer RragResult.
                match result {
                    Ok(output) => Ok(ToolResult::success(output).with_timing(execution_time)),
                    Err(e) => Ok(ToolResult::error(e.to_string()).with_timing(execution_time)),
                }
            }
        }
    };

    // Form with an explicit category: identical, plus a `category()` override.
    (
        name: $name:expr,
        description: $desc:expr,
        category: $category:expr,
        execute: $exec:expr
    ) => {
        #[derive(Debug)]
        pub struct GeneratedTool;

        #[async_trait::async_trait]
        impl Tool for GeneratedTool {
            fn name(&self) -> &str {
                $name
            }

            fn description(&self) -> &str {
                $desc
            }

            fn category(&self) -> &'static str {
                $category
            }

            async fn execute(&self, input: &str) -> RragResult<ToolResult> {
                let start = std::time::Instant::now();
                let result = ($exec)(input).await;
                let execution_time = start.elapsed().as_millis() as u64;

                // A failing `$exec` is reported as an unsuccessful ToolResult,
                // not as an `Err` of the outer RragResult.
                match result {
                    Ok(output) => Ok(ToolResult::success(output).with_timing(execution_time)),
                    Err(e) => Ok(ToolResult::error(e.to_string()).with_timing(execution_time)),
                }
            }
        }
    };
}
212
/// Collection of tools that can be looked up and executed by name.
///
/// Cloning the registry clones the map; the tools themselves are shared
/// through their `Arc` handles.
#[derive(Clone)]
pub struct ToolRegistry {
    /// Registered tools, keyed by the string returned from `Tool::name`.
    tools: HashMap<String, Arc<dyn Tool>>,
}
218
219impl ToolRegistry {
220 pub fn new() -> Self {
222 Self {
223 tools: HashMap::new(),
224 }
225 }
226
227 pub fn with_tools(tools: Vec<Arc<dyn Tool>>) -> Self {
229 let mut registry = HashMap::new();
230 for tool in tools {
231 registry.insert(tool.name().to_string(), tool);
232 }
233
234 Self { tools: registry }
235 }
236
237 pub fn register(&mut self, tool: Arc<dyn Tool>) -> RragResult<()> {
239 let name = tool.name().to_string();
240
241 if self.tools.contains_key(&name) {
242 return Err(RragError::config(
243 "tool_name",
244 "unique name",
245 format!("duplicate: {}", name),
246 ));
247 }
248
249 self.tools.insert(name, tool);
250 Ok(())
251 }
252
253 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
255 self.tools.get(name).cloned()
256 }
257
258 pub fn list_tools(&self) -> Vec<String> {
260 self.tools.keys().cloned().collect()
261 }
262
263 pub fn list_by_category(&self, category: &str) -> Vec<Arc<dyn Tool>> {
265 self.tools
266 .values()
267 .filter(|tool| tool.category() == category)
268 .cloned()
269 .collect()
270 }
271
272 pub fn list_by_capability(&self, capability: &str) -> Vec<Arc<dyn Tool>> {
274 self.tools
275 .values()
276 .filter(|tool| tool.capabilities().contains(&capability))
277 .cloned()
278 .collect()
279 }
280
281 pub async fn execute(&self, tool_name: &str, input: &str) -> RragResult<ToolResult> {
283 let tool = self
284 .get(tool_name)
285 .ok_or_else(|| RragError::tool_execution(tool_name, "Tool not found"))?;
286
287 tool.execute(input).await
288 }
289
290 pub fn get_tool_schemas(&self) -> HashMap<String, serde_json::Value> {
292 self.tools
293 .iter()
294 .filter_map(|(name, tool)| tool.schema().map(|schema| (name.clone(), schema)))
295 .collect()
296 }
297
298 pub fn get_tool_descriptions(&self) -> HashMap<String, String> {
300 self.tools
301 .iter()
302 .map(|(name, tool)| (name.clone(), tool.description().to_string()))
303 .collect()
304 }
305}
306
307impl Default for ToolRegistry {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
313#[derive(Debug)]
315pub struct Calculator;
316
317#[async_trait]
318impl Tool for Calculator {
319 fn name(&self) -> &str {
320 "calculator"
321 }
322
323 fn description(&self) -> &str {
324 "Performs mathematical calculations. Input should be a mathematical expression like '2+2', '10*5', or '15/3'."
325 }
326
327 fn category(&self) -> &'static str {
328 "math"
329 }
330
331 fn capabilities(&self) -> Vec<&'static str> {
332 vec!["math", "calculation", "arithmetic"]
333 }
334
335 fn is_cacheable(&self) -> bool {
336 true }
338
339 async fn execute(&self, input: &str) -> RragResult<ToolResult> {
340 let start = Instant::now();
341
342 match calculate(input) {
343 Ok(result) => {
344 let execution_time = start.elapsed().as_millis() as u64;
345 Ok(ToolResult::success(result.to_string())
346 .with_timing(execution_time)
347 .with_metadata("expression", serde_json::Value::String(input.to_string()))
348 .with_metadata(
349 "result_type",
350 serde_json::Value::String("number".to_string()),
351 ))
352 }
353 Err(e) => {
354 let execution_time = start.elapsed().as_millis() as u64;
355 Ok(ToolResult::error(format!("Calculation error: {}", e))
356 .with_timing(execution_time))
357 }
358 }
359 }
360
361 fn schema(&self) -> Option<serde_json::Value> {
362 Some(serde_json::json!({
363 "type": "object",
364 "properties": {
365 "expression": {
366 "type": "string",
367 "description": "Mathematical expression to evaluate",
368 "examples": ["2+2", "10*5", "15/3", "sqrt(16)", "2^3"]
369 }
370 },
371 "required": ["expression"]
372 }))
373 }
374}
375
376fn calculate(expr: &str) -> RragResult<f64> {
378 let expr = expr.trim().replace(" ", "");
379
380 if let Some(result) = try_parse_number(&expr) {
382 return Ok(result);
383 }
384
385 if let Some(pos) = expr.rfind('+') {
387 let (left, right) = expr.split_at(pos);
388 let right = &right[1..];
389 return Ok(calculate(left)? + calculate(right)?);
390 }
391
392 if let Some(pos) = expr.rfind('-') {
393 if pos > 0 {
394 let (left, right) = expr.split_at(pos);
396 let right = &right[1..];
397 return Ok(calculate(left)? - calculate(right)?);
398 }
399 }
400
401 if let Some(pos) = expr.rfind('*') {
403 let (left, right) = expr.split_at(pos);
404 let right = &right[1..];
405 return Ok(calculate(left)? * calculate(right)?);
406 }
407
408 if let Some(pos) = expr.rfind('/') {
409 let (left, right) = expr.split_at(pos);
410 let right = &right[1..];
411 let right_val = calculate(right)?;
412 if right_val == 0.0 {
413 return Err(RragError::tool_execution("calculator", "Division by zero"));
414 }
415 return Ok(calculate(left)? / right_val);
416 }
417
418 if let Some(pos) = expr.find('^') {
420 let (left, right) = expr.split_at(pos);
421 let right = &right[1..];
422 return Ok(calculate(left)?.powf(calculate(right)?));
423 }
424
425 if expr.starts_with("sqrt(") && expr.ends_with(')') {
427 let inner = &expr[5..expr.len() - 1];
428 let value = calculate(inner)?;
429 if value < 0.0 {
430 return Err(RragError::tool_execution(
431 "calculator",
432 "Square root of negative number",
433 ));
434 }
435 return Ok(value.sqrt());
436 }
437
438 if expr.starts_with("sin(") && expr.ends_with(')') {
439 let inner = &expr[4..expr.len() - 1];
440 return Ok(calculate(inner)?.sin());
441 }
442
443 if expr.starts_with("cos(") && expr.ends_with(')') {
444 let inner = &expr[4..expr.len() - 1];
445 return Ok(calculate(inner)?.cos());
446 }
447
448 if expr.starts_with('(') && expr.ends_with(')') {
450 let inner = &expr[1..expr.len() - 1];
451 return calculate(inner);
452 }
453
454 Err(RragError::tool_execution(
455 "calculator",
456 format!("Invalid expression: {}", expr),
457 ))
458}
459
/// Parses `s` as an `f64`, returning `None` when it is not a valid number.
fn try_parse_number(s: &str) -> Option<f64> {
    match s.parse::<f64>() {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
463
464#[derive(Debug)]
466pub struct EchoTool;
467
468#[async_trait]
469impl Tool for EchoTool {
470 fn name(&self) -> &str {
471 "echo"
472 }
473
474 fn description(&self) -> &str {
475 "Echoes back the input text. Useful for testing and debugging."
476 }
477
478 fn category(&self) -> &'static str {
479 "utility"
480 }
481
482 fn capabilities(&self) -> Vec<&'static str> {
483 vec!["test", "debug", "echo"]
484 }
485
486 async fn execute(&self, input: &str) -> RragResult<ToolResult> {
487 let start = Instant::now();
488 let output = format!("Echo: {}", input);
489 let execution_time = start.elapsed().as_millis() as u64;
490
491 Ok(ToolResult::success(output)
492 .with_timing(execution_time)
493 .with_metadata(
494 "input_length",
495 serde_json::Value::Number(input.len().into()),
496 ))
497 }
498}
499
/// Tool that performs HTTP GET requests; compiled only with the `http` feature.
#[cfg(feature = "http")]
#[derive(Debug)]
pub struct HttpTool {
    /// Client reused for every request made by this tool.
    client: reqwest::Client,
}

#[cfg(feature = "http")]
impl HttpTool {
    /// Creates the tool with a client that uses a 30-second request timeout.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn new() -> Self {
        let timeout = std::time::Duration::from_secs(30);
        let client = reqwest::Client::builder()
            .timeout(timeout)
            .build()
            .expect("Failed to create HTTP client");

        Self { client }
    }
}
518
#[cfg(feature = "http")]
#[async_trait]
impl Tool for HttpTool {
    fn name(&self) -> &str {
        "http"
    }

    fn description(&self) -> &str {
        "Makes HTTP GET requests to fetch web content. Input should be a valid URL."
    }

    fn category(&self) -> &'static str {
        "web"
    }

    fn capabilities(&self) -> Vec<&'static str> {
        vec!["web", "http", "fetch", "scraping"]
    }

    /// Fetches the URL given in `input` with a GET request.
    ///
    /// The input is trimmed and must start with `http://` or `https://`.
    /// On success the response body is returned, cut to at most 10000 bytes,
    /// with `status_code`, `headers_count` and `url` recorded as metadata.
    /// The HTTP status is not inspected, so any response whose body can be
    /// read counts as success.
    ///
    /// Invalid URLs, request failures and body-read failures are reported as
    /// an unsuccessful [`ToolResult`]; this method never returns `Err`.
    async fn execute(&self, input: &str) -> RragResult<ToolResult> {
        /// Upper bound, in bytes, on the body text returned to the caller.
        const MAX_BODY_BYTES: usize = 10000;

        let start = Instant::now();

        let url = input.trim();
        if !url.starts_with("http://") && !url.starts_with("https://") {
            let execution_time = start.elapsed().as_millis() as u64;
            return Ok(ToolResult::error("URL must start with http:// or https://")
                .with_timing(execution_time));
        }

        match self.client.get(url).send().await {
            Ok(response) => {
                // Capture these first: reading the body consumes the response.
                let status = response.status();
                let headers_count = response.headers().len();

                match response.text().await {
                    Ok(body) => {
                        let execution_time = start.elapsed().as_millis() as u64;
                        let truncated_body = if body.len() > MAX_BODY_BYTES {
                            // Slicing in the middle of a multi-byte UTF-8
                            // character panics, so back off to the nearest
                            // character boundary. Index 0 is always a
                            // boundary, so the loop terminates.
                            let mut cut = MAX_BODY_BYTES;
                            while !body.is_char_boundary(cut) {
                                cut -= 1;
                            }
                            // The figure reported is the body's byte length.
                            format!(
                                "{}... [truncated from {} chars]",
                                &body[..cut],
                                body.len()
                            )
                        } else {
                            body
                        };

                        Ok(ToolResult::success(truncated_body)
                            .with_timing(execution_time)
                            .with_metadata(
                                "status_code",
                                serde_json::Value::Number(status.as_u16().into()),
                            )
                            .with_metadata(
                                "headers_count",
                                serde_json::Value::Number(headers_count.into()),
                            )
                            .with_metadata("url", serde_json::Value::String(url.to_string())))
                    }
                    Err(e) => {
                        let execution_time = start.elapsed().as_millis() as u64;
                        Ok(
                            ToolResult::error(format!("Failed to read response body: {}", e))
                                .with_timing(execution_time),
                        )
                    }
                }
            }
            Err(e) => {
                let execution_time = start.elapsed().as_millis() as u64;
                Ok(ToolResult::error(format!("HTTP request failed: {}", e))
                    .with_timing(execution_time))
            }
        }
    }

    /// Declares a single required string parameter, `url`.
    fn schema(&self) -> Option<serde_json::Value> {
        Some(serde_json::json!({
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "format": "uri",
                    "description": "The URL to fetch"
                }
            },
            "required": ["url"]
        }))
    }
}
609
/// Unit tests for the built-in tools, the registry and the expression evaluator.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_calculator_tool() {
        let calc = Calculator;

        let result = calc.execute("2+2").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");

        let result = calc.execute("10*5").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "50");

        let result = calc.execute("sqrt(16)").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let echo = EchoTool;
        let result = echo.execute("hello world").await.unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Echo: hello world");
        // Echoing takes far less than a millisecond, so the reported time
        // truncates to 0 and cannot be asserted as positive. Check the
        // deterministic metadata instead.
        assert_eq!(
            result
                .metadata
                .get("input_length")
                .and_then(|value| value.as_u64()),
            Some(11)
        );
    }

    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();

        registry.register(Arc::new(Calculator)).unwrap();
        registry.register(Arc::new(EchoTool)).unwrap();

        assert_eq!(registry.list_tools().len(), 2);
        assert!(registry.list_tools().contains(&"calculator".to_string()));
        assert!(registry.list_tools().contains(&"echo".to_string()));

        let result = registry.execute("calculator", "5*5").await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "25");
    }

    #[test]
    fn test_calculator_functions() {
        assert_eq!(calculate("2+2").unwrap(), 4.0);
        assert_eq!(calculate("10-3").unwrap(), 7.0);
        assert_eq!(calculate("4*5").unwrap(), 20.0);
        assert_eq!(calculate("15/3").unwrap(), 5.0);
        assert_eq!(calculate("2^3").unwrap(), 8.0);
        assert_eq!(calculate("sqrt(9)").unwrap(), 3.0);
        assert_eq!(calculate("(2+3)*4").unwrap(), 20.0);
    }

    #[test]
    fn test_calculator_errors() {
        assert!(calculate("5/0").is_err());
        assert!(calculate("sqrt(-1)").is_err());
        assert!(calculate("invalid").is_err());
    }

    #[test]
    fn test_tool_categories() {
        let calc = Calculator;
        assert_eq!(calc.category(), "math");
        assert!(calc.capabilities().contains(&"math"));
        assert!(calc.is_cacheable());

        let echo = EchoTool;
        assert_eq!(echo.category(), "utility");
        assert!(echo.capabilities().contains(&"test"));
    }
}