Skip to main content

autoagents_core/tool/
mod.rs

1use autoagents_llm::chat::{FunctionTool, Tool};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::fmt::Debug;
5use std::sync::Arc;
6mod runtime;
7use async_trait::async_trait;
8pub use runtime::ToolRuntime;
9
10#[cfg(feature = "wasmtime")]
11pub use runtime::{WasmRuntime, WasmRuntimeError};
12
13/// Result emitted after executing a single tool call, including the tool name,
14/// success flag, original arguments, and the tool's structured result payload.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ToolCallResult {
17    pub tool_name: String,
18    pub success: bool,
19    pub arguments: Value,
20    pub result: Value,
21}
22
23#[derive(Debug, thiserror::Error)]
24pub enum ToolCallError {
25    #[error("Runtime Error {0}")]
26    RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),
27
28    #[error("Serde Error {0}")]
29    SerdeError(#[from] serde_json::Error),
30}
31
32pub trait ToolT: Send + Sync + Debug + ToolRuntime {
33    /// The name of the tool.
34    fn name(&self) -> &str;
35    /// A description explaining the tool’s purpose.
36    fn description(&self) -> &str;
37    /// Return a description of the expected arguments.
38    fn args_schema(&self) -> Value;
39}
40
/// Marker trait for tool input types, as used by the
/// `#[derive(ToolInput)]` macros.
pub trait ToolInputT {
    /// Static schema text for the input type.
    fn io_schema() -> &'static str;
}
45
46/// A wrapper that allows Arc<dyn ToolT> to be used as Box<dyn ToolT>
47/// This is useful for sharing tools across multiple agents without cloning
48/// Wrapper around `Arc<dyn ToolT>` that presents a `Box<dyn ToolT>`-like API.
49/// Useful when sharing tool instances across multiple agents without cloning.
50#[derive(Debug)]
51pub struct SharedTool {
52    inner: Arc<dyn ToolT>,
53}
54
55impl SharedTool {
56    /// Create a new SharedTool from an Arc<dyn ToolT>
57    pub fn new(tool: Arc<dyn ToolT>) -> Self {
58        Self { inner: tool }
59    }
60}
61
62#[async_trait]
63impl ToolRuntime for SharedTool {
64    async fn execute(&self, args: Value) -> Result<Value, ToolCallError> {
65        self.inner.execute(args).await
66    }
67}
68
69impl ToolT for SharedTool {
70    fn name(&self) -> &str {
71        self.inner.name()
72    }
73
74    fn description(&self) -> &str {
75        self.inner.description()
76    }
77
78    fn args_schema(&self) -> Value {
79        self.inner.args_schema()
80    }
81}
82
83/// Helper function to convert Vec<Arc<dyn ToolT>> to Vec<Box<dyn ToolT>>
84/// This is useful when implementing AgentDeriveT::tools() with shared tools
85/// Convert a vector of `Arc<dyn ToolT>` into boxed trait objects for use in
86/// agent definitions.
87pub fn shared_tools_to_boxes(tools: &[Arc<dyn ToolT>]) -> Vec<Box<dyn ToolT>> {
88    tools
89        .iter()
90        .map(|t| Box::new(SharedTool::new(Arc::clone(t))) as Box<dyn ToolT>)
91        .collect()
92}
93
94/// Convert a ToolT trait object to an LLM Tool
95#[allow(clippy::borrowed_box)]
96pub fn to_llm_tool(tool: &Box<dyn ToolT>) -> Tool {
97    Tool {
98        tool_type: "function".to_string(),
99        function: FunctionTool {
100            name: tool.name().to_string(),
101            description: tool.description().to_string(),
102            parameters: tool.args_schema(),
103        },
104    }
105}
106
107#[cfg(test)]
108mod tests {
109    use super::*;
110    use autoagents_llm::chat::Tool;
111    use serde::{Deserialize, Serialize};
112    use serde_json::json;
113
114    #[derive(Debug, Serialize, Deserialize)]
115    struct TestInput {
116        name: String,
117        value: i32,
118    }
119
120    impl ToolInputT for TestInput {
121        fn io_schema() -> &'static str {
122            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
123        }
124    }
125
126    #[derive(Debug)]
127    struct MockTool {
128        name: &'static str,
129        description: &'static str,
130        should_fail: bool,
131    }
132
133    impl MockTool {
134        fn new(name: &'static str, description: &'static str) -> Self {
135            Self {
136                name,
137                description,
138                should_fail: false,
139            }
140        }
141
142        fn with_failure(name: &'static str, description: &'static str) -> Self {
143            Self {
144                name,
145                description,
146                should_fail: true,
147            }
148        }
149    }
150
151    impl ToolT for MockTool {
152        fn name(&self) -> &'static str {
153            self.name
154        }
155
156        fn description(&self) -> &'static str {
157            self.description
158        }
159
160        fn args_schema(&self) -> Value {
161            json!({
162                "type": "object",
163                "properties": {
164                    "name": {"type": "string"},
165                    "value": {"type": "integer"}
166                },
167                "required": ["name", "value"]
168            })
169        }
170    }
171
172    #[async_trait]
173    impl ToolRuntime for MockTool {
174        async fn execute(
175            &self,
176            args: serde_json::Value,
177        ) -> Result<serde_json::Value, ToolCallError> {
178            if self.should_fail {
179                return Err(ToolCallError::RuntimeError(
180                    "Mock tool failure".to_string().into(),
181                ));
182            }
183
184            let input: TestInput = serde_json::from_value(args)?;
185            Ok(json!({
186                "processed_name": input.name,
187                "doubled_value": input.value * 2
188            }))
189        }
190    }
191
192    #[test]
193    fn test_tool_call_error_runtime_error() {
194        let error = ToolCallError::RuntimeError("Runtime error".to_string().into());
195        assert_eq!(error.to_string(), "Runtime Error Runtime error");
196    }
197
198    #[test]
199    fn test_tool_call_error_serde_error() {
200        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
201        let error = ToolCallError::SerdeError(json_error);
202        assert!(error.to_string().contains("Serde Error"));
203    }
204
205    #[test]
206    fn test_tool_call_error_debug() {
207        let error = ToolCallError::RuntimeError("Debug test".to_string().into());
208        let debug_str = format!("{error:?}");
209        assert!(debug_str.contains("RuntimeError"));
210    }
211
212    #[test]
213    fn test_tool_call_error_from_serde() {
214        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
215        let error: ToolCallError = json_error.into();
216        assert!(matches!(error, ToolCallError::SerdeError(_)));
217    }
218
219    #[test]
220    fn test_tool_call_error_from_box_error() {
221        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
222        let error: ToolCallError = box_error.into();
223        assert!(matches!(error, ToolCallError::RuntimeError(_)));
224    }
225
226    #[test]
227    fn test_mock_tool_creation() {
228        let tool = MockTool::new("test_tool", "A test tool");
229        assert_eq!(tool.name(), "test_tool");
230        assert_eq!(tool.description(), "A test tool");
231        assert!(!tool.should_fail);
232    }
233
234    #[test]
235    fn test_mock_tool_with_failure() {
236        let tool = MockTool::with_failure("failing_tool", "A failing tool");
237        assert_eq!(tool.name(), "failing_tool");
238        assert_eq!(tool.description(), "A failing tool");
239        assert!(tool.should_fail);
240    }
241
242    #[test]
243    fn test_mock_tool_args_schema() {
244        let tool = MockTool::new("schema_tool", "Schema test");
245        let schema = tool.args_schema();
246
247        assert_eq!(schema["type"], "object");
248        assert!(schema["properties"].is_object());
249        assert!(schema["properties"]["name"].is_object());
250        assert!(schema["properties"]["value"].is_object());
251        assert_eq!(schema["properties"]["name"]["type"], "string");
252        assert_eq!(schema["properties"]["value"]["type"], "integer");
253    }
254
255    #[tokio::test]
256    async fn test_mock_tool_run_success() {
257        let tool = MockTool::new("success_tool", "Success test");
258        let input = json!({
259            "name": "test",
260            "value": 42
261        });
262
263        let result = tool.execute(input).await;
264        assert!(result.is_ok());
265
266        let output = result.unwrap();
267        assert_eq!(output["processed_name"], "test");
268        assert_eq!(output["doubled_value"], 84);
269    }
270
271    #[tokio::test]
272    async fn test_mock_tool_run_failure() {
273        let tool = MockTool::with_failure("failure_tool", "Failure test");
274        let input = json!({
275            "name": "test",
276            "value": 42
277        });
278
279        let result = tool.execute(input).await;
280        assert!(result.is_err());
281        assert!(
282            result
283                .unwrap_err()
284                .to_string()
285                .contains("Mock tool failure")
286        );
287    }
288
289    #[tokio::test]
290    async fn test_mock_tool_run_invalid_input() {
291        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
292        let input = json!({
293            "invalid_field": "test"
294        });
295
296        let result = tool.execute(input).await;
297        assert!(result.is_err());
298        assert!(matches!(result.unwrap_err(), ToolCallError::SerdeError(_)));
299    }
300
301    #[tokio::test]
302    async fn test_mock_tool_run_with_extra_fields() {
303        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
304        let input = json!({
305            "name": "test",
306            "value": 42,
307            "extra_field": "ignored"
308        });
309
310        let result = tool.execute(input).await;
311        assert!(result.is_ok());
312
313        let output = result.unwrap();
314        assert_eq!(output["processed_name"], "test");
315        assert_eq!(output["doubled_value"], 84);
316    }
317
318    #[test]
319    fn test_mock_tool_debug() {
320        let tool = MockTool::new("debug_tool", "Debug test");
321        let debug_str = format!("{tool:?}");
322        assert!(debug_str.contains("MockTool"));
323        assert!(debug_str.contains("debug_tool"));
324    }
325
326    #[test]
327    fn test_tool_input_trait() {
328        let schema = TestInput::io_schema();
329        assert!(schema.contains("object"));
330        assert!(schema.contains("name"));
331        assert!(schema.contains("value"));
332        assert!(schema.contains("string"));
333        assert!(schema.contains("integer"));
334    }
335
336    #[test]
337    fn test_test_input_serialization() {
338        let input = TestInput {
339            name: "test".to_string(),
340            value: 42,
341        };
342        let serialized = serde_json::to_string(&input).unwrap();
343        assert!(serialized.contains("test"));
344        assert!(serialized.contains("42"));
345    }
346
347    #[test]
348    fn test_test_input_deserialization() {
349        let json = r#"{"name":"test","value":42}"#;
350        let input: TestInput = serde_json::from_str(json).unwrap();
351        assert_eq!(input.name, "test");
352        assert_eq!(input.value, 42);
353    }
354
355    #[test]
356    fn test_test_input_debug() {
357        let input = TestInput {
358            name: "debug".to_string(),
359            value: 123,
360        };
361        let debug_str = format!("{input:?}");
362        assert!(debug_str.contains("TestInput"));
363        assert!(debug_str.contains("debug"));
364        assert!(debug_str.contains("123"));
365    }
366
367    #[test]
368    fn test_boxed_tool_to_tool_conversion() {
369        let mock_tool = MockTool::new("convert_tool", "Conversion test");
370        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);
371
372        let tool: Tool = to_llm_tool(&boxed_tool);
373        assert_eq!(tool.tool_type, "function");
374        assert_eq!(tool.function.name, "convert_tool");
375        assert_eq!(tool.function.description, "Conversion test");
376        assert_eq!(tool.function.parameters["type"], "object");
377    }
378
379    #[test]
380    fn test_tool_conversion_preserves_schema() {
381        let mock_tool = MockTool::new("schema_tool", "Schema preservation test");
382        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);
383
384        let tool: Tool = to_llm_tool(&boxed_tool);
385        let schema = &tool.function.parameters;
386
387        assert_eq!(schema["type"], "object");
388        assert_eq!(schema["properties"]["name"]["type"], "string");
389        assert_eq!(schema["properties"]["value"]["type"], "integer");
390        assert_eq!(schema["required"][0], "name");
391        assert_eq!(schema["required"][1], "value");
392    }
393
394    #[test]
395    fn test_tool_trait_object_usage() {
396        let tools: Vec<Box<dyn ToolT>> = vec![
397            Box::new(MockTool::new("tool1", "First tool")),
398            Box::new(MockTool::new("tool2", "Second tool")),
399            Box::new(MockTool::with_failure("tool3", "Third tool")),
400        ];
401
402        for tool in &tools {
403            assert!(!tool.name().is_empty());
404            assert!(!tool.description().is_empty());
405            assert!(tool.args_schema().is_object());
406        }
407    }
408
409    #[tokio::test]
410    async fn test_tool_run_with_different_inputs() {
411        let tool = MockTool::new("varied_input_tool", "Varied input test");
412
413        let inputs = vec![
414            json!({"name": "test1", "value": 1}),
415            json!({"name": "test2", "value": -5}),
416            json!({"name": "", "value": 0}),
417            json!({"name": "long_name_test", "value": 999999}),
418        ];
419
420        for input in inputs {
421            let result = tool.execute(input.clone()).await;
422            assert!(result.is_ok());
423
424            let output = result.unwrap();
425            assert_eq!(output["processed_name"], input["name"]);
426            assert_eq!(
427                output["doubled_value"],
428                input["value"].as_i64().unwrap() * 2
429            );
430        }
431    }
432
433    #[test]
434    fn test_tool_error_chaining() {
435        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
436        let tool_error = ToolCallError::SerdeError(json_error);
437
438        // Test error source chain
439        use std::error::Error;
440        assert!(tool_error.source().is_some());
441    }
442
443    #[test]
444    fn test_tool_with_empty_name() {
445        let tool = MockTool::new("", "Empty name test");
446        assert_eq!(tool.name(), "");
447        assert_eq!(tool.description(), "Empty name test");
448    }
449
450    #[test]
451    fn test_tool_with_empty_description() {
452        let tool = MockTool::new("empty_desc", "");
453        assert_eq!(tool.name(), "empty_desc");
454        assert_eq!(tool.description(), "");
455    }
456
457    #[test]
458    fn test_tool_schema_complex() {
459        let tool = MockTool::new("complex_tool", "Complex schema test");
460        let schema = tool.args_schema();
461
462        // Verify schema structure
463        assert!(schema.is_object());
464        assert!(schema["properties"].is_object());
465        assert!(schema["required"].is_array());
466        assert_eq!(schema["required"].as_array().unwrap().len(), 2);
467    }
468
469    #[test]
470    fn test_multiple_tool_instances() {
471        let tool1 = MockTool::new("tool1", "First instance");
472        let tool2 = MockTool::new("tool2", "Second instance");
473
474        assert_ne!(tool1.name(), tool2.name());
475        assert_ne!(tool1.description(), tool2.description());
476
477        // Both should have the same schema structure
478        assert_eq!(tool1.args_schema(), tool2.args_schema());
479    }
480
481    #[test]
482    fn test_tool_send_sync() {
483        fn assert_send_sync<T: Send + Sync>() {}
484        assert_send_sync::<MockTool>();
485    }
486
487    #[test]
488    fn test_tool_trait_object_send_sync() {
489        fn assert_send_sync<T: Send + Sync>() {}
490        assert_send_sync::<Box<dyn ToolT>>();
491    }
492
493    #[test]
494    fn test_tool_call_result_creation() {
495        let result = ToolCallResult {
496            tool_name: "test_tool".to_string(),
497            success: true,
498            arguments: json!({"param": "value"}),
499            result: json!({"output": "success"}),
500        };
501
502        assert_eq!(result.tool_name, "test_tool");
503        assert!(result.success);
504        assert_eq!(result.arguments, json!({"param": "value"}));
505        assert_eq!(result.result, json!({"output": "success"}));
506    }
507
508    #[test]
509    fn test_tool_call_result_serialization() {
510        let result = ToolCallResult {
511            tool_name: "serialize_tool".to_string(),
512            success: false,
513            arguments: json!({"input": "test"}),
514            result: json!({"error": "failed"}),
515        };
516
517        let serialized = serde_json::to_string(&result).unwrap();
518        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();
519
520        assert_eq!(deserialized.tool_name, "serialize_tool");
521        assert!(!deserialized.success);
522        assert_eq!(deserialized.arguments, json!({"input": "test"}));
523        assert_eq!(deserialized.result, json!({"error": "failed"}));
524    }
525
526    #[test]
527    fn test_tool_call_result_clone() {
528        let result = ToolCallResult {
529            tool_name: "clone_tool".to_string(),
530            success: true,
531            arguments: json!({"data": [1, 2, 3]}),
532            result: json!({"processed": [2, 4, 6]}),
533        };
534
535        let cloned = result.clone();
536        assert_eq!(result.tool_name, cloned.tool_name);
537        assert_eq!(result.success, cloned.success);
538        assert_eq!(result.arguments, cloned.arguments);
539        assert_eq!(result.result, cloned.result);
540    }
541
542    #[test]
543    fn test_tool_call_result_debug() {
544        let result = ToolCallResult {
545            tool_name: "debug_tool".to_string(),
546            success: true,
547            arguments: json!({}),
548            result: json!(null),
549        };
550
551        let debug_str = format!("{result:?}");
552        assert!(debug_str.contains("ToolCallResult"));
553        assert!(debug_str.contains("debug_tool"));
554    }
555
556    #[test]
557    fn test_tool_call_result_with_null_values() {
558        let result = ToolCallResult {
559            tool_name: "null_tool".to_string(),
560            success: false,
561            arguments: json!(null),
562            result: json!(null),
563        };
564
565        let serialized = serde_json::to_string(&result).unwrap();
566        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();
567
568        assert_eq!(deserialized.tool_name, "null_tool");
569        assert!(!deserialized.success);
570        assert_eq!(deserialized.arguments, json!(null));
571        assert_eq!(deserialized.result, json!(null));
572    }
573
574    #[test]
575    fn test_tool_call_result_with_complex_data() {
576        let complex_args = json!({
577            "nested": {
578                "array": [1, 2, {"key": "value"}],
579                "string": "test",
580                "number": 42.5
581            }
582        });
583
584        let complex_result = json!({
585            "status": "completed",
586            "data": {
587                "items": ["a", "b", "c"],
588                "count": 3
589            }
590        });
591
592        let result = ToolCallResult {
593            tool_name: "complex_tool".to_string(),
594            success: true,
595            arguments: complex_args.clone(),
596            result: complex_result.clone(),
597        };
598
599        let serialized = serde_json::to_string(&result).unwrap();
600        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();
601
602        assert_eq!(deserialized.arguments, complex_args);
603        assert_eq!(deserialized.result, complex_result);
604    }
605
606    #[test]
607    fn test_tool_call_result_empty_tool_name() {
608        let result = ToolCallResult {
609            tool_name: String::new(),
610            success: true,
611            arguments: json!({}),
612            result: json!({}),
613        };
614
615        assert!(result.tool_name.is_empty());
616        assert!(result.success);
617    }
618
619    #[test]
620    fn test_tool_call_result_large_data() {
621        let large_string = "x".repeat(10000);
622        let result = ToolCallResult {
623            tool_name: "large_tool".to_string(),
624            success: true,
625            arguments: json!({"large_param": large_string}),
626            result: json!({"processed": true}),
627        };
628
629        let serialized = serde_json::to_string(&result).unwrap();
630        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();
631
632        assert_eq!(deserialized.tool_name, "large_tool");
633        assert!(deserialized.success);
634        assert!(
635            deserialized.arguments["large_param"]
636                .as_str()
637                .unwrap()
638                .len()
639                == 10000
640        );
641    }
642
643    #[test]
644    fn test_tool_call_result_equality() {
645        let result1 = ToolCallResult {
646            tool_name: "equal_tool".to_string(),
647            success: true,
648            arguments: json!({"param": "value"}),
649            result: json!({"output": "result"}),
650        };
651
652        let result2 = ToolCallResult {
653            tool_name: "equal_tool".to_string(),
654            success: true,
655            arguments: json!({"param": "value"}),
656            result: json!({"output": "result"}),
657        };
658
659        let result3 = ToolCallResult {
660            tool_name: "different_tool".to_string(),
661            success: true,
662            arguments: json!({"param": "value"}),
663            result: json!({"output": "result"}),
664        };
665
666        // Test equality through serialization since ToolCallResult doesn't implement PartialEq
667        let serialized1 = serde_json::to_string(&result1).unwrap();
668        let serialized2 = serde_json::to_string(&result2).unwrap();
669        let serialized3 = serde_json::to_string(&result3).unwrap();
670
671        assert_eq!(serialized1, serialized2);
672        assert_ne!(serialized1, serialized3);
673    }
674
675    #[test]
676    fn test_tool_call_result_with_unicode() {
677        let result = ToolCallResult {
678            tool_name: "unicode_tool".to_string(),
679            success: true,
680            arguments: json!({"message": "Hello δΈ–η•Œ! 🌍"}),
681            result: json!({"response": "Processed: Hello δΈ–η•Œ! 🌍"}),
682        };
683
684        let serialized = serde_json::to_string(&result).unwrap();
685        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();
686
687        assert_eq!(deserialized.arguments["message"], "Hello δΈ–η•Œ! 🌍");
688        assert_eq!(deserialized.result["response"], "Processed: Hello δΈ–η•Œ! 🌍");
689    }
690
691    #[test]
692    fn test_tool_call_result_with_arrays() {
693        let result = ToolCallResult {
694            tool_name: "array_tool".to_string(),
695            success: true,
696            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
697            result: json!({"sum": 15, "count": 5}),
698        };
699
700        let serialized = serde_json::to_string(&result).unwrap();
701        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();
702
703        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
704        assert_eq!(deserialized.result["sum"], 15);
705        assert_eq!(deserialized.result["count"], 5);
706    }
707
708    #[test]
709    fn test_tool_call_result_boolean_values() {
710        let result = ToolCallResult {
711            tool_name: "bool_tool".to_string(),
712            success: false,
713            arguments: json!({"enabled": true, "debug": false}),
714            result: json!({"valid": false, "error": true}),
715        };
716
717        let serialized = serde_json::to_string(&result).unwrap();
718        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();
719
720        assert!(!deserialized.success);
721        assert_eq!(deserialized.arguments["enabled"], true);
722        assert_eq!(deserialized.arguments["debug"], false);
723        assert_eq!(deserialized.result["valid"], false);
724        assert_eq!(deserialized.result["error"], true);
725    }
726}