1use autoagents_llm::chat::{FunctionTool, Tool};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::fmt::Debug;
5use std::sync::Arc;
6mod runtime;
7use async_trait::async_trait;
8pub use runtime::ToolRuntime;
9
10#[cfg(feature = "wasmtime")]
11pub use runtime::{WasmRuntime, WasmRuntimeError};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCallResult {
15 pub tool_name: String,
16 pub success: bool,
17 pub arguments: Value,
18 pub result: Value,
19}
20
21#[derive(Debug, thiserror::Error)]
22pub enum ToolCallError {
23 #[error("Runtime Error {0}")]
24 RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),
25
26 #[error("Serde Error {0}")]
27 SerdeError(#[from] serde_json::Error),
28}
29
30pub trait ToolT: Send + Sync + Debug + ToolRuntime {
31 fn name(&self) -> &'static str;
33 fn description(&self) -> &'static str;
35 fn args_schema(&self) -> Value;
37}
38
/// Input types that can describe their own shape with a static schema string.
pub trait ToolInputT {
    /// Returns the schema for this input type as a `'static` string
    /// (presumably a JSON-schema document; confirm against implementors).
    fn io_schema() -> &'static str;
}
42
43#[derive(Debug)]
46pub struct SharedTool {
47 inner: Arc<dyn ToolT>,
48}
49
50impl SharedTool {
51 pub fn new(tool: Arc<dyn ToolT>) -> Self {
53 Self { inner: tool }
54 }
55}
56
57#[async_trait]
58impl ToolRuntime for SharedTool {
59 async fn execute(&self, args: Value) -> Result<Value, ToolCallError> {
60 self.inner.execute(args).await
61 }
62}
63
64impl ToolT for SharedTool {
65 fn name(&self) -> &'static str {
66 self.inner.name()
67 }
68
69 fn description(&self) -> &'static str {
70 self.inner.description()
71 }
72
73 fn args_schema(&self) -> Value {
74 self.inner.args_schema()
75 }
76}
77
78pub fn shared_tools_to_boxes(tools: &[Arc<dyn ToolT>]) -> Vec<Box<dyn ToolT>> {
81 tools
82 .iter()
83 .map(|t| Box::new(SharedTool::new(Arc::clone(t))) as Box<dyn ToolT>)
84 .collect()
85}
86
87#[allow(clippy::borrowed_box)]
89pub fn to_llm_tool(tool: &Box<dyn ToolT>) -> Tool {
90 Tool {
91 tool_type: "function".to_string(),
92 function: FunctionTool {
93 name: tool.name().to_string(),
94 description: tool.description().to_string(),
95 parameters: tool.args_schema(),
96 },
97 }
98}
99
#[cfg(test)]
mod tests {
    // Unit tests for the tool abstractions in this file: `ToolCallError` formatting and
    // conversions, a mock `ToolT` implementation, `SharedTool` delegation, the
    // `shared_tools_to_boxes` / `to_llm_tool` helpers, and `ToolCallResult` serde.
    use super::*;
    use autoagents_llm::chat::Tool;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Input payload accepted by `MockTool`.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestInput {
        name: String,
        value: i32,
    }

    impl ToolInputT for TestInput {
        fn io_schema() -> &'static str {
            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
        }
    }

    /// Test double: echoes `name` and doubles `value` from a `TestInput`,
    /// or always fails with a `RuntimeError` when `should_fail` is set.
    #[derive(Debug)]
    struct MockTool {
        name: &'static str,
        description: &'static str,
        should_fail: bool,
    }

    impl MockTool {
        fn new(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: false,
            }
        }

        fn with_failure(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: true,
            }
        }
    }

    impl ToolT for MockTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            self.description
        }

        fn args_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "value": {"type": "integer"}
                },
                "required": ["name", "value"]
            })
        }
    }

    #[async_trait]
    impl ToolRuntime for MockTool {
        async fn execute(
            &self,
            args: serde_json::Value,
        ) -> Result<serde_json::Value, ToolCallError> {
            if self.should_fail {
                return Err(ToolCallError::RuntimeError(
                    "Mock tool failure".to_string().into(),
                ));
            }

            let input: TestInput = serde_json::from_value(args)?;
            Ok(json!({
                "processed_name": input.name,
                "doubled_value": input.value * 2
            }))
        }
    }

    #[test]
    fn test_tool_call_error_runtime_error() {
        let error = ToolCallError::RuntimeError("Runtime error".to_string().into());
        assert_eq!(error.to_string(), "Runtime Error Runtime error");
    }

    #[test]
    fn test_tool_call_error_serde_error() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error = ToolCallError::SerdeError(json_error);
        assert!(error.to_string().starts_with("Serde Error "));
    }

    #[test]
    fn test_tool_call_error_debug() {
        let error = ToolCallError::RuntimeError("Debug test".to_string().into());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("RuntimeError"));
    }

    #[test]
    fn test_tool_call_error_from_serde() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error: ToolCallError = json_error.into();
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[test]
    fn test_tool_call_error_from_box_error() {
        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
        let error: ToolCallError = box_error.into();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_mock_tool_creation() {
        let tool = MockTool::new("test_tool", "A test tool");
        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool");
        assert!(!tool.should_fail);
    }

    #[test]
    fn test_mock_tool_with_failure() {
        let tool = MockTool::with_failure("failing_tool", "A failing tool");
        assert_eq!(tool.name(), "failing_tool");
        assert_eq!(tool.description(), "A failing tool");
        assert!(tool.should_fail);
    }

    #[test]
    fn test_mock_tool_args_schema() {
        let tool = MockTool::new("schema_tool", "Schema test");
        let schema = tool.args_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["value"].is_object());
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
    }

    #[tokio::test]
    async fn test_mock_tool_run_success() {
        let tool = MockTool::new("success_tool", "Success test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        // `expect` surfaces the actual error on failure instead of a bare `is_ok()` assert.
        let output = tool
            .execute(input)
            .await
            .expect("valid input should be processed successfully");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[tokio::test]
    async fn test_mock_tool_run_failure() {
        let tool = MockTool::with_failure("failure_tool", "Failure test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let error = tool
            .execute(input)
            .await
            .expect_err("a failing mock tool must return an error");
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
        assert!(error.to_string().contains("Mock tool failure"));
    }

    #[tokio::test]
    async fn test_mock_tool_run_invalid_input() {
        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
        let input = json!({
            "invalid_field": "test"
        });

        let error = tool
            .execute(input)
            .await
            .expect_err("input missing required fields must be rejected");
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[tokio::test]
    async fn test_mock_tool_run_with_extra_fields() {
        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
        let input = json!({
            "name": "test",
            "value": 42,
            "extra_field": "ignored"
        });

        let output = tool
            .execute(input)
            .await
            .expect("unknown fields should be ignored during deserialization");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_debug() {
        let tool = MockTool::new("debug_tool", "Debug test");
        let debug_str = format!("{tool:?}");
        assert!(debug_str.contains("MockTool"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_input_trait() {
        let schema = TestInput::io_schema();
        assert!(schema.contains("object"));
        assert!(schema.contains("name"));
        assert!(schema.contains("value"));
        assert!(schema.contains("string"));
        assert!(schema.contains("integer"));
    }

    #[test]
    fn test_test_input_serialization() {
        let input = TestInput {
            name: "test".to_string(),
            value: 42,
        };
        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains("test"));
        assert!(serialized.contains("42"));
    }

    #[test]
    fn test_test_input_deserialization() {
        let json = r#"{"name":"test","value":42}"#;
        let input: TestInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.value, 42);
    }

    #[test]
    fn test_test_input_debug() {
        let input = TestInput {
            name: "debug".to_string(),
            value: 123,
        };
        let debug_str = format!("{input:?}");
        assert!(debug_str.contains("TestInput"));
        assert!(debug_str.contains("debug"));
        assert!(debug_str.contains("123"));
    }

    #[test]
    fn test_boxed_tool_to_tool_conversion() {
        let mock_tool = MockTool::new("convert_tool", "Conversion test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "convert_tool");
        assert_eq!(tool.function.description, "Conversion test");
        assert_eq!(tool.function.parameters["type"], "object");
    }

    #[test]
    fn test_tool_conversion_preserves_schema() {
        let mock_tool = MockTool::new("schema_tool", "Schema preservation test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        let schema = &tool.function.parameters;

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
        assert_eq!(schema["required"][0], "name");
        assert_eq!(schema["required"][1], "value");
    }

    #[test]
    fn test_tool_trait_object_usage() {
        let tools: Vec<Box<dyn ToolT>> = vec![
            Box::new(MockTool::new("tool1", "First tool")),
            Box::new(MockTool::new("tool2", "Second tool")),
            Box::new(MockTool::with_failure("tool3", "Third tool")),
        ];

        for tool in &tools {
            assert!(!tool.name().is_empty());
            assert!(!tool.description().is_empty());
            assert!(tool.args_schema().is_object());
        }
    }

    #[tokio::test]
    async fn test_tool_run_with_different_inputs() {
        let tool = MockTool::new("varied_input_tool", "Varied input test");

        let inputs = [
            json!({"name": "test1", "value": 1}),
            json!({"name": "test2", "value": -5}),
            json!({"name": "", "value": 0}),
            json!({"name": "long_name_test", "value": 999999}),
        ];

        for input in inputs {
            let output = tool
                .execute(input.clone())
                .await
                .expect("every well-formed input should succeed");
            assert_eq!(output["processed_name"], input["name"]);
            assert_eq!(
                output["doubled_value"],
                input["value"].as_i64().unwrap() * 2
            );
        }
    }

    #[test]
    fn test_tool_error_chaining() {
        use std::error::Error;

        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
        let tool_error = ToolCallError::SerdeError(json_error);
        assert!(tool_error.source().is_some());

        // `#[from]` also marks the boxed runtime error as the source.
        let runtime_error = ToolCallError::RuntimeError("boom".into());
        assert_eq!(
            runtime_error.source().map(|cause| cause.to_string()),
            Some("boom".to_string())
        );
    }

    #[test]
    fn test_tool_with_empty_name() {
        let tool = MockTool::new("", "Empty name test");
        assert_eq!(tool.name(), "");
        assert_eq!(tool.description(), "Empty name test");
    }

    #[test]
    fn test_tool_with_empty_description() {
        let tool = MockTool::new("empty_desc", "");
        assert_eq!(tool.name(), "empty_desc");
        assert_eq!(tool.description(), "");
    }

    #[test]
    fn test_tool_schema_complex() {
        let tool = MockTool::new("complex_tool", "Complex schema test");
        let schema = tool.args_schema();

        assert!(schema.is_object());
        assert!(schema["properties"].is_object());
        assert!(schema["required"].is_array());
        assert_eq!(schema["required"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_multiple_tool_instances() {
        let tool1 = MockTool::new("tool1", "First instance");
        let tool2 = MockTool::new("tool2", "Second instance");

        assert_ne!(tool1.name(), tool2.name());
        assert_ne!(tool1.description(), tool2.description());

        assert_eq!(tool1.args_schema(), tool2.args_schema());
    }

    #[test]
    fn test_tool_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockTool>();
    }

    #[test]
    fn test_tool_trait_object_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Box<dyn ToolT>>();
    }

    #[tokio::test]
    async fn test_shared_tool_delegates_to_inner() {
        let inner: Arc<dyn ToolT> = Arc::new(MockTool::new("shared_tool", "Shared test"));
        let shared = SharedTool::new(Arc::clone(&inner));

        assert_eq!(shared.name(), "shared_tool");
        assert_eq!(shared.description(), "Shared test");
        assert_eq!(shared.args_schema(), inner.args_schema());

        let output = shared
            .execute(json!({"name": "x", "value": 3}))
            .await
            .expect("SharedTool should forward execution to the inner tool");
        assert_eq!(output["processed_name"], "x");
        assert_eq!(output["doubled_value"], 6);
    }

    #[tokio::test]
    async fn test_shared_tool_propagates_errors() {
        let shared = SharedTool::new(Arc::new(MockTool::with_failure("shared_fail", "Fails")));

        let error = shared
            .execute(json!({"name": "x", "value": 3}))
            .await
            .expect_err("errors from the inner tool must be propagated");
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
        assert!(error.to_string().contains("Mock tool failure"));
    }

    #[test]
    fn test_shared_tools_to_boxes() {
        let tools: Vec<Arc<dyn ToolT>> = vec![
            Arc::new(MockTool::new("alpha", "First shared")),
            Arc::new(MockTool::new("beta", "Second shared")),
        ];

        let boxed = shared_tools_to_boxes(&tools);
        assert_eq!(boxed.len(), 2);
        assert_eq!(boxed[0].name(), "alpha");
        assert_eq!(boxed[1].name(), "beta");
        assert_eq!(boxed[1].description(), "Second shared");

        // Each box holds an extra strong reference to the same underlying tool.
        assert_eq!(Arc::strong_count(&tools[0]), 2);
        drop(boxed);
        assert_eq!(Arc::strong_count(&tools[0]), 1);
    }

    #[test]
    fn test_shared_tools_to_boxes_empty() {
        assert!(shared_tools_to_boxes(&[]).is_empty());
    }

    #[test]
    fn test_shared_box_to_llm_tool() {
        let tools: Vec<Arc<dyn ToolT>> = vec![Arc::new(MockTool::new("gamma", "Via shared"))];
        let boxed = shared_tools_to_boxes(&tools);

        let tool: Tool = to_llm_tool(&boxed[0]);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "gamma");
        assert_eq!(tool.function.description, "Via shared");
        assert_eq!(tool.function.parameters, tools[0].args_schema());
    }

    #[test]
    fn test_tool_call_result_creation() {
        let result = ToolCallResult {
            tool_name: "test_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "success"}),
        };

        assert_eq!(result.tool_name, "test_tool");
        assert!(result.success);
        assert_eq!(result.arguments, json!({"param": "value"}));
        assert_eq!(result.result, json!({"output": "success"}));
    }

    #[test]
    fn test_tool_call_result_serialization() {
        let result = ToolCallResult {
            tool_name: "serialize_tool".to_string(),
            success: false,
            arguments: json!({"input": "test"}),
            result: json!({"error": "failed"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "serialize_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!({"input": "test"}));
        assert_eq!(deserialized.result, json!({"error": "failed"}));
    }

    #[test]
    fn test_tool_call_result_clone() {
        let result = ToolCallResult {
            tool_name: "clone_tool".to_string(),
            success: true,
            arguments: json!({"data": [1, 2, 3]}),
            result: json!({"processed": [2, 4, 6]}),
        };

        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.success, cloned.success);
        assert_eq!(result.arguments, cloned.arguments);
        assert_eq!(result.result, cloned.result);
    }

    #[test]
    fn test_tool_call_result_debug() {
        let result = ToolCallResult {
            tool_name: "debug_tool".to_string(),
            success: true,
            arguments: json!({}),
            result: json!(null),
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("ToolCallResult"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_call_result_with_null_values() {
        let result = ToolCallResult {
            tool_name: "null_tool".to_string(),
            success: false,
            arguments: json!(null),
            result: json!(null),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "null_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!(null));
        assert_eq!(deserialized.result, json!(null));
    }

    #[test]
    fn test_tool_call_result_with_complex_data() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, {"key": "value"}],
                "string": "test",
                "number": 42.5
            }
        });

        let complex_result = json!({
            "status": "completed",
            "data": {
                "items": ["a", "b", "c"],
                "count": 3
            }
        });

        let result = ToolCallResult {
            tool_name: "complex_tool".to_string(),
            success: true,
            arguments: complex_args.clone(),
            result: complex_result.clone(),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments, complex_args);
        assert_eq!(deserialized.result, complex_result);
    }

    #[test]
    fn test_tool_call_result_empty_tool_name() {
        let result = ToolCallResult {
            tool_name: String::new(),
            success: true,
            arguments: json!({}),
            result: json!({}),
        };

        assert!(result.tool_name.is_empty());
        assert!(result.success);
    }

    #[test]
    fn test_tool_call_result_large_data() {
        let large_string = "x".repeat(10000);
        let result = ToolCallResult {
            tool_name: "large_tool".to_string(),
            success: true,
            arguments: json!({"large_param": large_string}),
            result: json!({"processed": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "large_tool");
        assert!(deserialized.success);
        assert_eq!(
            deserialized.arguments["large_param"]
                .as_str()
                .unwrap()
                .len(),
            10000
        );
    }

    #[test]
    fn test_tool_call_result_equality() {
        let result1 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result2 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result3 = ToolCallResult {
            tool_name: "different_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        // Compared through serialization so the test does not depend on `PartialEq`.
        let serialized1 = serde_json::to_string(&result1).unwrap();
        let serialized2 = serde_json::to_string(&result2).unwrap();
        let serialized3 = serde_json::to_string(&result3).unwrap();

        assert_eq!(serialized1, serialized2);
        assert_ne!(serialized1, serialized3);
    }

    #[test]
    fn test_tool_call_result_with_unicode() {
        // Mixes CJK (3-byte UTF-8) and an emoji (4-byte UTF-8) to exercise multi-byte text.
        let result = ToolCallResult {
            tool_name: "unicode_tool".to_string(),
            success: true,
            arguments: json!({"message": "Hello 世界! 🌍"}),
            result: json!({"response": "Processed: Hello 世界! 🌍"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["message"], "Hello 世界! 🌍");
        assert_eq!(deserialized.result["response"], "Processed: Hello 世界! 🌍");
    }

    #[test]
    fn test_tool_call_result_with_arrays() {
        let result = ToolCallResult {
            tool_name: "array_tool".to_string(),
            success: true,
            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
            result: json!({"sum": 15, "count": 5}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
        assert_eq!(deserialized.result["sum"], 15);
        assert_eq!(deserialized.result["count"], 5);
    }

    #[test]
    fn test_tool_call_result_boolean_values() {
        let result = ToolCallResult {
            tool_name: "bool_tool".to_string(),
            success: false,
            arguments: json!({"enabled": true, "debug": false}),
            result: json!({"valid": false, "error": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments["enabled"], true);
        assert_eq!(deserialized.arguments["debug"], false);
        assert_eq!(deserialized.result["valid"], false);
        assert_eq!(deserialized.result["error"], true);
    }
}