1use autoagents_llm::chat::{FunctionTool, Tool};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::fmt::Debug;
5use std::sync::Arc;
6mod runtime;
7use async_trait::async_trait;
8pub use runtime::ToolRuntime;
9
10#[cfg(feature = "wasmtime")]
11pub use runtime::{WasmRuntime, WasmRuntimeError};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ToolCallResult {
17 pub tool_name: String,
18 pub success: bool,
19 pub arguments: Value,
20 pub result: Value,
21}
22
/// Errors a tool invocation can return.
///
/// Both variants carry `#[from]`, so `?` converts the wrapped error type
/// automatically. `#[from]` also implies `#[source]`, which means
/// `std::error::Error::source` returns the wrapped error.
#[derive(Debug, thiserror::Error)]
pub enum ToolCallError {
    /// Any other failure, carried as a boxed, thread-safe error.
    /// The boxed error's message is embedded in this variant's `Display`.
    #[error("Runtime Error {0}")]
    RuntimeError(#[from] Box<dyn std::error::Error + Sync + Send>),

    /// A `serde_json` (de)serialization failure, e.g. tool arguments that do
    /// not match the tool's expected input type.
    #[error("Serde Error {0}")]
    SerdeError(#[from] serde_json::Error),
}
31
/// A tool that can be described to an LLM and executed.
///
/// Implementors supply static metadata and a JSON schema for their arguments.
/// Execution itself comes from the [`ToolRuntime`] supertrait. The
/// `Send + Sync` bounds make `Box<dyn ToolT>` / `Arc<dyn ToolT>` usable
/// across threads.
pub trait ToolT: Send + Sync + Debug + ToolRuntime {
    /// Tool identifier; becomes the function `name` in [`to_llm_tool`].
    fn name(&self) -> &'static str;
    /// Human-readable description; becomes the function `description` in [`to_llm_tool`].
    fn description(&self) -> &'static str;
    /// JSON schema of the arguments accepted by `execute`; becomes the
    /// function `parameters` in [`to_llm_tool`].
    fn args_schema(&self) -> Value;
}
40
/// Associates a tool input type with a static schema string.
pub trait ToolInputT {
    /// Returns the schema for this input type as a string.
    ///
    /// NOTE(review): the value is returned as raw text and is not validated
    /// here. The test implementation returns a JSON Schema object; callers
    /// presumably parse it — confirm.
    fn io_schema() -> &'static str;
}
45
46#[derive(Debug)]
51pub struct SharedTool {
52 inner: Arc<dyn ToolT>,
53}
54
55impl SharedTool {
56 pub fn new(tool: Arc<dyn ToolT>) -> Self {
58 Self { inner: tool }
59 }
60}
61
62#[async_trait]
63impl ToolRuntime for SharedTool {
64 async fn execute(&self, args: Value) -> Result<Value, ToolCallError> {
65 self.inner.execute(args).await
66 }
67}
68
69impl ToolT for SharedTool {
70 fn name(&self) -> &'static str {
71 self.inner.name()
72 }
73
74 fn description(&self) -> &'static str {
75 self.inner.description()
76 }
77
78 fn args_schema(&self) -> Value {
79 self.inner.args_schema()
80 }
81}
82
83pub fn shared_tools_to_boxes(tools: &[Arc<dyn ToolT>]) -> Vec<Box<dyn ToolT>> {
88 tools
89 .iter()
90 .map(|t| Box::new(SharedTool::new(Arc::clone(t))) as Box<dyn ToolT>)
91 .collect()
92}
93
94#[allow(clippy::borrowed_box)]
96pub fn to_llm_tool(tool: &Box<dyn ToolT>) -> Tool {
97 Tool {
98 tool_type: "function".to_string(),
99 function: FunctionTool {
100 name: tool.name().to_string(),
101 description: tool.description().to_string(),
102 parameters: tool.args_schema(),
103 },
104 }
105}
106
#[cfg(test)]
mod tests {
    //! Unit tests covering:
    //! - `ToolCallError` conversions and display text;
    //! - the `ToolT` / `ToolRuntime` contract, via `MockTool`;
    //! - `SharedTool` delegation and `shared_tools_to_boxes`;
    //! - `to_llm_tool`;
    //! - `ToolCallResult` serde round-trips.

    use super::*;
    use autoagents_llm::chat::Tool;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Typed input accepted by `MockTool::execute`.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestInput {
        name: String,
        value: i32,
    }

    impl ToolInputT for TestInput {
        fn io_schema() -> &'static str {
            r#"{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"integer"}},"required":["name","value"]}"#
        }
    }

    /// Minimal tool that echoes `name` and doubles `value`. When `should_fail`
    /// is set, it always returns a `RuntimeError` instead.
    #[derive(Debug)]
    struct MockTool {
        name: &'static str,
        description: &'static str,
        should_fail: bool,
    }

    impl MockTool {
        fn new(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: false,
            }
        }

        fn with_failure(name: &'static str, description: &'static str) -> Self {
            Self {
                name,
                description,
                should_fail: true,
            }
        }
    }

    impl ToolT for MockTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            self.description
        }

        fn args_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "value": {"type": "integer"}
                },
                "required": ["name", "value"]
            })
        }
    }

    #[async_trait]
    impl ToolRuntime for MockTool {
        async fn execute(
            &self,
            args: serde_json::Value,
        ) -> Result<serde_json::Value, ToolCallError> {
            if self.should_fail {
                return Err(ToolCallError::RuntimeError(
                    "Mock tool failure".to_string().into(),
                ));
            }

            let input: TestInput = serde_json::from_value(args)?;
            Ok(json!({
                "processed_name": input.name,
                "doubled_value": input.value * 2
            }))
        }
    }

    #[test]
    fn test_tool_call_error_runtime_error() {
        let error = ToolCallError::RuntimeError("Runtime error".to_string().into());
        assert_eq!(error.to_string(), "Runtime Error Runtime error");
    }

    #[test]
    fn test_tool_call_error_serde_error() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error = ToolCallError::SerdeError(json_error);
        assert!(error.to_string().contains("Serde Error"));
    }

    #[test]
    fn test_tool_call_error_debug() {
        let error = ToolCallError::RuntimeError("Debug test".to_string().into());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("RuntimeError"));
    }

    #[test]
    fn test_tool_call_error_from_serde() {
        let json_error = serde_json::from_str::<Value>("invalid json").unwrap_err();
        let error: ToolCallError = json_error.into();
        assert!(matches!(error, ToolCallError::SerdeError(_)));
    }

    #[test]
    fn test_tool_call_error_from_box_error() {
        let box_error: Box<dyn std::error::Error + Send + Sync> = "Test error".into();
        let error: ToolCallError = box_error.into();
        assert!(matches!(error, ToolCallError::RuntimeError(_)));
    }

    #[test]
    fn test_mock_tool_creation() {
        let tool = MockTool::new("test_tool", "A test tool");
        assert_eq!(tool.name(), "test_tool");
        assert_eq!(tool.description(), "A test tool");
        assert!(!tool.should_fail);
    }

    #[test]
    fn test_mock_tool_with_failure() {
        let tool = MockTool::with_failure("failing_tool", "A failing tool");
        assert_eq!(tool.name(), "failing_tool");
        assert_eq!(tool.description(), "A failing tool");
        assert!(tool.should_fail);
    }

    #[test]
    fn test_mock_tool_args_schema() {
        let tool = MockTool::new("schema_tool", "Schema test");
        let schema = tool.args_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].is_object());
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["value"].is_object());
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
    }

    #[tokio::test]
    async fn test_mock_tool_run_success() {
        let tool = MockTool::new("success_tool", "Success test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let output = tool
            .execute(input)
            .await
            .expect("valid input should succeed");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[tokio::test]
    async fn test_mock_tool_run_failure() {
        let tool = MockTool::with_failure("failure_tool", "Failure test");
        let input = json!({
            "name": "test",
            "value": 42
        });

        let err = tool
            .execute(input)
            .await
            .expect_err("a failing mock must return an error");
        assert!(err.to_string().contains("Mock tool failure"));
    }

    #[tokio::test]
    async fn test_mock_tool_run_invalid_input() {
        let tool = MockTool::new("invalid_input_tool", "Invalid input test");
        let input = json!({
            "invalid_field": "test"
        });

        let err = tool
            .execute(input)
            .await
            .expect_err("input missing required fields must be rejected");
        assert!(matches!(err, ToolCallError::SerdeError(_)));
    }

    #[tokio::test]
    async fn test_mock_tool_run_with_extra_fields() {
        let tool = MockTool::new("extra_fields_tool", "Extra fields test");
        let input = json!({
            "name": "test",
            "value": 42,
            "extra_field": "ignored"
        });

        let output = tool
            .execute(input)
            .await
            .expect("unknown fields should be ignored");
        assert_eq!(output["processed_name"], "test");
        assert_eq!(output["doubled_value"], 84);
    }

    #[test]
    fn test_mock_tool_debug() {
        let tool = MockTool::new("debug_tool", "Debug test");
        let debug_str = format!("{tool:?}");
        assert!(debug_str.contains("MockTool"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_input_trait() {
        let schema = TestInput::io_schema();
        assert!(schema.contains("object"));
        assert!(schema.contains("name"));
        assert!(schema.contains("value"));
        assert!(schema.contains("string"));
        assert!(schema.contains("integer"));
    }

    #[test]
    fn test_test_input_serialization() {
        let input = TestInput {
            name: "test".to_string(),
            value: 42,
        };
        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains("test"));
        assert!(serialized.contains("42"));
    }

    #[test]
    fn test_test_input_deserialization() {
        let json = r#"{"name":"test","value":42}"#;
        let input: TestInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.name, "test");
        assert_eq!(input.value, 42);
    }

    #[test]
    fn test_test_input_debug() {
        let input = TestInput {
            name: "debug".to_string(),
            value: 123,
        };
        let debug_str = format!("{input:?}");
        assert!(debug_str.contains("TestInput"));
        assert!(debug_str.contains("debug"));
        assert!(debug_str.contains("123"));
    }

    #[test]
    fn test_boxed_tool_to_tool_conversion() {
        let mock_tool = MockTool::new("convert_tool", "Conversion test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "convert_tool");
        assert_eq!(tool.function.description, "Conversion test");
        assert_eq!(tool.function.parameters["type"], "object");
    }

    #[test]
    fn test_tool_conversion_preserves_schema() {
        let mock_tool = MockTool::new("schema_tool", "Schema preservation test");
        let boxed_tool: Box<dyn ToolT> = Box::new(mock_tool);

        let tool: Tool = to_llm_tool(&boxed_tool);
        let schema = &tool.function.parameters;

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["value"]["type"], "integer");
        assert_eq!(schema["required"][0], "name");
        assert_eq!(schema["required"][1], "value");
    }

    #[test]
    fn test_tool_trait_object_usage() {
        let tools: Vec<Box<dyn ToolT>> = vec![
            Box::new(MockTool::new("tool1", "First tool")),
            Box::new(MockTool::new("tool2", "Second tool")),
            Box::new(MockTool::with_failure("tool3", "Third tool")),
        ];

        for tool in &tools {
            assert!(!tool.name().is_empty());
            assert!(!tool.description().is_empty());
            assert!(tool.args_schema().is_object());
        }
    }

    #[tokio::test]
    async fn test_tool_run_with_different_inputs() {
        let tool = MockTool::new("varied_input_tool", "Varied input test");

        let inputs = [
            json!({"name": "test1", "value": 1}),
            json!({"name": "test2", "value": -5}),
            json!({"name": "", "value": 0}),
            json!({"name": "long_name_test", "value": 999999}),
        ];

        for input in inputs {
            let output = tool
                .execute(input.clone())
                .await
                .expect("every sample input is valid");
            assert_eq!(output["processed_name"], input["name"]);
            assert_eq!(
                output["doubled_value"],
                input["value"].as_i64().unwrap() * 2
            );
        }
    }

    #[test]
    fn test_tool_error_chaining() {
        use std::error::Error;

        let json_error = serde_json::from_str::<Value>("invalid").unwrap_err();
        let tool_error = ToolCallError::SerdeError(json_error);
        assert!(tool_error.source().is_some());
    }

    #[test]
    fn test_tool_with_empty_name() {
        let tool = MockTool::new("", "Empty name test");
        assert_eq!(tool.name(), "");
        assert_eq!(tool.description(), "Empty name test");
    }

    #[test]
    fn test_tool_with_empty_description() {
        let tool = MockTool::new("empty_desc", "");
        assert_eq!(tool.name(), "empty_desc");
        assert_eq!(tool.description(), "");
    }

    #[test]
    fn test_tool_schema_complex() {
        let tool = MockTool::new("complex_tool", "Complex schema test");
        let schema = tool.args_schema();

        assert!(schema.is_object());
        assert!(schema["properties"].is_object());
        assert!(schema["required"].is_array());
        assert_eq!(schema["required"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_multiple_tool_instances() {
        let tool1 = MockTool::new("tool1", "First instance");
        let tool2 = MockTool::new("tool2", "Second instance");

        assert_ne!(tool1.name(), tool2.name());
        assert_ne!(tool1.description(), tool2.description());

        assert_eq!(tool1.args_schema(), tool2.args_schema());
    }

    #[test]
    fn test_tool_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockTool>();
    }

    #[test]
    fn test_tool_trait_object_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Box<dyn ToolT>>();
    }

    #[tokio::test]
    async fn test_shared_tool_delegates_to_inner() {
        let inner: Arc<dyn ToolT> = Arc::new(MockTool::new("shared_tool", "Shared test"));
        let shared = SharedTool::new(Arc::clone(&inner));

        assert_eq!(shared.name(), "shared_tool");
        assert_eq!(shared.description(), "Shared test");
        assert_eq!(shared.args_schema(), inner.args_schema());

        let output = shared
            .execute(json!({"name": "abc", "value": 7}))
            .await
            .expect("SharedTool should forward execution to the inner tool");
        assert_eq!(output["processed_name"], "abc");
        assert_eq!(output["doubled_value"], 14);
    }

    #[tokio::test]
    async fn test_shared_tool_propagates_inner_failure() {
        let shared = SharedTool::new(Arc::new(MockTool::with_failure(
            "shared_fail",
            "Shared failure",
        )));

        let err = shared
            .execute(json!({"name": "x", "value": 1}))
            .await
            .expect_err("inner failure must be propagated unchanged");
        assert!(matches!(err, ToolCallError::RuntimeError(_)));
        assert!(err.to_string().contains("Mock tool failure"));
    }

    #[test]
    fn test_shared_tools_to_boxes_preserves_order_and_shares_inner() {
        let tools: Vec<Arc<dyn ToolT>> = vec![
            Arc::new(MockTool::new("first", "First shared")),
            Arc::new(MockTool::new("second", "Second shared")),
        ];

        let boxed = shared_tools_to_boxes(&tools);
        assert_eq!(boxed.len(), 2);
        assert_eq!(boxed[0].name(), "first");
        assert_eq!(boxed[1].name(), "second");

        // Each box holds another handle to the same Arc rather than a copy.
        assert_eq!(Arc::strong_count(&tools[0]), 2);
        drop(boxed);
        assert_eq!(Arc::strong_count(&tools[0]), 1);
    }

    #[test]
    fn test_shared_tools_to_boxes_empty() {
        assert!(shared_tools_to_boxes(&[]).is_empty());
    }

    #[test]
    fn test_shared_tool_to_llm_tool() {
        let tools: Vec<Arc<dyn ToolT>> = vec![Arc::new(MockTool::new("llm_shared", "LLM shared"))];
        let boxed = shared_tools_to_boxes(&tools);

        let tool: Tool = to_llm_tool(&boxed[0]);
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "llm_shared");
        assert_eq!(tool.function.description, "LLM shared");
        assert_eq!(tool.function.parameters, tools[0].args_schema());
    }

    #[test]
    fn test_tool_call_result_creation() {
        let result = ToolCallResult {
            tool_name: "test_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "success"}),
        };

        assert_eq!(result.tool_name, "test_tool");
        assert!(result.success);
        assert_eq!(result.arguments, json!({"param": "value"}));
        assert_eq!(result.result, json!({"output": "success"}));
    }

    #[test]
    fn test_tool_call_result_serialization() {
        let result = ToolCallResult {
            tool_name: "serialize_tool".to_string(),
            success: false,
            arguments: json!({"input": "test"}),
            result: json!({"error": "failed"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "serialize_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!({"input": "test"}));
        assert_eq!(deserialized.result, json!({"error": "failed"}));
    }

    #[test]
    fn test_tool_call_result_clone() {
        let result = ToolCallResult {
            tool_name: "clone_tool".to_string(),
            success: true,
            arguments: json!({"data": [1, 2, 3]}),
            result: json!({"processed": [2, 4, 6]}),
        };

        let cloned = result.clone();
        assert_eq!(result.tool_name, cloned.tool_name);
        assert_eq!(result.success, cloned.success);
        assert_eq!(result.arguments, cloned.arguments);
        assert_eq!(result.result, cloned.result);
    }

    #[test]
    fn test_tool_call_result_debug() {
        let result = ToolCallResult {
            tool_name: "debug_tool".to_string(),
            success: true,
            arguments: json!({}),
            result: json!(null),
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("ToolCallResult"));
        assert!(debug_str.contains("debug_tool"));
    }

    #[test]
    fn test_tool_call_result_with_null_values() {
        let result = ToolCallResult {
            tool_name: "null_tool".to_string(),
            success: false,
            arguments: json!(null),
            result: json!(null),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "null_tool");
        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments, json!(null));
        assert_eq!(deserialized.result, json!(null));
    }

    #[test]
    fn test_tool_call_result_with_complex_data() {
        let complex_args = json!({
            "nested": {
                "array": [1, 2, {"key": "value"}],
                "string": "test",
                "number": 42.5
            }
        });

        let complex_result = json!({
            "status": "completed",
            "data": {
                "items": ["a", "b", "c"],
                "count": 3
            }
        });

        let result = ToolCallResult {
            tool_name: "complex_tool".to_string(),
            success: true,
            arguments: complex_args.clone(),
            result: complex_result.clone(),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments, complex_args);
        assert_eq!(deserialized.result, complex_result);
    }

    #[test]
    fn test_tool_call_result_empty_tool_name() {
        let result = ToolCallResult {
            tool_name: String::new(),
            success: true,
            arguments: json!({}),
            result: json!({}),
        };

        assert!(result.tool_name.is_empty());
        assert!(result.success);
    }

    #[test]
    fn test_tool_call_result_large_data() {
        let large_string = "x".repeat(10000);
        let result = ToolCallResult {
            tool_name: "large_tool".to_string(),
            success: true,
            arguments: json!({"large_param": large_string}),
            result: json!({"processed": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.tool_name, "large_tool");
        assert!(deserialized.success);
        assert_eq!(
            deserialized.arguments["large_param"].as_str().map(str::len),
            Some(10000)
        );
    }

    #[test]
    fn test_tool_call_result_equality() {
        let result1 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result2 = ToolCallResult {
            tool_name: "equal_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        let result3 = ToolCallResult {
            tool_name: "different_tool".to_string(),
            success: true,
            arguments: json!({"param": "value"}),
            result: json!({"output": "result"}),
        };

        // Compared via serialized form so this test does not depend on a
        // `PartialEq` impl for `ToolCallResult`.
        let serialized1 = serde_json::to_string(&result1).unwrap();
        let serialized2 = serde_json::to_string(&result2).unwrap();
        let serialized3 = serde_json::to_string(&result3).unwrap();

        assert_eq!(serialized1, serialized2);
        assert_ne!(serialized1, serialized3);
    }

    #[test]
    fn test_tool_call_result_with_unicode() {
        let result = ToolCallResult {
            tool_name: "unicode_tool".to_string(),
            success: true,
            arguments: json!({"message": "Hello 世界! 🌍"}),
            result: json!({"response": "Processed: Hello 世界! 🌍"}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["message"], "Hello 世界! 🌍");
        assert_eq!(deserialized.result["response"], "Processed: Hello 世界! 🌍");
    }

    #[test]
    fn test_tool_call_result_with_arrays() {
        let result = ToolCallResult {
            tool_name: "array_tool".to_string(),
            success: true,
            arguments: json!({"numbers": [1, 2, 3, 4, 5]}),
            result: json!({"sum": 15, "count": 5}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.arguments["numbers"], json!([1, 2, 3, 4, 5]));
        assert_eq!(deserialized.result["sum"], 15);
        assert_eq!(deserialized.result["count"], 5);
    }

    #[test]
    fn test_tool_call_result_boolean_values() {
        let result = ToolCallResult {
            tool_name: "bool_tool".to_string(),
            success: false,
            arguments: json!({"enabled": true, "debug": false}),
            result: json!({"valid": false, "error": true}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: ToolCallResult = serde_json::from_str(&serialized).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.arguments["enabled"], true);
        assert_eq!(deserialized.arguments["debug"], false);
        assert_eq!(deserialized.result["valid"], false);
        assert_eq!(deserialized.result["error"], true);
    }
}