1use crate::error::{HeliosError, Result};
7use crate::tools::{Tool, ToolParameter, ToolResult};
8use async_trait::async_trait;
9use serde_json::Value;
10use std::collections::HashMap;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::Arc;
14
/// Shared, thread-safe callable that implements a custom tool.
///
/// It receives the raw JSON arguments and returns a boxed, pinned, `Send`
/// future resolving to the tool's `Result<ToolResult>`.
pub type ToolFunction =
    Arc<dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send>> + Send + Sync>;
19
/// Builder that assembles a custom [`Tool`] from a name, a description,
/// parameter metadata and the function to execute.
pub struct ToolBuilder {
    /// Name reported by the built tool.
    name: String,
    /// Description reported by the built tool; empty until set.
    description: String,
    /// Parameter metadata keyed by parameter name.
    parameters: HashMap<String, ToolParameter>,
    /// Ordered list of parameter names; the `ftool*` methods read it to bind
    /// JSON arguments to positional function arguments.
    parameter_order: Vec<String>,
    /// Function to run on execution; `None` until one of the function setters is called.
    function: Option<ToolFunction>,
}
51
52impl ToolBuilder {
53 pub fn new(name: impl Into<String>) -> Self {
59 Self {
60 name: name.into(),
61 description: String::new(),
62 parameters: HashMap::new(),
63 parameter_order: Vec::new(),
64 function: None,
65 }
66 }
67
68 pub fn from_fn<F>(
105 func_name: impl Into<String>,
106 description: impl Into<String>,
107 params: impl Into<String>,
108 func: F,
109 ) -> Self
110 where
111 F: Fn(Value) -> Result<ToolResult> + Send + Sync + 'static,
112 {
113 Self::new(func_name)
114 .description(description)
115 .parameters(params)
116 .sync_function(func)
117 }
118
119 pub fn simple(
143 name: impl Into<String>,
144 description: impl Into<String>,
145 params: impl Into<String>,
146 ) -> Self {
147 Self::new(name).description(description).parameters(params)
148 }
149
150 pub fn from_async_fn<F, Fut>(
186 func_name: impl Into<String>,
187 description: impl Into<String>,
188 params: impl Into<String>,
189 func: F,
190 ) -> Self
191 where
192 F: Fn(Value) -> Fut + Send + Sync + 'static,
193 Fut: Future<Output = Result<ToolResult>> + Send + 'static,
194 {
195 Self::new(func_name)
196 .description(description)
197 .parameters(params)
198 .function(func)
199 }
200
201 pub fn description(mut self, description: impl Into<String>) -> Self {
207 self.description = description.into();
208 self
209 }
210
211 pub fn parameter(
220 mut self,
221 name: impl Into<String>,
222 param_type: impl Into<String>,
223 description: impl Into<String>,
224 required: bool,
225 ) -> Self {
226 self.parameters.insert(
227 name.into(),
228 ToolParameter {
229 param_type: param_type.into(),
230 description: description.into(),
231 required: Some(required),
232 },
233 );
234 self
235 }
236
237 pub fn optional_parameter(
245 self,
246 name: impl Into<String>,
247 param_type: impl Into<String>,
248 description: impl Into<String>,
249 ) -> Self {
250 self.parameter(name, param_type, description, false)
251 }
252
253 pub fn required_parameter(
261 self,
262 name: impl Into<String>,
263 param_type: impl Into<String>,
264 description: impl Into<String>,
265 ) -> Self {
266 self.parameter(name, param_type, description, true)
267 }
268
269 pub fn parameters(mut self, params: impl Into<String>) -> Self {
305 let params_str = params.into();
306
307 for param in params_str.split(',') {
308 let param = param.trim();
309 if param.is_empty() {
310 continue;
311 }
312
313 let parts: Vec<&str> = param.splitn(3, ':').collect();
314 if parts.len() < 2 {
315 continue;
316 }
317
318 let name = parts[0].trim();
319 let param_type = parts[1].trim();
320 let description = if parts.len() >= 3 {
321 parts[2].trim()
322 } else {
323 ""
324 };
325
326 let json_type = match param_type.to_lowercase().as_str() {
328 "i32" | "i64" | "u32" | "u64" | "isize" | "usize" | "integer" => "integer",
329 "f32" | "f64" | "number" => "number",
330 "str" | "string" => "string",
331 "bool" | "boolean" => "boolean",
332 "object" => "object",
333 "array" => "array",
334 _ => param_type, };
336
337 let name_string = name.to_string();
338 self.parameters.insert(
339 name_string.clone(),
340 ToolParameter {
341 param_type: json_type.to_string(),
342 description: description.to_string(),
343 required: Some(true),
344 },
345 );
346 self.parameter_order.push(name_string);
347 }
348
349 self
350 }
351
352 pub fn function<F, Fut>(mut self, f: F) -> Self
361 where
362 F: Fn(Value) -> Fut + Send + Sync + 'static,
363 Fut: Future<Output = Result<ToolResult>> + Send + 'static,
364 {
365 self.function = Some(Arc::new(move |args| Box::pin(f(args))));
366 self
367 }
368
369 pub fn sync_function<F>(mut self, f: F) -> Self
377 where
378 F: Fn(Value) -> Result<ToolResult> + Send + Sync + 'static,
379 {
380 self.function = Some(Arc::new(move |args| {
381 let result = f(args);
382 Box::pin(async move { result })
383 }));
384 self
385 }
386
387 pub fn ftool<F, T1, T2, R>(self, f: F) -> Self
424 where
425 F: Fn(T1, T2) -> R + Send + Sync + 'static,
426 T1: FromValue + Send + 'static,
427 T2: FromValue + Send + 'static,
428 R: ToString + Send + 'static,
429 {
430 let param_order = self.parameter_order.clone();
431 self.sync_function(move |args| {
432 let obj = args.as_object().ok_or_else(|| {
433 HeliosError::ToolError("Expected JSON object for arguments".to_string())
434 })?;
435
436 if param_order.len() < 2 {
437 return Ok(ToolResult::error("Expected at least 2 parameters"));
438 }
439
440 let p1 = obj
441 .get(¶m_order[0])
442 .ok_or_else(|| {
443 HeliosError::ToolError(format!("Missing parameter: {}", param_order[0]))
444 })?
445 .clone();
446 let p2 = obj
447 .get(¶m_order[1])
448 .ok_or_else(|| {
449 HeliosError::ToolError(format!("Missing parameter: {}", param_order[1]))
450 })?
451 .clone();
452
453 let p1 = T1::from_value(p1)?;
454 let p2 = T2::from_value(p2)?;
455
456 let result = f(p1, p2);
457 Ok(ToolResult::success(result.to_string()))
458 })
459 }
460
461 pub fn ftool3<F, T1, T2, T3, R>(self, f: F) -> Self
482 where
483 F: Fn(T1, T2, T3) -> R + Send + Sync + 'static,
484 T1: FromValue + Send + 'static,
485 T2: FromValue + Send + 'static,
486 T3: FromValue + Send + 'static,
487 R: ToString + Send + 'static,
488 {
489 let param_order = self.parameter_order.clone();
490 self.sync_function(move |args| {
491 let obj = args.as_object().ok_or_else(|| {
492 HeliosError::ToolError("Expected JSON object for arguments".to_string())
493 })?;
494
495 if param_order.len() < 3 {
496 return Ok(ToolResult::error("Expected at least 3 parameters"));
497 }
498
499 let p1 = obj
500 .get(¶m_order[0])
501 .ok_or_else(|| {
502 HeliosError::ToolError(format!("Missing parameter: {}", param_order[0]))
503 })?
504 .clone();
505 let p2 = obj
506 .get(¶m_order[1])
507 .ok_or_else(|| {
508 HeliosError::ToolError(format!("Missing parameter: {}", param_order[1]))
509 })?
510 .clone();
511 let p3 = obj
512 .get(¶m_order[2])
513 .ok_or_else(|| {
514 HeliosError::ToolError(format!("Missing parameter: {}", param_order[2]))
515 })?
516 .clone();
517
518 let p1 = T1::from_value(p1)?;
519 let p2 = T2::from_value(p2)?;
520 let p3 = T3::from_value(p3)?;
521
522 let result = f(p1, p2, p3);
523 Ok(ToolResult::success(result.to_string()))
524 })
525 }
526
527 pub fn ftool4<F, T1, T2, T3, T4, R>(self, f: F) -> Self
529 where
530 F: Fn(T1, T2, T3, T4) -> R + Send + Sync + 'static,
531 T1: FromValue + Send + 'static,
532 T2: FromValue + Send + 'static,
533 T3: FromValue + Send + 'static,
534 T4: FromValue + Send + 'static,
535 R: ToString + Send + 'static,
536 {
537 let param_order = self.parameter_order.clone();
538 self.sync_function(move |args| {
539 let obj = args.as_object().ok_or_else(|| {
540 HeliosError::ToolError("Expected JSON object for arguments".to_string())
541 })?;
542
543 if param_order.len() < 4 {
544 return Ok(ToolResult::error("Expected at least 4 parameters"));
545 }
546
547 let p1 = T1::from_value(obj.get(¶m_order[0]).cloned().unwrap_or(Value::Null))?;
548 let p2 = T2::from_value(obj.get(¶m_order[1]).cloned().unwrap_or(Value::Null))?;
549 let p3 = T3::from_value(obj.get(¶m_order[2]).cloned().unwrap_or(Value::Null))?;
550 let p4 = T4::from_value(obj.get(¶m_order[3]).cloned().unwrap_or(Value::Null))?;
551
552 let result = f(p1, p2, p3, p4);
553 Ok(ToolResult::success(result.to_string()))
554 })
555 }
556
557 pub fn build(self) -> Box<dyn Tool> {
563 if self.function.is_none() {
564 panic!("Tool function must be set before building");
565 }
566
567 Box::new(CustomTool {
568 name: self.name,
569 description: self.description,
570 parameters: self.parameters,
571 function: self.function.unwrap(),
572 })
573 }
574
575 pub fn try_build(self) -> Result<Box<dyn Tool>> {
579 if self.function.is_none() {
580 return Err(HeliosError::ConfigError(
581 "Tool function must be set before building".to_string(),
582 ));
583 }
584
585 Ok(Box::new(CustomTool {
586 name: self.name,
587 description: self.description,
588 parameters: self.parameters,
589 function: self.function.unwrap(),
590 }))
591 }
592}
593
594struct CustomTool {
596 name: String,
597 description: String,
598 parameters: std::collections::HashMap<String, ToolParameter>,
599 function: ToolFunction,
600}
601
602#[async_trait]
603impl Tool for CustomTool {
604 fn name(&self) -> &str {
605 &self.name
606 }
607
608 fn description(&self) -> &str {
609 &self.description
610 }
611
612 fn parameters(&self) -> HashMap<String, ToolParameter> {
613 self.parameters.clone()
614 }
615
616 async fn execute(&self, args: Value) -> Result<ToolResult> {
617 (self.function)(args).await
618 }
619}
620
/// Conversion from a JSON [`Value`] into a typed Rust argument.
///
/// The `ftool*` builder methods use it to turn each JSON argument into the
/// corresponding parameter type of the wrapped function.
pub trait FromValue: Sized {
    /// Converts `value` into `Self`.
    ///
    /// The implementations in this file return `HeliosError::ToolError` when
    /// the value does not have the expected JSON type.
    fn from_value(value: Value) -> Result<Self>;
}
626
627impl FromValue for i32 {
628 fn from_value(value: Value) -> Result<Self> {
629 value
630 .as_i64()
631 .map(|n| n as i32)
632 .ok_or_else(|| HeliosError::ToolError("Expected integer value".to_string()))
633 }
634}
635
636impl FromValue for i64 {
637 fn from_value(value: Value) -> Result<Self> {
638 value
639 .as_i64()
640 .ok_or_else(|| HeliosError::ToolError("Expected integer value".to_string()))
641 }
642}
643
644impl FromValue for u32 {
645 fn from_value(value: Value) -> Result<Self> {
646 value
647 .as_u64()
648 .map(|n| n as u32)
649 .ok_or_else(|| HeliosError::ToolError("Expected unsigned integer value".to_string()))
650 }
651}
652
653impl FromValue for u64 {
654 fn from_value(value: Value) -> Result<Self> {
655 value
656 .as_u64()
657 .ok_or_else(|| HeliosError::ToolError("Expected unsigned integer value".to_string()))
658 }
659}
660
661impl FromValue for f32 {
662 fn from_value(value: Value) -> Result<Self> {
663 value
664 .as_f64()
665 .map(|n| n as f32)
666 .ok_or_else(|| HeliosError::ToolError("Expected float value".to_string()))
667 }
668}
669
670impl FromValue for f64 {
671 fn from_value(value: Value) -> Result<Self> {
672 value
673 .as_f64()
674 .ok_or_else(|| HeliosError::ToolError("Expected float value".to_string()))
675 }
676}
677
678impl FromValue for bool {
679 fn from_value(value: Value) -> Result<Self> {
680 value
681 .as_bool()
682 .ok_or_else(|| HeliosError::ToolError("Expected boolean value".to_string()))
683 }
684}
685
686impl FromValue for String {
687 fn from_value(value: Value) -> Result<Self> {
688 value
689 .as_str()
690 .map(|s| s.to_string())
691 .ok_or_else(|| HeliosError::ToolError("Expected string value".to_string()))
692 }
693}
694
695#[cfg(test)]
696mod tests {
697 use super::*;
698 use serde_json::json;
699
700 #[tokio::test]
701 async fn test_basic_tool_builder() {
702 async fn add_numbers(args: Value) -> Result<ToolResult> {
703 let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0);
704 let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0);
705 Ok(ToolResult::success((a + b).to_string()))
706 }
707
708 let tool = ToolBuilder::new("add")
709 .description("Add two numbers")
710 .parameter("a", "number", "First number", true)
711 .parameter("b", "number", "Second number", true)
712 .function(add_numbers)
713 .build();
714
715 assert_eq!(tool.name(), "add");
716 assert_eq!(tool.description(), "Add two numbers");
717
718 let result = tool.execute(json!({ "a": 5.0, "b": 3.0 })).await.unwrap();
719 assert!(result.success);
720 assert_eq!(result.output, "8");
721 }
722
723 #[tokio::test]
724 async fn test_sync_function_builder() {
725 let tool = ToolBuilder::new("echo")
726 .description("Echo a message")
727 .parameter("message", "string", "Message to echo", true)
728 .sync_function(|args: Value| {
729 let msg = args.get("message").and_then(|v| v.as_str()).unwrap_or("");
730 Ok(ToolResult::success(format!("Echo: {}", msg)))
731 })
732 .build();
733
734 let result = tool.execute(json!({ "message": "hello" })).await.unwrap();
735 assert!(result.success);
736 assert_eq!(result.output, "Echo: hello");
737 }
738
739 #[tokio::test]
740 async fn test_optional_parameters() {
741 let tool = ToolBuilder::new("greet")
742 .description("Greet someone")
743 .required_parameter("name", "string", "Name of person to greet")
744 .optional_parameter("title", "string", "Optional title (Mr, Mrs, etc)")
745 .sync_function(|args: Value| {
746 let name = args
747 .get("name")
748 .and_then(|v| v.as_str())
749 .unwrap_or("stranger");
750 let title = args.get("title").and_then(|v| v.as_str());
751
752 let greeting = if let Some(t) = title {
753 format!("Hello, {} {}!", t, name)
754 } else {
755 format!("Hello, {}!", name)
756 };
757
758 Ok(ToolResult::success(greeting))
759 })
760 .build();
761
762 let result1 = tool.execute(json!({ "name": "Alice" })).await.unwrap();
764 assert_eq!(result1.output, "Hello, Alice!");
765
766 let result2 = tool
768 .execute(json!({ "name": "Smith", "title": "Dr" }))
769 .await
770 .unwrap();
771 assert_eq!(result2.output, "Hello, Dr Smith!");
772 }
773
774 #[tokio::test]
775 async fn test_closure_capture() {
776 let multiplier = 10;
777
778 let tool = ToolBuilder::new("multiply")
779 .description("Multiply a number by a fixed value")
780 .parameter("value", "number", "Value to multiply", true)
781 .sync_function(move |args: Value| {
782 let value = args.get("value").and_then(|v| v.as_f64()).unwrap_or(0.0);
783 Ok(ToolResult::success((value * multiplier as f64).to_string()))
784 })
785 .build();
786
787 let result = tool.execute(json!({ "value": 5.0 })).await.unwrap();
788 assert_eq!(result.output, "50");
789 }
790
791 #[tokio::test]
792 async fn test_error_handling() {
793 let tool = ToolBuilder::new("fail")
794 .description("A tool that fails")
795 .sync_function(|_args: Value| {
796 Err(HeliosError::ToolError("Intentional failure".to_string()))
797 })
798 .build();
799
800 let result = tool.execute(json!({})).await;
801 assert!(result.is_err());
802 }
803
804 #[test]
805 #[should_panic(expected = "Tool function must be set before building")]
806 fn test_build_without_function() {
807 let _tool = ToolBuilder::new("incomplete")
808 .description("This will fail")
809 .build();
810 }
811
812 #[tokio::test]
813 async fn test_try_build_without_function() {
814 let result = ToolBuilder::new("incomplete")
815 .description("This will fail")
816 .try_build();
817
818 assert!(result.is_err());
819 }
820
821 #[tokio::test]
822 async fn test_complex_json_arguments() {
823 let tool = ToolBuilder::new("process_data")
824 .description("Process complex JSON data")
825 .parameter("data", "object", "Data object to process", true)
826 .sync_function(|args: Value| {
827 let data = args
828 .get("data")
829 .ok_or_else(|| HeliosError::ToolError("Missing data parameter".to_string()))?;
830
831 let count = if let Some(obj) = data.as_object() {
832 obj.len()
833 } else {
834 0
835 };
836
837 Ok(ToolResult::success(format!("Processed {} fields", count)))
838 })
839 .build();
840
841 let result = tool
842 .execute(json!({
843 "data": {
844 "field1": "value1",
845 "field2": 42,
846 "field3": true
847 }
848 }))
849 .await
850 .unwrap();
851
852 assert_eq!(result.output, "Processed 3 fields");
853 }
854
855 #[tokio::test]
856 async fn test_parameters_method() {
857 let tool = ToolBuilder::new("calculate_area")
858 .description("Calculate area of a rectangle")
859 .parameters("width:i32:The width, height:i32:The height")
860 .sync_function(|args: Value| {
861 let width = args.get("width").and_then(|v| v.as_i64()).unwrap_or(0);
862 let height = args.get("height").and_then(|v| v.as_i64()).unwrap_or(0);
863 Ok(ToolResult::success(format!("Area: {}", width * height)))
864 })
865 .build();
866
867 assert_eq!(tool.name(), "calculate_area");
868
869 let params = tool.parameters();
870 assert!(params.contains_key("width"));
871 assert!(params.contains_key("height"));
872 assert_eq!(params.get("width").unwrap().param_type, "integer");
873 assert_eq!(params.get("height").unwrap().param_type, "integer");
874
875 let result = tool
876 .execute(json!({"width": 5, "height": 10}))
877 .await
878 .unwrap();
879 assert_eq!(result.output, "Area: 50");
880 }
881
882 #[tokio::test]
883 async fn test_parameters_with_float_types() {
884 let tool = ToolBuilder::new("calculate_volume")
885 .description("Calculate volume")
886 .parameters("width:f64:Width in meters, height:f32:Height in meters, depth:number:Depth in meters")
887 .sync_function(|args: Value| {
888 let width = args.get("width").and_then(|v| v.as_f64()).unwrap_or(0.0);
889 let height = args.get("height").and_then(|v| v.as_f64()).unwrap_or(0.0);
890 let depth = args.get("depth").and_then(|v| v.as_f64()).unwrap_or(0.0);
891 Ok(ToolResult::success(format!("Volume: {:.2}", width * height * depth)))
892 })
893 .build();
894
895 let params = tool.parameters();
896 assert_eq!(params.get("width").unwrap().param_type, "number");
897 assert_eq!(params.get("height").unwrap().param_type, "number");
898 assert_eq!(params.get("depth").unwrap().param_type, "number");
899 }
900
901 #[tokio::test]
902 async fn test_parameters_with_string_and_bool() {
903 let tool = ToolBuilder::new("greet")
904 .description("Greet someone")
905 .parameters("name:string:Person's name, formal:bool:Use formal greeting")
906 .sync_function(|args: Value| {
907 let name = args
908 .get("name")
909 .and_then(|v| v.as_str())
910 .unwrap_or("stranger");
911 let formal = args
912 .get("formal")
913 .and_then(|v| v.as_bool())
914 .unwrap_or(false);
915 let greeting = if formal {
916 format!("Good day, {}", name)
917 } else {
918 format!("Hey {}", name)
919 };
920 Ok(ToolResult::success(greeting))
921 })
922 .build();
923
924 let params = tool.parameters();
925 assert_eq!(params.get("name").unwrap().param_type, "string");
926 assert_eq!(params.get("formal").unwrap().param_type, "boolean");
927 }
928
929 #[tokio::test]
930 async fn test_from_fn() {
931 fn add(a: i32, b: i32) -> i32 {
932 a + b
933 }
934
935 let tool = ToolBuilder::from_fn(
936 "add",
937 "Add two numbers",
938 "a:i32:First number, b:i32:Second number",
939 |args| {
940 let a = args.get("a").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
941 let b = args.get("b").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
942 Ok(ToolResult::success(add(a, b).to_string()))
943 },
944 )
945 .build();
946
947 assert_eq!(tool.name(), "add");
948 assert_eq!(tool.description(), "Add two numbers");
949
950 let result = tool.execute(json!({"a": 3, "b": 7})).await.unwrap();
951 assert_eq!(result.output, "10");
952 }
953
954 #[tokio::test]
955 async fn test_from_async_fn() {
956 async fn fetch_data(id: i32) -> String {
957 format!("Data for ID: {}", id)
958 }
959
960 let tool = ToolBuilder::from_async_fn(
961 "fetch_data",
962 "Fetch data by ID",
963 "id:i32:The ID to fetch",
964 |args| async move {
965 let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
966 Ok(ToolResult::success(fetch_data(id).await))
967 },
968 )
969 .build();
970
971 assert_eq!(tool.name(), "fetch_data");
972
973 let result = tool.execute(json!({"id": 42})).await.unwrap();
974 assert_eq!(result.output, "Data for ID: 42");
975 }
976
977 #[tokio::test]
978 async fn test_parameters_empty_and_whitespace() {
979 let tool = ToolBuilder::new("test")
980 .description("Test tool")
981 .parameters("a:i32:First, , b:i32:Second, , c:string:Third ")
982 .sync_function(|_| Ok(ToolResult::success("ok".to_string())))
983 .build();
984
985 let params = tool.parameters();
986 assert_eq!(params.len(), 3);
988 assert!(params.contains_key("a"));
989 assert!(params.contains_key("b"));
990 assert!(params.contains_key("c"));
991 }
992
993 #[tokio::test]
994 async fn test_parameters_without_description() {
995 let tool = ToolBuilder::new("test")
996 .description("Test tool")
997 .parameters("x:i32, y:i32")
998 .sync_function(|_| Ok(ToolResult::success("ok".to_string())))
999 .build();
1000
1001 let params = tool.parameters();
1002 assert_eq!(params.len(), 2);
1003 assert_eq!(params.get("x").unwrap().description, "");
1004 assert_eq!(params.get("y").unwrap().description, "");
1005 }
1006}