1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6use typed_builder::TypedBuilder;
7use uuid::Uuid;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11pub enum ToolApproval {
12 Approved,
14 Denied(String),
16 Quit,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
22pub struct Property {
23 #[serde(rename = "type")]
25 pub prop_type: String,
26 pub description: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32pub struct Parameters {
33 #[serde(rename = "type")]
35 pub param_type: String,
36 pub properties: HashMap<String, Property>,
38 pub required: Vec<String>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
44pub struct Function {
45 pub name: String,
47 pub description: String,
49 pub parameters: serde_json::Value,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder, PartialEq)]
55pub struct Tool {
56 #[serde(rename = "type")]
58 #[builder(default = "function".to_string())]
59 pub r#type: String,
60 pub function: Function,
62}
63
64#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
66pub struct FunctionCall {
67 pub name: String,
69 pub arguments: Vec<String>,
71}
72
73#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
78pub struct ToolCall {
79 pub id: String,
81 pub function: FunctionCall,
83 pub call_type: String,
85}
86
87impl ToolCall {
88 pub fn new<I, T>(name: impl Into<String>, arguments: I) -> Self
90 where
91 I: IntoIterator<Item = T>,
92 T: Into<String>,
93 {
94 Self {
95 id: Uuid::new_v4().to_string(),
96 function: FunctionCall {
97 name: name.into(),
98 arguments: arguments.into_iter().map(|arg| arg.into()).collect(),
99 },
100 call_type: "function".to_string(),
101 }
102 }
103
104 pub fn merge_deltas(mut accumulated: Vec<Self>, deltas: &[Self]) -> Vec<Self> {
109 for delta in deltas {
110 if let Some(existing) = accumulated.iter_mut().find(|tc| tc.id == delta.id) {
111 for arg in &delta.function.arguments {
113 if let Some(last_arg) = existing.function.arguments.last_mut() {
114 last_arg.push_str(arg);
116 } else {
117 existing.function.arguments.push(arg.clone());
119 }
120 }
121 } else {
122 accumulated.push(delta.clone());
124 }
125 }
126
127 accumulated
128 }
129}
130
/// Unit tests for the tool data types: construction, serde round-trips,
/// the `Tool` builder, and `ToolCall::merge_deltas`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_approval_variants() {
        let approved = ToolApproval::Approved;
        let denied = ToolApproval::Denied("Invalid request".to_string());
        let quit = ToolApproval::Quit;

        assert_eq!(approved, ToolApproval::Approved);
        assert_eq!(denied, ToolApproval::Denied("Invalid request".to_string()));
        assert_eq!(quit, ToolApproval::Quit);
    }

    #[test]
    fn test_tool_approval_serialization() {
        // Unit variant round-trip.
        let approved = ToolApproval::Approved;
        let json = serde_json::to_string(&approved).expect("Failed to serialize");
        let deserialized: ToolApproval =
            serde_json::from_str(&json).expect("Failed to deserialize");
        assert_eq!(approved, deserialized);

        // Variant carrying a payload round-trip.
        let denied = ToolApproval::Denied("Reason".to_string());
        let json = serde_json::to_string(&denied).expect("Failed to serialize");
        let deserialized: ToolApproval =
            serde_json::from_str(&json).expect("Failed to deserialize");
        assert_eq!(denied, deserialized);
    }

    #[test]
    fn test_property_creation() {
        let prop = Property {
            prop_type: "string".to_string(),
            description: "The user's name".to_string(),
        };

        assert_eq!(prop.prop_type, "string");
        assert_eq!(prop.description, "The user's name");
    }

    #[test]
    fn test_property_serialization() {
        let prop = Property {
            prop_type: "number".to_string(),
            description: "Age in years".to_string(),
        };

        // `prop_type` must be emitted under the renamed key `type`.
        let json = serde_json::to_value(&prop).expect("Failed to serialize");
        assert_eq!(json["type"], "number");
        assert_eq!(json["description"], "Age in years");

        let deserialized: Property = serde_json::from_value(json).expect("Failed to deserialize");
        assert_eq!(prop, deserialized);
    }

    #[test]
    fn test_parameters_creation() {
        let mut properties = HashMap::new();
        properties.insert(
            "location".to_string(),
            Property {
                prop_type: "string".to_string(),
                description: "City name".to_string(),
            },
        );

        let params = Parameters {
            param_type: "object".to_string(),
            properties,
            required: vec!["location".to_string()],
        };

        assert_eq!(params.param_type, "object");
        assert_eq!(params.properties.len(), 1);
        assert_eq!(params.required, vec!["location"]);
    }

    #[test]
    fn test_parameters_serialization() {
        let mut properties = HashMap::new();
        properties.insert(
            "name".to_string(),
            Property {
                prop_type: "string".to_string(),
                description: "Name".to_string(),
            },
        );

        let params = Parameters {
            param_type: "object".to_string(),
            properties,
            required: vec!["name".to_string()],
        };

        let json = serde_json::to_value(&params).expect("Failed to serialize");
        assert_eq!(json["type"], "object");
        assert!(json["properties"].is_object());
        assert!(json["required"].is_array());

        let deserialized: Parameters = serde_json::from_value(json).expect("Failed to deserialize");
        assert_eq!(params, deserialized);
    }

    #[test]
    fn test_function_creation() {
        let func = Function {
            name: "get_weather".to_string(),
            description: "Get weather for a location".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {},
                "required": [],
            }),
        };

        assert_eq!(func.name, "get_weather");
        assert_eq!(func.description, "Get weather for a location");
        assert!(func.parameters.is_object());
    }

    #[test]
    fn test_function_serialization() {
        let func = Function {
            name: "calculate".to_string(),
            description: "Perform calculation".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        };

        let json = serde_json::to_value(&func).expect("Failed to serialize");
        assert_eq!(json["name"], "calculate");
        assert_eq!(json["description"], "Perform calculation");

        let deserialized: Function = serde_json::from_value(json).expect("Failed to deserialize");
        assert_eq!(func, deserialized);
    }

    #[test]
    fn test_tool_builder() {
        // Leaving the type unset must fall back to the builder default.
        let tool = Tool::builder()
            .function(Function {
                name: "test_func".to_string(),
                description: "A test function".to_string(),
                parameters: serde_json::json!({}),
            })
            .build();

        assert_eq!(tool.r#type, "function");
        assert_eq!(tool.function.name, "test_func");
    }

    #[test]
    fn test_tool_builder_with_custom_type() {
        let tool = Tool::builder()
            .r#type("custom".to_string())
            .function(Function {
                name: "custom_func".to_string(),
                description: "Custom function".to_string(),
                parameters: serde_json::json!({}),
            })
            .build();

        assert_eq!(tool.r#type, "custom");
        assert_eq!(tool.function.name, "custom_func");
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool::builder()
            .function(Function {
                name: "test".to_string(),
                description: "Test".to_string(),
                parameters: serde_json::json!({}),
            })
            .build();

        let json = serde_json::to_value(&tool).expect("Failed to serialize");
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "test");

        let deserialized: Tool = serde_json::from_value(json).expect("Failed to deserialize");
        assert_eq!(tool, deserialized);
    }

    #[test]
    fn test_function_call_creation() {
        let call = FunctionCall {
            name: "my_function".to_string(),
            arguments: vec!["arg1".to_string(), "arg2".to_string()],
        };

        assert_eq!(call.name, "my_function");
        assert_eq!(call.arguments.len(), 2);
        assert_eq!(call.arguments[0], "arg1");
    }

    #[test]
    fn test_tool_call_new() {
        let call = ToolCall::new("get_weather", vec!["NYC".to_string()]);

        assert!(!call.id.is_empty());
        assert_eq!(call.function.name, "get_weather");
        assert_eq!(call.function.arguments, vec!["NYC"]);
        assert_eq!(call.call_type, "function");
    }

    #[test]
    fn test_tool_call_new_with_array_literal() {
        // `new` accepts any `IntoIterator` of `Into<String>`, including `[&str; N]`.
        let call = ToolCall::new("test_func", [r#"{"key": "value"}"#]);

        assert_eq!(call.function.name, "test_func");
        assert_eq!(call.function.arguments.len(), 1);
        assert_eq!(call.function.arguments[0], r#"{"key": "value"}"#);
    }

    #[test]
    fn test_tool_call_new_empty_args() {
        let call = ToolCall::new("no_args_func", Vec::<String>::new());

        assert_eq!(call.function.name, "no_args_func");
        assert!(call.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_new_multiple_args() {
        let call = ToolCall::new(
            "multi_arg_func",
            vec!["arg1".to_string(), "arg2".to_string(), "arg3".to_string()],
        );

        assert_eq!(call.function.arguments.len(), 3);
        assert_eq!(call.function.arguments[0], "arg1");
        assert_eq!(call.function.arguments[1], "arg2");
        assert_eq!(call.function.arguments[2], "arg3");
    }

    #[test]
    fn test_tool_call_serialization() {
        let call = ToolCall::new("test_function", vec!["test_arg".to_string()]);

        let json = serde_json::to_value(&call).expect("Failed to serialize");
        assert_eq!(json["function"]["name"], "test_function");
        assert_eq!(json["call_type"], "function");

        let deserialized: ToolCall = serde_json::from_value(json).expect("Failed to deserialize");
        assert_eq!(call.function.name, deserialized.function.name);
        assert_eq!(call.function.arguments, deserialized.function.arguments);
    }

    #[test]
    fn test_tool_call_unique_ids() {
        let call1 = ToolCall::new("func", Vec::<String>::new());
        let call2 = ToolCall::new("func", Vec::<String>::new());

        assert_ne!(call1.id, call2.id);
    }

    #[test]
    fn test_tool_call_delta_merging() {
        // Three fragments of one call; together they form a single JSON object.
        let deltas = vec![
            ToolCall {
                id: "call_123".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_function".to_string(),
                    arguments: vec![r#"{"param1": ""#.to_string()],
                },
            },
            ToolCall {
                id: "call_123".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_function".to_string(),
                    arguments: vec![r#"hello", "param2": "#.to_string()],
                },
            },
            ToolCall {
                id: "call_123".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_function".to_string(),
                    arguments: vec![r#"123}"#.to_string()],
                },
            },
        ];

        // Feed the deltas one at a time, as they would arrive incrementally.
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        for delta in &deltas {
            tool_calls = ToolCall::merge_deltas(tool_calls, std::slice::from_ref(delta));
        }

        assert_eq!(tool_calls.len(), 1);

        let merged = &tool_calls[0];
        assert_eq!(merged.id, "call_123");
        assert_eq!(merged.function.name, "test_function");
        assert_eq!(merged.function.arguments.len(), 1);
        assert_eq!(
            merged.function.arguments[0],
            r#"{"param1": "hello", "param2": 123}"#
        );

        // The reassembled argument string must parse as JSON.
        let parsed: serde_json::Value = serde_json::from_str(&merged.function.arguments[0])
            .expect("Merged arguments should be valid JSON");
        assert_eq!(parsed["param1"], "hello");
        assert_eq!(parsed["param2"], 123);
    }

    #[test]
    fn test_multiple_tool_call_delta_merging() {
        // Fragments of two different calls, interleaved.
        let deltas = vec![
            ToolCall {
                id: "call_1".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "func1".to_string(),
                    arguments: vec![r#"{"a":"#.to_string()],
                },
            },
            ToolCall {
                id: "call_2".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "func2".to_string(),
                    arguments: vec![r#"{"b":"#.to_string()],
                },
            },
            ToolCall {
                id: "call_1".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "func1".to_string(),
                    arguments: vec![r#"1}"#.to_string()],
                },
            },
            ToolCall {
                id: "call_2".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "func2".to_string(),
                    arguments: vec![r#"2}"#.to_string()],
                },
            },
        ];

        let mut tool_calls: Vec<ToolCall> = Vec::new();
        for delta in &deltas {
            tool_calls = ToolCall::merge_deltas(tool_calls, std::slice::from_ref(delta));
        }

        // Each id ends up as exactly one call, in order of first appearance.
        assert_eq!(tool_calls.len(), 2);

        let call1 = &tool_calls[0];
        assert_eq!(call1.id, "call_1");
        assert_eq!(call1.function.name, "func1");
        assert_eq!(call1.function.arguments[0], r#"{"a":1}"#);

        let call2 = &tool_calls[1];
        assert_eq!(call2.id, "call_2");
        assert_eq!(call2.function.name, "func2");
        assert_eq!(call2.function.arguments[0], r#"{"b":2}"#);

        serde_json::from_str::<serde_json::Value>(&call1.function.arguments[0])
            .expect("First call should be valid JSON");
        serde_json::from_str::<serde_json::Value>(&call2.function.arguments[0])
            .expect("Second call should be valid JSON");
    }
}
515
/// Property-based tests: deserialization of arbitrary bytes must not panic,
/// and serde round-trips must preserve the data for generated inputs.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Arbitrary bytes may fail to deserialize, but must never panic.
        #[test]
        fn fuzz_tool_call_deserialization(data in prop::collection::vec(any::<u8>(), 0..1000)) {
            let _ = serde_json::from_slice::<ToolCall>(&data);
        }

        #[test]
        fn fuzz_function_call_with_arbitrary_args(
            name in ".*",
            args in prop::collection::vec(".*", 0..10),
        ) {
            let call = FunctionCall {
                name,
                arguments: args,
            };

            // Serializing a struct of plain strings is expected to succeed.
            let json = serde_json::to_string(&call).unwrap();
            let parsed: FunctionCall = serde_json::from_str(&json).unwrap();
            assert_eq!(call.name, parsed.name);
            assert_eq!(call.arguments, parsed.arguments);
        }

        #[test]
        fn fuzz_tool_call_new_with_special_chars(
            func_name in r#"[a-zA-Z0-9_\-\.]{1,50}"#,
            // Single backslashes: inside a raw string `\x00-\x7F` is the regex
            // escape for the full ASCII range, control characters included.
            args in prop::collection::vec(r#"[\x00-\x7F]*"#, 0..5),
        ) {
            let call = ToolCall::new(func_name.clone(), args.clone());

            assert_eq!(call.function.name, func_name);
            assert_eq!(call.function.arguments, args);
            assert_eq!(call.call_type, "function");
            assert!(!call.id.is_empty());
        }

        // Arbitrary bytes may fail to deserialize, but must never panic.
        #[test]
        fn fuzz_tool_deserialization(data in prop::collection::vec(any::<u8>(), 0..1000)) {
            let _ = serde_json::from_slice::<Tool>(&data);
        }

        #[test]
        fn fuzz_function_with_arbitrary_json_params(
            name in ".*",
            description in ".*",
        ) {
            // `parameters` is untyped JSON, so non-object values must round-trip too.
            let params_variants = vec![
                serde_json::json!({}),
                serde_json::json!({"type": "object"}),
                serde_json::json!({"type": "object", "properties": {}, "required": []}),
                serde_json::json!(null),
                serde_json::json!([]),
                serde_json::json!("string"),
            ];

            for params in params_variants {
                let func = Function {
                    name: name.clone(),
                    description: description.clone(),
                    parameters: params,
                };

                let json = serde_json::to_string(&func).unwrap();
                let parsed: Function = serde_json::from_str(&json).unwrap();
                assert_eq!(func.name, parsed.name);
                assert_eq!(func.description, parsed.description);
            }
        }

        #[test]
        fn fuzz_parameters_with_arbitrary_properties(
            num_props in 0usize..10,
        ) {
            let mut properties = HashMap::new();

            for i in 0..num_props {
                properties.insert(
                    format!("prop_{}", i),
                    Property {
                        prop_type: format!("type_{}", i % 3),
                        description: format!("desc_{}", i),
                    },
                );
            }

            let params = Parameters {
                param_type: "object".to_string(),
                properties,
                required: (0..num_props).map(|i| format!("prop_{}", i)).collect(),
            };

            let json = serde_json::to_string(&params).unwrap();
            let parsed: Parameters = serde_json::from_str(&json).unwrap();
            assert_eq!(params.param_type, parsed.param_type);
            assert_eq!(params.properties.len(), parsed.properties.len());
            assert_eq!(params.required, parsed.required);
        }

        #[test]
        fn fuzz_tool_approval_serialization(
            approval_type in 0usize..3,
            reason in ".*",
        ) {
            let approval = match approval_type {
                0 => ToolApproval::Approved,
                1 => ToolApproval::Denied(reason),
                _ => ToolApproval::Quit,
            };

            let json = serde_json::to_string(&approval).unwrap();
            let parsed: ToolApproval = serde_json::from_str(&json).unwrap();
            assert_eq!(approval, parsed);
        }

        #[test]
        fn fuzz_tool_call_with_malformed_json_args(
            func_name in ".*",
            num_args in 0usize..10,
        ) {
            // Arguments are opaque strings to `ToolCall`; malformed JSON text
            // must be stored and round-tripped verbatim.
            let malformed_jsons = [
                "{",
                "}",
                "[",
                "]",
                "null",
                "undefined",
                r#"{"incomplete": "#,
                r#"{"key": "value"}"#,
                "",
                " ",
            ];

            let args: Vec<String> = (0..num_args)
                .map(|i| malformed_jsons[i % malformed_jsons.len()].to_string())
                .collect();

            let call = ToolCall::new(func_name.clone(), args.clone());

            assert_eq!(call.function.name, func_name);
            assert_eq!(call.function.arguments, args);

            let json = serde_json::to_string(&call).unwrap();
            let parsed: ToolCall = serde_json::from_str(&json).unwrap();
            assert_eq!(call.function.arguments, parsed.function.arguments);
        }
    }
}