autoagents_core/tool/
mod.rs

1use autoagents_llm::chat::{FunctionTool, Tool};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::fmt::Debug;
5use std::sync::Arc;
6mod runtime;
7use async_trait::async_trait;
8pub use runtime::ToolRuntime;
9
10#[cfg(feature = "wasmtime")]
11pub use runtime::{WasmRuntime, WasmRuntimeError};
12
13/// Result emitted after executing a single tool call, including the tool name,
14/// success flag, original arguments, and the tool's structured result payload.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ToolCallResult {
17    pub tool_name: String,
18    pub success: bool,
19    pub arguments: Value,
20    pub result: Value,
21}
22
23#[derive(Debug, thiserror::Error)]
24pub enum ToolCallError {
25    #[error("Runtime Error {0}")]
26    RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),
27
28    #[error("Serde Error {0}")]
29    SerdeError(#[from] serde_json::Error),
30}
31
32pub trait ToolT: Send + Sync + Debug + ToolRuntime {
33    /// The name of the tool.
34    fn name(&self) -> &'static str;
35    /// A description explaining the tool’s purpose.
36    fn description(&self) -> &'static str;
37    /// Return a description of the expected arguments.
38    fn args_schema(&self) -> Value;
39}
40
/// Implemented by tool input types, typically generated by the
/// `#[derive(ToolInput)]` macro.
pub trait ToolInputT {
    /// Static string holding the schema that describes this input type.
    fn io_schema() -> &'static str;
}
45
46/// A wrapper that allows Arc<dyn ToolT> to be used as Box<dyn ToolT>
47/// This is useful for sharing tools across multiple agents without cloning
48/// Wrapper around `Arc<dyn ToolT>` that presents a `Box<dyn ToolT>`-like API.
49/// Useful when sharing tool instances across multiple agents without cloning.
50#[derive(Debug)]
51pub struct SharedTool {
52    inner: Arc<dyn ToolT>,
53}
54
55impl SharedTool {
56    /// Create a new SharedTool from an Arc<dyn ToolT>
57    pub fn new(tool: Arc<dyn ToolT>) -> Self {
58        Self { inner: tool }
59    }
60}
61
62#[async_trait]
63impl ToolRuntime for SharedTool {
64    async fn execute(&self, args: Value) -> Result<Value, ToolCallError> {
65        self.inner.execute(args).await
66    }
67}
68
69impl ToolT for SharedTool {
70    fn name(&self) -> &'static str {
71        self.inner.name()
72    }
73
74    fn description(&self) -> &'static str {
75        self.inner.description()
76    }
77
78    fn args_schema(&self) -> Value {
79        self.inner.args_schema()
80    }
81}
82
83/// Helper function to convert Vec<Arc<dyn ToolT>> to Vec<Box<dyn ToolT>>
84/// This is useful when implementing AgentDeriveT::tools() with shared tools
85/// Convert a vector of `Arc<dyn ToolT>` into boxed trait objects for use in
86/// agent definitions.
87pub fn shared_tools_to_boxes(tools: &[Arc<dyn ToolT>]) -> Vec<Box<dyn ToolT>> {
88    tools
89        .iter()
90        .map(|t| Box::new(SharedTool::new(Arc::clone(t))) as Box<dyn ToolT>)
91        .collect()
92}
93
94/// Convert a ToolT trait object to an LLM Tool
95#[allow(clippy::borrowed_box)]
96pub fn to_llm_tool(tool: &Box<dyn ToolT>) -> Tool {
97    Tool {
98        tool_type: "function".to_string(),
99        function: FunctionTool {
100            name: tool.name().to_string(),
101            description: tool.description().to_string(),
102            parameters: tool.args_schema(),
103        },
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::Tool;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input type decoded by `MockTool`; its schema mirrors `MockTool::args_schema`.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestInput {
        name: String,
        value: i32,
    }

    impl ToolInputT for TestInput {
        fn io_schema() -> &'static str {
            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
        }
    }

    /// Minimal tool: echoes `name` and doubles `value`, or always returns a
    /// runtime error when built with `with_failure`.
    #[derive(Debug)]
    struct MockTool {
        name: &'static str,
        description: &'static str,
        should_fail: bool,
    }

    impl MockTool {
        fn new(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: false,
            }
        }

        fn with_failure(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: true,
            }
        }
    }

    impl ToolT for MockTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            self.description
        }

        fn args_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "value": {"type": "integer"}
                },
                "required": ["name", "value"]
            })
        }
    }

    #[async_trait]
    impl ToolRuntime for MockTool {
        async fn execute(
            &self,
            args: serde_json::Value,
        ) -> Result<serde_json::Value, ToolCallError> {
            if self.should_fail {
                return Err(ToolCallError::RuntimeError(
                    "Mock tool failure".to_string().into(),
                ));
            }

            let input: TestInput = serde_json::from_value(args)?;
            Ok(json!({
                "processed_name": input.name,
                "doubled_value": input.value * 2
            }))
        }
    }

    #[test]
    fn test_tool_call_error_runtime_error() {
        let error = ToolCallError::RuntimeError("Runtime error".to_string().into());
        assert_eq!(error.to_string(), "Runtime Error Runtime error");
    }

    #[test]
    fn test_tool_call_error_serde_error() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error = ToolCallError::SerdeError(json_error);
        assert!(error.to_string().contains("Serde Error"));
    }

    #[test]
    fn test_tool_call_error_debug() {
        let error = ToolCallError::RuntimeError("Debug test".to_string().into());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("RuntimeError"));
    }

    #[test]
    fn test_tool_call_error_from_serde() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error: ToolCallError = json_error.into();
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[test]
    fn test_tool_call_error_from_box_error() {
        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
        let error: ToolCallError = box_error.into();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_mock_tool_creation() {
        let tool = MockTool::new("test_tool", "A test tool");
        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool");
        assert!(!tool.should_fail);
    }

    #[test]
    fn test_mock_tool_with_failure() {
        let tool = MockTool::with_failure("failing_tool", "A failing tool");
        assert_eq!(tool.name(), "failing_tool");
        assert_eq!(tool.description(), "A failing tool");
        assert!(tool.should_fail);
    }

    #[test]
    fn test_mock_tool_args_schema() {
        let tool = MockTool::new("schema_tool", "Schema test");
        let schema = tool.args_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["value"].is_object());
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
    }

    #[tokio::test]
    async fn test_mock_tool_run_success() {
        let tool = MockTool::new("success_tool", "Success test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let output = tool
            .execute(input)
            .await
            .expect("valid input should succeed");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[tokio::test]
    async fn test_mock_tool_run_failure() {
        let tool = MockTool::with_failure("failure_tool", "Failure test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let result = tool.execute(input).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Mock tool failure"));
    }

    #[tokio::test]
    async fn test_mock_tool_run_invalid_input() {
        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
        let input = json!({
            "invalid_field": "test"
        });

        let result = tool.execute(input).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ToolCallError::SerdeError(_)));
    }

    #[tokio::test]
    async fn test_mock_tool_run_with_extra_fields() {
        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
        let input = json!({
            "name": "test",
            "value": 42,
            "extra_field": "ignored"
        });

        let output = tool
            .execute(input)
            .await
            .expect("unknown fields should be ignored");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_debug() {
        let tool = MockTool::new("debug_tool", "Debug test");
        let debug_str = format!("{tool:?}");
        assert!(debug_str.contains("MockTool"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_input_trait() {
        let schema = TestInput::io_schema();
        assert!(schema.contains("object"));
        assert!(schema.contains("name"));
        assert!(schema.contains("value"));
        assert!(schema.contains("string"));
        assert!(schema.contains("integer"));
    }

    #[test]
    fn test_test_input_serialization() {
        let input = TestInput {
            name: "test".to_string(),
            value: 42,
        };
        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains("test"));
        assert!(serialized.contains("42"));
    }

    #[test]
    fn test_test_input_deserialization() {
        let json = r#"{"name":"test","value":42}"#;
        let input: TestInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.value, 42);
    }

    #[test]
    fn test_test_input_debug() {
        let input = TestInput {
            name: "debug".to_string(),
            value: 123,
        };
        let debug_str = format!("{input:?}");
        assert!(debug_str.contains("TestInput"));
        assert!(debug_str.contains("debug"));
        assert!(debug_str.contains("123"));
    }

    #[test]
    fn test_boxed_tool_to_tool_conversion() {
        let mock_tool = MockTool::new("convert_tool", "Conversion test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "convert_tool");
        assert_eq!(tool.function.description, "Conversion test");
        assert_eq!(tool.function.parameters["type"], "object");
    }

    #[test]
    fn test_tool_conversion_preserves_schema() {
        let mock_tool = MockTool::new("schema_tool", "Schema preservation test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        let schema = &tool.function.parameters;

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
        assert_eq!(schema["required"][0], "name");
        assert_eq!(schema["required"][1], "value");
    }

    #[test]
    fn test_tool_trait_object_usage() {
        let tools: Vec<Box<dyn ToolT>> = vec![
            Box::new(MockTool::new("tool1", "First tool")),
            Box::new(MockTool::new("tool2", "Second tool")),
            Box::new(MockTool::with_failure("tool3", "Third tool")),
        ];

        for tool in &tools {
            assert!(!tool.name().is_empty());
            assert!(!tool.description().is_empty());
            assert!(tool.args_schema().is_object());
        }
    }

    #[tokio::test]
    async fn test_tool_run_with_different_inputs() {
        let tool = MockTool::new("varied_input_tool", "Varied input test");

        let inputs = [
            json!({"name": "test1", "value": 1}),
            json!({"name": "test2", "value": -5}),
            json!({"name": "", "value": 0}),
            json!({"name": "long_name_test", "value": 999999}),
        ];

        for input in inputs {
            let output = tool
                .execute(input.clone())
                .await
                .expect("well-formed input should succeed");
            assert_eq!(output["processed_name"], input["name"]);
            assert_eq!(
                output["doubled_value"],
                input["value"].as_i64().unwrap() * 2
            );
        }
    }

    #[test]
    fn test_tool_error_chaining() {
        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
        let tool_error = ToolCallError::SerdeError(json_error);

        // `#[from]` also marks the wrapped error as the source.
        use std::error::Error;
        assert!(tool_error.source().is_some());
    }

    #[test]
    fn test_tool_with_empty_name() {
        let tool = MockTool::new("", "Empty name test");
        assert_eq!(tool.name(), "");
        assert_eq!(tool.description(), "Empty name test");
    }

    #[test]
    fn test_tool_with_empty_description() {
        let tool = MockTool::new("empty_desc", "");
        assert_eq!(tool.name(), "empty_desc");
        assert_eq!(tool.description(), "");
    }

    #[test]
    fn test_tool_schema_complex() {
        let tool = MockTool::new("complex_tool", "Complex schema test");
        let schema = tool.args_schema();

        // Verify schema structure
        assert!(schema.is_object());
        assert!(schema["properties"].is_object());
        assert!(schema["required"].is_array());
        assert_eq!(schema["required"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_multiple_tool_instances() {
        let tool1 = MockTool::new("tool1", "First instance");
        let tool2 = MockTool::new("tool2", "Second instance");

        assert_ne!(tool1.name(), tool2.name());
        assert_ne!(tool1.description(), tool2.description());

        // Both should have the same schema structure
        assert_eq!(tool1.args_schema(), tool2.args_schema());
    }

    #[test]
    fn test_tool_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockTool>();
    }

    #[test]
    fn test_tool_trait_object_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Box<dyn ToolT>>();
    }

    #[test]
    fn test_tool_call_result_creation() {
        let result = ToolCallResult {
            tool_name: "test_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "success"}),
        };

        assert_eq!(result.tool_name, "test_tool");
        assert!(result.success);
        assert_eq!(result.arguments, json!({"param": "value"}));
        assert_eq!(result.result, json!({"output": "success"}));
    }

    #[test]
    fn test_tool_call_result_serialization() {
        let result = ToolCallResult {
            tool_name: "serialize_tool".to_string(),
            success: false,
            arguments: json!({"input": "test"}),
            result: json!({"error": "failed"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "serialize_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!({"input": "test"}));
        assert_eq!(deserialized.result, json!({"error": "failed"}));
    }

    #[test]
    fn test_tool_call_result_clone() {
        let result = ToolCallResult {
            tool_name: "clone_tool".to_string(),
            success: true,
            arguments: json!({"data": [1, 2, 3]}),
            result: json!({"processed": [2, 4, 6]}),
        };

        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.success, cloned.success);
        assert_eq!(result.arguments, cloned.arguments);
        assert_eq!(result.result, cloned.result);
    }

    #[test]
    fn test_tool_call_result_debug() {
        let result = ToolCallResult {
            tool_name: "debug_tool".to_string(),
            success: true,
            arguments: json!({}),
            result: json!(null),
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("ToolCallResult"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_call_result_with_null_values() {
        let result = ToolCallResult {
            tool_name: "null_tool".to_string(),
            success: false,
            arguments: json!(null),
            result: json!(null),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "null_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!(null));
        assert_eq!(deserialized.result, json!(null));
    }

    #[test]
    fn test_tool_call_result_with_complex_data() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, {"key": "value"}],
                "string": "test",
                "number": 42.5
            }
        });

        let complex_result = json!({
            "status": "completed",
            "data": {
                "items": ["a", "b", "c"],
                "count": 3
            }
        });

        let result = ToolCallResult {
            tool_name: "complex_tool".to_string(),
            success: true,
            arguments: complex_args.clone(),
            result: complex_result.clone(),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments, complex_args);
        assert_eq!(deserialized.result, complex_result);
    }

    #[test]
    fn test_tool_call_result_empty_tool_name() {
        let result = ToolCallResult {
            tool_name: String::new(),
            success: true,
            arguments: json!({}),
            result: json!({}),
        };

        assert!(result.tool_name.is_empty());
        assert!(result.success);
    }

    #[test]
    fn test_tool_call_result_large_data() {
        let large_string = "x".repeat(10000);
        let result = ToolCallResult {
            tool_name: "large_tool".to_string(),
            success: true,
            arguments: json!({"large_param": large_string}),
            result: json!({"processed": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "large_tool");
        assert!(deserialized.success);
        assert_eq!(
            deserialized.arguments["large_param"]
                .as_str()
                .unwrap()
                .len(),
            10000
        );
    }

    #[test]
    fn test_tool_call_result_equality() {
        let result1 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result2 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result3 = ToolCallResult {
            tool_name: "different_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        // Compare the serialized JSON so the check also covers the serde representation.
        let serialized1 = serde_json::to_string(&result1).unwrap();
        let serialized2 = serde_json::to_string(&result2).unwrap();
        let serialized3 = serde_json::to_string(&result3).unwrap();

        assert_eq!(serialized1, serialized2);
        assert_ne!(serialized1, serialized3);
    }

    #[test]
    fn test_tool_call_result_with_unicode() {
        let result = ToolCallResult {
            tool_name: "unicode_tool".to_string(),
            success: true,
            arguments: json!({"message": "Hello 世界! 🌍"}),
            result: json!({"response": "Processed: Hello 世界! 🌍"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["message"], "Hello 世界! 🌍");
        assert_eq!(deserialized.result["response"], "Processed: Hello 世界! 🌍");
    }

    #[test]
    fn test_tool_call_result_with_arrays() {
        let result = ToolCallResult {
            tool_name: "array_tool".to_string(),
            success: true,
            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
            result: json!({"sum": 15, "count": 5}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
        assert_eq!(deserialized.result["sum"], 15);
        assert_eq!(deserialized.result["count"], 5);
    }

    #[test]
    fn test_tool_call_result_boolean_values() {
        let result = ToolCallResult {
            tool_name: "bool_tool".to_string(),
            success: false,
            arguments: json!({"enabled": true, "debug": false}),
            result: json!({"valid": false, "error": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments["enabled"], true);
        assert_eq!(deserialized.arguments["debug"], false);
        assert_eq!(deserialized.result["valid"], false);
        assert_eq!(deserialized.result["error"], true);
    }

    #[test]
    fn test_shared_tool_delegates_metadata() {
        let shared = SharedTool::new(Arc::new(MockTool::new("shared_meta", "Shared metadata")));

        assert_eq!(shared.name(), "shared_meta");
        assert_eq!(shared.description(), "Shared metadata");
        assert_eq!(
            shared.args_schema(),
            MockTool::new("other", "other").args_schema()
        );
    }

    #[test]
    fn test_shared_tool_debug() {
        let shared = SharedTool::new(Arc::new(MockTool::new("shared_debug", "Shared debug")));
        let debug_str = format!("{shared:?}");
        assert!(debug_str.contains("SharedTool"));
        assert!(debug_str.contains("shared_debug"));
    }

    #[tokio::test]
    async fn test_shared_tool_delegates_execute() {
        let shared = SharedTool::new(Arc::new(MockTool::new("shared_exec", "Shared execute")));

        let output = shared
            .execute(json!({"name": "shared", "value": 3}))
            .await
            .expect("wrapped tool should succeed");
        assert_eq!(output["processed_name"], "shared");
        assert_eq!(output["doubled_value"], 6);
    }

    #[tokio::test]
    async fn test_shared_tool_propagates_errors() {
        let shared = SharedTool::new(Arc::new(MockTool::with_failure(
            "shared_fail",
            "Shared failure",
        )));

        let error = shared
            .execute(json!({"name": "shared", "value": 3}))
            .await
            .unwrap_err();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
        assert!(error.to_string().contains("Mock tool failure"));
    }

    #[test]
    fn test_shared_tools_to_boxes_empty() {
        let empty: Vec<Arc<dyn ToolT>> = Vec::new();
        assert!(shared_tools_to_boxes(&empty).is_empty());
    }

    #[test]
    fn test_shared_tools_to_boxes_shares_instances() {
        let tools: Vec<Arc<dyn ToolT>> = vec![
            Arc::new(MockTool::new("shared1", "First shared")),
            Arc::new(MockTool::new("shared2", "Second shared")),
        ];

        let boxed = shared_tools_to_boxes(&tools);
        assert_eq!(boxed.len(), 2);
        assert_eq!(boxed[0].name(), "shared1");
        assert_eq!(boxed[1].name(), "shared2");

        // Each wrapper holds one extra reference; no tool was cloned.
        assert!(tools.iter().all(|t| Arc::strong_count(t) == 2));
        drop(boxed);
        assert!(tools.iter().all(|t| Arc::strong_count(t) == 1));
    }

    #[test]
    fn test_shared_tool_to_llm_tool() {
        let tools: Vec<Arc<dyn ToolT>> =
            vec![Arc::new(MockTool::new("llm_shared", "Shared LLM tool"))];
        let boxed = shared_tools_to_boxes(&tools);

        let tool: Tool = to_llm_tool(&boxed[0]);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "llm_shared");
        assert_eq!(tool.function.description, "Shared LLM tool");
        assert_eq!(tool.function.parameters["type"], "object");
    }
}