1use crate::{De, Ser};
2use serde::{self, Deserialize, Serialize};
3use serde_json::Value;
4
5#[derive(Debug, Clone, PartialEq, Eq, Ser, De)]
7pub struct FunctionTool {
8 pub name: String,
10 pub description: String,
12 pub parameters: Value,
14 #[serde(skip_serializing_if = "Option::is_none")]
16 pub strict: Option<bool>,
17}
18
19#[derive(Debug, Clone, Ser, De)]
21pub struct FunctionCall {
22 pub call_id: String,
24 pub name: String,
26 pub arguments: String,
28}
29
30#[derive(Debug, Clone, Ser, De)]
32pub struct FunctionCallOutput {
33 pub call_id: String,
35 pub output: String,
37}
38
39#[derive(Debug, Clone, Ser, De)]
41#[serde(tag = "type")]
42pub enum Tool {
43 #[serde(rename = "function")]
45 Function {
46 function: FunctionTool,
48 },
49 #[serde(rename = "custom")]
51 Custom {
52 custom_tool: CustomTool,
54 },
55}
56
57#[derive(Debug, Clone, Ser, De)]
59pub struct CustomTool {
60 pub name: String,
62 pub description: String,
64 #[serde(skip_serializing_if = "Option::is_none")]
66 pub grammar: Option<Grammar>,
67}
68
69#[derive(Debug, Clone, Ser, De)]
71#[serde(tag = "type")]
72pub enum Grammar {
73 #[serde(rename = "lark")]
75 Lark {
76 definition: String,
78 },
79 #[serde(rename = "regex")]
81 Regex {
82 pattern: String,
84 #[serde(skip_serializing_if = "Option::is_none")]
86 flags: Option<Vec<String>>,
87 },
88}
89
90#[derive(Debug, Clone, Ser, De)]
92#[serde(untagged)]
93pub enum ToolChoice {
94 Auto,
96 Required,
98 None,
100 Function {
102 r#type: String,
104 function: FunctionSelector,
106 },
107 AllowedTools {
109 allowed_tools: Vec<String>,
111 },
112}
113
114#[derive(Debug, Clone, Ser, De)]
116pub struct FunctionSelector {
117 pub name: String,
119}
120
121impl FunctionTool {
122 pub fn new(name: impl Into<String>, description: impl Into<String>, parameters: Value) -> Self {
124 Self {
125 name: name.into(),
126 description: description.into(),
127 parameters,
128 strict: None,
129 }
130 }
131
132 #[must_use]
134 pub fn with_strict(mut self, strict: bool) -> Self {
135 self.strict = Some(strict);
136 self
137 }
138
139 pub fn simple(name: impl Into<String>, description: impl Into<String>) -> Self {
141 Self::new(
142 name,
143 description,
144 serde_json::json!({
145 "type": "object",
146 "properties": {},
147 "required": [],
148 "additionalProperties": false
149 }),
150 )
151 }
152}
153
154impl Tool {
155 #[must_use]
157 pub fn function(function: FunctionTool) -> Self {
158 Self::Function { function }
159 }
160
161 #[must_use]
163 pub fn custom(custom_tool: CustomTool) -> Self {
164 Self::Custom { custom_tool }
165 }
166
167 #[must_use]
169 pub fn name(&self) -> &str {
170 match self {
171 Self::Function { function } => &function.name,
172 Self::Custom { custom_tool } => &custom_tool.name,
173 }
174 }
175
176 #[must_use]
178 pub fn description(&self) -> &str {
179 match self {
180 Self::Function { function } => &function.description,
181 Self::Custom { custom_tool } => &custom_tool.description,
182 }
183 }
184}
185
186impl CustomTool {
187 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189 Self {
190 name: name.into(),
191 description: description.into(),
192 grammar: None,
193 }
194 }
195
196 pub fn with_lark_grammar(mut self, definition: impl Into<String>) -> Self {
198 self.grammar = Some(Grammar::Lark {
199 definition: definition.into(),
200 });
201 self
202 }
203
204 pub fn with_regex_grammar(
206 mut self,
207 pattern: impl Into<String>,
208 flags: Option<Vec<String>>,
209 ) -> Self {
210 self.grammar = Some(Grammar::Regex {
211 pattern: pattern.into(),
212 flags,
213 });
214 self
215 }
216}
217
218impl Grammar {
219 pub fn lark(definition: impl Into<String>) -> Self {
221 Self::Lark {
222 definition: definition.into(),
223 }
224 }
225
226 pub fn regex(pattern: impl Into<String>, flags: Option<Vec<String>>) -> Self {
228 Self::Regex {
229 pattern: pattern.into(),
230 flags,
231 }
232 }
233}
234
235impl ToolChoice {
236 #[must_use]
238 pub fn auto() -> Self {
239 Self::Auto
240 }
241
242 #[must_use]
244 pub fn required() -> Self {
245 Self::Required
246 }
247
248 #[must_use]
250 pub fn none() -> Self {
251 Self::None
252 }
253
254 pub fn function(name: impl Into<String>) -> Self {
256 Self::Function {
257 r#type: "function".to_string(),
258 function: FunctionSelector { name: name.into() },
259 }
260 }
261
262 #[must_use]
264 pub fn allowed_tools(tools: Vec<String>) -> Self {
265 Self::AllowedTools {
266 allowed_tools: tools,
267 }
268 }
269}
270
271impl FunctionCall {
272 pub fn new(
274 call_id: impl Into<String>,
275 name: impl Into<String>,
276 arguments: impl Into<String>,
277 ) -> Self {
278 Self {
279 call_id: call_id.into(),
280 name: name.into(),
281 arguments: arguments.into(),
282 }
283 }
284
285 pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
287 serde_json::from_str(&self.arguments)
288 }
289
290 pub fn arguments_json(&self) -> Result<Value, serde_json::Error> {
292 serde_json::from_str(&self.arguments)
293 }
294}
295
296impl FunctionCallOutput {
297 pub fn new(call_id: impl Into<String>, output: impl Into<String>) -> Self {
299 Self {
300 call_id: call_id.into(),
301 output: output.into(),
302 }
303 }
304
305 pub fn from_value<T: Serialize>(
307 call_id: impl Into<String>,
308 value: &T,
309 ) -> Result<Self, serde_json::Error> {
310 let output = serde_json::to_string(value)?;
311 Ok(Self::new(call_id, output))
312 }
313
314 pub fn from_json(call_id: impl Into<String>, value: &Value) -> Result<Self, serde_json::Error> {
316 let output = serde_json::to_string(&value)?;
317 Ok(Self::new(call_id, output))
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_tool_creation() {
        let func = FunctionTool::new(
            "get_weather",
            "Get weather for a location",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
                },
                "required": ["location"]
            }),
        );

        assert_eq!(func.name, "get_weather");
        assert_eq!(func.description, "Get weather for a location");
        assert_eq!(func.parameters["required"][0], "location");
        assert!(func.strict.is_none());
    }

    #[test]
    fn test_function_tool_with_strict() {
        let func = FunctionTool::simple("test", "Test function").with_strict(true);
        assert_eq!(func.strict, Some(true));
    }

    #[test]
    fn test_function_tool_strict_only_serialized_when_set() {
        let unset = serde_json::to_value(FunctionTool::simple("test", "Test")).unwrap();
        assert!(unset.get("strict").is_none());

        let set = serde_json::to_value(FunctionTool::simple("test", "Test").with_strict(false))
            .unwrap();
        assert_eq!(set["strict"], false);
    }

    #[test]
    fn test_tool_creation() {
        let func_tool = FunctionTool::simple("test", "Test");
        let tool = Tool::function(func_tool);

        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "Test");
    }

    #[test]
    fn test_tool_serializes_with_type_tag() {
        let tool = Tool::function(FunctionTool::simple("test", "Test"));
        let json = serde_json::to_value(&tool).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "test");

        let round_trip: Tool = serde_json::from_value(json).unwrap();
        assert_eq!(round_trip.name(), "test");
    }

    #[test]
    fn test_custom_tool_accessors() {
        let tool = Tool::custom(CustomTool::new("parser", "Parse text"));
        assert_eq!(tool.name(), "parser");
        assert_eq!(tool.description(), "Parse text");
    }

    #[test]
    fn test_custom_tool_with_grammar() {
        let tool =
            CustomTool::new("parser", "Parse text").with_lark_grammar("start: word+\nword: /\\w+/");

        assert_eq!(tool.name, "parser");
        match &tool.grammar {
            Some(Grammar::Lark { definition }) => assert!(definition.contains("start: word+")),
            other => panic!("Expected Lark grammar, got {:?}", other),
        }
    }

    #[test]
    fn test_custom_tool_with_regex_grammar() {
        let flags = Some(vec!["i".to_string()]);
        let tool =
            CustomTool::new("matcher", "Match digits").with_regex_grammar(r"\d+", flags.clone());

        match &tool.grammar {
            Some(Grammar::Regex { pattern, flags: got }) => {
                assert_eq!(pattern, r"\d+");
                assert_eq!(got, &flags);
            }
            other => panic!("Expected regex grammar, got {:?}", other),
        }
    }

    #[test]
    fn test_tool_choice_variants() {
        let auto = ToolChoice::auto();
        let required = ToolChoice::required();
        let none = ToolChoice::none();
        let function = ToolChoice::function("get_weather");
        let allowed = ToolChoice::allowed_tools(vec!["tool1".to_string(), "tool2".to_string()]);

        assert!(matches!(auto, ToolChoice::Auto));
        assert!(matches!(required, ToolChoice::Required));
        assert!(matches!(none, ToolChoice::None));
        assert!(matches!(function, ToolChoice::Function { .. }));
        assert!(matches!(allowed, ToolChoice::AllowedTools { .. }));
    }

    #[test]
    fn test_tool_choice_function_selector() {
        match ToolChoice::function("get_weather") {
            ToolChoice::Function { r#type, function } => {
                assert_eq!(r#type, "function");
                assert_eq!(function.name, "get_weather");
            }
            other => panic!("Expected function tool choice, got {:?}", other),
        }
    }

    #[test]
    fn test_function_call_arguments() {
        let call = FunctionCall::new(
            "call-123",
            "get_weather",
            r#"{"location": "San Francisco", "unit": "celsius"}"#,
        );

        let args: Value = call.arguments_json().unwrap();
        assert_eq!(args["location"], "San Francisco");
        assert_eq!(args["unit"], "celsius");
    }

    #[test]
    fn test_function_call_parse_arguments() {
        use std::collections::HashMap;

        let call = FunctionCall::new("call-1", "get_weather", r#"{"location": "Paris"}"#);
        let args: HashMap<String, String> = call.parse_arguments().unwrap();
        assert_eq!(args.get("location").map(String::as_str), Some("Paris"));

        let malformed = FunctionCall::new("call-2", "get_weather", "not json");
        assert!(malformed.arguments_json().is_err());
        assert!(malformed
            .parse_arguments::<HashMap<String, String>>()
            .is_err());
    }

    #[test]
    fn test_function_call_output() {
        let output = FunctionCallOutput::new("call-123", "Temperature: 22°C");
        assert_eq!(output.call_id, "call-123");
        assert_eq!(output.output, "Temperature: 22°C");

        let json_output = FunctionCallOutput::from_json(
            "call-456",
            &serde_json::json!({"temperature": 22, "unit": "celsius"}),
        )
        .unwrap();

        assert_eq!(json_output.call_id, "call-456");
        assert_eq!(json_output.output, r#"{"temperature":22,"unit":"celsius"}"#);
    }

    #[test]
    fn test_function_call_output_from_value() {
        let output = FunctionCallOutput::from_value("call-789", &[1, 2, 3]).unwrap();
        assert_eq!(output.call_id, "call-789");
        assert_eq!(output.output, "[1,2,3]");
    }
}