Skip to main content

autoagents_core/tool/
mod.rs

1use autoagents_llm::chat::{FunctionTool, Tool};
2pub use autoagents_protocol::ToolCallResult;
3use serde_json::Value;
4use std::fmt::Debug;
5use std::sync::Arc;
6mod runtime;
7use async_trait::async_trait;
8pub use runtime::ToolRuntime;
9
10#[cfg(feature = "wasmtime")]
11pub use runtime::{WasmRuntime, WasmRuntimeError};
12
13#[derive(Debug, thiserror::Error)]
14pub enum ToolCallError {
15    #[error("Runtime Error {0}")]
16    RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),
17
18    #[error("Serde Error {0}")]
19    SerdeError(#[from] serde_json::Error),
20}
21
22pub trait ToolT: Send + Sync + Debug + ToolRuntime {
23    /// The name of the tool.
24    fn name(&self) -> &str;
25    /// A description explaining the tool’s purpose.
26    fn description(&self) -> &str;
27    /// Return a description of the expected arguments.
28    fn args_schema(&self) -> Value;
29}
30
/// Implemented by types that describe the input accepted by a tool; this is
/// what the `#[derive(ToolInput)]` macro generates.
///
/// Not a marker trait: implementors must provide [`ToolInputT::io_schema`].
pub trait ToolInputT {
    /// Static schema text describing the input type. The implementation in
    /// this module's tests returns a JSON Schema object serialized as a
    /// string; derived implementations presumably follow the same shape.
    fn io_schema() -> &'static str;
}
35
36/// A wrapper that allows Arc<dyn ToolT> to be used as Box<dyn ToolT>
37/// This is useful for sharing tools across multiple agents without cloning
38/// Wrapper around `Arc<dyn ToolT>` that presents a `Box<dyn ToolT>`-like API.
39/// Useful when sharing tool instances across multiple agents without cloning.
40#[derive(Debug)]
41pub struct SharedTool {
42    inner: Arc<dyn ToolT>,
43}
44
45impl SharedTool {
46    /// Create a new SharedTool from an Arc<dyn ToolT>
47    pub fn new(tool: Arc<dyn ToolT>) -> Self {
48        Self { inner: tool }
49    }
50}
51
52#[async_trait]
53impl ToolRuntime for SharedTool {
54    async fn execute(&self, args: Value) -> Result<Value, ToolCallError> {
55        self.inner.execute(args).await
56    }
57}
58
59impl ToolT for SharedTool {
60    fn name(&self) -> &str {
61        self.inner.name()
62    }
63
64    fn description(&self) -> &str {
65        self.inner.description()
66    }
67
68    fn args_schema(&self) -> Value {
69        self.inner.args_schema()
70    }
71}
72
73/// Helper function to convert Vec<Arc<dyn ToolT>> to Vec<Box<dyn ToolT>>
74/// This is useful when implementing AgentDeriveT::tools() with shared tools
75/// Convert a vector of `Arc<dyn ToolT>` into boxed trait objects for use in
76/// agent definitions.
77pub fn shared_tools_to_boxes(tools: &[Arc<dyn ToolT>]) -> Vec<Box<dyn ToolT>> {
78    tools
79        .iter()
80        .map(|t| Box::new(SharedTool::new(Arc::clone(t))) as Box<dyn ToolT>)
81        .collect()
82}
83
84/// Convert a ToolT trait object to an LLM Tool
85#[allow(clippy::borrowed_box)]
86pub fn to_llm_tool(tool: &Box<dyn ToolT>) -> Tool {
87    Tool {
88        tool_type: "function".to_string(),
89        function: FunctionTool {
90            name: tool.name().to_string(),
91            description: tool.description().to_string(),
92            parameters: tool.args_schema(),
93        },
94    }
95}
96
#[cfg(test)]
mod tests {
    // Unit tests for this module: `ToolCallError`, the `ToolT`/`ToolRuntime`
    // contract (through a configurable `MockTool`), `SharedTool`,
    // `shared_tools_to_boxes`, `to_llm_tool`, and the re-exported `ToolCallResult`.
    use super::*;
    use autoagents_llm::chat::Tool;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input type that `MockTool::execute` deserializes its arguments into.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestInput {
        name: String,
        value: i32,
    }

    impl ToolInputT for TestInput {
        fn io_schema() -> &'static str {
            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
        }
    }

    /// Test double: echoes `name` and doubles `value`, or always fails with a
    /// `RuntimeError` when `should_fail` is set.
    #[derive(Debug)]
    struct MockTool {
        name: &'static str,
        description: &'static str,
        should_fail: bool,
    }

    impl MockTool {
        fn new(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: false,
            }
        }

        fn with_failure(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: true,
            }
        }
    }

    impl ToolT for MockTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            self.description
        }

        fn args_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "value": {"type": "integer"}
                },
                "required": ["name", "value"]
            })
        }
    }

    #[async_trait]
    impl ToolRuntime for MockTool {
        async fn execute(
            &self,
            args: serde_json::Value,
        ) -> Result<serde_json::Value, ToolCallError> {
            if self.should_fail {
                return Err(ToolCallError::RuntimeError(
                    "Mock tool failure".to_string().into(),
                ));
            }

            let input: TestInput = serde_json::from_value(args)?;
            Ok(json!({
                "processed_name": input.name,
                "doubled_value": input.value * 2
            }))
        }
    }

    #[test]
    fn test_tool_call_error_runtime_error() {
        let error = ToolCallError::RuntimeError("Runtime error".to_string().into());
        assert_eq!(error.to_string(), "Runtime Error Runtime error");
    }

    #[test]
    fn test_tool_call_error_serde_error() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error = ToolCallError::SerdeError(json_error);
        assert!(error.to_string().contains("Serde Error"));
    }

    #[test]
    fn test_tool_call_error_debug() {
        let error = ToolCallError::RuntimeError("Debug test".to_string().into());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("RuntimeError"));
    }

    #[test]
    fn test_tool_call_error_from_serde() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error: ToolCallError = json_error.into();
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[test]
    fn test_tool_call_error_from_box_error() {
        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
        let error: ToolCallError = box_error.into();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_mock_tool_creation() {
        let tool = MockTool::new("test_tool", "A test tool");
        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool");
        assert!(!tool.should_fail);
    }

    #[test]
    fn test_mock_tool_with_failure() {
        let tool = MockTool::with_failure("failing_tool", "A failing tool");
        assert_eq!(tool.name(), "failing_tool");
        assert_eq!(tool.description(), "A failing tool");
        assert!(tool.should_fail);
    }

    #[test]
    fn test_mock_tool_args_schema() {
        let tool = MockTool::new("schema_tool", "Schema test");
        let schema = tool.args_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["value"].is_object());
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
    }

    #[tokio::test]
    async fn test_mock_tool_run_success() {
        let tool = MockTool::new("success_tool", "Success test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let result = tool.execute(input).await;
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[tokio::test]
    async fn test_mock_tool_run_failure() {
        let tool = MockTool::with_failure("failure_tool", "Failure test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let result = tool.execute(input).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Mock tool failure")
        );
    }

    #[tokio::test]
    async fn test_mock_tool_run_invalid_input() {
        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
        let input = json!({
            "invalid_field": "test"
        });

        let result = tool.execute(input).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolCallError::SerdeError(_)));
    }

    #[tokio::test]
    async fn test_mock_tool_run_with_extra_fields() {
        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
        let input = json!({
            "name": "test",
            "value": 42,
            "extra_field": "ignored"
        });

        let result = tool.execute(input).await;
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_debug() {
        let tool = MockTool::new("debug_tool", "Debug test");
        let debug_str = format!("{tool:?}");
        assert!(debug_str.contains("MockTool"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_input_trait() {
        let schema = TestInput::io_schema();
        assert!(schema.contains("object"));
        assert!(schema.contains("name"));
        assert!(schema.contains("value"));
        assert!(schema.contains("string"));
        assert!(schema.contains("integer"));
    }

    #[test]
    fn test_test_input_serialization() {
        let input = TestInput {
            name: "test".to_string(),
            value: 42,
        };
        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains("test"));
        assert!(serialized.contains("42"));
    }

    #[test]
    fn test_test_input_deserialization() {
        let json = r#"{"name":"test","value":42}"#;
        let input: TestInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.value, 42);
    }

    #[test]
    fn test_test_input_debug() {
        let input = TestInput {
            name: "debug".to_string(),
            value: 123,
        };
        let debug_str = format!("{input:?}");
        assert!(debug_str.contains("TestInput"));
        assert!(debug_str.contains("debug"));
        assert!(debug_str.contains("123"));
    }

    #[test]
    fn test_boxed_tool_to_tool_conversion() {
        let mock_tool = MockTool::new("convert_tool", "Conversion test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "convert_tool");
        assert_eq!(tool.function.description, "Conversion test");
        assert_eq!(tool.function.parameters["type"], "object");
    }

    #[test]
    fn test_tool_conversion_preserves_schema() {
        let mock_tool = MockTool::new("schema_tool", "Schema preservation test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        let schema = &tool.function.parameters;

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
        assert_eq!(schema["required"][0], "name");
        assert_eq!(schema["required"][1], "value");
    }

    #[test]
    fn test_tool_trait_object_usage() {
        let tools: Vec<Box<dyn ToolT>> = vec![
            Box::new(MockTool::new("tool1", "First tool")),
            Box::new(MockTool::new("tool2", "Second tool")),
            Box::new(MockTool::with_failure("tool3", "Third tool")),
        ];

        for tool in &tools {
            assert!(!tool.name().is_empty());
            assert!(!tool.description().is_empty());
            assert!(tool.args_schema().is_object());
        }
    }

    #[tokio::test]
    async fn test_tool_run_with_different_inputs() {
        let tool = MockTool::new("varied_input_tool", "Varied input test");

        let inputs = [
            json!({"name": "test1", "value": 1}),
            json!({"name": "test2", "value": -5}),
            json!({"name": "", "value": 0}),
            json!({"name": "long_name_test", "value": 999999}),
        ];

        for input in inputs {
            let result = tool.execute(input.clone()).await;
            assert!(result.is_ok());

            let output = result.unwrap();
            assert_eq!(output["processed_name"], input["name"]);
            assert_eq!(
                output["doubled_value"],
                input["value"].as_i64().unwrap() * 2
            );
        }
    }

    #[test]
    fn test_tool_error_chaining() {
        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
        let tool_error = ToolCallError::SerdeError(json_error);

        // Test error source chain
        use std::error::Error;
        assert!(tool_error.source().is_some());
    }

    #[test]
    fn test_tool_with_empty_name() {
        let tool = MockTool::new("", "Empty name test");
        assert_eq!(tool.name(), "");
        assert_eq!(tool.description(), "Empty name test");
    }

    #[test]
    fn test_tool_with_empty_description() {
        let tool = MockTool::new("empty_desc", "");
        assert_eq!(tool.name(), "empty_desc");
        assert_eq!(tool.description(), "");
    }

    #[test]
    fn test_tool_schema_complex() {
        let tool = MockTool::new("complex_tool", "Complex schema test");
        let schema = tool.args_schema();

        // Verify schema structure
        assert!(schema.is_object());
        assert!(schema["properties"].is_object());
        assert!(schema["required"].is_array());
        assert_eq!(schema["required"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_multiple_tool_instances() {
        let tool1 = MockTool::new("tool1", "First instance");
        let tool2 = MockTool::new("tool2", "Second instance");

        assert_ne!(tool1.name(), tool2.name());
        assert_ne!(tool1.description(), tool2.description());

        // Both should have the same schema structure
        assert_eq!(tool1.args_schema(), tool2.args_schema());
    }

    #[test]
    fn test_tool_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockTool>();
    }

    #[test]
    fn test_tool_trait_object_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Box<dyn ToolT>>();
    }

    #[test]
    fn test_tool_call_result_creation() {
        let result = ToolCallResult {
            tool_name: "test_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "success"}),
        };

        assert_eq!(result.tool_name, "test_tool");
        assert!(result.success);
        assert_eq!(result.arguments, json!({"param": "value"}));
        assert_eq!(result.result, json!({"output": "success"}));
    }

    #[test]
    fn test_tool_call_result_serialization() {
        let result = ToolCallResult {
            tool_name: "serialize_tool".to_string(),
            success: false,
            arguments: json!({"input": "test"}),
            result: json!({"error": "failed"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "serialize_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!({"input": "test"}));
        assert_eq!(deserialized.result, json!({"error": "failed"}));
    }

    #[test]
    fn test_tool_call_result_clone() {
        let result = ToolCallResult {
            tool_name: "clone_tool".to_string(),
            success: true,
            arguments: json!({"data": [1, 2, 3]}),
            result: json!({"processed": [2, 4, 6]}),
        };

        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.success, cloned.success);
        assert_eq!(result.arguments, cloned.arguments);
        assert_eq!(result.result, cloned.result);
    }

    #[test]
    fn test_tool_call_result_debug() {
        let result = ToolCallResult {
            tool_name: "debug_tool".to_string(),
            success: true,
            arguments: json!({}),
            result: json!(null),
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("ToolCallResult"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_call_result_with_null_values() {
        let result = ToolCallResult {
            tool_name: "null_tool".to_string(),
            success: false,
            arguments: json!(null),
            result: json!(null),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "null_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!(null));
        assert_eq!(deserialized.result, json!(null));
    }

    #[test]
    fn test_tool_call_result_with_complex_data() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, {"key": "value"}],
                "string": "test",
                "number": 42.5
            }
        });

        let complex_result = json!({
            "status": "completed",
            "data": {
                "items": ["a", "b", "c"],
                "count": 3
            }
        });

        let result = ToolCallResult {
            tool_name: "complex_tool".to_string(),
            success: true,
            arguments: complex_args.clone(),
            result: complex_result.clone(),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments, complex_args);
        assert_eq!(deserialized.result, complex_result);
    }

    #[test]
    fn test_tool_call_result_empty_tool_name() {
        let result = ToolCallResult {
            tool_name: String::default(),
            success: true,
            arguments: json!({}),
            result: json!({}),
        };

        assert!(result.tool_name.is_empty());
        assert!(result.success);
    }

    #[test]
    fn test_tool_call_result_large_data() {
        let large_string = "x".repeat(10000);
        let result = ToolCallResult {
            tool_name: "large_tool".to_string(),
            success: true,
            arguments: json!({"large_param": large_string}),
            result: json!({"processed": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "large_tool");
        assert!(deserialized.success);
        assert_eq!(
            deserialized.arguments["large_param"]
                .as_str()
                .unwrap()
                .len(),
            10000
        );
    }

    #[test]
    fn test_tool_call_result_equality() {
        let result1 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result2 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result3 = ToolCallResult {
            tool_name: "different_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        // Test equality through serialization since ToolCallResult doesn't implement PartialEq
        let serialized1 = serde_json::to_string(&result1).unwrap();
        let serialized2 = serde_json::to_string(&result2).unwrap();
        let serialized3 = serde_json::to_string(&result3).unwrap();

        assert_eq!(serialized1, serialized2);
        assert_ne!(serialized1, serialized3);
    }

    #[test]
    fn test_tool_call_result_with_unicode() {
        let result = ToolCallResult {
            tool_name: "unicode_tool".to_string(),
            success: true,
            arguments: json!({"message": "Hello 世界! 🌍"}),
            result: json!({"response": "Processed: Hello 世界! 🌍"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["message"], "Hello 世界! 🌍");
        assert_eq!(deserialized.result["response"], "Processed: Hello 世界! 🌍");
    }

    #[test]
    fn test_tool_call_result_with_arrays() {
        let result = ToolCallResult {
            tool_name: "array_tool".to_string(),
            success: true,
            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
            result: json!({"sum": 15, "count": 5}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
        assert_eq!(deserialized.result["sum"], 15);
        assert_eq!(deserialized.result["count"], 5);
    }

    #[test]
    fn test_tool_call_result_boolean_values() {
        let result = ToolCallResult {
            tool_name: "bool_tool".to_string(),
            success: false,
            arguments: json!({"enabled": true, "debug": false}),
            result: json!({"valid": false, "error": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments["enabled"], true);
        assert_eq!(deserialized.arguments["debug"], false);
        assert_eq!(deserialized.result["valid"], false);
        assert_eq!(deserialized.result["error"], true);
    }

    // --- SharedTool / shared_tools_to_boxes coverage ---

    #[tokio::test]
    async fn test_shared_tool_delegates_to_inner() {
        let inner: Arc<dyn ToolT> = Arc::new(MockTool::new("shared_tool", "Shared test"));
        let shared = SharedTool::new(Arc::clone(&inner));

        assert_eq!(shared.name(), "shared_tool");
        assert_eq!(shared.description(), "Shared test");
        assert_eq!(shared.args_schema(), inner.args_schema());

        let output = shared
            .execute(json!({"name": "shared", "value": 7}))
            .await
            .unwrap();
        assert_eq!(output["processed_name"], "shared");
        assert_eq!(output["doubled_value"], 14);
    }

    #[tokio::test]
    async fn test_shared_tool_propagates_errors() {
        let shared = SharedTool::new(Arc::new(MockTool::with_failure(
            "shared_fail",
            "Shared failure",
        )));

        let error = shared
            .execute(json!({"name": "x", "value": 1}))
            .await
            .unwrap_err();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
        assert!(error.to_string().contains("Mock tool failure"));
    }

    #[test]
    fn test_shared_tools_to_boxes_shares_instances() {
        let shared: Vec<Arc<dyn ToolT>> = vec![
            Arc::new(MockTool::new("a", "Tool A")),
            Arc::new(MockTool::new("b", "Tool B")),
        ];

        let boxes = shared_tools_to_boxes(&shared);
        assert_eq!(boxes.len(), 2);
        assert_eq!(boxes[0].name(), "a");
        assert_eq!(boxes[1].description(), "Tool B");

        // Each box holds another handle to the same tool, not a copy of it.
        assert_eq!(Arc::strong_count(&shared[0]), 2);
        assert_eq!(Arc::strong_count(&shared[1]), 2);
        drop(boxes);
        assert_eq!(Arc::strong_count(&shared[0]), 1);
        assert_eq!(Arc::strong_count(&shared[1]), 1);
    }

    #[test]
    fn test_shared_tools_to_boxes_empty() {
        assert!(shared_tools_to_boxes(&[]).is_empty());
    }

    #[test]
    fn test_shared_tool_to_llm_tool() {
        let shared: Arc<dyn ToolT> =
            Arc::new(MockTool::new("llm_shared", "Shared conversion"));
        let boxes = shared_tools_to_boxes(std::slice::from_ref(&shared));

        let tool: Tool = to_llm_tool(&boxes[0]);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "llm_shared");
        assert_eq!(tool.function.description, "Shared conversion");
        assert_eq!(tool.function.parameters, shared.args_schema());
    }
}