// llm_chain/tools/multitool.rs

/// Defines a "multitool": an enum wrapping several tools that itself implements `Tool` by
/// dispatching every call to the wrapped tool.
///
/// Invocation shape (a trailing comma is accepted):
///
/// ```ignore
/// multitool!(
///     MyMultitool, MyMultitoolInput, MyMultitoolOutput, MyMultitoolError,
///     ToolA, ToolAInput, ToolAOutput, ToolAError,
///     ToolB, ToolBInput, ToolBOutput, ToolBError
/// );
/// ```
///
/// The macro generates:
/// - `$input` / `$output`: enums with one variant per tool, each variant named after and
///   wrapping that tool's input/output type, plus `From` (tool type -> enum) and `TryFrom`
///   (enum -> tool type) conversions. A `TryFrom` on the wrong variant yields
///   `$error::BadVariant`.
/// - `$error`: an error enum with `BadVariant`, `YamlError` (from `serde_yaml::Error`) and one
///   transparent variant per tool error type, each convertible via `From`.
/// - `$multitool`: an enum with one variant per tool, with `From` for each tool and a `Tool`
///   implementation that forwards to the wrapped tool.
///
/// The expansion refers to these names unqualified, so they must be in scope at the call
/// site: `Serialize`, `Deserialize`, `Error` (the derive providing `#[error]`/`#[from]`),
/// `async_trait`, `Tool`, `ToolDescription`, `ToolError`, and the `serde_yaml` crate.
#[macro_export]
macro_rules! multitool {
    (
        $multitool:ident, $input:ident, $output:ident, $error:ident,
        $($tool:ident, $tool_input:ident, $tool_output:ident, $tool_error:ident),+ $(,)?
    ) => {
        /// Input accepted by the generated multitool: one variant per wrapped tool.
        #[derive(Serialize, Deserialize)]
        enum $input {
            $($tool_input($tool_input)),+
        }

        $(
            impl From<$tool_input> for $input {
                fn from(tool: $tool_input) -> Self {
                    $input::$tool_input(tool)
                }
            }

            // Implement `TryFrom` (not `TryInto`): the std blanket impl still provides
            // `TryInto`, so `input.try_into()` keeps working for existing callers. The error
            // type is spelled concretely to avoid clashing with an enum variant named `Error`.
            impl ::core::convert::TryFrom<$input> for $tool_input {
                type Error = $error;

                fn try_from(value: $input) -> ::core::result::Result<Self, $error> {
                    match value {
                        $input::$tool_input(t) => Ok(t),
                        // Unreachable when only one tool is wrapped.
                        #[allow(unreachable_patterns)]
                        _ => Err($error::BadVariant),
                    }
                }
            }
        )+

        /// Output produced by the generated multitool: one variant per wrapped tool.
        #[derive(Serialize, Deserialize)]
        enum $output {
            $($tool_output($tool_output)),+
        }

        $(
            impl From<$tool_output> for $output {
                fn from(tool: $tool_output) -> Self {
                    $output::$tool_output(tool)
                }
            }

            impl ::core::convert::TryFrom<$output> for $tool_output {
                type Error = $error;

                fn try_from(value: $output) -> ::core::result::Result<Self, $error> {
                    match value {
                        $output::$tool_output(t) => Ok(t),
                        // Unreachable when only one tool is wrapped.
                        #[allow(unreachable_patterns)]
                        _ => Err($error::BadVariant),
                    }
                }
            }
        )+

        /// Errors raised by the generated multitool.
        #[derive(Debug, Error)]
        enum $error {
            /// An input/output enum held a variant other than the one requested.
            #[error("Could not convert")]
            BadVariant,
            #[error(transparent)]
            YamlError(#[from] serde_yaml::Error),
            $(
                #[error(transparent)]
                $tool_error(#[from] $tool_error),
            )+
        }

        impl ToolError for $error {}

        /// A tool that wraps exactly one of the listed tools and forwards calls to it.
        enum $multitool {
            $($tool($tool)),+
        }

        $(
            impl From<$tool> for $multitool {
                fn from(tool: $tool) -> Self {
                    $multitool::$tool(tool)
                }
            }
        )+

        #[async_trait]
        impl Tool for $multitool {
            type Input = $input;
            type Output = $output;
            type Error = $error;

            /// Invokes the wrapped tool with its matching input variant.
            ///
            /// # Errors
            ///
            /// Returns `BadVariant` if `input` belongs to a different tool than the one wrapped,
            /// otherwise the wrapped tool's error converted into this multitool's error type.
            async fn invoke_typed(
                &self,
                input: &Self::Input,
            ) -> ::core::result::Result<Self::Output, Self::Error> {
                match (self, input) {
                    $(
                        ($multitool::$tool(t), $input::$tool_input(i)) => t
                            .invoke_typed(i)
                            .await
                            .map($output::from)
                            .map_err($error::from),
                    )+
                    // Tool/input mismatch; unreachable when only one tool is wrapped.
                    #[allow(unreachable_patterns)]
                    _ => Err($error::BadVariant),
                }
            }

            /// Returns the `ToolDescription` of the wrapped tool.
            fn description(&self) -> ToolDescription {
                match self {
                    $($multitool::$tool(t) => t.description()),+
                }
            }

            /// Invokes the wrapped tool with the given YAML-formatted input.
            ///
            /// # Errors
            ///
            /// Returns the wrapped tool's error, converted into this multitool's error type,
            /// if the input is not in the expected format or the tool fails to produce a valid
            /// output.
            async fn invoke(
                &self,
                input: serde_yaml::Value,
            ) -> ::core::result::Result<serde_yaml::Value, Self::Error> {
                match self {
                    $($multitool::$tool(t) => t.invoke(input).await.map_err($error::from)),+
                }
            }

            /// Checks whether the tool matches the given name.
            ///
            /// Delegates to the wrapped tool's own `matches` (as `description` and `invoke`
            /// delegate), so any custom matching logic a wrapped tool defines is respected.
            fn matches(&self, name: &str) -> bool {
                match self {
                    $($multitool::$tool(t) => t.matches(name)),+
                }
            }
        }
    };
}