// autoagents_core/tool/mod.rs

1use autoagents_llm::chat::{FunctionTool, Tool};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::fmt::Debug;
5mod runtime;
6pub use runtime::ToolRuntime;
7
8#[cfg(feature = "wasmtime")]
9pub use runtime::{WasmRuntime, WasmRuntimeError};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ToolCallResult {
13    pub tool_name: String,
14    pub success: bool,
15    pub arguments: Value,
16    pub result: Value,
17}
18
19#[derive(Debug, thiserror::Error)]
20pub enum ToolCallError {
21    #[error("Runtime Error {0}")]
22    RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),
23
24    #[error("Serde Error {0}")]
25    SerdeError(#[from] serde_json::Error),
26}
27
28pub trait ToolT: Send + Sync + Debug + ToolRuntime {
29    /// The name of the tool.
30    fn name(&self) -> &'static str;
31    /// A description explaining the tool’s purpose.
32    fn description(&self) -> &'static str;
33    /// Return a description of the expected arguments.
34    fn args_schema(&self) -> Value;
35    /// Run the tool with the given arguments (in JSON) and return the result (in JSON).
36    fn run(&self, args: Value) -> Result<Value, ToolCallError> {
37        self.execute(args)
38    }
39}
40
/// Input types that can report their argument schema without an instance.
pub trait ToolInputT {
    /// Returns the schema of this input type as a static string.
    ///
    /// NOTE(review): the implementation in this module's tests returns a
    /// JSON Schema document; presumably all implementors follow that
    /// convention — confirm against callers.
    fn io_schema() -> &'static str;
}
44
45impl From<&Box<dyn ToolT>> for Tool {
46    fn from(tool: &Box<dyn ToolT>) -> Self {
47        Tool {
48            tool_type: "function".to_string(),
49            function: FunctionTool {
50                name: tool.name().to_string(),
51                description: tool.description().to_string(),
52                parameters: tool.args_schema(),
53            },
54        }
55    }
56}
57
#[cfg(test)]
mod tests {
    //! Unit tests for `ToolCallError`, `ToolT` (exercised through `MockTool`),
    //! `ToolInputT`, the boxed-tool-to-`Tool` conversion and `ToolCallResult`.

    use super::*;
    use autoagents_llm::chat::Tool;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Typed arguments that `MockTool` deserializes its JSON input into.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestInput {
        name: String,
        value: i32,
    }

    impl ToolInputT for TestInput {
        fn io_schema() -> &'static str {
            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
        }
    }

    /// Minimal tool used by the tests; when `should_fail` is set, `execute`
    /// always returns a `RuntimeError`.
    #[derive(Debug)]
    struct MockTool {
        name: &'static str,
        description: &'static str,
        should_fail: bool,
    }

    impl MockTool {
        /// Creates a mock tool that succeeds on well-formed input.
        fn new(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: false,
            }
        }

        /// Creates a mock tool whose every execution fails.
        fn with_failure(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: true,
            }
        }
    }

    impl ToolT for MockTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            self.description
        }

        fn args_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "value": {"type": "integer"}
                },
                "required": ["name", "value"]
            })
        }
    }

    impl ToolRuntime for MockTool {
        /// Echoes `name` and doubles `value`, or fails if configured to.
        fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolCallError> {
            if self.should_fail {
                return Err(ToolCallError::RuntimeError(
                    "Mock tool failure".to_string().into(),
                ));
            }

            // An argument shape mismatch surfaces as `ToolCallError::SerdeError` via `?`.
            let input: TestInput = serde_json::from_value(args)?;
            Ok(json!({
                "processed_name": input.name,
                "doubled_value": input.value * 2
            }))
        }
    }

    #[test]
    fn test_tool_call_error_runtime_error() {
        let error = ToolCallError::RuntimeError("Runtime error".to_string().into());
        assert_eq!(error.to_string(), "Runtime Error Runtime error");
    }

    #[test]
    fn test_tool_call_error_serde_error() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error = ToolCallError::SerdeError(json_error);
        assert!(error.to_string().contains("Serde Error"));
    }

    #[test]
    fn test_tool_call_error_debug() {
        let error = ToolCallError::RuntimeError("Debug test".to_string().into());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("RuntimeError"));
    }

    #[test]
    fn test_tool_call_error_from_serde() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error: ToolCallError = json_error.into();
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[test]
    fn test_tool_call_error_from_box_error() {
        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
        let error: ToolCallError = box_error.into();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_mock_tool_creation() {
        let tool = MockTool::new("test_tool", "A test tool");
        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool");
        assert!(!tool.should_fail);
    }

    #[test]
    fn test_mock_tool_with_failure() {
        let tool = MockTool::with_failure("failing_tool", "A failing tool");
        assert_eq!(tool.name(), "failing_tool");
        assert_eq!(tool.description(), "A failing tool");
        assert!(tool.should_fail);
    }

    #[test]
    fn test_mock_tool_args_schema() {
        let tool = MockTool::new("schema_tool", "Schema test");
        let schema = tool.args_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["value"].is_object());
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
    }

    #[test]
    fn test_mock_tool_run_success() {
        let tool = MockTool::new("success_tool", "Success test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        // `expect` reports the actual error if the call unexpectedly fails.
        let output = tool.run(input).expect("valid input should succeed");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_run_failure() {
        let tool = MockTool::with_failure("failure_tool", "Failure test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let error = tool
            .run(input)
            .expect_err("a failing tool must return an error");
        assert!(error.to_string().contains("Mock tool failure"));
    }

    #[test]
    fn test_mock_tool_run_invalid_input() {
        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
        let input = json!({
            "invalid_field": "test"
        });

        let result = tool.run(input);
        assert!(matches!(result, Err(ToolCallError::SerdeError(_))));
    }

    #[test]
    fn test_mock_tool_run_with_extra_fields() {
        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
        let input = json!({
            "name": "test",
            "value": 42,
            "extra_field": "ignored"
        });

        let output = tool
            .run(input)
            .expect("unknown fields should be ignored");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_debug() {
        let tool = MockTool::new("debug_tool", "Debug test");
        let debug_str = format!("{tool:?}");
        assert!(debug_str.contains("MockTool"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_input_trait() {
        let schema = TestInput::io_schema();
        assert!(schema.contains("object"));
        assert!(schema.contains("name"));
        assert!(schema.contains("value"));
        assert!(schema.contains("string"));
        assert!(schema.contains("integer"));
    }

    #[test]
    fn test_test_input_serialization() {
        let input = TestInput {
            name: "test".to_string(),
            value: 42,
        };
        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains("test"));
        assert!(serialized.contains("42"));
    }

    #[test]
    fn test_test_input_deserialization() {
        let json = r#"{"name":"test","value":42}"#;
        let input: TestInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.value, 42);
    }

    #[test]
    fn test_test_input_debug() {
        let input = TestInput {
            name: "debug".to_string(),
            value: 123,
        };
        let debug_str = format!("{input:?}");
        assert!(debug_str.contains("TestInput"));
        assert!(debug_str.contains("debug"));
        assert!(debug_str.contains("123"));
    }

    #[test]
    fn test_boxed_tool_to_tool_conversion() {
        let mock_tool = MockTool::new("convert_tool", "Conversion test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = (&boxed_tool).into();
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "convert_tool");
        assert_eq!(tool.function.description, "Conversion test");
        assert_eq!(tool.function.parameters["type"], "object");
    }

    #[test]
    fn test_tool_conversion_preserves_schema() {
        let mock_tool = MockTool::new("schema_tool", "Schema preservation test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = (&boxed_tool).into();
        let schema = &tool.function.parameters;

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
        assert_eq!(schema["required"][0], "name");
        assert_eq!(schema["required"][1], "value");
    }

    #[test]
    fn test_tool_trait_object_usage() {
        let tools: Vec<Box<dyn ToolT>> = vec![
            Box::new(MockTool::new("tool1", "First tool")),
            Box::new(MockTool::new("tool2", "Second tool")),
            Box::new(MockTool::with_failure("tool3", "Third tool")),
        ];

        for tool in &tools {
            assert!(!tool.name().is_empty());
            assert!(!tool.description().is_empty());
            assert!(tool.args_schema().is_object());
        }
    }

    #[test]
    fn test_tool_run_with_different_inputs() {
        let tool = MockTool::new("varied_input_tool", "Varied input test");

        let inputs = [
            json!({"name": "test1", "value": 1}),
            json!({"name": "test2", "value": -5}),
            json!({"name": "", "value": 0}),
            json!({"name": "long_name_test", "value": 999999}),
        ];

        for input in inputs {
            // Clone because `input` is still needed for the assertions below.
            let output = tool
                .run(input.clone())
                .expect("valid input should succeed");
            assert_eq!(output["processed_name"], input["name"]);
            assert_eq!(
                output["doubled_value"],
                input["value"].as_i64().unwrap() * 2
            );
        }
    }

    #[test]
    fn test_tool_error_chaining() {
        use std::error::Error;

        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
        let tool_error = ToolCallError::SerdeError(json_error);

        // `#[from]` also marks the wrapped error as the source.
        assert!(tool_error.source().is_some());
    }

    #[test]
    fn test_tool_with_empty_name() {
        let tool = MockTool::new("", "Empty name test");
        assert_eq!(tool.name(), "");
        assert_eq!(tool.description(), "Empty name test");
    }

    #[test]
    fn test_tool_with_empty_description() {
        let tool = MockTool::new("empty_desc", "");
        assert_eq!(tool.name(), "empty_desc");
        assert_eq!(tool.description(), "");
    }

    #[test]
    fn test_tool_schema_complex() {
        let tool = MockTool::new("complex_tool", "Complex schema test");
        let schema = tool.args_schema();

        // Verify schema structure
        assert!(schema.is_object());
        assert!(schema["properties"].is_object());
        assert!(schema["required"].is_array());
        assert_eq!(schema["required"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_multiple_tool_instances() {
        let tool1 = MockTool::new("tool1", "First instance");
        let tool2 = MockTool::new("tool2", "Second instance");

        assert_ne!(tool1.name(), tool2.name());
        assert_ne!(tool1.description(), tool2.description());

        // Both should have the same schema structure
        assert_eq!(tool1.args_schema(), tool2.args_schema());
    }

    #[test]
    fn test_tool_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockTool>();
    }

    #[test]
    fn test_tool_trait_object_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Box<dyn ToolT>>();
    }

    #[test]
    fn test_tool_call_result_creation() {
        let result = ToolCallResult {
            tool_name: "test_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "success"}),
        };

        assert_eq!(result.tool_name, "test_tool");
        assert!(result.success);
        assert_eq!(result.arguments, json!({"param": "value"}));
        assert_eq!(result.result, json!({"output": "success"}));
    }

    #[test]
    fn test_tool_call_result_serialization() {
        let result = ToolCallResult {
            tool_name: "serialize_tool".to_string(),
            success: false,
            arguments: json!({"input": "test"}),
            result: json!({"error": "failed"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "serialize_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!({"input": "test"}));
        assert_eq!(deserialized.result, json!({"error": "failed"}));
    }

    #[test]
    fn test_tool_call_result_clone() {
        let result = ToolCallResult {
            tool_name: "clone_tool".to_string(),
            success: true,
            arguments: json!({"data": [1, 2, 3]}),
            result: json!({"processed": [2, 4, 6]}),
        };

        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.success, cloned.success);
        assert_eq!(result.arguments, cloned.arguments);
        assert_eq!(result.result, cloned.result);
    }

    #[test]
    fn test_tool_call_result_debug() {
        let result = ToolCallResult {
            tool_name: "debug_tool".to_string(),
            success: true,
            arguments: json!({}),
            result: json!(null),
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("ToolCallResult"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_call_result_with_null_values() {
        let result = ToolCallResult {
            tool_name: "null_tool".to_string(),
            success: false,
            arguments: json!(null),
            result: json!(null),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "null_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!(null));
        assert_eq!(deserialized.result, json!(null));
    }

    #[test]
    fn test_tool_call_result_with_complex_data() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, {"key": "value"}],
                "string": "test",
                "number": 42.5
            }
        });

        let complex_result = json!({
            "status": "completed",
            "data": {
                "items": ["a", "b", "c"],
                "count": 3
            }
        });

        let result = ToolCallResult {
            tool_name: "complex_tool".to_string(),
            success: true,
            arguments: complex_args.clone(),
            result: complex_result.clone(),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments, complex_args);
        assert_eq!(deserialized.result, complex_result);
    }

    #[test]
    fn test_tool_call_result_empty_tool_name() {
        let result = ToolCallResult {
            tool_name: String::new(),
            success: true,
            arguments: json!({}),
            result: json!({}),
        };

        assert!(result.tool_name.is_empty());
        assert!(result.success);
    }

    #[test]
    fn test_tool_call_result_large_data() {
        let large_string = "x".repeat(10000);
        let result = ToolCallResult {
            tool_name: "large_tool".to_string(),
            success: true,
            arguments: json!({"large_param": large_string}),
            result: json!({"processed": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "large_tool");
        assert!(deserialized.success);
        // `assert_eq!` prints the actual length on failure, unlike `assert!(a == b)`.
        assert_eq!(
            deserialized.arguments["large_param"]
                .as_str()
                .map(str::len),
            Some(10000)
        );
    }

    #[test]
    fn test_tool_call_result_equality() {
        let result1 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result2 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result3 = ToolCallResult {
            tool_name: "different_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        // Compare the serialized JSON so the check exercises the wire format
        // and does not rely on a `PartialEq` impl for `ToolCallResult`.
        let serialized1 = serde_json::to_string(&result1).unwrap();
        let serialized2 = serde_json::to_string(&result2).unwrap();
        let serialized3 = serde_json::to_string(&result3).unwrap();

        assert_eq!(serialized1, serialized2);
        assert_ne!(serialized1, serialized3);
    }

    #[test]
    fn test_tool_call_result_with_unicode() {
        let result = ToolCallResult {
            tool_name: "unicode_tool".to_string(),
            success: true,
            arguments: json!({"message": "Hello 世界! 🌍"}),
            result: json!({"response": "Processed: Hello 世界! 🌍"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["message"], "Hello 世界! 🌍");
        assert_eq!(deserialized.result["response"], "Processed: Hello 世界! 🌍");
    }

    #[test]
    fn test_tool_call_result_with_arrays() {
        let result = ToolCallResult {
            tool_name: "array_tool".to_string(),
            success: true,
            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
            result: json!({"sum": 15, "count": 5}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
        assert_eq!(deserialized.result["sum"], 15);
        assert_eq!(deserialized.result["count"], 5);
    }

    #[test]
    fn test_tool_call_result_boolean_values() {
        let result = ToolCallResult {
            tool_name: "bool_tool".to_string(),
            success: false,
            arguments: json!({"enabled": true, "debug": false}),
            result: json!({"valid": false, "error": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments["enabled"], true);
        assert_eq!(deserialized.arguments["debug"], false);
        assert_eq!(deserialized.result["valid"], false);
        assert_eq!(deserialized.result["error"], true);
    }
}
669        assert_eq!(deserialized.result["error"], true);
670    }
671}