Skip to main content

mimo_api/types/
tool.rs

1//! Tool types for the MiMo API.
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
7/// A tool that can be called by the model.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Tool {
10    /// Tool type
11    #[serde(rename = "type")]
12    pub tool_type: ToolType,
13    /// Function tool (if type is function)
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub function: Option<FunctionTool>,
16    /// Maximum number of keywords for web search (web_search only)
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub max_keyword: Option<u32>,
19    /// Force search even if model thinks it's unnecessary (web_search only)
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub force_search: Option<bool>,
22    /// Limit the number of search results (web_search only)
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub limit: Option<u32>,
25    /// User location for localized search (web_search only)
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub user_location: Option<UserLocation>,
28}
29
30impl Tool {
31    /// Create a new function tool.
32    pub fn function(name: impl Into<String>, description: impl Into<String>) -> Self {
33        Self {
34            tool_type: ToolType::Function,
35            function: Some(FunctionTool::new(name, description)),
36            max_keyword: None,
37            force_search: None,
38            limit: None,
39            user_location: None,
40        }
41    }
42
43    /// Create a new function tool with parameters.
44    pub fn function_with_params(
45        name: impl Into<String>,
46        description: impl Into<String>,
47        parameters: HashMap<String, Value>,
48    ) -> Self {
49        Self {
50            tool_type: ToolType::Function,
51            function: Some(FunctionTool::with_params(name, description, parameters)),
52            max_keyword: None,
53            force_search: None,
54            limit: None,
55            user_location: None,
56        }
57    }
58
59    /// Create a web search tool.
60    ///
61    /// **Note:** You must first enable the "联网服务插件" (Web Search Plugin)
62    /// in the MiMo console and set `web_search_enabled(true)` in your
63    /// `ChatRequest` before using this feature.
64    pub fn web_search() -> Self {
65        Self {
66            tool_type: ToolType::WebSearch,
67            function: None,
68            max_keyword: None,
69            force_search: None,
70            limit: None,
71            user_location: None,
72        }
73    }
74
75    /// Set maximum number of keywords for web search.
76    pub fn max_keyword(mut self, max: u32) -> Self {
77        self.max_keyword = Some(max);
78        self
79    }
80
81    /// Set whether to force search.
82    pub fn force_search(mut self, force: bool) -> Self {
83        self.force_search = Some(force);
84        self
85    }
86
87    /// Set the result limit for web search.
88    pub fn limit(mut self, limit: u32) -> Self {
89        self.limit = Some(limit);
90        self
91    }
92
93    /// Set the user location for localized search.
94    pub fn user_location(mut self, location: UserLocation) -> Self {
95        self.user_location = Some(location);
96        self
97    }
98
99    /// Set whether to use strict mode.
100    pub fn strict(mut self, strict: bool) -> Self {
101        if let Some(ref mut function) = self.function {
102            function.strict = Some(strict);
103        }
104        self
105    }
106}
107
108/// Tool type.
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
110#[serde(rename_all = "snake_case")]
111pub enum ToolType {
112    /// Function tool
113    Function,
114    /// Web search tool
115    WebSearch,
116}
117
118/// User location for localized web search.
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct UserLocation {
121    /// Location type (e.g., "approximate")
122    #[serde(rename = "type")]
123    pub location_type: String,
124    /// Country name
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub country: Option<String>,
127    /// Region/Province name
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub region: Option<String>,
130    /// City name
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub city: Option<String>,
133}
134
135impl UserLocation {
136    /// Create a new user location.
137    pub fn new(location_type: impl Into<String>) -> Self {
138        Self {
139            location_type: location_type.into(),
140            country: None,
141            region: None,
142            city: None,
143        }
144    }
145
146    /// Create an approximate location.
147    pub fn approximate() -> Self {
148        Self::new("approximate")
149    }
150
151    /// Set the country.
152    pub fn country(mut self, country: impl Into<String>) -> Self {
153        self.country = Some(country.into());
154        self
155    }
156
157    /// Set the region.
158    pub fn region(mut self, region: impl Into<String>) -> Self {
159        self.region = Some(region.into());
160        self
161    }
162
163    /// Set the city.
164    pub fn city(mut self, city: impl Into<String>) -> Self {
165        self.city = Some(city.into());
166        self
167    }
168}
169
170/// Function tool definition.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct FunctionTool {
173    /// Function name (alphanumeric, underscore, hyphen; max 64 chars)
174    pub name: String,
175    /// Function description
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub description: Option<String>,
178    /// Function parameters (JSON Schema)
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub parameters: Option<HashMap<String, Value>>,
181    /// Whether to use strict mode
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub strict: Option<bool>,
184}
185
186impl FunctionTool {
187    /// Create a new function tool.
188    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189        Self {
190            name: name.into(),
191            description: Some(description.into()),
192            parameters: None,
193            strict: None,
194        }
195    }
196
197    /// Create a new function tool with parameters.
198    pub fn with_params(
199        name: impl Into<String>,
200        description: impl Into<String>,
201        parameters: HashMap<String, Value>,
202    ) -> Self {
203        Self {
204            name: name.into(),
205            description: Some(description.into()),
206            parameters: Some(parameters),
207            strict: None,
208        }
209    }
210
211    /// Set the parameters.
212    pub fn parameters(mut self, parameters: HashMap<String, Value>) -> Self {
213        self.parameters = Some(parameters);
214        self
215    }
216
217    /// Set strict mode.
218    pub fn strict(mut self, strict: bool) -> Self {
219        self.strict = Some(strict);
220        self
221    }
222}
223
224/// Builder for creating JSON Schema parameters.
225#[derive(Debug, Clone, Default)]
226pub struct ParameterBuilder {
227    params: HashMap<String, Value>,
228}
229
230impl ParameterBuilder {
231    /// Create a new parameter builder.
232    pub fn new() -> Self {
233        Self::default()
234    }
235
236    /// Set the schema type.
237    pub fn type_object(mut self) -> Self {
238        self.params
239            .insert("type".to_string(), Value::String("object".to_string()));
240        self
241    }
242
243    /// Add a required property.
244    pub fn required_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
245        // Add to properties
246        let properties = self
247            .params
248            .entry("properties".to_string())
249            .or_insert_with(|| Value::Object(serde_json::Map::new()));
250
251        if let Value::Object(props) = properties {
252            props.insert(
253                name.to_string(),
254                Value::Object(schema.into_iter().collect()),
255            );
256        }
257
258        // Add to required
259        let required = self
260            .params
261            .entry("required".to_string())
262            .or_insert_with(|| Value::Array(Vec::new()));
263
264        if let Value::Array(req) = required {
265            req.push(Value::String(name.to_string()));
266        }
267
268        self
269    }
270
271    /// Add an optional property.
272    pub fn optional_property(mut self, name: &str, schema: HashMap<String, Value>) -> Self {
273        let properties = self
274            .params
275            .entry("properties".to_string())
276            .or_insert_with(|| Value::Object(serde_json::Map::new()));
277
278        if let Value::Object(props) = properties {
279            props.insert(
280                name.to_string(),
281                Value::Object(schema.into_iter().collect()),
282            );
283        }
284
285        self
286    }
287
288    /// Build the parameters.
289    pub fn build(self) -> HashMap<String, Value> {
290        self.params
291    }
292}
293
294/// Helper functions for creating property schemas.
295pub mod schema {
296    use serde_json::Value;
297    use std::collections::HashMap;
298
299    /// Create a string property schema.
300    pub fn string() -> HashMap<String, Value> {
301        let mut map = HashMap::new();
302        map.insert("type".to_string(), Value::String("string".to_string()));
303        map
304    }
305
306    /// Create a string property with description.
307    pub fn string_with_description(desc: &str) -> HashMap<String, Value> {
308        let mut map = string();
309        map.insert("description".to_string(), Value::String(desc.to_string()));
310        map
311    }
312
313    /// Create a number property schema.
314    pub fn number() -> HashMap<String, Value> {
315        let mut map = HashMap::new();
316        map.insert("type".to_string(), Value::String("number".to_string()));
317        map
318    }
319
320    /// Create an integer property schema.
321    pub fn integer() -> HashMap<String, Value> {
322        let mut map = HashMap::new();
323        map.insert("type".to_string(), Value::String("integer".to_string()));
324        map
325    }
326
327    /// Create a boolean property schema.
328    pub fn boolean() -> HashMap<String, Value> {
329        let mut map = HashMap::new();
330        map.insert("type".to_string(), Value::String("boolean".to_string()));
331        map
332    }
333
334    /// Create an array property schema.
335    pub fn array(items: HashMap<String, Value>) -> HashMap<String, Value> {
336        let mut map = HashMap::new();
337        map.insert("type".to_string(), Value::String("array".to_string()));
338        map.insert(
339            "items".to_string(),
340            Value::Object(items.into_iter().collect()),
341        );
342        map
343    }
344
345    /// Create an enum property schema.
346    pub fn enum_values(values: &[&str]) -> HashMap<String, Value> {
347        let mut map = HashMap::new();
348        map.insert("type".to_string(), Value::String("string".to_string()));
349        map.insert(
350            "enum".to_string(),
351            Value::Array(
352                values
353                    .iter()
354                    .map(|v| Value::String(v.to_string()))
355                    .collect(),
356            ),
357        );
358        map
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// A function tool has the function type and carries a definition.
    #[test]
    fn test_tool_creation() {
        let tool = Tool::function("get_weather", "Get the current weather");
        assert_eq!(tool.tool_type, ToolType::Function);
        assert!(tool.function.is_some());
    }

    /// A web search tool has the web-search type and no function definition.
    #[test]
    fn test_web_search_tool() {
        let tool = Tool::web_search();
        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert!(tool.function.is_none());
    }

    /// The tool type and the function name appear in the serialized JSON.
    #[test]
    fn test_tool_serialization() {
        let tool = Tool::function("get_weather", "Get weather info");
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
    }

    /// Parameters passed to the constructor end up on the function definition.
    #[test]
    fn test_tool_with_parameters() {
        let mut params = HashMap::new();
        params.insert("type".to_string(), Value::String("object".to_string()));

        // `params` is not used again, so it is moved rather than cloned.
        let tool =
            Tool::function_with_params("get_weather", "Get weather for a location", params);

        assert!(tool.function.as_ref().unwrap().parameters.is_some());
    }

    /// The builder records the schema type, the property and the required list.
    #[test]
    fn test_parameter_builder() {
        let params = ParameterBuilder::new()
            .type_object()
            .required_property("location", schema::string())
            .build();

        assert_eq!(
            params.get("type").unwrap(),
            &Value::String("object".to_string())
        );
        assert_eq!(
            params["properties"]["location"]["type"],
            Value::String("string".to_string())
        );
        assert_eq!(
            params["required"],
            Value::Array(vec![Value::String("location".to_string())])
        );
    }

    /// The schema helpers set the expected `type` and `enum` keywords.
    #[test]
    fn test_schema_helpers() {
        let s = schema::string();
        assert_eq!(s.get("type").unwrap(), &Value::String("string".to_string()));

        let n = schema::number();
        assert_eq!(n.get("type").unwrap(), &Value::String("number".to_string()));

        let e = schema::enum_values(&["a", "b", "c"]);
        assert!(e.contains_key("enum"));
    }

    /// `Tool::strict` reaches the contained function definition.
    #[test]
    fn test_strict_mode() {
        let tool = Tool::function("test", "test function").strict(true);
        assert_eq!(tool.function.unwrap().strict, Some(true));
    }

    /// The web search setters store the values they are given.
    #[test]
    fn test_web_search_with_options() {
        let tool = Tool::web_search()
            .max_keyword(3)
            .force_search(true)
            .limit(5);

        assert_eq!(tool.tool_type, ToolType::WebSearch);
        assert_eq!(tool.max_keyword, Some(3));
        assert_eq!(tool.force_search, Some(true));
        assert_eq!(tool.limit, Some(5));
    }

    /// Web search options appear in the serialized JSON.
    #[test]
    fn test_web_search_serialization() {
        let tool = Tool::web_search().max_keyword(3).force_search(true);

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"max_keyword\":3"));
        assert!(json.contains("\"force_search\":true"));
    }

    /// The location builder stores the type and every field that was set.
    #[test]
    fn test_user_location() {
        let location = UserLocation::approximate()
            .country("China")
            .region("Hubei")
            .city("Wuhan");

        assert_eq!(location.location_type, "approximate");
        assert_eq!(location.country, Some("China".to_string()));
        assert_eq!(location.region, Some("Hubei".to_string()));
        assert_eq!(location.city, Some("Wuhan".to_string()));
    }

    /// Set fields are serialized; unset ones are omitted.
    #[test]
    fn test_user_location_serialization() {
        let location = UserLocation::approximate().country("China").city("Beijing");

        let json = serde_json::to_string(&location).unwrap();
        assert!(json.contains("\"type\":\"approximate\""));
        assert!(json.contains("\"country\":\"China\""));
        assert!(json.contains("\"city\":\"Beijing\""));
        assert!(!json.contains("region")); // None should not be serialized
    }

    /// A location attached to a web search tool is kept as given.
    #[test]
    fn test_web_search_with_location() {
        let tool = Tool::web_search().user_location(
            UserLocation::approximate()
                .country("China")
                .city("Shanghai"),
        );

        assert!(tool.user_location.is_some());
        let loc = tool.user_location.unwrap();
        assert_eq!(loc.country, Some("China".to_string()));
        assert_eq!(loc.city, Some("Shanghai".to_string()));
    }
}