openai_rust_sdk/models/
functions.rs

1use crate::{De, Ser};
2use serde::{self, Deserialize, Serialize};
3use serde_json::Value;
4
5/// Function tool definition with JSON schema parameters
6#[derive(Debug, Clone, PartialEq, Eq, Ser, De)]
7pub struct FunctionTool {
8    /// Name of the function
9    pub name: String,
10    /// Description of what the function does
11    pub description: String,
12    /// JSON schema for the function parameters
13    pub parameters: Value,
14    /// Whether to use strict mode for reliable schema adherence
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub strict: Option<bool>,
17}
18
19/// A function call made by the model
20#[derive(Debug, Clone, Ser, De)]
21pub struct FunctionCall {
22    /// Unique identifier for this function call
23    pub call_id: String,
24    /// Name of the function being called
25    pub name: String,
26    /// JSON string containing the function arguments
27    pub arguments: String,
28}
29
30/// Output from a function call execution
31#[derive(Debug, Clone, Ser, De)]
32pub struct FunctionCallOutput {
33    /// The `call_id` this output corresponds to
34    pub call_id: String,
35    /// The output content from the function execution
36    pub output: String,
37}
38
39/// Different types of tools that can be used
40#[derive(Debug, Clone, Ser, De)]
41#[serde(tag = "type")]
42pub enum Tool {
43    /// Function tool
44    #[serde(rename = "function")]
45    Function {
46        /// The function definition
47        function: FunctionTool,
48    },
49    /// Custom tool (for extensibility)
50    #[serde(rename = "custom")]
51    Custom {
52        /// Custom tool definition
53        custom_tool: CustomTool,
54    },
55}
56
57/// Custom tool definition without explicit schema
58#[derive(Debug, Clone, Ser, De)]
59pub struct CustomTool {
60    /// Name of the custom tool
61    pub name: String,
62    /// Description of what the tool does
63    pub description: String,
64    /// Optional grammar specification
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub grammar: Option<Grammar>,
67}
68
69/// Grammar specification for custom tools
70#[derive(Debug, Clone, Ser, De)]
71#[serde(tag = "type")]
72pub enum Grammar {
73    /// Lark grammar syntax
74    #[serde(rename = "lark")]
75    Lark {
76        /// The Lark grammar definition
77        definition: String,
78    },
79    /// Regular expression syntax
80    #[serde(rename = "regex")]
81    Regex {
82        /// The regex pattern
83        pattern: String,
84        /// Optional flags for the regex
85        #[serde(skip_serializing_if = "Option::is_none")]
86        flags: Option<Vec<String>>,
87    },
88}
89
90/// Tool choice configuration
91#[derive(Debug, Clone, Ser, De)]
92#[serde(untagged)]
93pub enum ToolChoice {
94    /// Let the model choose automatically
95    Auto,
96    /// Require the model to call a tool
97    Required,
98    /// Don't use any tools
99    None,
100    /// Force a specific function to be called
101    Function {
102        /// Type must be "function"
103        r#type: String,
104        /// The function to force
105        function: FunctionSelector,
106    },
107    /// Only allow specific tools
108    AllowedTools {
109        /// List of allowed tool names
110        allowed_tools: Vec<String>,
111    },
112}
113
114/// Function selector for tool choice
115#[derive(Debug, Clone, Ser, De)]
116pub struct FunctionSelector {
117    /// Name of the function to select
118    pub name: String,
119}
120
121impl FunctionTool {
122    /// Create a new function tool
123    pub fn new(name: impl Into<String>, description: impl Into<String>, parameters: Value) -> Self {
124        Self {
125            name: name.into(),
126            description: description.into(),
127            parameters,
128            strict: None,
129        }
130    }
131
132    /// Enable strict mode for this function
133    #[must_use]
134    pub fn with_strict(mut self, strict: bool) -> Self {
135        self.strict = Some(strict);
136        self
137    }
138
139    /// Create a simple function with no parameters
140    pub fn simple(name: impl Into<String>, description: impl Into<String>) -> Self {
141        Self::new(
142            name,
143            description,
144            serde_json::json!({
145                "type": "object",
146                "properties": {},
147                "required": [],
148                "additionalProperties": false
149            }),
150        )
151    }
152}
153
154impl Tool {
155    /// Create a function tool
156    #[must_use]
157    pub fn function(function: FunctionTool) -> Self {
158        Self::Function { function }
159    }
160
161    /// Create a custom tool
162    #[must_use]
163    pub fn custom(custom_tool: CustomTool) -> Self {
164        Self::Custom { custom_tool }
165    }
166
167    /// Get the name of this tool
168    #[must_use]
169    pub fn name(&self) -> &str {
170        match self {
171            Self::Function { function } => &function.name,
172            Self::Custom { custom_tool } => &custom_tool.name,
173        }
174    }
175
176    /// Get the description of this tool
177    #[must_use]
178    pub fn description(&self) -> &str {
179        match self {
180            Self::Function { function } => &function.description,
181            Self::Custom { custom_tool } => &custom_tool.description,
182        }
183    }
184}
185
186impl CustomTool {
187    /// Create a new custom tool
188    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189        Self {
190            name: name.into(),
191            description: description.into(),
192            grammar: None,
193        }
194    }
195
196    /// Add a Lark grammar to this tool
197    pub fn with_lark_grammar(mut self, definition: impl Into<String>) -> Self {
198        self.grammar = Some(Grammar::Lark {
199            definition: definition.into(),
200        });
201        self
202    }
203
204    /// Add a regex grammar to this tool
205    pub fn with_regex_grammar(
206        mut self,
207        pattern: impl Into<String>,
208        flags: Option<Vec<String>>,
209    ) -> Self {
210        self.grammar = Some(Grammar::Regex {
211            pattern: pattern.into(),
212            flags,
213        });
214        self
215    }
216}
217
218impl Grammar {
219    /// Create a Lark grammar
220    pub fn lark(definition: impl Into<String>) -> Self {
221        Self::Lark {
222            definition: definition.into(),
223        }
224    }
225
226    /// Create a regex grammar
227    pub fn regex(pattern: impl Into<String>, flags: Option<Vec<String>>) -> Self {
228        Self::Regex {
229            pattern: pattern.into(),
230            flags,
231        }
232    }
233}
234
235impl ToolChoice {
236    /// Auto tool choice
237    #[must_use]
238    pub fn auto() -> Self {
239        Self::Auto
240    }
241
242    /// Required tool choice
243    #[must_use]
244    pub fn required() -> Self {
245        Self::Required
246    }
247
248    /// No tools
249    #[must_use]
250    pub fn none() -> Self {
251        Self::None
252    }
253
254    /// Force a specific function
255    pub fn function(name: impl Into<String>) -> Self {
256        Self::Function {
257            r#type: "function".to_string(),
258            function: FunctionSelector { name: name.into() },
259        }
260    }
261
262    /// Only allow specific tools
263    #[must_use]
264    pub fn allowed_tools(tools: Vec<String>) -> Self {
265        Self::AllowedTools {
266            allowed_tools: tools,
267        }
268    }
269}
270
271impl FunctionCall {
272    /// Create a new function call
273    pub fn new(
274        call_id: impl Into<String>,
275        name: impl Into<String>,
276        arguments: impl Into<String>,
277    ) -> Self {
278        Self {
279            call_id: call_id.into(),
280            name: name.into(),
281            arguments: arguments.into(),
282        }
283    }
284
285    /// Parse the arguments as JSON
286    pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
287        serde_json::from_str(&self.arguments)
288    }
289
290    /// Get arguments as a JSON Value
291    pub fn arguments_json(&self) -> Result<Value, serde_json::Error> {
292        serde_json::from_str(&self.arguments)
293    }
294}
295
296impl FunctionCallOutput {
297    /// Create a new function call output
298    pub fn new(call_id: impl Into<String>, output: impl Into<String>) -> Self {
299        Self {
300            call_id: call_id.into(),
301            output: output.into(),
302        }
303    }
304
305    /// Create output from a serializable value
306    pub fn from_value<T: Serialize>(
307        call_id: impl Into<String>,
308        value: &T,
309    ) -> Result<Self, serde_json::Error> {
310        let output = serde_json::to_string(value)?;
311        Ok(Self::new(call_id, output))
312    }
313
314    /// Create output from a JSON value
315    pub fn from_json(call_id: impl Into<String>, value: &Value) -> Result<Self, serde_json::Error> {
316        let output = serde_json::to_string(&value)?;
317        Ok(Self::new(call_id, output))
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_function_tool_creation() {
        let func = FunctionTool::new(
            "get_weather",
            "Get weather for a location",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
                },
                "required": ["location"]
            }),
        );

        assert_eq!(func.name, "get_weather");
        assert_eq!(func.description, "Get weather for a location");
        assert!(func.strict.is_none());
    }

    #[test]
    fn test_function_tool_with_strict() {
        let func = FunctionTool::simple("test", "Test function").with_strict(true);
        assert_eq!(func.strict, Some(true));
    }

    #[test]
    fn test_simple_function_has_empty_object_schema() {
        let func = FunctionTool::simple("ping", "Ping");
        assert_eq!(func.parameters["type"], "object");
        assert_eq!(func.parameters["properties"], serde_json::json!({}));
        assert_eq!(func.parameters["required"], serde_json::json!([]));
        assert_eq!(func.parameters["additionalProperties"], false);
    }

    #[test]
    fn test_function_tool_serialization_omits_unset_strict() {
        let unset = serde_json::to_value(FunctionTool::simple("a", "b")).unwrap();
        assert!(unset.get("strict").is_none());

        let set = serde_json::to_value(FunctionTool::simple("a", "b").with_strict(true)).unwrap();
        assert_eq!(set["strict"], true);
    }

    #[test]
    fn test_tool_creation() {
        let func_tool = FunctionTool::simple("test", "Test");
        let tool = Tool::function(func_tool);

        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "Test");
    }

    #[test]
    fn test_tool_serialization_is_tagged_by_type() {
        let value = serde_json::to_value(Tool::function(FunctionTool::simple("a", "b"))).unwrap();
        assert_eq!(value["type"], "function");
        assert_eq!(value["function"]["name"], "a");
    }

    #[test]
    fn test_custom_tool_with_grammar() {
        let tool =
            CustomTool::new("parser", "Parse text").with_lark_grammar("start: word+\nword: /\\w+/");

        assert_eq!(tool.name, "parser");
        assert!(tool.grammar.is_some());

        if let Some(Grammar::Lark { definition }) = &tool.grammar {
            assert!(definition.contains("start: word+"));
        } else {
            panic!("Expected Lark grammar");
        }
    }

    #[test]
    fn test_custom_tool_with_regex_grammar() {
        let tool = CustomTool::new("digits", "Match digits")
            .with_regex_grammar(r"^\d+$", Some(vec!["i".to_string()]));

        match &tool.grammar {
            Some(Grammar::Regex { pattern, flags }) => {
                assert_eq!(pattern, r"^\d+$");
                assert_eq!(flags, &Some(vec!["i".to_string()]));
            }
            other => panic!("Expected regex grammar, got {other:?}"),
        }

        let wrapped = Tool::custom(tool);
        assert_eq!(wrapped.name(), "digits");
        assert_eq!(wrapped.description(), "Match digits");
    }

    #[test]
    fn test_tool_choice_variants() {
        let auto = ToolChoice::auto();
        let required = ToolChoice::required();
        let none = ToolChoice::none();
        let function = ToolChoice::function("get_weather");
        let allowed = ToolChoice::allowed_tools(vec!["tool1".to_string(), "tool2".to_string()]);

        // Each constructor must produce its matching variant.
        assert!(matches!(auto, ToolChoice::Auto));
        assert!(matches!(required, ToolChoice::Required));
        assert!(matches!(none, ToolChoice::None));
        assert!(matches!(allowed, ToolChoice::AllowedTools { .. }));

        match &function {
            ToolChoice::Function {
                r#type,
                function: selector,
            } => {
                assert_eq!(r#type, "function");
                assert_eq!(selector.name, "get_weather");
            }
            other => panic!("Expected function tool choice, got {other:?}"),
        }
    }

    #[test]
    fn test_function_call_arguments() {
        let call = FunctionCall::new(
            "call-123",
            "get_weather",
            r#"{"location": "San Francisco", "unit": "celsius"}"#,
        );

        let args: Value = call.arguments_json().unwrap();
        assert_eq!(args["location"], "San Francisco");
        assert_eq!(args["unit"], "celsius");
    }

    #[test]
    fn test_function_call_parse_arguments_typed() {
        let call = FunctionCall::new("call-1", "lookup", r#"{"city": "Paris"}"#);
        let args: HashMap<String, String> = call.parse_arguments().unwrap();
        assert_eq!(args.get("city").map(String::as_str), Some("Paris"));
    }

    #[test]
    fn test_function_call_invalid_arguments_error() {
        let call = FunctionCall::new("call-2", "lookup", "not json");
        assert!(call.arguments_json().is_err());
        assert!(call.parse_arguments::<HashMap<String, String>>().is_err());
    }

    #[test]
    fn test_function_call_output() {
        let output = FunctionCallOutput::new("call-123", "Temperature: 22°C");
        assert_eq!(output.call_id, "call-123");
        assert_eq!(output.output, "Temperature: 22°C");

        let json_output = FunctionCallOutput::from_json(
            "call-456",
            &serde_json::json!({"temperature": 22, "unit": "celsius"}),
        )
        .unwrap();

        assert_eq!(json_output.call_id, "call-456");
        assert_eq!(json_output.output, r#"{"temperature":22,"unit":"celsius"}"#);
    }

    #[test]
    fn test_function_call_output_from_value() {
        let output = FunctionCallOutput::from_value("call-3", &[1, 2, 3]).unwrap();
        assert_eq!(output.call_id, "call-3");
        assert_eq!(output.output, "[1,2,3]");
    }
}