1use serde_json::{json, Map, Value};
2use std::collections::HashMap;
3
4
5
6#[derive(Debug, Clone)]
7pub struct Tool {
8 name: String,
9 description: String,
10 parameters: HashMap<String, ToolParameter>,
11}
12
/// A single input parameter of a [`Tool`].
///
/// Derives `PartialEq`/`Eq` so whole parameter definitions can be compared
/// (e.g. in tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolParameter {
    /// Written verbatim into the property's `"type"` field (e.g. `"string"`).
    parameter_type: String,
    /// Written into the property's `"description"` field.
    description: String,
    /// Whether the parameter name is listed in the schema's `"required"` array.
    required: bool,
    /// When present, written into the property's `"enum"` field.
    enum_values: Option<Vec<String>>,
}
20
21pub struct ToolBuilder {
22 name: Option<String>,
23 description: Option<String>,
24 parameters: HashMap<String, ToolParameter>,
25}
26
27impl ToolBuilder {
28
29 pub fn new() -> Self {
30 ToolBuilder {
31 name: None,
32 description: None,
33 parameters: HashMap::new(),
34 }
35 }
36
37 pub fn name(mut self, name: &str) -> Self {
38 self.name = Some(name.to_string());
39 self
40 }
41
42 pub fn description(mut self, description: &str) -> Self {
43 self.description = Some(description.to_string());
44 self
45 }
46
47 pub fn add_parameter(
48 mut self,
49 name: &str,
50 parameter_type: &str,
51 description: &str,
52 required: bool,
53 ) -> Self {
54 self.parameters.insert(
55 name.to_string(),
56 ToolParameter {
57 parameter_type: parameter_type.to_string(),
58 description: description.to_string(),
59 required,
60 enum_values: None,
61 },
62 );
63 self
64 }
65
66 pub fn add_enum_parameter(
67 mut self,
68 name: &str,
69 description: &str,
70 required: bool,
71 enum_values: Vec<String>,
72 ) -> Self {
73 self.parameters.insert(
74 name.to_string(),
75 ToolParameter {
76 parameter_type: "string".to_string(),
77 description: description.to_string(),
78 required,
79 enum_values: Some(enum_values),
80 },
81 );
82 self
83 }
84
85 pub fn build(self) -> Result<Tool, String> {
86 let name = self.name.ok_or("Tool name is required")?;
87 let description = self.description.ok_or("Tool description is required")?;
88
89 Ok(Tool {
90 name,
91 description,
92 parameters: self.parameters,
93 })
94 }
95}
96
97impl Tool {
98 pub fn builder() -> ToolBuilder {
99 ToolBuilder::new()
100 }
101
102 pub fn to_anthropic_format(&self) -> Value {
103 let mut properties = serde_json::Map::new();
104 let mut required = Vec::new();
105
106 self.process_tool_input(&mut properties, &mut required);
107
108 json!({
109 "name": self.name,
110 "description": self.description,
111 "input_schema": {
112 "type": "object",
113 "properties": properties,
114 "required": ["location"]
115 }
116 })
117 }
118
119 pub fn to_openai_format(&self) -> Value {
120 let mut properties = serde_json::Map::new();
121 let mut required = Vec::new();
122
123 self.process_tool_input(&mut properties, &mut required);
124
125 json!({
126 "type": "function",
127 "function": {
128 "name": self.name,
129 "description": self.description,
130 "parameters": {
131 "type": "object",
132 "properties": properties,
133 "required": required
134 }
135 }
136 })
137 }
138
139 fn process_tool_input(&self, properties: &mut Map<String, Value>, required: &mut Vec<Value>) {
140 for (name, param) in &self.parameters {
141 let mut property = serde_json::Map::new();
142 property.insert(
143 "type".to_string(),
144 Value::String(param.parameter_type.clone()),
145 );
146 property.insert(
147 "description".to_string(),
148 Value::String(param.description.clone()),
149 );
150
151 if let Some(enum_values) = ¶m.enum_values {
152 property.insert(
153 "enum".to_string(),
154 Value::Array(
155 enum_values
156 .iter()
157 .map(|v| Value::String(v.clone()))
158 .collect(),
159 ),
160 );
161 }
162
163 properties.insert(name.clone(), Value::Object(property));
164
165 if param.required {
166 required.push(Value::String(name.clone()));
167 }
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Enum values shared by the temperature-unit parameters below.
    fn temperature_units() -> Vec<String> {
        ["celsius", "fahrenheit"]
            .iter()
            .map(|unit| unit.to_string())
            .collect()
    }

    /// Reads a JSON array of strings and returns it sorted, so comparisons
    /// ignore element order.
    fn sorted_names(array: &serde_json::Value) -> Vec<String> {
        let mut names: Vec<String> = array
            .as_array()
            .unwrap()
            .iter()
            .map(|entry| entry.as_str().unwrap().to_string())
            .collect();
        names.sort();
        names
    }

    #[test]
    fn test_tool_builder() {
        let tool = Tool::builder()
            .name("get_weather")
            .description("Get the current weather in a given location")
            .add_parameter("location", "string", "The city and state, e.g. San Francisco, CA", true)
            .add_enum_parameter("unit", "The unit of temperature to use", false, temperature_units())
            .build()
            .expect("Failed to build tool");

        assert_eq!(tool.name, "get_weather");
        assert_eq!(tool.description, "Get the current weather in a given location");
        assert_eq!(tool.parameters.len(), 2);

        let location = tool
            .parameters
            .get("location")
            .expect("Location parameter not found");
        assert_eq!(location.parameter_type, "string");
        assert_eq!(location.description, "The city and state, e.g. San Francisco, CA");
        assert!(location.required);
        assert!(location.enum_values.is_none());

        let unit = tool
            .parameters
            .get("unit")
            .expect("Unit parameter not found");
        assert_eq!(unit.parameter_type, "string");
        assert_eq!(unit.description, "The unit of temperature to use");
        assert!(!unit.required);
        assert_eq!(unit.enum_values, Some(temperature_units()));
    }

    #[test]
    fn test_tool_builder_missing_name() {
        let error = Tool::builder()
            .description("Get the current weather in a given location")
            .build()
            .unwrap_err();

        assert_eq!(error, "Tool name is required");
    }

    #[test]
    fn test_tool_builder_missing_description() {
        let error = Tool::builder().name("get_weather").build().unwrap_err();

        assert_eq!(error, "Tool description is required");
    }

    #[test]
    fn test_to_anthropic_format() {
        let tool = Tool::builder()
            .name("get_weather")
            .description("Get the current weather in a given location")
            .add_parameter("location", "string", "The city and state, e.g. San Francisco, CA", true)
            .add_enum_parameter(
                "unit",
                "The unit of temperature, either 'celsius' or 'fahrenheit'",
                false,
                temperature_units(),
            )
            .build()
            .expect("Failed to build tool");

        let expected = json!({
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "input_schema": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The unit of temperature, either 'celsius' or 'fahrenheit'"
                    }
                },
                "required": ["location"]
            }
        });

        assert_eq!(tool.to_anthropic_format(), expected);
    }

    #[test]
    fn test_to_openai_format() {
        let tool = Tool::builder()
            .name("get_current_weather")
            .description("Get the current weather in a given location")
            .add_parameter("location", "string", "The city and state, e.g. San Francisco, CA", true)
            .add_enum_parameter(
                "format",
                "The temperature unit to use. Infer this from the users location.",
                true,
                temperature_units(),
            )
            .build()
            .expect("Failed to build tool");

        let openai_format = tool.to_openai_format();

        let expected = json!({
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        }
                    },
                    "required": ["location", "format"]
                }
            }
        });

        for key in ["type"] {
            assert_eq!(openai_format[key], expected[key]);
        }
        for key in ["name", "description"] {
            assert_eq!(openai_format["function"][key], expected["function"][key]);
        }

        let actual_params = &openai_format["function"]["parameters"];
        let expected_params = &expected["function"]["parameters"];

        assert_eq!(actual_params["type"], expected_params["type"]);
        assert_eq!(actual_params["properties"], expected_params["properties"]);
        // The order of `required` is not part of the contract being tested.
        assert_eq!(
            sorted_names(&actual_params["required"]),
            sorted_names(&expected_params["required"])
        );
    }
}