1use crate::types::{ToolError, ToolExecutionResult, ToolResult};
2use async_trait::async_trait;
3use futures::future::BoxFuture;
4use serde_json::Value as JsonValue;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8#[async_trait]
10pub trait Tool: Send + Sync {
11 fn name(&self) -> &str;
13
14 fn description(&self) -> &str;
16
17 fn input_schema(&self) -> JsonValue;
19
20 async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult;
22}
23
24type ToolExecutor = Arc<
25 dyn Fn(HashMap<String, JsonValue>) -> BoxFuture<'static, ToolExecutionResult> + Send + Sync,
26>;
27
28pub struct NativeTool {
33 name: String,
34 description: String,
35 input_schema: JsonValue,
36 executor: ToolExecutor,
37}
38
39impl NativeTool {
40 pub fn new<F, Fut>(
48 name: impl Into<String>,
49 description: impl Into<String>,
50 input_schema: JsonValue,
51 executor: F,
52 ) -> Self
53 where
54 F: Fn(HashMap<String, JsonValue>) -> Fut + Send + Sync + 'static,
55 Fut: std::future::Future<Output = ToolExecutionResult> + Send + 'static,
56 {
57 Self {
58 name: name.into(),
59 description: description.into(),
60 input_schema,
61 executor: Arc::new(move |params| Box::pin(executor(params))),
62 }
63 }
64}
65
66#[async_trait]
67impl Tool for NativeTool {
68 fn name(&self) -> &str {
69 &self.name
70 }
71
72 fn description(&self) -> &str {
73 &self.description
74 }
75
76 fn input_schema(&self) -> JsonValue {
77 self.input_schema.clone()
78 }
79
80 async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
81 (self.executor)(params).await
82 }
83}
84
85impl std::fmt::Debug for NativeTool {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("NativeTool")
88 .field("name", &self.name)
89 .field("description", &self.description)
90 .field("input_schema", &self.input_schema)
91 .finish()
92 }
93}
94
95pub struct ToolRegistry {
100 tools: HashMap<String, Arc<dyn Tool>>,
101}
102
103impl ToolRegistry {
104 pub fn new() -> Self {
106 Self {
107 tools: HashMap::new(),
108 }
109 }
110
111 pub fn register(&mut self, tool: impl Tool + 'static) -> &mut Self {
119 let name = tool.name().to_string();
120 self.tools.insert(name, Arc::new(tool));
121 self
122 }
123
124 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
126 self.tools.get(name)
127 }
128
129 pub fn list_names(&self) -> Vec<String> {
131 self.tools.keys().cloned().collect()
132 }
133
134 pub fn list_tools(&self) -> Vec<JsonValue> {
136 self.tools
137 .values()
138 .map(|tool| {
139 serde_json::json!({
140 "type": "function",
141 "function": {
142 "name": tool.name(),
143 "description": tool.description(),
144 "parameters": tool.input_schema(),
145 }
146 })
147 })
148 .collect()
149 }
150
151 pub async fn call_tool(
153 &self,
154 name: &str,
155 params: HashMap<String, JsonValue>,
156 ) -> ToolExecutionResult {
157 match self.tools.get(name) {
158 Some(tool) => tool.execute(params).await,
159 None => Err(ToolError::InvalidParameters(format!(
160 "Tool not found: {}",
161 name
162 ))),
163 }
164 }
165
166 pub fn has_tool(&self, name: &str) -> bool {
168 self.tools.contains_key(name)
169 }
170
171 pub fn len(&self) -> usize {
173 self.tools.len()
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.tools.is_empty()
179 }
180}
181
182impl Default for ToolRegistry {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl std::fmt::Debug for ToolRegistry {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("ToolRegistry")
191 .field("tool_count", &self.tools.len())
192 .field("tools", &self.tools.keys().collect::<Vec<_>>())
193 .finish()
194 }
195}
196
/// Built-in tool that returns its string `message` parameter as
/// `{"echoed": <message>}`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EchoTool;
199
200#[async_trait]
201impl Tool for EchoTool {
202 fn name(&self) -> &str {
203 "echo"
204 }
205
206 fn description(&self) -> &str {
207 "Echoes back the input message"
208 }
209
210 fn input_schema(&self) -> JsonValue {
211 serde_json::json!({
212 "type": "object",
213 "properties": {
214 "message": {
215 "type": "string",
216 "description": "The message to echo"
217 }
218 },
219 "required": ["message"]
220 })
221 }
222
223 async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
224 let start = std::time::Instant::now();
225
226 let message = params
227 .get("message")
228 .and_then(|v| v.as_str())
229 .ok_or_else(|| ToolError::InvalidParameters("missing 'message' parameter".into()))?;
230
231 let output = serde_json::json!({
232 "echoed": message
233 });
234
235 Ok(ToolResult::success(
236 output,
237 start.elapsed().as_secs_f64() * 1000.0,
238 ))
239 }
240}
241
/// Built-in tool that applies one binary arithmetic operation (`add`,
/// `subtract`, `multiply`, `divide`) to the numeric parameters `a` and `b`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CalculatorTool;
244
245#[async_trait]
246impl Tool for CalculatorTool {
247 fn name(&self) -> &str {
248 "calculator"
249 }
250
251 fn description(&self) -> &str {
252 "Performs basic arithmetic operations (add, subtract, multiply, divide)"
253 }
254
255 fn input_schema(&self) -> JsonValue {
256 serde_json::json!({
257 "type": "object",
258 "properties": {
259 "operation": {
260 "type": "string",
261 "enum": ["add", "subtract", "multiply", "divide"]
262 },
263 "a": { "type": "number" },
264 "b": { "type": "number" }
265 },
266 "required": ["operation", "a", "b"]
267 })
268 }
269
270 async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
271 let start = std::time::Instant::now();
272
273 let operation = params
274 .get("operation")
275 .and_then(|v| v.as_str())
276 .ok_or_else(|| ToolError::InvalidParameters("missing 'operation'".into()))?;
277
278 let a = params
279 .get("a")
280 .and_then(|v| v.as_f64())
281 .ok_or_else(|| ToolError::InvalidParameters("missing 'a'".into()))?;
282
283 let b = params
284 .get("b")
285 .and_then(|v| v.as_f64())
286 .ok_or_else(|| ToolError::InvalidParameters("missing 'b'".into()))?;
287
288 let result = match operation {
289 "add" => a + b,
290 "subtract" => a - b,
291 "multiply" => a * b,
292 "divide" => {
293 if b == 0.0 {
294 return Err(ToolError::ExecutionFailed("division by zero".into()));
295 }
296 a / b
297 }
298 _ => {
299 return Err(ToolError::InvalidParameters(format!(
300 "unknown operation: {}",
301 operation
302 )))
303 }
304 };
305
306 Ok(ToolResult::success(
307 serde_json::json!({ "result": result }),
308 start.elapsed().as_secs_f64() * 1000.0,
309 ))
310 }
311}