sh_layer2/
tool_registry.rs1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::types::{Layer2Error, Layer2Result, ToolResult};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolMeta {
15 pub name: String,
16 pub description: String,
17 pub parameters: serde_json::Value,
18 pub required: Vec<String>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ToolRequest {
24 pub tool_call_id: String,
25 pub name: String,
26 pub arguments: serde_json::Value,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ToolDefinition {
32 pub r#type: String,
33 pub function: FunctionDefinition,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct FunctionDefinition {
39 pub name: String,
40 pub description: String,
41 pub parameters: serde_json::Value,
42}
43
44#[async_trait]
48pub trait Tool: Send + Sync {
49 fn name(&self) -> &str;
51
52 fn description(&self) -> &str;
54
55 fn parameters(&self) -> serde_json::Value;
57
58 async fn execute(&self, args: &str) -> Layer2Result<ToolResult>;
60
61 fn validate_args(&self, _args: &serde_json::Value) -> Layer2Result<bool> {
63 Ok(true)
65 }
66}
67
68#[async_trait]
70pub trait ToolRegistryTrait: Send + Sync {
71 fn register(&self, tool: Box<dyn Tool>) -> Layer2Result<()>;
73
74 fn unregister(&self, name: &str) -> Layer2Result<bool>;
76
77 fn get(&self, name: &str) -> Option<Arc<dyn Tool>>;
79
80 fn exists(&self, name: &str) -> bool;
82
83 fn list(&self) -> Vec<String>;
85
86 fn definitions(&self) -> Vec<ToolDefinition>;
88
89 async fn execute(&self, name: &str, args: &str) -> Layer2Result<ToolResult>;
91
92 fn count(&self) -> usize;
94}
95
96pub struct ToolRegistry {
98 tools: parking_lot::RwLock<HashMap<String, Arc<dyn Tool>>>,
99}
100
101impl ToolRegistry {
102 pub fn new() -> Self {
103 Self {
104 tools: parking_lot::RwLock::new(HashMap::new()),
105 }
106 }
107
108 pub fn with_builtin_tools() -> Self {
110 Self::new()
111 }
113}
114
115impl Default for ToolRegistry {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121#[async_trait]
122impl ToolRegistryTrait for ToolRegistry {
123 fn register(&self, tool: Box<dyn Tool>) -> Layer2Result<()> {
124 let mut tools = self.tools.write();
125 let name = tool.name().to_string();
126 tools.insert(name, Arc::from(tool));
127 Ok(())
128 }
129
130 fn unregister(&self, name: &str) -> Layer2Result<bool> {
131 let mut tools = self.tools.write();
132 Ok(tools.remove(name).is_some())
133 }
134
135 fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
136 let tools = self.tools.read();
137 tools.get(name).cloned()
138 }
139
140 fn exists(&self, name: &str) -> bool {
141 let tools = self.tools.read();
142 tools.contains_key(name)
143 }
144
145 fn list(&self) -> Vec<String> {
146 let tools = self.tools.read();
147 tools.keys().cloned().collect()
148 }
149
150 fn definitions(&self) -> Vec<ToolDefinition> {
151 let tools = self.tools.read();
152 tools
153 .values()
154 .map(|tool| ToolDefinition {
155 r#type: "function".to_string(),
156 function: FunctionDefinition {
157 name: tool.name().to_string(),
158 description: tool.description().to_string(),
159 parameters: tool.parameters(),
160 },
161 })
162 .collect()
163 }
164
165 async fn execute(&self, name: &str, args: &str) -> Layer2Result<ToolResult> {
166 let tool = self
167 .get(name)
168 .ok_or_else(|| Layer2Error::ToolNotFound(name.to_string()))?;
169
170 tool.execute(args).await
171 }
172
173 fn count(&self) -> usize {
174 let tools = self.tools.read();
175 tools.len()
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool for exercising registry bookkeeping. `execute` is never
    /// invoked by these tests.
    struct StubTool {
        name: &'static str,
        description: &'static str,
    }

    #[async_trait]
    impl Tool for StubTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            self.description
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object", "properties": {} })
        }

        async fn execute(&self, _args: &str) -> Layer2Result<ToolResult> {
            unimplemented!("registry tests do not execute tools")
        }
    }

    /// Boxes a stub tool with the given name and description.
    fn stub(name: &'static str, description: &'static str) -> Box<dyn Tool> {
        Box::new(StubTool { name, description })
    }

    #[test]
    fn test_tool_registry_creation() {
        let registry = ToolRegistry::new();
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_tool_registry_list() {
        let registry = ToolRegistry::new();
        let list = registry.list();
        assert!(list.is_empty());
    }

    #[test]
    fn test_register_and_lookup() {
        let registry = ToolRegistry::new();
        assert!(registry.register(stub("alpha", "a")).is_ok());
        assert!(registry.register(stub("beta", "b")).is_ok());

        assert_eq!(registry.count(), 2);
        assert!(registry.exists("alpha"));
        assert!(!registry.exists("gamma"));
        assert_eq!(
            registry.get("beta").map(|t| t.name().to_string()),
            Some("beta".to_string())
        );
        assert!(registry.get("gamma").is_none());

        // Order-agnostic comparison: ordering is not part of the trait contract.
        let mut names = registry.list();
        names.sort();
        assert_eq!(names, vec!["alpha".to_string(), "beta".to_string()]);
    }

    #[test]
    fn test_register_same_name_replaces() {
        let registry = ToolRegistry::new();
        assert!(registry.register(stub("alpha", "first")).is_ok());
        assert!(registry.register(stub("alpha", "second")).is_ok());

        assert_eq!(registry.count(), 1);
        assert_eq!(
            registry.get("alpha").map(|t| t.description().to_string()),
            Some("second".to_string())
        );
    }

    #[test]
    fn test_unregister() {
        let registry = ToolRegistry::new();
        assert!(registry.register(stub("alpha", "a")).is_ok());

        assert!(matches!(registry.unregister("alpha"), Ok(true)));
        assert!(matches!(registry.unregister("alpha"), Ok(false)));
        assert_eq!(registry.count(), 0);
        assert!(!registry.exists("alpha"));
    }

    #[test]
    fn test_definitions() {
        let registry = ToolRegistry::new();
        assert!(registry.register(stub("alpha", "does alpha")).is_ok());

        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].r#type, "function");
        assert_eq!(defs[0].function.name, "alpha");
        assert_eq!(defs[0].function.description, "does alpha");
        assert_eq!(
            defs[0].function.parameters,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }
}