Skip to main content

mimo_api/types/
tool.rs

1//! Tool types for the MiMo API.
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7/// A tool that can be called by the model.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Tool {
10    /// Tool type
11    #[serde(rename = "type")]
12    pub tool_type: ToolType,
13    /// Function tool (if type is function)
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub function: Option<FunctionTool>,
16    /// Maximum number of keywords for web search (web_search only)
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub max_keyword: Option<u32>,
19    /// Force search even if model thinks it's unnecessary (web_search only)
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub force_search: Option<bool>,
22    /// Limit the number of search results (web_search only)
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub limit: Option<u32>,
25    /// User location for localized search (web_search only)
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub user_location: Option<UserLocation>,
28}
29
30impl Tool {
31    /// Create a new function tool.
32    pub fn function(name: impl Into<String>, description: impl Into<String>) -> Self {
33        Self {
34            tool_type: ToolType::Function,
35            function: Some(FunctionTool::new(name, description)),
36            max_keyword: None,
37            force_search: None,
38            limit: None,
39            user_location: None,
40        }
41    }
42
43    /// Create a new function tool with parameters.
44    pub fn function_with_params(
45        name: impl Into<String>,
46        description: impl Into<String>,
47        parameters: HashMap<String, Value>,
48    ) -> Self {
49        Self {
50            tool_type: ToolType::Function,
51            function: Some(FunctionTool::with_params(name, description, parameters)),
52            max_keyword: None,
53            force_search: None,
54            limit: None,
55            user_location: None,
56        }
57    }
58
59    /// Create a web search tool.
60    ///
61    /// **Note:** You must first enable the "联网服务插件" (Web Search Plugin)
62    /// in the MiMo console and set `web_search_enabled(true)` in your
63    /// `ChatRequest` before using this feature.
64    pub fn web_search() -> Self {
65        Self {
66            tool_type: ToolType::WebSearch,
67            function: None,
68            max_keyword: None,
69            force_search: None,
70            limit: None,
71            user_location: None,
72        }
73    }
74
75    /// Set maximum number of keywords for web search.
76    pub fn max_keyword(mut self, max: u32) -> Self {
77        self.max_keyword = Some(max);
78        self
79    }
80
81    /// Set whether to force search.
82    pub fn force_search(mut self, force: bool) -> Self {
83        self.force_search = Some(force);
84        self
85    }
86
87    /// Set the result limit for web search.
88    pub fn limit(mut self, limit: u32) -> Self {
89        self.limit = Some(limit);
90        self
91    }
92
93    /// Set the user location for localized search.
94    pub fn user_location(mut self, location: UserLocation) -> Self {
95        self.user_location = Some(location);
96        self
97    }
98
99    /// Set whether to use strict mode.
100    pub fn strict(mut self, strict: bool) -> Self {
101        if let Some(ref mut function) = self.function {
102            function.strict = Some(strict);
103        }
104        self
105    }
106}
107
108/// Tool type.
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
110#[serde(rename_all = "snake_case")]
111pub enum ToolType {
112    /// Function tool
113    Function,
114    /// Web search tool
115    WebSearch,
116}
117
118/// User location for localized web search.
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct UserLocation {
121    /// Location type (e.g., "approximate")
122    #[serde(rename = "type")]
123    pub location_type: String,
124    /// Country name
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub country: Option<String>,
127    /// Region/Province name
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub region: Option<String>,
130    /// City name
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub city: Option<String>,
133}
134
135impl UserLocation {
136    /// Create a new user location.
137    pub fn new(location_type: impl Into<String>) -> Self {
138        Self {
139            location_type: location_type.into(),
140            country: None,
141            region: None,
142            city: None,
143        }
144    }
145
146    /// Create an approximate location.
147    pub fn approximate() -> Self {
148        Self::new("approximate")
149    }
150
151    /// Set the country.
152    pub fn country(mut self, country: impl Into<String>) -> Self {
153        self.country = Some(country.into());
154        self
155    }
156
157    /// Set the region.
158    pub fn region(mut self, region: impl Into<String>) -> Self {
159        self.region = Some(region.into());
160        self
161    }
162
163    /// Set the city.
164    pub fn city(mut self, city: impl Into<String>) -> Self {
165        self.city = Some(city.into());
166        self
167    }
168}
169
170/// Function tool definition.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct FunctionTool {
173    /// Function name (alphanumeric, underscore, hyphen; max 64 chars)
174    pub name: String,
175    /// Function description
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub description: Option<String>,
178    /// Function parameters (JSON Schema)
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub parameters: Option<HashMap<String, Value>>,
181    /// Whether to use strict mode
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub strict: Option<bool>,
184}
185
186impl FunctionTool {
187    /// Create a new function tool.
188    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189        Self {
190            name: name.into(),
191            description: Some(description.into()),
192            parameters: None,
193            strict: None,
194        }
195    }
196
197    /// Create a new function tool with parameters.
198    pub fn with_params(
199        name: impl Into<String>,
200        description: impl Into<String>,
201        parameters: HashMap<String, Value>,
202    ) -> Self {
203        Self {
204            name: name.into(),
205            description: Some(description.into()),
206            parameters: Some(parameters),
207            strict: None,
208        }
209    }
210
211    /// Set the parameters.
212    pub fn parameters(mut self, parameters: HashMap<String, Value>) -> Self {
213        self.parameters = Some(parameters);
214        self
215    }
216
217    /// Set strict mode.
218    pub fn strict(mut self, strict: bool) -> Self {
219        self.strict = Some(strict);
220        self
221    }
222}
223
224/// Builder for creating JSON Schema parameters.
225#[derive(Debug, Clone, Default)]
226pub struct ParameterBuilder {
227    params: HashMap<String, Value>,
228}
229
230impl ParameterBuilder {
231    /// Create a new parameter builder.
232    pub fn new() -> Self {
233        Self::default()
234    }
235
236    /// Set the schema type.
237    pub fn type_object(mut self) -> Self {
238        self.params.insert("type".to_string(), Value::String("object".to_string()));
239        self
240    }
241
242    /// Add a required property.
243    pub fn required_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
244        // Add to properties
245        let properties = self
246            .params
247            .entry("properties".to_string())
248            .or_insert_with(|| Value::Object(serde_json::Map::new()));
249        
250        if let Value::Object(props) = properties {
251            props.insert(name.to_string(), Value::Object(schema.into_iter().collect()));
252        }
253
254        // Add to required
255        let required = self
256            .params
257            .entry("required".to_string())
258            .or_insert_with(|| Value::Array(Vec::new()));
259        
260        if let Value::Array(req) = required {
261            req.push(Value::String(name.to_string()));
262        }
263
264        self
265    }
266
267    /// Add an optional property.
268    pub fn optional_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
269        let properties = self
270            .params
271            .entry("properties".to_string())
272            .or_insert_with(|| Value::Object(serde_json::Map::new()));
273        
274        if let Value::Object(props) = properties {
275            props.insert(name.to_string(), Value::Object(schema.into_iter().collect()));
276        }
277
278        self
279    }
280
281    /// Build the parameters.
282    pub fn build(self) -> HashMap<String, Value> {
283        self.params
284    }
285}
286
287/// Helper functions for creating property schemas.
288pub mod schema {
289    use serde_json::Value;
290    use std::collections::HashMap;
291
292    /// Create a string property schema.
293    pub fn string() -> HashMap<String, Value> {
294        let mut map = HashMap::new();
295        map.insert("type".to_string(), Value::String("string".to_string()));
296        map
297    }
298
299    /// Create a string property with description.
300    pub fn string_with_description(desc: &str) -> HashMap<String, Value> {
301        let mut map = string();
302        map.insert("description".to_string(), Value::String(desc.to_string()));
303        map
304    }
305
306    /// Create a number property schema.
307    pub fn number() -> HashMap<String, Value> {
308        let mut map = HashMap::new();
309        map.insert("type".to_string(), Value::String("number".to_string()));
310        map
311    }
312
313    /// Create an integer property schema.
314    pub fn integer() -> HashMap<String, Value> {
315        let mut map = HashMap::new();
316        map.insert("type".to_string(), Value::String("integer".to_string()));
317        map
318    }
319
320    /// Create a boolean property schema.
321    pub fn boolean() -> HashMap<String, Value> {
322        let mut map = HashMap::new();
323        map.insert("type".to_string(), Value::String("boolean".to_string()));
324        map
325    }
326
327    /// Create an array property schema.
328    pub fn array(items: HashMap<String, Value>) -> HashMap<String, Value> {
329        let mut map = HashMap::new();
330        map.insert("type".to_string(), Value::String("array".to_string()));
331        map.insert("items".to_string(), Value::Object(items.into_iter().collect()));
332        map
333    }
334
335    /// Create an enum property schema.
336    pub fn enum_values(values: &[&str]) -> HashMap<String, Value> {
337        let mut map = HashMap::new();
338        map.insert("type".to_string(), Value::String("string".to_string()));
339        map.insert(
340            "enum".to_string(),
341            Value::Array(values.iter().map(|v| Value::String(v.to_string())).collect()),
342        );
343        map
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// A function tool carries its name and description and leaves every optional field unset.
    #[test]
    fn test_tool_creation() {
        let tool = Tool::function("get_weather", "Get the current weather");
        assert_eq!(tool.tool_type, ToolType::Function);

        let function = tool.function.expect("function tools carry a definition");
        assert_eq!(function.name, "get_weather");
        assert_eq!(function.description.as_deref(), Some("Get the current weather"));
        assert!(function.parameters.is_none());
        assert!(function.strict.is_none());
    }

    /// A web search tool starts with no function definition and no options.
    #[test]
    fn test_web_search_tool() {
        let tool = Tool::web_search();
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert!(tool.function.is_none());
        assert!(tool.max_keyword.is_none());
        assert!(tool.force_search.is_none());
        assert!(tool.limit.is_none());
        assert!(tool.user_location.is_none());
    }

    /// Serialization emits the type and name and skips every unset optional field.
    #[test]
    fn test_tool_serialization() {
        let tool = Tool::function("get_weather", "Get weather info");
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));

        // Web-search-only options must not leak into a function tool's JSON.
        for absent in [
            "max_keyword",
            "force_search",
            "limit",
            "user_location",
            "strict",
            "parameters",
        ] {
            assert!(!json.contains(absent), "unexpected `{absent}` in {json}");
        }
    }

    /// Deserialization maps the snake_case `type` string back to the enum variant.
    #[test]
    fn test_tool_deserialization() {
        let tool: Tool =
            serde_json::from_str(r#"{"type":"web_search","max_keyword":2}"#).unwrap();
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert_eq!(tool.max_keyword, Some(2));
        assert!(tool.function.is_none());
        assert!(tool.limit.is_none());
    }

    /// The supplied parameter map is stored unchanged on the function definition.
    #[test]
    fn test_tool_with_parameters() {
        let mut params = HashMap::new();
        params.insert("type".to_string(), Value::String("object".to_string()));

        let tool = Tool::function_with_params("get_weather", "Get weather for a location", params);

        let function = tool.function.expect("function tools carry a definition");
        let stored = function.parameters.expect("parameters were supplied");
        assert_eq!(stored.get("type"), Some(&Value::String("object".to_string())));
    }

    /// The builder fills in `type`, `properties`, and `required` as expected.
    #[test]
    fn test_parameter_builder() {
        let params = ParameterBuilder::new()
            .type_object()
            .required_property("location", schema::string())
            .optional_property("unit", schema::enum_values(&["celsius", "fahrenheit"]))
            .build();

        assert_eq!(params.get("type"), Some(&Value::String("object".to_string())));
        // Only the required property is listed under `required`.
        assert_eq!(params.get("required"), Some(&serde_json::json!(["location"])));

        let properties = params
            .get("properties")
            .and_then(Value::as_object)
            .expect("properties should be a JSON object");
        assert_eq!(
            properties.get("location"),
            Some(&serde_json::json!({ "type": "string" }))
        );
        assert!(properties.contains_key("unit"));
    }

    /// Each schema helper produces the documented type and extra keys.
    #[test]
    fn test_schema_helpers() {
        assert_eq!(schema::string().get("type"), Some(&serde_json::json!("string")));
        assert_eq!(schema::number().get("type"), Some(&serde_json::json!("number")));
        assert_eq!(schema::integer().get("type"), Some(&serde_json::json!("integer")));
        assert_eq!(schema::boolean().get("type"), Some(&serde_json::json!("boolean")));

        let described = schema::string_with_description("City name");
        assert_eq!(described.get("type"), Some(&serde_json::json!("string")));
        assert_eq!(described.get("description"), Some(&serde_json::json!("City name")));

        let list = schema::array(schema::integer());
        assert_eq!(list.get("type"), Some(&serde_json::json!("array")));
        assert_eq!(list.get("items"), Some(&serde_json::json!({ "type": "integer" })));

        let choices = schema::enum_values(&["a", "b", "c"]);
        assert_eq!(choices.get("type"), Some(&serde_json::json!("string")));
        assert_eq!(choices.get("enum"), Some(&serde_json::json!(["a", "b", "c"])));
    }

    /// `strict` sets the flag on function tools and is a no-op on web search tools.
    #[test]
    fn test_strict_mode() {
        let tool = Tool::function("test", "test function").strict(true);
        assert_eq!(tool.function.expect("function tool").strict, Some(true));

        let search = Tool::web_search().strict(true);
        assert!(search.function.is_none());
    }

    /// The web search setters store their values.
    #[test]
    fn test_web_search_with_options() {
        let tool = Tool::web_search().max_keyword(3).force_search(true).limit(5);

        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert_eq!(tool.max_keyword, Some(3));
        assert_eq!(tool.force_search, Some(true));
        assert_eq!(tool.limit, Some(5));
    }

    /// Set web search options appear in the serialized JSON.
    #[test]
    fn test_web_search_serialization() {
        let tool = Tool::web_search().max_keyword(3).force_search(true);

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"max_keyword\":3"));
        assert!(json.contains("\"force_search\":true"));
        assert!(!json.contains("limit"));
    }

    /// The location builder stores each component.
    #[test]
    fn test_user_location() {
        let location = UserLocation::approximate()
            .country("China")
            .region("Hubei")
            .city("Wuhan");

        assert_eq!(location.location_type, "approximate");
        assert_eq!(location.country.as_deref(), Some("China"));
        assert_eq!(location.region.as_deref(), Some("Hubei"));
        assert_eq!(location.city.as_deref(), Some("Wuhan"));
    }

    /// Unset location components are omitted from the JSON.
    #[test]
    fn test_user_location_serialization() {
        let location = UserLocation::approximate().country("China").city("Beijing");

        let json = serde_json::to_string(&location).unwrap();
        assert!(json.contains("\"type\":\"approximate\""));
        assert!(json.contains("\"country\":\"China\""));
        assert!(json.contains("\"city\":\"Beijing\""));
        assert!(!json.contains("region")); // None should not be serialized
    }

    /// A location attached to a web search tool is kept intact.
    #[test]
    fn test_web_search_with_location() {
        let tool = Tool::web_search()
            .user_location(UserLocation::approximate().country("China").city("Shanghai"));

        let loc = tool.user_location.expect("location was attached");
        assert_eq!(loc.location_type, "approximate");
        assert_eq!(loc.country.as_deref(), Some("China"));
        assert_eq!(loc.city.as_deref(), Some("Shanghai"));
        assert!(loc.region.is_none());
    }
}
480}