autoagents_core/tool/
mod.rs

1use autoagents_llm::chat::{FunctionTool, Tool};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::fmt::Debug;
5use std::sync::Arc;
6mod runtime;
7use async_trait::async_trait;
8pub use runtime::ToolRuntime;
9
10#[cfg(feature = "wasmtime")]
11pub use runtime::{WasmRuntime, WasmRuntimeError};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCallResult {
15    pub tool_name: String,
16    pub success: bool,
17    pub arguments: Value,
18    pub result: Value,
19}
20
21#[derive(Debug, thiserror::Error)]
22pub enum ToolCallError {
23    #[error("Runtime Error {0}")]
24    RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),
25
26    #[error("Serde Error {0}")]
27    SerdeError(#[from] serde_json::Error),
28}
29
30pub trait ToolT: Send + Sync + Debug + ToolRuntime {
31    /// The name of the tool.
32    fn name(&self) -> &'static str;
33    /// A description explaining the tool’s purpose.
34    fn description(&self) -> &'static str;
35    /// Return a description of the expected arguments.
36    fn args_schema(&self) -> Value;
37}
38
/// Implemented by types that describe a tool's input, so the schema can be
/// obtained from the type alone without constructing a value.
pub trait ToolInputT {
    /// Returns the input schema as a `'static` string. Presumably this is a
    /// JSON Schema document serialized to text; confirm against implementors.
    fn io_schema() -> &'static str;
}
42
43/// A wrapper that allows Arc<dyn ToolT> to be used as Box<dyn ToolT>
44/// This is useful for sharing tools across multiple agents without cloning
45#[derive(Debug)]
46pub struct SharedTool {
47    inner: Arc<dyn ToolT>,
48}
49
50impl SharedTool {
51    /// Create a new SharedTool from an Arc<dyn ToolT>
52    pub fn new(tool: Arc<dyn ToolT>) -> Self {
53        Self { inner: tool }
54    }
55}
56
57#[async_trait]
58impl ToolRuntime for SharedTool {
59    async fn execute(&self, args: Value) -> Result<Value, ToolCallError> {
60        self.inner.execute(args).await
61    }
62}
63
64impl ToolT for SharedTool {
65    fn name(&self) -> &'static str {
66        self.inner.name()
67    }
68
69    fn description(&self) -> &'static str {
70        self.inner.description()
71    }
72
73    fn args_schema(&self) -> Value {
74        self.inner.args_schema()
75    }
76}
77
78/// Helper function to convert Vec<Arc<dyn ToolT>> to Vec<Box<dyn ToolT>>
79/// This is useful when implementing AgentDeriveT::tools() with shared tools
80pub fn shared_tools_to_boxes(tools: &[Arc<dyn ToolT>]) -> Vec<Box<dyn ToolT>> {
81    tools
82        .iter()
83        .map(|t| Box::new(SharedTool::new(Arc::clone(t))) as Box<dyn ToolT>)
84        .collect()
85}
86
87/// Convert a ToolT trait object to an LLM Tool
88#[allow(clippy::borrowed_box)]
89pub fn to_llm_tool(tool: &Box<dyn ToolT>) -> Tool {
90    Tool {
91        tool_type: "function".to_string(),
92        function: FunctionTool {
93            name: tool.name().to_string(),
94            description: tool.description().to_string(),
95            parameters: tool.args_schema(),
96        },
97    }
98}
99
#[cfg(test)]
mod tests {
    // Unit tests for the tool module:
    // - ToolCallError conversions and formatting
    // - a MockTool implementation of ToolT
    // - ToolCallResult serde round-trips
    // - the SharedTool helpers and to_llm_tool
    use super::*;
    use autoagents_llm::chat::Tool;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Argument type decoded by `MockTool::execute`.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestInput {
        name: String,
        value: i32,
    }

    impl ToolInputT for TestInput {
        fn io_schema() -> &'static str {
            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
        }
    }

    /// Minimal tool: echoes `name` and doubles `value`, or always fails when
    /// constructed with `with_failure`.
    #[derive(Debug)]
    struct MockTool {
        name: &'static str,
        description: &'static str,
        should_fail: bool,
    }

    impl MockTool {
        fn new(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: false,
            }
        }

        fn with_failure(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: true,
            }
        }
    }

    impl ToolT for MockTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            self.description
        }

        fn args_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "value": {"type": "integer"}
                },
                "required": ["name", "value"]
            })
        }
    }

    #[async_trait]
    impl ToolRuntime for MockTool {
        async fn execute(
            &self,
            args: serde_json::Value,
        ) -> Result<serde_json::Value, ToolCallError> {
            if self.should_fail {
                return Err(ToolCallError::RuntimeError(
                    "Mock tool failure".to_string().into(),
                ));
            }

            let input: TestInput = serde_json::from_value(args)?;
            Ok(json!({
                "processed_name": input.name,
                "doubled_value": input.value * 2
            }))
        }
    }

    #[test]
    fn test_tool_call_error_runtime_error() {
        let error = ToolCallError::RuntimeError("Runtime error".to_string().into());
        assert_eq!(error.to_string(), "Runtime Error Runtime error");
    }

    #[test]
    fn test_tool_call_error_serde_error() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error = ToolCallError::SerdeError(json_error);
        assert!(error.to_string().starts_with("Serde Error "));
    }

    #[test]
    fn test_tool_call_error_debug() {
        let error = ToolCallError::RuntimeError("Debug test".to_string().into());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("RuntimeError"));
    }

    #[test]
    fn test_tool_call_error_from_serde() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error: ToolCallError = json_error.into();
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[test]
    fn test_tool_call_error_from_box_error() {
        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
        let error: ToolCallError = box_error.into();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_mock_tool_creation() {
        let tool = MockTool::new("test_tool", "A test tool");
        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool");
        assert!(!tool.should_fail);
    }

    #[test]
    fn test_mock_tool_with_failure() {
        let tool = MockTool::with_failure("failing_tool", "A failing tool");
        assert_eq!(tool.name(), "failing_tool");
        assert_eq!(tool.description(), "A failing tool");
        assert!(tool.should_fail);
    }

    #[test]
    fn test_mock_tool_args_schema() {
        let tool = MockTool::new("schema_tool", "Schema test");
        let schema = tool.args_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["value"].is_object());
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
    }

    #[tokio::test]
    async fn test_mock_tool_run_success() {
        let tool = MockTool::new("success_tool", "Success test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let output = tool
            .execute(input)
            .await
            .expect("valid input should execute successfully");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[tokio::test]
    async fn test_mock_tool_run_failure() {
        let tool = MockTool::with_failure("failure_tool", "Failure test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let err = tool
            .execute(input)
            .await
            .expect_err("a failing mock tool must return an error");
        assert!(matches!(err, ToolCallError::RuntimeError(_)));
        assert!(err.to_string().contains("Mock tool failure"));
    }

    #[tokio::test]
    async fn test_mock_tool_run_invalid_input() {
        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
        let input = json!({
            "invalid_field": "test"
        });

        let err = tool
            .execute(input)
            .await
            .expect_err("input missing required fields must be rejected");
        assert!(matches!(err, ToolCallError::SerdeError(_)));
    }

    #[tokio::test]
    async fn test_mock_tool_run_with_extra_fields() {
        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
        let input = json!({
            "name": "test",
            "value": 42,
            "extra_field": "ignored"
        });

        let output = tool
            .execute(input)
            .await
            .expect("unknown fields should be ignored, not rejected");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_debug() {
        let tool = MockTool::new("debug_tool", "Debug test");
        let debug_str = format!("{tool:?}");
        assert!(debug_str.contains("MockTool"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_input_trait() {
        let schema = TestInput::io_schema();
        assert!(schema.contains("object"));
        assert!(schema.contains("name"));
        assert!(schema.contains("value"));
        assert!(schema.contains("string"));
        assert!(schema.contains("integer"));
    }

    #[test]
    fn test_test_input_serialization() {
        let input = TestInput {
            name: "test".to_string(),
            value: 42,
        };
        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains("test"));
        assert!(serialized.contains("42"));
    }

    #[test]
    fn test_test_input_deserialization() {
        let json = r#"{"name":"test","value":42}"#;
        let input: TestInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.value, 42);
    }

    #[test]
    fn test_test_input_debug() {
        let input = TestInput {
            name: "debug".to_string(),
            value: 123,
        };
        let debug_str = format!("{input:?}");
        assert!(debug_str.contains("TestInput"));
        assert!(debug_str.contains("debug"));
        assert!(debug_str.contains("123"));
    }

    #[test]
    fn test_boxed_tool_to_tool_conversion() {
        let mock_tool = MockTool::new("convert_tool", "Conversion test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "convert_tool");
        assert_eq!(tool.function.description, "Conversion test");
        assert_eq!(tool.function.parameters["type"], "object");
    }

    #[test]
    fn test_tool_conversion_preserves_schema() {
        let mock_tool = MockTool::new("schema_tool", "Schema preservation test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        let schema = &tool.function.parameters;

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
        assert_eq!(schema["required"][0], "name");
        assert_eq!(schema["required"][1], "value");
    }

    #[test]
    fn test_tool_trait_object_usage() {
        let tools: Vec<Box<dyn ToolT>> = vec![
            Box::new(MockTool::new("tool1", "First tool")),
            Box::new(MockTool::new("tool2", "Second tool")),
            Box::new(MockTool::with_failure("tool3", "Third tool")),
        ];

        for tool in &tools {
            assert!(!tool.name().is_empty());
            assert!(!tool.description().is_empty());
            assert!(tool.args_schema().is_object());
        }
    }

    #[tokio::test]
    async fn test_tool_run_with_different_inputs() {
        let tool = MockTool::new("varied_input_tool", "Varied input test");

        let inputs = [
            json!({"name": "test1", "value": 1}),
            json!({"name": "test2", "value": -5}),
            json!({"name": "", "value": 0}),
            json!({"name": "long_name_test", "value": 999999}),
        ];

        for input in inputs {
            let output = tool
                .execute(input.clone())
                .await
                .expect("every well-formed input should succeed");
            assert_eq!(output["processed_name"], input["name"]);
            assert_eq!(
                output["doubled_value"],
                input["value"].as_i64().unwrap() * 2
            );
        }
    }

    #[test]
    fn test_tool_error_chaining() {
        use std::error::Error;

        // `#[from]` implies `#[source]`, so both variants expose their cause.
        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
        let tool_error = ToolCallError::SerdeError(json_error);
        assert!(tool_error.source().is_some());

        let runtime_error = ToolCallError::RuntimeError("boom".into());
        assert!(runtime_error.source().is_some());
    }

    #[test]
    fn test_tool_with_empty_name() {
        let tool = MockTool::new("", "Empty name test");
        assert_eq!(tool.name(), "");
        assert_eq!(tool.description(), "Empty name test");
    }

    #[test]
    fn test_tool_with_empty_description() {
        let tool = MockTool::new("empty_desc", "");
        assert_eq!(tool.name(), "empty_desc");
        assert_eq!(tool.description(), "");
    }

    #[test]
    fn test_tool_schema_complex() {
        let tool = MockTool::new("complex_tool", "Complex schema test");
        let schema = tool.args_schema();

        // Verify schema structure
        assert!(schema.is_object());
        assert!(schema["properties"].is_object());
        assert!(schema["required"].is_array());
        assert_eq!(schema["required"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_multiple_tool_instances() {
        let tool1 = MockTool::new("tool1", "First instance");
        let tool2 = MockTool::new("tool2", "Second instance");

        assert_ne!(tool1.name(), tool2.name());
        assert_ne!(tool1.description(), tool2.description());

        // Both should have the same schema structure
        assert_eq!(tool1.args_schema(), tool2.args_schema());
    }

    #[test]
    fn test_tool_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockTool>();
    }

    #[test]
    fn test_tool_trait_object_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Box<dyn ToolT>>();
    }

    #[test]
    fn test_shared_tool_delegates_metadata() {
        let inner: Arc<dyn ToolT> = Arc::new(MockTool::new("shared_tool", "Shared test"));
        let shared = SharedTool::new(Arc::clone(&inner));

        assert_eq!(shared.name(), "shared_tool");
        assert_eq!(shared.description(), "Shared test");
        assert_eq!(shared.args_schema(), inner.args_schema());
    }

    #[tokio::test]
    async fn test_shared_tool_delegates_execute() {
        let shared = SharedTool::new(Arc::new(MockTool::new("shared_exec", "Shared execute")));
        let output = shared
            .execute(json!({"name": "shared", "value": 21}))
            .await
            .expect("shared tool should forward to the inner tool");
        assert_eq!(output["processed_name"], "shared");
        assert_eq!(output["doubled_value"], 42);

        let failing = SharedTool::new(Arc::new(MockTool::with_failure(
            "shared_fail",
            "Shared failure",
        )));
        let err = failing
            .execute(json!({"name": "x", "value": 1}))
            .await
            .expect_err("errors from the inner tool must be propagated");
        assert!(matches!(err, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_shared_tools_to_boxes_preserves_order_and_shares_tools() {
        let tools: Vec<Arc<dyn ToolT>> = vec![
            Arc::new(MockTool::new("first", "First shared")),
            Arc::new(MockTool::new("second", "Second shared")),
        ];

        let boxed = shared_tools_to_boxes(&tools);
        let names: Vec<&str> = boxed.iter().map(|t| t.name()).collect();
        assert_eq!(names, ["first", "second"]);
        assert_eq!(to_llm_tool(&boxed[0]).function.name, "first");

        // Each wrapper holds a clone of the Arc, not a deep copy of the tool.
        assert!(tools.iter().all(|t| Arc::strong_count(t) == 2));

        drop(boxed);
        assert!(tools.iter().all(|t| Arc::strong_count(t) == 1));
    }

    #[test]
    fn test_shared_tools_to_boxes_empty() {
        assert!(shared_tools_to_boxes(&[]).is_empty());
    }

    #[test]
    fn test_tool_call_result_creation() {
        let result = ToolCallResult {
            tool_name: "test_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "success"}),
        };

        assert_eq!(result.tool_name, "test_tool");
        assert!(result.success);
        assert_eq!(result.arguments, json!({"param": "value"}));
        assert_eq!(result.result, json!({"output": "success"}));
    }

    #[test]
    fn test_tool_call_result_serialization() {
        let result = ToolCallResult {
            tool_name: "serialize_tool".to_string(),
            success: false,
            arguments: json!({"input": "test"}),
            result: json!({"error": "failed"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "serialize_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!({"input": "test"}));
        assert_eq!(deserialized.result, json!({"error": "failed"}));
    }

    #[test]
    fn test_tool_call_result_clone() {
        let result = ToolCallResult {
            tool_name: "clone_tool".to_string(),
            success: true,
            arguments: json!({"data": [1, 2, 3]}),
            result: json!({"processed": [2, 4, 6]}),
        };

        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.success, cloned.success);
        assert_eq!(result.arguments, cloned.arguments);
        assert_eq!(result.result, cloned.result);
    }

    #[test]
    fn test_tool_call_result_debug() {
        let result = ToolCallResult {
            tool_name: "debug_tool".to_string(),
            success: true,
            arguments: json!({}),
            result: json!(null),
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("ToolCallResult"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_call_result_with_null_values() {
        let result = ToolCallResult {
            tool_name: "null_tool".to_string(),
            success: false,
            arguments: json!(null),
            result: json!(null),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "null_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!(null));
        assert_eq!(deserialized.result, json!(null));
    }

    #[test]
    fn test_tool_call_result_with_complex_data() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, {"key": "value"}],
                "string": "test",
                "number": 42.5
            }
        });

        let complex_result = json!({
            "status": "completed",
            "data": {
                "items": ["a", "b", "c"],
                "count": 3
            }
        });

        let result = ToolCallResult {
            tool_name: "complex_tool".to_string(),
            success: true,
            arguments: complex_args.clone(),
            result: complex_result.clone(),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments, complex_args);
        assert_eq!(deserialized.result, complex_result);
    }

    #[test]
    fn test_tool_call_result_empty_tool_name() {
        let result = ToolCallResult {
            tool_name: String::new(),
            success: true,
            arguments: json!({}),
            result: json!({}),
        };

        assert!(result.tool_name.is_empty());
        assert!(result.success);
    }

    #[test]
    fn test_tool_call_result_large_data() {
        let large_string = "x".repeat(10000);
        let result = ToolCallResult {
            tool_name: "large_tool".to_string(),
            success: true,
            arguments: json!({"large_param": large_string}),
            result: json!({"processed": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "large_tool");
        assert!(deserialized.success);
        assert_eq!(
            deserialized.arguments["large_param"].as_str().map(str::len),
            Some(10_000)
        );
    }

    #[test]
    fn test_tool_call_result_equality() {
        let result1 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result2 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result3 = ToolCallResult {
            tool_name: "different_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        // Compare through serialization so this test relies only on the
        // Serialize derive, not on any PartialEq implementation.
        let serialized1 = serde_json::to_string(&result1).unwrap();
        let serialized2 = serde_json::to_string(&result2).unwrap();
        let serialized3 = serde_json::to_string(&result3).unwrap();

        assert_eq!(serialized1, serialized2);
        assert_ne!(serialized1, serialized3);
    }

    #[test]
    fn test_tool_call_result_with_unicode() {
        let result = ToolCallResult {
            tool_name: "unicode_tool".to_string(),
            success: true,
            arguments: json!({"message": "Hello δΈ–η•Œ! 🌍"}),
            result: json!({"response": "Processed: Hello δΈ–η•Œ! 🌍"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["message"], "Hello δΈ–η•Œ! 🌍");
        assert_eq!(deserialized.result["response"], "Processed: Hello δΈ–η•Œ! 🌍");
    }

    #[test]
    fn test_tool_call_result_with_arrays() {
        let result = ToolCallResult {
            tool_name: "array_tool".to_string(),
            success: true,
            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
            result: json!({"sum": 15, "count": 5}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
        assert_eq!(deserialized.result["sum"], 15);
        assert_eq!(deserialized.result["count"], 5);
    }

    #[test]
    fn test_tool_call_result_boolean_values() {
        let result = ToolCallResult {
            tool_name: "bool_tool".to_string(),
            success: false,
            arguments: json!({"enabled": true, "debug": false}),
            result: json!({"valid": false, "error": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments["enabled"], true);
        assert_eq!(deserialized.arguments["debug"], false);
        assert_eq!(deserialized.result["valid"], false);
        assert_eq!(deserialized.result["error"], true);
    }
}