Skip to main content

llm_bridge/
tool.rs

1use serde_json::{json, Map, Value};
2use std::collections::HashMap;
3
4
5
6#[derive(Debug, Clone)]
7pub struct Tool {
8    name: String,
9    description: String,
10    parameters: HashMap<String, ToolParameter>,
11}
12
/// Description of a single parameter accepted by a tool.
///
/// `PartialEq`/`Eq` are derived so parameter definitions can be compared by
/// value (for example in assertions) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolParameter {
    /// Type name emitted under the `"type"` key of the parameter's schema
    /// entry (e.g. `"string"`).
    parameter_type: String,
    /// Free-text explanation of the parameter.
    description: String,
    /// Whether the parameter is listed among the schema's required names.
    required: bool,
    /// Allowed values, emitted under the `"enum"` key when present.
    enum_values: Option<Vec<String>>,
}
20
21pub struct ToolBuilder {
22    name: Option<String>,
23    description: Option<String>,
24    parameters: HashMap<String, ToolParameter>,
25}
26
27impl ToolBuilder {
28    
29    pub fn new() -> Self {
30        ToolBuilder {
31            name: None,
32            description: None,
33            parameters: HashMap::new(),
34        }
35    }
36
37    pub fn name(mut self, name: &str) -> Self {
38        self.name = Some(name.to_string());
39        self
40    }
41
42    pub fn description(mut self, description: &str) -> Self {
43        self.description = Some(description.to_string());
44        self
45    }
46
47    pub fn add_parameter(
48        mut self,
49        name: &str,
50        parameter_type: &str,
51        description: &str,
52        required: bool,
53    ) -> Self {
54        self.parameters.insert(
55            name.to_string(),
56            ToolParameter {
57                parameter_type: parameter_type.to_string(),
58                description: description.to_string(),
59                required,
60                enum_values: None,
61            },
62        );
63        self
64    }
65
66    pub fn add_enum_parameter(
67        mut self,
68        name: &str,
69        description: &str,
70        required: bool,
71        enum_values: Vec<String>,
72    ) -> Self {
73        self.parameters.insert(
74            name.to_string(),
75            ToolParameter {
76                parameter_type: "string".to_string(),
77                description: description.to_string(),
78                required,
79                enum_values: Some(enum_values),
80            },
81        );
82        self
83    }
84
85    pub fn build(self) -> Result<Tool, String> {
86        let name = self.name.ok_or("Tool name is required")?;
87        let description = self.description.ok_or("Tool description is required")?;
88
89        Ok(Tool {
90            name,
91            description,
92            parameters: self.parameters,
93        })
94    }
95}
96
97impl Tool {
98    pub fn builder() -> ToolBuilder {
99        ToolBuilder::new()
100    }
101
102    pub fn to_anthropic_format(&self) -> Value {
103        let mut properties = serde_json::Map::new();
104        let mut required = Vec::new();
105
106        self.process_tool_input(&mut properties, &mut required);
107
108        json!({
109            "name": self.name,
110            "description": self.description,
111            "input_schema": {
112                "type": "object",
113                "properties": properties,
114                "required": ["location"]
115            }
116        })
117    }
118
119    pub fn to_openai_format(&self) -> Value {
120        let mut properties = serde_json::Map::new();
121        let mut required = Vec::new();
122
123        self.process_tool_input(&mut properties, &mut required);
124
125        json!({
126            "type": "function",
127            "function": {
128                "name": self.name,
129                "description": self.description,
130                "parameters": {
131                    "type": "object",
132                    "properties": properties,
133                    "required": required
134                }
135            }
136        })
137    }
138
139    fn process_tool_input(&self, properties: &mut Map<String, Value>, required: &mut Vec<Value>) {
140        for (name, param) in &self.parameters {
141            let mut property = serde_json::Map::new();
142            property.insert(
143                "type".to_string(),
144                Value::String(param.parameter_type.clone()),
145            );
146            property.insert(
147                "description".to_string(),
148                Value::String(param.description.clone()),
149            );
150
151            if let Some(enum_values) = &param.enum_values {
152                property.insert(
153                    "enum".to_string(),
154                    Value::Array(
155                        enum_values
156                            .iter()
157                            .map(|v| Value::String(v.clone()))
158                            .collect(),
159                    ),
160                );
161            }
162
163            properties.insert(name.clone(), Value::Object(property));
164
165            if param.required {
166                required.push(Value::String(name.clone()));
167            }
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The two temperature units used as enum values throughout these tests.
    fn temperature_units() -> Vec<String> {
        ["celsius", "fahrenheit"]
            .iter()
            .map(|unit| unit.to_string())
            .collect()
    }

    /// Extracts a JSON array of strings as a sorted `Vec<String>` so that
    /// comparisons do not depend on element order.
    fn sorted_strings(value: &serde_json::Value) -> Vec<String> {
        let mut items: Vec<String> = value
            .as_array()
            .unwrap()
            .iter()
            .map(|item| item.as_str().unwrap().to_string())
            .collect();
        items.sort();
        items
    }

    /// The builder stores name, description and both kinds of parameters.
    #[test]
    fn test_tool_builder() {
        let builder = Tool::builder()
            .name("get_weather")
            .description("Get the current weather in a given location")
            .add_parameter(
                "location",
                "string",
                "The city and state, e.g. San Francisco, CA",
                true,
            )
            .add_enum_parameter(
                "unit",
                "The unit of temperature to use",
                false,
                temperature_units(),
            );
        let tool = builder.build().expect("Failed to build tool");

        assert_eq!(tool.name, "get_weather");
        assert_eq!(
            tool.description,
            "Get the current weather in a given location"
        );
        assert_eq!(tool.parameters.len(), 2);

        let location = tool
            .parameters
            .get("location")
            .expect("Location parameter not found");
        assert_eq!(location.parameter_type, "string");
        assert_eq!(
            location.description,
            "The city and state, e.g. San Francisco, CA"
        );
        assert!(location.required);
        assert!(location.enum_values.is_none());

        let unit = tool
            .parameters
            .get("unit")
            .expect("Unit parameter not found");
        assert_eq!(unit.parameter_type, "string");
        assert_eq!(unit.description, "The unit of temperature to use");
        assert!(!unit.required);
        assert_eq!(unit.enum_values, Some(temperature_units()));
    }

    /// Building without a name fails with the name error.
    #[test]
    fn test_tool_builder_missing_name() {
        let outcome = Tool::builder()
            .description("Get the current weather in a given location")
            .build();

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), "Tool name is required");
    }

    /// Building without a description fails with the description error.
    #[test]
    fn test_tool_builder_missing_description() {
        let outcome = Tool::builder().name("get_weather").build();

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), "Tool description is required");
    }

    /// The Anthropic serialization matches the expected document exactly.
    #[test]
    fn test_to_anthropic_format() {
        let tool = Tool::builder()
            .name("get_weather")
            .description("Get the current weather in a given location")
            .add_parameter(
                "location",
                "string",
                "The city and state, e.g. San Francisco, CA",
                true,
            )
            .add_enum_parameter(
                "unit",
                "The unit of temperature, either 'celsius' or 'fahrenheit'",
                false,
                temperature_units(),
            )
            .build()
            .expect("Failed to build tool");

        let expected = json!({
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "input_schema": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The unit of temperature, either 'celsius' or 'fahrenheit'"
                    }
                },
                "required": ["location"]
            }
        });

        assert_eq!(tool.to_anthropic_format(), expected);
    }

    /// The OpenAI serialization matches the expected document; the `required`
    /// array is compared without regard to order.
    #[test]
    fn test_to_openai_format() {
        let tool = Tool::builder()
            .name("get_current_weather")
            .description("Get the current weather in a given location")
            .add_parameter(
                "location",
                "string",
                "The city and state, e.g. San Francisco, CA",
                true,
            )
            .add_enum_parameter(
                "format",
                "The temperature unit to use. Infer this from the users location.",
                true,
                temperature_units(),
            )
            .build()
            .expect("Failed to build tool");

        let actual = tool.to_openai_format();

        let expected = json!({
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location.",
                        },
                    },
                    "required": ["location", "format"],
                },
            }
        });

        assert_eq!(actual["type"], expected["type"]);

        let actual_function = &actual["function"];
        let expected_function = &expected["function"];
        for key in ["name", "description"].iter() {
            assert_eq!(actual_function[*key], expected_function[*key]);
        }

        let actual_params = &actual_function["parameters"];
        let expected_params = &expected_function["parameters"];
        for key in ["type", "properties"].iter() {
            assert_eq!(actual_params[*key], expected_params[*key]);
        }

        // The order of `required` is not part of the contract, so both sides
        // are sorted before comparing.
        assert_eq!(
            sorted_strings(&actual_params["required"]),
            sorted_strings(&expected_params["required"])
        );
    }
}