1use serde::{Deserialize, Serialize};
41use serde_json::Value as JsonValue;
42use std::collections::HashMap;
43use std::error::Error;
44use std::fmt;
45
46#[cfg(feature = "json-schema")]
47use schemars::{schema_for, JsonSchema};
48
49pub trait Tool: Send + Sync {
51 fn name(&self) -> &str;
53
54 fn description(&self) -> &str;
56
57 fn parameters_schema(&self) -> JsonValue;
59
60 fn execute(&self, args: JsonValue) -> Result<JsonValue, Box<dyn Error + Send + Sync>>;
62
63 fn validate(&self, _args: &JsonValue) -> Result<(), Box<dyn Error + Send + Sync>> {
65 Ok(())
66 }
67}
68
/// A tool whose arguments are a strongly typed `Params` value.
///
/// `Params` must implement `JsonSchema`, so its schema can be generated, and
/// `Deserialize`, so raw JSON arguments can be decoded into it. This trait is
/// only available with the `json-schema` feature.
#[cfg(feature = "json-schema")]
pub trait SchemaBasedTool: Sync + Send {
    /// Typed argument struct, decoded from the incoming JSON.
    type Params: JsonSchema + for<'de> Deserialize<'de>;

    /// Identifier the tool is registered and looked up under.
    fn name(&self) -> &str;

    /// Human-readable explanation of what the tool does.
    fn description(&self) -> &str;

    /// Runs the tool with already-decoded parameters.
    fn execute_typed(
        &self,
        params: Self::Params,
    ) -> Result<JsonValue, Box<dyn Error + Send + Sync>>;

    /// Checks decoded parameters before execution.
    ///
    /// The default implementation accepts every value.
    fn validate_typed(&self, _params: &Self::Params) -> Result<(), Box<dyn Error + Send + Sync>> {
        Ok(())
    }
}
115
// Bridges every `SchemaBasedTool` into the dynamic `Tool` interface: raw JSON is
// decoded into `T::Params`, and decode failures become errors rather than panics.
#[cfg(feature = "json-schema")]
impl<T: SchemaBasedTool> Tool for T {
    fn name(&self) -> &str {
        // Fully qualified so this resolves to `SchemaBasedTool::name`
        // instead of recursing into `Tool::name`.
        SchemaBasedTool::name(self)
    }

    fn description(&self) -> &str {
        SchemaBasedTool::description(self)
    }

    /// Schema generated from `T::Params` by `schemars`.
    ///
    /// Returns `JsonValue::Null` if the schema cannot be converted to JSON.
    fn parameters_schema(&self) -> JsonValue {
        let schema = schema_for!(T::Params);
        serde_json::to_value(&schema).unwrap_or(JsonValue::Null)
    }

    /// Decodes `args` into `T::Params`, validates them, then runs `execute_typed`.
    ///
    /// # Errors
    /// Fails if `args` cannot be decoded into `T::Params`, if `validate_typed`
    /// rejects the parameters, or if `execute_typed` fails.
    fn execute(&self, args: JsonValue) -> Result<JsonValue, Box<dyn Error + Send + Sync>> {
        let params: T::Params = serde_json::from_value(args)
            .map_err(|e| format!("Failed to deserialize parameters: {}", e))?;

        // Validated here as well, so callers that invoke `execute` directly
        // (without a separate `validate` step) still get checked input.
        self.validate_typed(&params)?;

        self.execute_typed(params)
    }

    /// Decodes `args` into `T::Params` and runs `validate_typed` on the result.
    ///
    /// # Errors
    /// Fails if `args` cannot be decoded or if `validate_typed` rejects it.
    fn validate(&self, args: &JsonValue) -> Result<(), Box<dyn Error + Send + Sync>> {
        // `&serde_json::Value` implements `Deserializer`, so we decode straight
        // from the borrowed value instead of cloning the whole JSON tree.
        let params = <T::Params as Deserialize>::deserialize(args)
            .map_err(|e| format!("Invalid parameters: {}", e))?;
        self.validate_typed(&params)
    }
}
153
/// Serializable description of a tool: its name, description and argument schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool identifier (taken from `Tool::name` by `ToolDefinition::from_tool`).
    pub name: String,

    /// Human-readable description of the tool.
    pub description: String,

    /// Argument description (taken from `Tool::parameters_schema` by `from_tool`).
    pub parameters: JsonValue,

    /// Tool kind, serialized under the JSON key `"type"`.
    ///
    /// When the key is absent during deserialization it defaults to `"function"`.
    #[serde(rename = "type", default = "default_tool_type")]
    pub tool_type: String,
}
170
/// Serde default for `ToolDefinition::tool_type`: always the string `"function"`.
fn default_tool_type() -> String {
    String::from("function")
}
174
175impl ToolDefinition {
176 pub fn new(
178 name: impl Into<String>,
179 description: impl Into<String>,
180 parameters: JsonValue,
181 ) -> Self {
182 Self {
183 name: name.into(),
184 description: description.into(),
185 parameters,
186 tool_type: "function".to_string(),
187 }
188 }
189
190 pub fn from_tool(tool: &dyn Tool) -> Self {
192 Self::new(tool.name(), tool.description(), tool.parameters_schema())
193 }
194}
195
/// A request to invoke a tool by name with JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Caller-chosen identifier; `ToolRegistry::execute` copies it into `ToolResult::tool_call_id`.
    pub id: String,

    /// Name of the tool to run, looked up in a `ToolRegistry`.
    pub name: String,

    /// Arguments handed to the tool's `validate` and `execute`.
    pub arguments: JsonValue,
}
208
209impl ToolCall {
210 pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: JsonValue) -> Self {
212 Self {
213 id: id.into(),
214 name: name.into(),
215 arguments,
216 }
217 }
218}
219
/// Outcome of executing a `ToolCall`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Identifier of the originating `ToolCall`.
    pub tool_call_id: String,

    /// Name of the tool that was requested (set even if the tool was not found).
    pub tool_name: String,

    /// Tool output on success; `JsonValue::Null` when built with `ToolResult::error`.
    pub content: JsonValue,

    /// Whether the call succeeded (`true` from `success`, `false` from `error`).
    pub success: bool,

    /// Failure message; `None` when built with `ToolResult::success`.
    pub error: Option<String>,
}
238
239impl ToolResult {
240 pub fn success(
242 tool_call_id: impl Into<String>,
243 tool_name: impl Into<String>,
244 content: JsonValue,
245 ) -> Self {
246 Self {
247 tool_call_id: tool_call_id.into(),
248 tool_name: tool_name.into(),
249 content,
250 success: true,
251 error: None,
252 }
253 }
254
255 pub fn error(
257 tool_call_id: impl Into<String>,
258 tool_name: impl Into<String>,
259 error: impl Into<String>,
260 ) -> Self {
261 Self {
262 tool_call_id: tool_call_id.into(),
263 tool_name: tool_name.into(),
264 content: JsonValue::Null,
265 success: false,
266 error: Some(error.into()),
267 }
268 }
269}
270
/// Collection of tools, keyed by name, that can execute `ToolCall`s.
pub struct ToolRegistry {
    // Keyed by each tool's `name()`; `register` refuses duplicate names.
    tools: HashMap<String, Box<dyn Tool>>,
}
275
276impl ToolRegistry {
277 pub fn new() -> Self {
279 Self {
280 tools: HashMap::new(),
281 }
282 }
283
284 pub fn register(&mut self, tool: Box<dyn Tool>) -> Result<(), ToolRegistryError> {
286 let name = tool.name().to_string();
287
288 if self.tools.contains_key(&name) {
289 return Err(ToolRegistryError::DuplicateTool(name));
290 }
291
292 self.tools.insert(name, tool);
293 Ok(())
294 }
295
296 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
298 self.tools.get(name).map(|b| b.as_ref())
299 }
300
301 pub fn contains(&self, name: &str) -> bool {
303 self.tools.contains_key(name)
304 }
305
306 pub fn tool_names(&self) -> Vec<&str> {
308 self.tools.keys().map(|s| s.as_str()).collect()
309 }
310
311 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
313 self.tools
314 .values()
315 .map(|tool| ToolDefinition::from_tool(tool.as_ref()))
316 .collect()
317 }
318
319 pub fn execute(&self, tool_call: &ToolCall) -> ToolResult {
321 match self.get(&tool_call.name) {
322 Some(tool) => {
323 if let Err(e) = tool.validate(&tool_call.arguments) {
325 return ToolResult::error(
326 &tool_call.id,
327 &tool_call.name,
328 format!("Validation failed: {}", e),
329 );
330 }
331
332 match tool.execute(tool_call.arguments.clone()) {
334 Ok(result) => ToolResult::success(&tool_call.id, &tool_call.name, result),
335 Err(e) => ToolResult::error(&tool_call.id, &tool_call.name, e.to_string()),
336 }
337 }
338 None => ToolResult::error(
339 &tool_call.id,
340 &tool_call.name,
341 format!("Tool '{}' not found", tool_call.name),
342 ),
343 }
344 }
345
346 pub fn execute_batch(&self, tool_calls: &[ToolCall]) -> Vec<ToolResult> {
348 tool_calls.iter().map(|tc| self.execute(tc)).collect()
349 }
350
351 pub fn len(&self) -> usize {
353 self.tools.len()
354 }
355
356 pub fn is_empty(&self) -> bool {
358 self.tools.is_empty()
359 }
360}
361
362impl Default for ToolRegistry {
363 fn default() -> Self {
364 Self::new()
365 }
366}
367
368impl fmt::Debug for ToolRegistry {
369 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370 f.debug_struct("ToolRegistry")
371 .field("tools", &self.tool_names())
372 .finish()
373 }
374}
375
/// Errors returned by `ToolRegistry` operations.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolRegistryError {
    /// A tool with this name is already registered.
    DuplicateTool(String),
}

impl fmt::Display for ToolRegistryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ToolRegistryError::DuplicateTool(name) => {
                write!(f, "Tool '{}' is already registered", name)
            }
        }
    }
}

impl Error for ToolRegistryError {}
394
/// Builds a boxed, zero-sized `Tool` from a name, description, schema expression and body.
///
/// `execute: |args| body` binds the incoming `serde_json::Value` to `args`, and the
/// body's value is wrapped in `Ok`. The `parameters` expression is evaluated on every
/// `parameters_schema` call. A trailing comma after the body is accepted.
///
/// The expansion names `serde_json` directly, so the invoking crate must depend on it.
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        execute: |$args:ident| $body:expr $(,)?
    ) => {{
        struct SimpleTool;
        impl $crate::tools::Tool for SimpleTool {
            fn name(&self) -> &str {
                $name
            }
            fn description(&self) -> &str {
                $desc
            }
            fn parameters_schema(&self) -> serde_json::Value {
                $params
            }
            fn execute(
                &self,
                $args: serde_json::Value,
            ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
                Ok($body)
            }
        }
        Box::new(SimpleTool)
    }};
}
425
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Tool that echoes its `input` argument inside a formatted string.
    struct TestTool;
    impl Tool for TestTool {
        fn name(&self) -> &str {
            "test_tool"
        }
        fn description(&self) -> &str {
            "A test tool"
        }
        fn parameters_schema(&self) -> JsonValue {
            json!({"type": "object", "properties": {"input": {"type": "string"}}})
        }
        fn execute(&self, args: JsonValue) -> Result<JsonValue, Box<dyn Error + Send + Sync>> {
            Ok(json!({"result": format!("Processed: {}", args["input"])}))
        }
    }

    /// Tool whose `validate` always fails, used to exercise the validation path.
    struct RejectingTool;
    impl Tool for RejectingTool {
        fn name(&self) -> &str {
            "rejecting_tool"
        }
        fn description(&self) -> &str {
            "Always fails validation"
        }
        fn parameters_schema(&self) -> JsonValue {
            json!({"type": "object"})
        }
        fn execute(&self, _args: JsonValue) -> Result<JsonValue, Box<dyn Error + Send + Sync>> {
            unreachable!("execute must not run when validation fails")
        }
        fn validate(&self, _args: &JsonValue) -> Result<(), Box<dyn Error + Send + Sync>> {
            Err("bad input".into())
        }
    }

    #[test]
    fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        assert_eq!(registry.len(), 0);

        registry.register(Box::new(TestTool)).unwrap();
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("test_tool"));
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(TestTool)).unwrap();

        let err = registry.register(Box::new(TestTool)).unwrap_err();
        assert!(matches!(err, ToolRegistryError::DuplicateTool(ref name) if name == "test_tool"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_tool_execution() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(TestTool)).unwrap();

        let call = ToolCall::new("call-1", "test_tool", json!({"input": "hello"}));
        let result = registry.execute(&call);

        assert!(result.success);
        assert_eq!(result.tool_name, "test_tool");
        assert_eq!(result.tool_call_id, "call-1");
        assert!(result.error.is_none());
        // `Value`'s Display renders JSON, so the string keeps its quotes.
        assert_eq!(result.content["result"], "Processed: \"hello\"");
    }

    #[test]
    fn test_unknown_tool_reports_error() {
        let registry = ToolRegistry::new();
        let result = registry.execute(&ToolCall::new("call-2", "missing", json!({})));

        assert!(!result.success);
        assert_eq!(result.tool_call_id, "call-2");
        assert_eq!(result.error.as_deref(), Some("Tool 'missing' not found"));
        assert!(result.content.is_null());
    }

    #[test]
    fn test_validation_failure_skips_execution() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(RejectingTool)).unwrap();

        let result = registry.execute(&ToolCall::new("call-3", "rejecting_tool", json!({})));

        assert!(!result.success);
        assert_eq!(result.error.as_deref(), Some("Validation failed: bad input"));
    }

    #[test]
    fn test_execute_batch_preserves_order() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(TestTool)).unwrap();

        let calls = [
            ToolCall::new("a", "test_tool", json!({"input": "x"})),
            ToolCall::new("b", "missing", json!({})),
        ];
        let results = registry.execute_batch(&calls);

        assert_eq!(results.len(), 2);
        assert!(results[0].success);
        assert_eq!(results[0].tool_call_id, "a");
        assert!(!results[1].success);
        assert_eq!(results[1].tool_call_id, "b");
    }

    #[test]
    fn test_simple_tool_macro() {
        let tool = simple_tool!(
            name: "echo",
            description: "Echoes input",
            parameters: json!({"type": "object", "properties": {"text": {"type": "string"}}}),
            execute: |args| {
                json!({"echo": args["text"]})
            }
        );

        assert_eq!(tool.name(), "echo");
        let result = tool.execute(json!({"text": "hello"})).unwrap();
        assert_eq!(result["echo"], "hello");
    }
}