ic_llm/
tool.rs

1use candid::CandidType;
2use serde::{Deserialize, Serialize};
3
4#[derive(CandidType, Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
5pub enum Tool {
6    #[serde(rename = "function")]
7    Function(Function),
8}
9
10#[derive(CandidType, Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
11pub struct Parameters {
12    #[serde(rename = "type")]
13    pub type_: String,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub properties: Option<Vec<Property>>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub required: Option<Vec<String>>,
18}
19
20#[derive(CandidType, Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
21pub struct Function {
22    pub name: String,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub description: Option<String>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub parameters: Option<Parameters>,
27}
28
29#[derive(CandidType, Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
30pub struct Property {
31    #[serde(rename = "type")]
32    pub type_: String,
33    pub name: String,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub description: Option<String>,
36    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
37    pub enum_: Option<Vec<String>>,
38}
39
/// The type a function-tool parameter can have.
///
/// Each variant maps to a lowercase type name (`"string"`, `"boolean"`, `"number"`),
/// which is what ends up in [`Property::type_`].
// Fieldless enum: cheap to copy, and deriving equality/hashing lets callers compare
// parameter types or use them as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ParameterType {
    /// A text value (`"string"`).
    String,
    /// A true/false value (`"boolean"`).
    Boolean,
    /// A numeric value (`"number"`).
    Number,
    // More types can be added as needed. The enum is not `#[non_exhaustive]`, so adding
    // a variant breaks exhaustive `match`es in downstream code.
}

impl ParameterType {
    /// Returns the type name written into the serialized schema.
    fn as_str(&self) -> &'static str {
        match self {
            ParameterType::String => "string",
            ParameterType::Boolean => "boolean",
            ParameterType::Number => "number",
        }
    }
}
58
59/// Builder for creating a parameter for a function tool.
60#[derive(Clone, Debug)]
61pub struct ParameterBuilder {
62    name: String,
63    type_: ParameterType,
64    description: Option<String>,
65    required: bool,
66    enum_values: Option<Vec<String>>,
67}
68
69impl ParameterBuilder {
70    /// Create a new parameter builder with a name and type.
71    pub fn new<S: Into<String>>(name: S, type_: ParameterType) -> Self {
72        Self {
73            name: name.into(),
74            type_,
75            description: None,
76            required: false,
77            enum_values: None,
78        }
79    }
80
81    /// Add a description to the parameter.
82    pub fn with_description<S: Into<String>>(mut self, description: S) -> Self {
83        self.description = Some(description.into());
84        self
85    }
86
87    /// Mark the parameter as required.
88    pub fn is_required(mut self) -> Self {
89        self.required = true;
90        self
91    }
92
93    /// Add allowed enum values for the parameter.
94    pub fn with_enum_values<S: Into<String>, I: IntoIterator<Item = S>>(
95        mut self,
96        values: I,
97    ) -> Self {
98        self.enum_values = Some(values.into_iter().map(|s| s.into()).collect());
99        self
100    }
101
102    /// Convert the builder to a Property.
103    fn to_property(&self) -> Property {
104        Property {
105            type_: self.type_.as_str().to_string(),
106            name: self.name.clone(),
107            description: self.description.clone(),
108            enum_: self.enum_values.clone(),
109        }
110    }
111}
112
113/// Builder for creating a function tool.
114pub struct ToolBuilder {
115    function: Function,
116    parameters: Vec<ParameterBuilder>,
117}
118
119impl ToolBuilder {
120    /// Creates a new tool builder with a function name.
121    pub fn new<S: Into<String>>(name: S) -> Self {
122        Self {
123            function: Function {
124                name: name.into(),
125                description: None,
126                parameters: None,
127            },
128            parameters: Vec::new(),
129        }
130    }
131
132    /// Adds a description to the function.
133    pub fn with_description<S: Into<String>>(mut self, description: S) -> Self {
134        self.function.description = Some(description.into());
135        self
136    }
137
138    /// Adds a parameter to the function.
139    pub fn with_parameter(mut self, parameter: ParameterBuilder) -> Self {
140        self.parameters.push(parameter);
141        self
142    }
143
144    /// Builds the final Tool.
145    pub fn build(self) -> Tool {
146        let mut function = self.function;
147
148        if !self.parameters.is_empty() {
149            let properties = self
150                .parameters
151                .iter()
152                .map(|p| p.to_property())
153                .collect::<Vec<_>>();
154            let required = self
155                .parameters
156                .iter()
157                .filter(|p| p.required)
158                .map(|p| p.name.clone())
159                .collect::<Vec<_>>();
160
161            function.parameters = Some(Parameters {
162                type_: "object".to_string(),
163                properties: Some(properties),
164                required: if required.is_empty() {
165                    None
166                } else {
167                    Some(required)
168                },
169            });
170        }
171
172        Tool::Function(function)
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected [`Property`] with a description and no enum constraint.
    fn prop(name: &str, type_: &str, description: &str) -> Property {
        Property {
            type_: type_.to_string(),
            name: name.to_string(),
            description: Some(description.to_string()),
            enum_: None,
        }
    }

    /// Expected `"object"` parameter schema; `required: None` means no required list.
    fn object_schema(properties: Vec<Property>, required: Option<Vec<&str>>) -> Parameters {
        Parameters {
            type_: "object".to_string(),
            properties: Some(properties),
            required: required.map(|names| names.into_iter().map(String::from).collect()),
        }
    }

    /// Expected function tool with the given name, description and parameter schema.
    fn function_tool(
        name: &str,
        description: Option<&str>,
        parameters: Option<Parameters>,
    ) -> Tool {
        Tool::Function(Function {
            name: name.to_string(),
            description: description.map(String::from),
            parameters,
        })
    }

    #[test]
    fn create_simple_tool() {
        let built = ToolBuilder::new("test_tool").build();

        assert_eq!(built, function_tool("test_tool", None, None));
    }

    #[test]
    fn tool_with_description() {
        let built = ToolBuilder::new("test_tool")
            .with_description("This is a test tool")
            .build();

        assert_eq!(
            built,
            function_tool("test_tool", Some("This is a test tool"), None)
        );
    }

    #[test]
    fn tool_with_single_parameter() {
        let param = ParameterBuilder::new("param1", ParameterType::String)
            .with_description("Test parameter")
            .is_required();
        let built = ToolBuilder::new("test_tool").with_parameter(param).build();

        let schema = object_schema(
            vec![prop("param1", "string", "Test parameter")],
            Some(vec!["param1"]),
        );
        assert_eq!(built, function_tool("test_tool", None, Some(schema)));
    }

    #[test]
    fn tool_with_multiple_parameters() {
        let location = ParameterBuilder::new("location", ParameterType::String)
            .with_description("City name")
            .is_required();
        let units = ParameterBuilder::new("units", ParameterType::String)
            .with_description("Temperature units");
        let forecast = ParameterBuilder::new("forecast", ParameterType::Boolean)
            .with_description("Include forecast");

        let built = ToolBuilder::new("weather_tool")
            .with_description("Get weather information")
            .with_parameter(location)
            .with_parameter(units)
            .with_parameter(forecast)
            .build();

        let schema = object_schema(
            vec![
                prop("location", "string", "City name"),
                prop("units", "string", "Temperature units"),
                prop("forecast", "boolean", "Include forecast"),
            ],
            Some(vec!["location"]),
        );
        assert_eq!(
            built,
            function_tool("weather_tool", Some("Get weather information"), Some(schema))
        );
    }

    #[test]
    fn optional_parameter() {
        // `is_required()` is deliberately not called.
        let param = ParameterBuilder::new("optional_param", ParameterType::String)
            .with_description("This parameter is optional");
        let built = ToolBuilder::new("test_tool").with_parameter(param).build();

        let schema = object_schema(
            vec![prop("optional_param", "string", "This parameter is optional")],
            None,
        );
        assert_eq!(built, function_tool("test_tool", None, Some(schema)));
    }

    #[test]
    fn parameter_type_conversion() {
        let cases = [
            (ParameterType::String, "string"),
            (ParameterType::Boolean, "boolean"),
            (ParameterType::Number, "number"),
        ];
        for (parameter_type, expected) in cases {
            assert_eq!(parameter_type.as_str(), expected);
        }
    }

    #[test]
    fn weather_tool_example() {
        // Test the example from the documentation
        let weather_tool = ToolBuilder::new("get_current_weather")
            .with_description("Get current weather for a location.")
            .with_parameter(
                ParameterBuilder::new("location", ParameterType::String)
                    .with_description("The location to get the weather for (e.g. Cairo, Egypt)")
                    .is_required(),
            )
            .build();

        let schema = object_schema(
            vec![prop(
                "location",
                "string",
                "The location to get the weather for (e.g. Cairo, Egypt)",
            )],
            Some(vec!["location"]),
        );
        assert_eq!(
            weather_tool,
            function_tool(
                "get_current_weather",
                Some("Get current weather for a location."),
                Some(schema),
            )
        );
    }

    #[test]
    fn number_parameter() {
        let value = ParameterBuilder::new("value", ParameterType::Number)
            .with_description("The numeric value to use in calculation")
            .is_required();
        let built = ToolBuilder::new("calculator")
            .with_description("Perform mathematical calculations")
            .with_parameter(value)
            .build();

        let schema = object_schema(
            vec![prop("value", "number", "The numeric value to use in calculation")],
            Some(vec!["value"]),
        );
        assert_eq!(
            built,
            function_tool("calculator", Some("Perform mathematical calculations"), Some(schema))
        );
    }

    #[test]
    fn enum_parameter() {
        let units = ["meters", "feet", "kilometers", "miles"];
        let built = ToolBuilder::new("unit_converter")
            .with_description("Convert between different units")
            .with_parameter(
                ParameterBuilder::new("value", ParameterType::Number)
                    .with_description("The value to convert")
                    .is_required(),
            )
            .with_parameter(
                ParameterBuilder::new("unit", ParameterType::String)
                    .with_description("The unit to convert to")
                    .with_enum_values(units)
                    .is_required(),
            )
            .build();

        let unit = Property {
            enum_: Some(units.iter().map(|u| u.to_string()).collect()),
            ..prop("unit", "string", "The unit to convert to")
        };
        let schema = object_schema(
            vec![prop("value", "number", "The value to convert"), unit],
            Some(vec!["value", "unit"]),
        );
        assert_eq!(
            built,
            function_tool("unit_converter", Some("Convert between different units"), Some(schema))
        );
    }
}