// llm_chain/tools/multitool.rs

/// Generates a "multitool": one `Tool` implementation that dispatches to one of several
/// wrapped tools.
///
/// # Arguments (all identifiers)
///
/// 1. `$multitool` – name of the generated dispatcher enum (one variant per wrapped tool).
/// 2. `$input` – name of the generated input enum (one variant per tool input type).
/// 3. `$output` – name of the generated output enum (one variant per tool output type).
/// 4. `$error` – name of the generated error enum.
/// 5. One or more groups of `$tool, $tool_input, $tool_output, $tool_error`. Each group names
///    an existing tool type and the input, output and error types it is expected to use as
///    its `Tool::Input`, `Tool::Output` and `Tool::Error`.
///
/// A trailing comma after the last group is accepted.
///
/// Each generated enum variant is named after the type it wraps. Every type passed in must
/// therefore be distinct; a repeated type produces duplicate variants and conflicting
/// conversion impls.
///
/// # Generated items
///
/// - `$input` / `$output`: enums deriving `Serialize` and `Deserialize`.
///   - Each has `From<T>` for every wrapped type.
///   - Each wrapped type implements `TryFrom` from its enum, so `TryInto` is available
///     through the standard blanket impl.
///   - Converting from the wrong variant fails with `$error::BadVariant`.
/// - `$error`: an enum deriving `Debug` and `Error`, with `BadVariant`,
///   `YamlError(serde_yaml::Error)` and one transparent variant per tool error. Each variant
///   has a `#[from]` conversion, and the enum implements `ToolError`.
/// - `$multitool`: an enum with `From<T>` for every wrapped tool. It implements `Tool` by
///   delegating `invoke_typed`, `invoke`, `description` and `matches` to the wrapped tool.
///   `invoke_typed` returns `$error::BadVariant` when the input variant belongs to a
///   different tool than the one wrapped.
///
/// The generated enums are declared without `pub`, so they are private to the module that
/// invokes the macro.
///
/// # Requirements at the call site
///
/// The expansion refers to the following names by unqualified path, so they must all be in
/// scope where the macro is invoked:
/// `Serialize`, `Deserialize`, `Error`, `ToolError`, `Tool`, `ToolDescription`,
/// `async_trait` and the `serde_yaml` crate.
///
/// These are presumably serde's derives, thiserror's `Error` derive and the `async_trait`
/// attribute — confirm against existing callers.
///
/// # Example
///
/// ```ignore
/// multitool!(
///     MyMultitool, MyMultitoolInput, MyMultitoolOutput, MyMultitoolError,
///     ToolA, ToolAInput, ToolAOutput, ToolAError,
///     ToolB, ToolBInput, ToolBOutput, ToolBError
/// );
/// ```
#[macro_export]
macro_rules! multitool {
    (
        $multitool:ident, $input:ident, $output:ident, $error:ident,
        $($tool:ident, $tool_input:ident, $tool_output:ident, $tool_error:ident),+
        $(,)?
    ) => {
        #[derive(Serialize, Deserialize)]
        enum $input {
            $($tool_input($tool_input),)+
        }

        $(
            impl From<$tool_input> for $input {
                fn from(value: $tool_input) -> Self {
                    $input::$tool_input(value)
                }
            }

            // Implement `TryFrom` rather than `TryInto`: the standard library derives
            // `TryInto` from it through a blanket impl. The path is fully qualified so it does
            // not depend on the invoking crate's edition prelude.
            impl ::core::convert::TryFrom<$input> for $tool_input {
                type Error = $error;

                fn try_from(value: $input) -> Result<Self, Self::Error> {
                    match value {
                        $input::$tool_input(inner) => Ok(inner),
                        // Unreachable (hence allowed) when only one tool is wrapped.
                        #[allow(unreachable_patterns)]
                        _ => Err($error::BadVariant),
                    }
                }
            }
        )+

        #[derive(Serialize, Deserialize)]
        enum $output {
            $($tool_output($tool_output),)+
        }

        $(
            impl From<$tool_output> for $output {
                fn from(value: $tool_output) -> Self {
                    $output::$tool_output(value)
                }
            }

            impl ::core::convert::TryFrom<$output> for $tool_output {
                type Error = $error;

                fn try_from(value: $output) -> Result<Self, Self::Error> {
                    match value {
                        $output::$tool_output(inner) => Ok(inner),
                        // Unreachable (hence allowed) when only one tool is wrapped.
                        #[allow(unreachable_patterns)]
                        _ => Err($error::BadVariant),
                    }
                }
            }
        )+

        #[derive(Debug, Error)]
        enum $error {
            #[error("Could not convert")]
            BadVariant,
            #[error(transparent)]
            YamlError(#[from] serde_yaml::Error),
            $(
                #[error(transparent)]
                $tool_error(#[from] $tool_error),
            )+
        }

        impl ToolError for $error {}

        enum $multitool {
            $($tool($tool),)+
        }

        $(
            impl From<$tool> for $multitool {
                fn from(tool: $tool) -> Self {
                    $multitool::$tool(tool)
                }
            }
        )+

        #[async_trait]
        impl Tool for $multitool {
            type Input = $input;
            type Output = $output;
            type Error = $error;

            async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
                match (self, input) {
                    $(
                        ($multitool::$tool(inner_tool), $input::$tool_input(inner_input)) => inner_tool
                            .invoke_typed(inner_input)
                            .await
                            .map($output::from)
                            .map_err($error::from),
                    )+
                    // The input belongs to a different tool than the one wrapped. This arm is
                    // unreachable (hence allowed) when only one tool is wrapped.
                    #[allow(unreachable_patterns)]
                    _ => Err($error::BadVariant),
                }
            }

            fn description(&self) -> ToolDescription {
                match self {
                    $($multitool::$tool(inner_tool) => inner_tool.description(),)+
                }
            }

            async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, Self::Error> {
                // Hand the raw YAML straight to the wrapped tool's own `invoke`; only its error
                // is converted into `$error`. NOTE(review): presumably done so the YAML keeps
                // the wrapped tool's own shape instead of the tagged `$input` shape — confirm.
                match self {
                    $($multitool::$tool(inner_tool) => inner_tool.invoke(input).await.map_err($error::from),)+
                }
            }

            fn matches(&self, name: &str) -> bool {
                match self {
                    $($multitool::$tool(inner_tool) => inner_tool.description().name == name,)+
                }
            }
        }
    };
}