1use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Tool {
10 #[serde(rename = "type")]
12 pub tool_type: ToolType,
13 #[serde(skip_serializing_if = "Option::is_none")]
15 pub function: Option<FunctionTool>,
16 #[serde(skip_serializing_if = "Option::is_none")]
18 pub max_keyword: Option<u32>,
19 #[serde(skip_serializing_if = "Option::is_none")]
21 pub force_search: Option<bool>,
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub limit: Option<u32>,
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub user_location: Option<UserLocation>,
28}
29
30impl Tool {
31 pub fn function(name: impl Into<String>, description: impl Into<String>) -> Self {
33 Self {
34 tool_type: ToolType::Function,
35 function: Some(FunctionTool::new(name, description)),
36 max_keyword: None,
37 force_search: None,
38 limit: None,
39 user_location: None,
40 }
41 }
42
43 pub fn function_with_params(
45 name: impl Into<String>,
46 description: impl Into<String>,
47 parameters: HashMap<String, Value>,
48 ) -> Self {
49 Self {
50 tool_type: ToolType::Function,
51 function: Some(FunctionTool::with_params(name, description, parameters)),
52 max_keyword: None,
53 force_search: None,
54 limit: None,
55 user_location: None,
56 }
57 }
58
59 pub fn web_search() -> Self {
65 Self {
66 tool_type: ToolType::WebSearch,
67 function: None,
68 max_keyword: None,
69 force_search: None,
70 limit: None,
71 user_location: None,
72 }
73 }
74
75 pub fn max_keyword(mut self, max: u32) -> Self {
77 self.max_keyword = Some(max);
78 self
79 }
80
81 pub fn force_search(mut self, force: bool) -> Self {
83 self.force_search = Some(force);
84 self
85 }
86
87 pub fn limit(mut self, limit: u32) -> Self {
89 self.limit = Some(limit);
90 self
91 }
92
93 pub fn user_location(mut self, location: UserLocation) -> Self {
95 self.user_location = Some(location);
96 self
97 }
98
99 pub fn strict(mut self, strict: bool) -> Self {
101 if let Some(ref mut function) = self.function {
102 function.strict = Some(strict);
103 }
104 self
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
110#[serde(rename_all = "snake_case")]
111pub enum ToolType {
112 Function,
114 WebSearch,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct UserLocation {
121 #[serde(rename = "type")]
123 pub location_type: String,
124 #[serde(skip_serializing_if = "Option::is_none")]
126 pub country: Option<String>,
127 #[serde(skip_serializing_if = "Option::is_none")]
129 pub region: Option<String>,
130 #[serde(skip_serializing_if = "Option::is_none")]
132 pub city: Option<String>,
133}
134
135impl UserLocation {
136 pub fn new(location_type: impl Into<String>) -> Self {
138 Self {
139 location_type: location_type.into(),
140 country: None,
141 region: None,
142 city: None,
143 }
144 }
145
146 pub fn approximate() -> Self {
148 Self::new("approximate")
149 }
150
151 pub fn country(mut self, country: impl Into<String>) -> Self {
153 self.country = Some(country.into());
154 self
155 }
156
157 pub fn region(mut self, region: impl Into<String>) -> Self {
159 self.region = Some(region.into());
160 self
161 }
162
163 pub fn city(mut self, city: impl Into<String>) -> Self {
165 self.city = Some(city.into());
166 self
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct FunctionTool {
173 pub name: String,
175 #[serde(skip_serializing_if = "Option::is_none")]
177 pub description: Option<String>,
178 #[serde(skip_serializing_if = "Option::is_none")]
180 pub parameters: Option<HashMap<String, Value>>,
181 #[serde(skip_serializing_if = "Option::is_none")]
183 pub strict: Option<bool>,
184}
185
186impl FunctionTool {
187 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189 Self {
190 name: name.into(),
191 description: Some(description.into()),
192 parameters: None,
193 strict: None,
194 }
195 }
196
197 pub fn with_params(
199 name: impl Into<String>,
200 description: impl Into<String>,
201 parameters: HashMap<String, Value>,
202 ) -> Self {
203 Self {
204 name: name.into(),
205 description: Some(description.into()),
206 parameters: Some(parameters),
207 strict: None,
208 }
209 }
210
211 pub fn parameters(mut self, parameters: HashMap<String, Value>) -> Self {
213 self.parameters = Some(parameters);
214 self
215 }
216
217 pub fn strict(mut self, strict: bool) -> Self {
219 self.strict = Some(strict);
220 self
221 }
222}
223
224#[derive(Debug, Clone, Default)]
226pub struct ParameterBuilder {
227 params: HashMap<String, Value>,
228}
229
230impl ParameterBuilder {
231 pub fn new() -> Self {
233 Self::default()
234 }
235
236 pub fn type_object(mut self) -> Self {
238 self.params.insert("type".to_string(), Value::String("object".to_string()));
239 self
240 }
241
242 pub fn required_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
244 let properties = self
246 .params
247 .entry("properties".to_string())
248 .or_insert_with(|| Value::Object(serde_json::Map::new()));
249
250 if let Value::Object(props) = properties {
251 props.insert(name.to_string(), Value::Object(schema.into_iter().collect()));
252 }
253
254 let required = self
256 .params
257 .entry("required".to_string())
258 .or_insert_with(|| Value::Array(Vec::new()));
259
260 if let Value::Array(req) = required {
261 req.push(Value::String(name.to_string()));
262 }
263
264 self
265 }
266
267 pub fn optional_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
269 let properties = self
270 .params
271 .entry("properties".to_string())
272 .or_insert_with(|| Value::Object(serde_json::Map::new()));
273
274 if let Value::Object(props) = properties {
275 props.insert(name.to_string(), Value::Object(schema.into_iter().collect()));
276 }
277
278 self
279 }
280
281 pub fn build(self) -> HashMap<String, Value> {
283 self.params
284 }
285}
286
287pub mod schema {
289 use serde_json::Value;
290 use std::collections::HashMap;
291
292 pub fn string() -> HashMap<String, Value> {
294 let mut map = HashMap::new();
295 map.insert("type".to_string(), Value::String("string".to_string()));
296 map
297 }
298
299 pub fn string_with_description(desc: &str) -> HashMap<String, Value> {
301 let mut map = string();
302 map.insert("description".to_string(), Value::String(desc.to_string()));
303 map
304 }
305
306 pub fn number() -> HashMap<String, Value> {
308 let mut map = HashMap::new();
309 map.insert("type".to_string(), Value::String("number".to_string()));
310 map
311 }
312
313 pub fn integer() -> HashMap<String, Value> {
315 let mut map = HashMap::new();
316 map.insert("type".to_string(), Value::String("integer".to_string()));
317 map
318 }
319
320 pub fn boolean() -> HashMap<String, Value> {
322 let mut map = HashMap::new();
323 map.insert("type".to_string(), Value::String("boolean".to_string()));
324 map
325 }
326
327 pub fn array(items: HashMap<String, Value>) -> HashMap<String, Value> {
329 let mut map = HashMap::new();
330 map.insert("type".to_string(), Value::String("array".to_string()));
331 map.insert("items".to_string(), Value::Object(items.into_iter().collect()));
332 map
333 }
334
335 pub fn enum_values(values: &[&str]) -> HashMap<String, Value> {
337 let mut map = HashMap::new();
338 map.insert("type".to_string(), Value::String("string".to_string()));
339 map.insert(
340 "enum".to_string(),
341 Value::Array(values.iter().map(|v| Value::String(v.to_string())).collect()),
342 );
343 map
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_creation() {
        let tool = Tool::function("get_weather", "Get the current weather");
        assert_eq!(tool.tool_type, ToolType::Function);
        assert!(tool.function.is_some());
    }

    #[test]
    fn test_web_search_tool() {
        let tool = Tool::web_search();
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert!(tool.function.is_none());
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool::function("get_weather", "Get weather info");
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
    }

    #[test]
    fn test_tool_with_parameters() {
        let mut params = HashMap::new();
        params.insert("type".to_string(), Value::String("object".to_string()));

        // Clone so the stored schema can be compared against the input below.
        let tool = Tool::function_with_params(
            "get_weather",
            "Get weather for a location",
            params.clone(),
        );

        let function = tool
            .function
            .expect("function tool must carry a definition");
        assert_eq!(function.parameters, Some(params));
    }

    #[test]
    fn test_parameter_builder() {
        let params = ParameterBuilder::new()
            .type_object()
            .required_property("location", schema::string())
            .build();

        assert_eq!(params.get("type"), Some(&Value::String("object".to_string())));
        assert_eq!(
            params["properties"]["location"]["type"],
            Value::String("string".to_string())
        );
        assert_eq!(
            params.get("required"),
            Some(&Value::Array(vec![Value::String("location".to_string())]))
        );
    }

    #[test]
    fn test_schema_helpers() {
        let s = schema::string();
        assert_eq!(s.get("type").unwrap(), &Value::String("string".to_string()));

        let n = schema::number();
        assert_eq!(n.get("type").unwrap(), &Value::String("number".to_string()));

        let e = schema::enum_values(&["a", "b", "c"]);
        assert!(e.contains_key("enum"));
    }

    #[test]
    fn test_strict_mode() {
        let tool = Tool::function("test", "test function").strict(true);
        let function = tool
            .function
            .expect("function tool must carry a definition");
        assert_eq!(function.strict, Some(true));
    }

    #[test]
    fn test_strict_is_noop_for_web_search() {
        let tool = Tool::web_search().strict(true);
        assert!(tool.function.is_none());
    }

    #[test]
    fn test_web_search_with_options() {
        let tool = Tool::web_search()
            .max_keyword(3)
            .force_search(true)
            .limit(5);

        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert_eq!(tool.max_keyword, Some(3));
        assert_eq!(tool.force_search, Some(true));
        assert_eq!(tool.limit, Some(5));
    }

    #[test]
    fn test_web_search_serialization() {
        let tool = Tool::web_search().max_keyword(3).force_search(true);

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"max_keyword\":3"));
        assert!(json.contains("\"force_search\":true"));
    }

    #[test]
    fn test_user_location() {
        let location = UserLocation::approximate()
            .country("China")
            .region("Hubei")
            .city("Wuhan");

        assert_eq!(location.location_type, "approximate");
        assert_eq!(location.country, Some("China".to_string()));
        assert_eq!(location.region, Some("Hubei".to_string()));
        assert_eq!(location.city, Some("Wuhan".to_string()));
    }

    #[test]
    fn test_user_location_serialization() {
        let location = UserLocation::approximate().country("China").city("Beijing");

        let json = serde_json::to_string(&location).unwrap();
        assert!(json.contains("\"type\":\"approximate\""));
        assert!(json.contains("\"country\":\"China\""));
        assert!(json.contains("\"city\":\"Beijing\""));
        // Unset optional parts must be skipped entirely.
        assert!(!json.contains("region"));
    }

    #[test]
    fn test_web_search_with_location() {
        let tool = Tool::web_search().user_location(
            UserLocation::approximate()
                .country("China")
                .city("Shanghai"),
        );

        let loc = tool
            .user_location
            .expect("user_location was set by the builder");
        assert_eq!(loc.country, Some("China".to_string()));
        assert_eq!(loc.city, Some("Shanghai".to_string()));
    }
}
480}