1use autoagents_llm::chat::{FunctionTool, Tool};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::fmt::Debug;
5use std::sync::Arc;
6mod runtime;
7use async_trait::async_trait;
8pub use runtime::ToolRuntime;
9
10#[cfg(feature = "wasmtime")]
11pub use runtime::{WasmRuntime, WasmRuntimeError};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ToolCallResult {
17 pub tool_name: String,
18 pub success: bool,
19 pub arguments: Value,
20 pub result: Value,
21}
22
23#[derive(Debug, thiserror::Error)]
24pub enum ToolCallError {
25 #[error("Runtime Error {0}")]
26 RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),
27
28 #[error("Serde Error {0}")]
29 SerdeError(#[from] serde_json::Error),
30}
31
32pub trait ToolT: Send + Sync + Debug + ToolRuntime {
33 fn name(&self) -> &str;
35 fn description(&self) -> &str;
37 fn args_schema(&self) -> Value;
39}
40
/// Implemented by tool input types that can describe their own shape.
pub trait ToolInputT {
    /// Returns the schema of this input type as a static string
    /// (the in-file test implementation returns a JSON Schema object).
    fn io_schema() -> &'static str;
}
45
46#[derive(Debug)]
51pub struct SharedTool {
52 inner: Arc<dyn ToolT>,
53}
54
55impl SharedTool {
56 pub fn new(tool: Arc<dyn ToolT>) -> Self {
58 Self { inner: tool }
59 }
60}
61
62#[async_trait]
63impl ToolRuntime for SharedTool {
64 async fn execute(&self, args: Value) -> Result<Value, ToolCallError> {
65 self.inner.execute(args).await
66 }
67}
68
69impl ToolT for SharedTool {
70 fn name(&self) -> &str {
71 self.inner.name()
72 }
73
74 fn description(&self) -> &str {
75 self.inner.description()
76 }
77
78 fn args_schema(&self) -> Value {
79 self.inner.args_schema()
80 }
81}
82
83pub fn shared_tools_to_boxes(tools: &[Arc<dyn ToolT>]) -> Vec<Box<dyn ToolT>> {
88 tools
89 .iter()
90 .map(|t| Box::new(SharedTool::new(Arc::clone(t))) as Box<dyn ToolT>)
91 .collect()
92}
93
94#[allow(clippy::borrowed_box)]
96pub fn to_llm_tool(tool: &Box<dyn ToolT>) -> Tool {
97 Tool {
98 tool_type: "function".to_string(),
99 function: FunctionTool {
100 name: tool.name().to_string(),
101 description: tool.description().to_string(),
102 parameters: tool.args_schema(),
103 },
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::Tool;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Strongly typed arguments accepted by [`MockTool`].
    #[derive(Debug, Serialize, Deserialize)]
    struct TestInput {
        name: String,
        value: i32,
    }

    impl ToolInputT for TestInput {
        fn io_schema() -> &'static str {
            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
        }
    }

    /// Minimal tool used to exercise the `ToolT` / `ToolRuntime` contracts.
    ///
    /// On success it echoes `name` and doubles `value`; when `should_fail` is
    /// set, every call returns `ToolCallError::RuntimeError`.
    #[derive(Debug)]
    struct MockTool {
        name: &'static str,
        description: &'static str,
        should_fail: bool,
    }

    impl MockTool {
        fn new(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: false,
            }
        }

        fn with_failure(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: true,
            }
        }
    }

    impl ToolT for MockTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            self.description
        }

        fn args_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "value": {"type": "integer"}
                },
                "required": ["name", "value"]
            })
        }
    }

    #[async_trait]
    impl ToolRuntime for MockTool {
        async fn execute(
            &self,
            args: serde_json::Value,
        ) -> Result<serde_json::Value, ToolCallError> {
            if self.should_fail {
                return Err(ToolCallError::RuntimeError("Mock tool failure".into()));
            }

            let input: TestInput = serde_json::from_value(args)?;
            Ok(json!({
                "processed_name": input.name,
                "doubled_value": input.value * 2
            }))
        }
    }

    #[test]
    fn test_tool_call_error_runtime_error() {
        let error = ToolCallError::RuntimeError("Runtime error".into());
        assert_eq!(error.to_string(), "Runtime Error Runtime error");
    }

    #[test]
    fn test_tool_call_error_serde_error() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error = ToolCallError::SerdeError(json_error);
        assert!(error.to_string().contains("Serde Error"));
    }

    #[test]
    fn test_tool_call_error_debug() {
        let error = ToolCallError::RuntimeError("Debug test".into());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("RuntimeError"));
    }

    #[test]
    fn test_tool_call_error_from_serde() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error: ToolCallError = json_error.into();
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[test]
    fn test_tool_call_error_from_box_error() {
        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
        let error: ToolCallError = box_error.into();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_mock_tool_creation() {
        let tool = MockTool::new("test_tool", "A test tool");
        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool");
        assert!(!tool.should_fail);
    }

    #[test]
    fn test_mock_tool_with_failure() {
        let tool = MockTool::with_failure("failing_tool", "A failing tool");
        assert_eq!(tool.name(), "failing_tool");
        assert_eq!(tool.description(), "A failing tool");
        assert!(tool.should_fail);
    }

    #[test]
    fn test_mock_tool_args_schema() {
        let tool = MockTool::new("schema_tool", "Schema test");
        let schema = tool.args_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["value"].is_object());
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
    }

    #[tokio::test]
    async fn test_mock_tool_run_success() {
        let tool = MockTool::new("success_tool", "Success test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let output = tool
            .execute(input)
            .await
            .expect("valid input should succeed");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[tokio::test]
    async fn test_mock_tool_run_failure() {
        let tool = MockTool::with_failure("failure_tool", "Failure test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let err = tool
            .execute(input)
            .await
            .expect_err("a failing mock must return an error");
        assert!(err.to_string().contains("Mock tool failure"));
    }

    #[tokio::test]
    async fn test_mock_tool_run_invalid_input() {
        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
        let input = json!({
            "invalid_field": "test"
        });

        let result = tool.execute(input).await;
        assert!(matches!(result, Err(ToolCallError::SerdeError(_))));
    }

    #[tokio::test]
    async fn test_mock_tool_run_with_extra_fields() {
        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
        let input = json!({
            "name": "test",
            "value": 42,
            "extra_field": "ignored"
        });

        let output = tool
            .execute(input)
            .await
            .expect("unknown fields should be ignored");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_debug() {
        let tool = MockTool::new("debug_tool", "Debug test");
        let debug_str = format!("{tool:?}");
        assert!(debug_str.contains("MockTool"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_input_trait() {
        let schema = TestInput::io_schema();
        for needle in ["object", "name", "value", "string", "integer"] {
            assert!(schema.contains(needle), "schema is missing {needle:?}");
        }
    }

    #[test]
    fn test_test_input_serialization() {
        let input = TestInput {
            name: "test".to_string(),
            value: 42,
        };
        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains("test"));
        assert!(serialized.contains("42"));
    }

    #[test]
    fn test_test_input_deserialization() {
        let json = r#"{"name":"test","value":42}"#;
        let input: TestInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.value, 42);
    }

    #[test]
    fn test_test_input_debug() {
        let input = TestInput {
            name: "debug".to_string(),
            value: 123,
        };
        let debug_str = format!("{input:?}");
        assert!(debug_str.contains("TestInput"));
        assert!(debug_str.contains("debug"));
        assert!(debug_str.contains("123"));
    }

    #[test]
    fn test_boxed_tool_to_tool_conversion() {
        let boxed_tool: Box<dyn ToolT> = Box::new(MockTool::new("convert_tool", "Conversion test"));

        let tool: Tool = to_llm_tool(&boxed_tool);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "convert_tool");
        assert_eq!(tool.function.description, "Conversion test");
        assert_eq!(tool.function.parameters["type"], "object");
    }

    #[test]
    fn test_tool_conversion_preserves_schema() {
        let boxed_tool: Box<dyn ToolT> =
            Box::new(MockTool::new("schema_tool", "Schema preservation test"));

        let tool: Tool = to_llm_tool(&boxed_tool);
        let schema = &tool.function.parameters;

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
        assert_eq!(schema["required"][0], "name");
        assert_eq!(schema["required"][1], "value");
    }

    #[test]
    fn test_tool_trait_object_usage() {
        let tools: Vec<Box<dyn ToolT>> = vec![
            Box::new(MockTool::new("tool1", "First tool")),
            Box::new(MockTool::new("tool2", "Second tool")),
            Box::new(MockTool::with_failure("tool3", "Third tool")),
        ];

        for tool in &tools {
            assert!(!tool.name().is_empty());
            assert!(!tool.description().is_empty());
            assert!(tool.args_schema().is_object());
        }
    }

    #[tokio::test]
    async fn test_tool_run_with_different_inputs() {
        let tool = MockTool::new("varied_input_tool", "Varied input test");

        let inputs = [
            json!({"name": "test1", "value": 1}),
            json!({"name": "test2", "value": -5}),
            json!({"name": "", "value": 0}),
            json!({"name": "long_name_test", "value": 999999}),
        ];

        for input in inputs {
            let output = tool
                .execute(input.clone())
                .await
                .unwrap_or_else(|e| panic!("input {input} should succeed: {e}"));
            assert_eq!(output["processed_name"], input["name"]);
            assert_eq!(
                output["doubled_value"],
                input["value"].as_i64().unwrap() * 2
            );
        }
    }

    #[tokio::test]
    async fn test_shared_tool_delegates_to_inner() {
        let inner: Arc<dyn ToolT> = Arc::new(MockTool::new("shared_tool", "Shared test"));
        let boxed = shared_tools_to_boxes(&[Arc::clone(&inner)]);

        assert_eq!(boxed.len(), 1);
        let tool = &boxed[0];
        assert_eq!(tool.name(), "shared_tool");
        assert_eq!(tool.description(), "Shared test");
        assert_eq!(tool.args_schema(), inner.args_schema());

        let output = tool
            .execute(json!({"name": "shared", "value": 3}))
            .await
            .expect("delegated execution should succeed");
        assert_eq!(output["processed_name"], "shared");
        assert_eq!(output["doubled_value"], 6);

        // The box holds a handle to the same tool, not a copy.
        assert_eq!(Arc::strong_count(&inner), 2);
    }

    #[test]
    fn test_tool_error_chaining() {
        use std::error::Error;

        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
        let tool_error = ToolCallError::SerdeError(json_error);

        assert!(tool_error.source().is_some());
    }

    #[test]
    fn test_tool_with_empty_name() {
        let tool = MockTool::new("", "Empty name test");
        assert_eq!(tool.name(), "");
        assert_eq!(tool.description(), "Empty name test");
    }

    #[test]
    fn test_tool_with_empty_description() {
        let tool = MockTool::new("empty_desc", "");
        assert_eq!(tool.name(), "empty_desc");
        assert_eq!(tool.description(), "");
    }

    #[test]
    fn test_tool_schema_complex() {
        let tool = MockTool::new("complex_tool", "Complex schema test");
        let schema = tool.args_schema();

        assert!(schema.is_object());
        assert!(schema["properties"].is_object());
        assert!(schema["required"].is_array());
        assert_eq!(schema["required"].as_array().map(Vec::len), Some(2));
    }

    #[test]
    fn test_multiple_tool_instances() {
        let tool1 = MockTool::new("tool1", "First instance");
        let tool2 = MockTool::new("tool2", "Second instance");

        assert_ne!(tool1.name(), tool2.name());
        assert_ne!(tool1.description(), tool2.description());

        assert_eq!(tool1.args_schema(), tool2.args_schema());
    }

    #[test]
    fn test_tool_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockTool>();
    }

    #[test]
    fn test_tool_trait_object_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Box<dyn ToolT>>();
    }

    #[test]
    fn test_tool_call_result_creation() {
        let result = ToolCallResult {
            tool_name: "test_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "success"}),
        };

        assert_eq!(result.tool_name, "test_tool");
        assert!(result.success);
        assert_eq!(result.arguments, json!({"param": "value"}));
        assert_eq!(result.result, json!({"output": "success"}));
    }

    #[test]
    fn test_tool_call_result_serialization() {
        let result = ToolCallResult {
            tool_name: "serialize_tool".to_string(),
            success: false,
            arguments: json!({"input": "test"}),
            result: json!({"error": "failed"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "serialize_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!({"input": "test"}));
        assert_eq!(deserialized.result, json!({"error": "failed"}));
    }

    #[test]
    fn test_tool_call_result_clone() {
        let result = ToolCallResult {
            tool_name: "clone_tool".to_string(),
            success: true,
            arguments: json!({"data": [1, 2, 3]}),
            result: json!({"processed": [2, 4, 6]}),
        };

        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.success, cloned.success);
        assert_eq!(result.arguments, cloned.arguments);
        assert_eq!(result.result, cloned.result);
    }

    #[test]
    fn test_tool_call_result_debug() {
        let result = ToolCallResult {
            tool_name: "debug_tool".to_string(),
            success: true,
            arguments: json!({}),
            result: json!(null),
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("ToolCallResult"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_call_result_with_null_values() {
        let result = ToolCallResult {
            tool_name: "null_tool".to_string(),
            success: false,
            arguments: json!(null),
            result: json!(null),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "null_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!(null));
        assert_eq!(deserialized.result, json!(null));
    }

    #[test]
    fn test_tool_call_result_with_complex_data() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, {"key": "value"}],
                "string": "test",
                "number": 42.5
            }
        });

        let complex_result = json!({
            "status": "completed",
            "data": {
                "items": ["a", "b", "c"],
                "count": 3
            }
        });

        let result = ToolCallResult {
            tool_name: "complex_tool".to_string(),
            success: true,
            arguments: complex_args.clone(),
            result: complex_result.clone(),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments, complex_args);
        assert_eq!(deserialized.result, complex_result);
    }

    #[test]
    fn test_tool_call_result_empty_tool_name() {
        let result = ToolCallResult {
            tool_name: String::new(),
            success: true,
            arguments: json!({}),
            result: json!({}),
        };

        assert!(result.tool_name.is_empty());
        assert!(result.success);
    }

    #[test]
    fn test_tool_call_result_large_data() {
        let large_string = "x".repeat(10_000);
        let result = ToolCallResult {
            tool_name: "large_tool".to_string(),
            success: true,
            arguments: json!({"large_param": large_string}),
            result: json!({"processed": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "large_tool");
        assert!(deserialized.success);
        assert_eq!(
            deserialized.arguments["large_param"].as_str().map(str::len),
            Some(10_000)
        );
    }

    #[test]
    fn test_tool_call_result_equality() {
        let result1 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result2 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result3 = ToolCallResult {
            tool_name: "different_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let serialized1 = serde_json::to_string(&result1).unwrap();
        let serialized2 = serde_json::to_string(&result2).unwrap();
        let serialized3 = serde_json::to_string(&result3).unwrap();

        assert_eq!(serialized1, serialized2);
        assert_ne!(serialized1, serialized3);
    }

    #[test]
    fn test_tool_call_result_with_unicode() {
        // Mixes multi-byte CJK characters with a 4-byte (astral-plane) emoji.
        let result = ToolCallResult {
            tool_name: "unicode_tool".to_string(),
            success: true,
            arguments: json!({"message": "Hello 世界! 🌍"}),
            result: json!({"response": "Processed: Hello 世界! 🌍"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["message"], "Hello 世界! 🌍");
        assert_eq!(deserialized.result["response"], "Processed: Hello 世界! 🌍");
    }

    #[test]
    fn test_tool_call_result_with_arrays() {
        let result = ToolCallResult {
            tool_name: "array_tool".to_string(),
            success: true,
            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
            result: json!({"sum": 15, "count": 5}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
        assert_eq!(deserialized.result["sum"], 15);
        assert_eq!(deserialized.result["count"], 5);
    }

    #[test]
    fn test_tool_call_result_boolean_values() {
        let result = ToolCallResult {
            tool_name: "bool_tool".to_string(),
            success: false,
            arguments: json!({"enabled": true, "debug": false}),
            result: json!({"valid": false, "error": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments["enabled"], true);
        assert_eq!(deserialized.arguments["debug"], false);
        assert_eq!(deserialized.result["valid"], false);
        assert_eq!(deserialized.result["error"], true);
    }
}