1use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Tool {
10 #[serde(rename = "type")]
12 pub tool_type: ToolType,
13 #[serde(skip_serializing_if = "Option::is_none")]
15 pub function: Option<FunctionTool>,
16 #[serde(skip_serializing_if = "Option::is_none")]
18 pub max_keyword: Option<u32>,
19 #[serde(skip_serializing_if = "Option::is_none")]
21 pub force_search: Option<bool>,
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub limit: Option<u32>,
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub user_location: Option<UserLocation>,
28}
29
30impl Tool {
31 pub fn function(name: impl Into<String>, description: impl Into<String>) -> Self {
33 Self {
34 tool_type: ToolType::Function,
35 function: Some(FunctionTool::new(name, description)),
36 max_keyword: None,
37 force_search: None,
38 limit: None,
39 user_location: None,
40 }
41 }
42
43 pub fn function_with_params(
45 name: impl Into<String>,
46 description: impl Into<String>,
47 parameters: HashMap<String, Value>,
48 ) -> Self {
49 Self {
50 tool_type: ToolType::Function,
51 function: Some(FunctionTool::with_params(name, description, parameters)),
52 max_keyword: None,
53 force_search: None,
54 limit: None,
55 user_location: None,
56 }
57 }
58
59 pub fn web_search() -> Self {
65 Self {
66 tool_type: ToolType::WebSearch,
67 function: None,
68 max_keyword: None,
69 force_search: None,
70 limit: None,
71 user_location: None,
72 }
73 }
74
75 pub fn max_keyword(mut self, max: u32) -> Self {
77 self.max_keyword = Some(max);
78 self
79 }
80
81 pub fn force_search(mut self, force: bool) -> Self {
83 self.force_search = Some(force);
84 self
85 }
86
87 pub fn limit(mut self, limit: u32) -> Self {
89 self.limit = Some(limit);
90 self
91 }
92
93 pub fn user_location(mut self, location: UserLocation) -> Self {
95 self.user_location = Some(location);
96 self
97 }
98
99 pub fn strict(mut self, strict: bool) -> Self {
101 if let Some(ref mut function) = self.function {
102 function.strict = Some(strict);
103 }
104 self
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
110#[serde(rename_all = "snake_case")]
111pub enum ToolType {
112 Function,
114 WebSearch,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct UserLocation {
121 #[serde(rename = "type")]
123 pub location_type: String,
124 #[serde(skip_serializing_if = "Option::is_none")]
126 pub country: Option<String>,
127 #[serde(skip_serializing_if = "Option::is_none")]
129 pub region: Option<String>,
130 #[serde(skip_serializing_if = "Option::is_none")]
132 pub city: Option<String>,
133}
134
135impl UserLocation {
136 pub fn new(location_type: impl Into<String>) -> Self {
138 Self {
139 location_type: location_type.into(),
140 country: None,
141 region: None,
142 city: None,
143 }
144 }
145
146 pub fn approximate() -> Self {
148 Self::new("approximate")
149 }
150
151 pub fn country(mut self, country: impl Into<String>) -> Self {
153 self.country = Some(country.into());
154 self
155 }
156
157 pub fn region(mut self, region: impl Into<String>) -> Self {
159 self.region = Some(region.into());
160 self
161 }
162
163 pub fn city(mut self, city: impl Into<String>) -> Self {
165 self.city = Some(city.into());
166 self
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct FunctionTool {
173 pub name: String,
175 #[serde(skip_serializing_if = "Option::is_none")]
177 pub description: Option<String>,
178 #[serde(skip_serializing_if = "Option::is_none")]
180 pub parameters: Option<HashMap<String, Value>>,
181 #[serde(skip_serializing_if = "Option::is_none")]
183 pub strict: Option<bool>,
184}
185
186impl FunctionTool {
187 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189 Self {
190 name: name.into(),
191 description: Some(description.into()),
192 parameters: None,
193 strict: None,
194 }
195 }
196
197 pub fn with_params(
199 name: impl Into<String>,
200 description: impl Into<String>,
201 parameters: HashMap<String, Value>,
202 ) -> Self {
203 Self {
204 name: name.into(),
205 description: Some(description.into()),
206 parameters: Some(parameters),
207 strict: None,
208 }
209 }
210
211 pub fn parameters(mut self, parameters: HashMap<String, Value>) -> Self {
213 self.parameters = Some(parameters);
214 self
215 }
216
217 pub fn strict(mut self, strict: bool) -> Self {
219 self.strict = Some(strict);
220 self
221 }
222}
223
224#[derive(Debug, Clone, Default)]
226pub struct ParameterBuilder {
227 params: HashMap<String, Value>,
228}
229
230impl ParameterBuilder {
231 pub fn new() -> Self {
233 Self::default()
234 }
235
236 pub fn type_object(mut self) -> Self {
238 self.params
239 .insert("type".to_string(), Value::String("object".to_string()));
240 self
241 }
242
243 pub fn required_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
245 let properties = self
247 .params
248 .entry("properties".to_string())
249 .or_insert_with(|| Value::Object(serde_json::Map::new()));
250
251 if let Value::Object(props) = properties {
252 props.insert(
253 name.to_string(),
254 Value::Object(schema.into_iter().collect()),
255 );
256 }
257
258 let required = self
260 .params
261 .entry("required".to_string())
262 .or_insert_with(|| Value::Array(Vec::new()));
263
264 if let Value::Array(req) = required {
265 req.push(Value::String(name.to_string()));
266 }
267
268 self
269 }
270
271 pub fn optional_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
273 let properties = self
274 .params
275 .entry("properties".to_string())
276 .or_insert_with(|| Value::Object(serde_json::Map::new()));
277
278 if let Value::Object(props) = properties {
279 props.insert(
280 name.to_string(),
281 Value::Object(schema.into_iter().collect()),
282 );
283 }
284
285 self
286 }
287
288 pub fn build(self) -> HashMap<String, Value> {
290 self.params
291 }
292}
293
294pub mod schema {
296 use serde_json::Value;
297 use std::collections::HashMap;
298
299 pub fn string() -> HashMap<String, Value> {
301 let mut map = HashMap::new();
302 map.insert("type".to_string(), Value::String("string".to_string()));
303 map
304 }
305
306 pub fn string_with_description(desc: &str) -> HashMap<String, Value> {
308 let mut map = string();
309 map.insert("description".to_string(), Value::String(desc.to_string()));
310 map
311 }
312
313 pub fn number() -> HashMap<String, Value> {
315 let mut map = HashMap::new();
316 map.insert("type".to_string(), Value::String("number".to_string()));
317 map
318 }
319
320 pub fn integer() -> HashMap<String, Value> {
322 let mut map = HashMap::new();
323 map.insert("type".to_string(), Value::String("integer".to_string()));
324 map
325 }
326
327 pub fn boolean() -> HashMap<String, Value> {
329 let mut map = HashMap::new();
330 map.insert("type".to_string(), Value::String("boolean".to_string()));
331 map
332 }
333
334 pub fn array(items: HashMap<String, Value>) -> HashMap<String, Value> {
336 let mut map = HashMap::new();
337 map.insert("type".to_string(), Value::String("array".to_string()));
338 map.insert(
339 "items".to_string(),
340 Value::Object(items.into_iter().collect()),
341 );
342 map
343 }
344
345 pub fn enum_values(values: &[&str]) -> HashMap<String, Value> {
347 let mut map = HashMap::new();
348 map.insert("type".to_string(), Value::String("string".to_string()));
349 map.insert(
350 "enum".to_string(),
351 Value::Array(
352 values
353 .iter()
354 .map(|v| Value::String(v.to_string()))
355 .collect(),
356 ),
357 );
358 map
359 }
360}
361
#[cfg(test)]
mod tests {
    //! Unit tests for tool construction, builders and JSON (de)serialization.
    //! JSON output is checked structurally via `serde_json::to_value` rather
    //! than by substring matching, so assertions don't depend on key order
    //! or on unrelated values happening to contain a searched-for word.

    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_creation() {
        let tool = Tool::function("get_weather", "Get the current weather");
        assert_eq!(tool.tool_type, ToolType::Function);
        let function = tool.function.expect("function tools carry a definition");
        assert_eq!(function.name, "get_weather");
        assert_eq!(function.description.as_deref(), Some("Get the current weather"));
    }

    #[test]
    fn test_web_search_tool() {
        let tool = Tool::web_search();
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert!(tool.function.is_none());
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool::function("get_weather", "Get weather info");
        let json = serde_json::to_value(&tool).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "get_weather");
        // Unset optional fields are omitted rather than serialized as null.
        assert!(json.get("max_keyword").is_none());
        assert!(json.get("user_location").is_none());
    }

    #[test]
    fn test_tool_deserialization() {
        let tool: Tool = serde_json::from_str(r#"{"type":"web_search","limit":5}"#).unwrap();
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert_eq!(tool.limit, Some(5));
        assert!(tool.function.is_none());
        assert!(tool.max_keyword.is_none());
    }

    #[test]
    fn test_tool_with_parameters() {
        let mut params = HashMap::new();
        params.insert("type".to_string(), Value::String("object".to_string()));

        let tool =
            Tool::function_with_params("get_weather", "Get weather for a location", params.clone());

        let function = tool.function.expect("function tools carry a definition");
        assert_eq!(function.parameters, Some(params));
    }

    #[test]
    fn test_parameter_builder() {
        let params = ParameterBuilder::new()
            .type_object()
            .required_property("location", schema::string())
            .optional_property("unit", schema::enum_values(&["celsius", "fahrenheit"]))
            .build();

        assert_eq!(params["type"], json!("object"));
        assert_eq!(params["properties"]["location"], json!({ "type": "string" }));
        assert_eq!(
            params["properties"]["unit"]["enum"],
            json!(["celsius", "fahrenheit"])
        );
        // Only the required property is listed in "required".
        assert_eq!(params["required"], json!(["location"]));
    }

    #[test]
    fn test_schema_helpers() {
        let s = schema::string();
        assert_eq!(s.get("type").unwrap(), &Value::String("string".to_string()));

        let n = schema::number();
        assert_eq!(n.get("type").unwrap(), &Value::String("number".to_string()));

        let d = schema::string_with_description("City name");
        assert_eq!(d["description"], json!("City name"));

        let a = schema::array(schema::integer());
        assert_eq!(a["type"], json!("array"));
        assert_eq!(a["items"], json!({ "type": "integer" }));

        let e = schema::enum_values(&["a", "b", "c"]);
        assert_eq!(e["type"], json!("string"));
        assert_eq!(e["enum"], json!(["a", "b", "c"]));
    }

    #[test]
    fn test_strict_mode() {
        let tool = Tool::function("test", "test function").strict(true);
        assert_eq!(tool.function.unwrap().strict, Some(true));
    }

    #[test]
    fn test_strict_mode_is_noop_for_web_search() {
        let tool = Tool::web_search().strict(true);
        assert!(tool.function.is_none());
    }

    #[test]
    fn test_web_search_with_options() {
        let tool = Tool::web_search()
            .max_keyword(3)
            .force_search(true)
            .limit(5);

        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert_eq!(tool.max_keyword, Some(3));
        assert_eq!(tool.force_search, Some(true));
        assert_eq!(tool.limit, Some(5));
    }

    #[test]
    fn test_web_search_serialization() {
        let tool = Tool::web_search().max_keyword(3).force_search(true);

        let json = serde_json::to_value(&tool).unwrap();
        assert_eq!(json["type"], "web_search");
        assert_eq!(json["max_keyword"], 3);
        assert_eq!(json["force_search"], true);
        assert!(json.get("function").is_none());
        assert!(json.get("limit").is_none());
    }

    #[test]
    fn test_user_location() {
        let location = UserLocation::approximate()
            .country("China")
            .region("Hubei")
            .city("Wuhan");

        assert_eq!(location.location_type, "approximate");
        assert_eq!(location.country, Some("China".to_string()));
        assert_eq!(location.region, Some("Hubei".to_string()));
        assert_eq!(location.city, Some("Wuhan".to_string()));
    }

    #[test]
    fn test_user_location_serialization() {
        let location = UserLocation::approximate().country("China").city("Beijing");

        let json = serde_json::to_value(&location).unwrap();
        assert_eq!(json["type"], "approximate");
        assert_eq!(json["country"], "China");
        assert_eq!(json["city"], "Beijing");
        // The unset region key must be omitted entirely.
        assert!(json.get("region").is_none());
    }

    #[test]
    fn test_web_search_with_location() {
        let tool = Tool::web_search().user_location(
            UserLocation::approximate()
                .country("China")
                .city("Shanghai"),
        );

        let loc = tool.user_location.expect("location was set");
        assert_eq!(loc.country, Some("China".to_string()));
        assert_eq!(loc.city, Some("Shanghai".to_string()));
    }
}