1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use thiserror::Error;
7
8#[derive(Debug, Clone, PartialEq)]
10pub struct ToolResult {
11 pub content: Value,
12}
13
14#[derive(Debug, Error, PartialEq)]
16pub enum ToolError {
17 #[error("tool `{name}` not found")]
18 NotFound { name: String },
19 #[error("invalid input: {0}")]
20 InvalidInput(String),
21 #[error("execution error: {0}")]
22 Execution(String),
23}
24
25#[async_trait]
26pub trait Tool: Send + Sync {
27 fn name(&self) -> &'static str;
28 fn json_schema(&self) -> Value;
29 async fn call(&self, args: Value) -> Result<ToolResult, ToolError>;
30}
31
32#[derive(Default)]
34pub struct ToolRegistry {
35 tools: HashMap<String, Arc<dyn Tool>>,
36}
37
38impl ToolRegistry {
39 pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn register<T>(&mut self, tool: T)
44 where
45 T: Tool + 'static,
46 {
47 self.tools.insert(tool.name().to_string(), Arc::new(tool));
48 }
49
50 pub fn register_boxed(&mut self, tool: Box<dyn Tool>) {
51 let name = tool.name().to_string();
52 let arc: Arc<dyn Tool> = tool.into();
53 self.tools.insert(name, arc);
54 }
55
56 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
57 self.tools.get(name).cloned()
58 }
59
60 pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolResult, ToolError> {
61 match self.get(name) {
62 Some(tool) => tool.call(args).await,
63 None => Err(ToolError::NotFound {
64 name: name.to_string(),
65 }),
66 }
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Test tool that echoes back its required string `message` argument.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn json_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {"type": "string"}
                },
                "required": ["message"]
            })
        }

        async fn call(&self, args: Value) -> Result<ToolResult, ToolError> {
            let msg = args
                .get("message")
                .and_then(Value::as_str)
                .ok_or_else(|| ToolError::InvalidInput("missing message".into()))?;
            Ok(ToolResult {
                content: serde_json::json!({ "echo": msg }),
            })
        }
    }

    /// Test tool that deliberately reuses `EchoTool`'s name to exercise replacement.
    struct ShadowEchoTool;

    #[async_trait]
    impl Tool for ShadowEchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn json_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn call(&self, _args: Value) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                content: serde_json::json!("shadow"),
            })
        }
    }

    #[tokio::test]
    async fn registry_registers_and_invokes_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool);
        let args = serde_json::json!({ "message": "hello" });
        let result = registry.invoke("echo", args).await.expect("tool result");
        assert_eq!(result.content, serde_json::json!({ "echo": "hello" }));
    }

    #[tokio::test]
    async fn registry_returns_not_found() {
        let registry = ToolRegistry::new();
        let err = registry
            .invoke("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert_eq!(
            err,
            ToolError::NotFound {
                name: "missing".to_string()
            }
        );
        assert_eq!(err.to_string(), "tool `missing` not found");
    }

    #[tokio::test]
    async fn registry_propagates_tool_errors() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool);
        // A non-string `message` must be rejected by the tool and surfaced unchanged.
        let err = registry
            .invoke("echo", serde_json::json!({ "message": 42 }))
            .await
            .unwrap_err();
        assert_eq!(err, ToolError::InvalidInput("missing message".into()));
    }

    #[test]
    fn register_boxed_exposes_tool_via_get() {
        let mut registry = ToolRegistry::new();
        registry.register_boxed(Box::new(EchoTool));
        let tool = registry.get("echo").expect("registered tool");
        assert_eq!(tool.name(), "echo");
        assert_eq!(
            tool.json_schema()["required"],
            serde_json::json!(["message"])
        );
        assert!(registry.get("missing").is_none());
    }

    #[tokio::test]
    async fn registering_same_name_replaces_previous_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool);
        registry.register(ShadowEchoTool);
        let result = registry
            .invoke("echo", serde_json::json!({ "message": "hello" }))
            .await
            .expect("tool result");
        assert_eq!(result.content, serde_json::json!("shadow"));
    }
}