neuromance_common/
tools.rs

1//! Tool calling and function execution types for LLM interactions.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6use typed_builder::TypedBuilder;
7use uuid::Uuid;
8
9/// Represents the approval status of a tool call.
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11pub enum ToolApproval {
12    /// The tool call is approved and should be executed.
13    Approved,
14    /// The tool call is denied with a reason.
15    Denied(String),
16    /// Quit the current operation.
17    Quit,
18}
19
20/// Describes a single property in a function parameter schema.
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
22pub struct Property {
23    /// The JSON type (e.g., "string", "number", "object").
24    #[serde(rename = "type")]
25    pub prop_type: String,
26    /// Human-readable description of this property.
27    pub description: String,
28}
29
30/// Defines the parameter schema for a function using JSON Schema conventions.
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32pub struct Parameters {
33    /// The JSON type, typically "object".
34    #[serde(rename = "type")]
35    pub param_type: String,
36    /// Map of parameter names to their property definitions.
37    pub properties: HashMap<String, Property>,
38    /// List of required parameter names.
39    pub required: Vec<String>,
40}
41
42/// Describes a function that can be called by an LLM.
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
44pub struct Function {
45    /// The name of the function.
46    pub name: String,
47    /// Human-readable description of what the function does.
48    pub description: String,
49    /// JSON Schema definition of the function's parameters.
50    pub parameters: serde_json::Value,
51}
52
53/// Represents a tool available to the LLM, typically wrapping a function.
54#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder, PartialEq)]
55pub struct Tool {
56    /// The type of tool (defaults to "function").
57    #[serde(rename = "type")]
58    #[builder(default = "function".to_string())]
59    pub r#type: String,
60    /// The function definition.
61    pub function: Function,
62}
63
64/// Represents an invocation of a function with arguments.
65#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
66pub struct FunctionCall {
67    /// The name of the function being called.
68    pub name: String,
69    /// The arguments passed to the function as strings.
70    pub arguments: Vec<String>,
71}
72
73/// Represents a complete tool call from an LLM, including ID and function details.
74///
75/// Arguments in `function.arguments` are passed through as-is from API responses.
76/// Users should validate and parse arguments when executing tools.
77#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
78pub struct ToolCall {
79    /// Unique identifier for this tool call.
80    pub id: String,
81    /// The function being invoked.
82    pub function: FunctionCall,
83    /// The type of call, typically "function".
84    pub call_type: String,
85}
86
87impl ToolCall {
88    /// Creates a new tool call with a generated ID.
89    pub fn new<I, T>(name: impl Into<String>, arguments: I) -> Self
90    where
91        I: IntoIterator<Item = T>,
92        T: Into<String>,
93    {
94        Self {
95            id: Uuid::new_v4().to_string(),
96            function: FunctionCall {
97                name: name.into(),
98                arguments: arguments.into_iter().map(|arg| arg.into()).collect(),
99            },
100            call_type: "function".to_string(),
101        }
102    }
103
104    /// Merges tool call deltas by ID, concatenating argument fragments.
105    ///
106    /// Used when processing streaming LLM responses where tool calls arrive incrementally.
107    /// Argument fragments are concatenated for matching IDs.
108    pub fn merge_deltas(mut accumulated: Vec<Self>, deltas: &[Self]) -> Vec<Self> {
109        for delta in deltas {
110            if let Some(existing) = accumulated.iter_mut().find(|tc| tc.id == delta.id) {
111                // Merge arguments by concatenating fragments
112                for arg in &delta.function.arguments {
113                    if let Some(last_arg) = existing.function.arguments.last_mut() {
114                        // Append to the last argument (streaming sends fragments)
115                        last_arg.push_str(arg);
116                    } else {
117                        // No existing arguments, add as new
118                        existing.function.arguments.push(arg.clone());
119                    }
120                }
121            } else {
122                // New tool call, add it
123                accumulated.push(delta.clone());
124            }
125        }
126
127        accumulated
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Function` from borrowed strings and a JSON schema value.
    fn function(name: &str, description: &str, parameters: serde_json::Value) -> Function {
        Function {
            name: name.to_string(),
            description: description.to_string(),
            parameters,
        }
    }

    /// Builds an "object" schema with one required "string" property called `key`.
    fn single_string_param(key: &str, description: &str) -> Parameters {
        let property = Property {
            prop_type: "string".to_string(),
            description: description.to_string(),
        };

        Parameters {
            param_type: "object".to_string(),
            properties: std::iter::once((key.to_string(), property)).collect(),
            required: vec![key.to_string()],
        }
    }

    /// Builds a streamed tool call delta carrying exactly one argument fragment.
    fn delta(id: &str, name: &str, fragment: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: vec![fragment.to_string()],
            },
        }
    }

    /// Feeds `deltas` through the production merge function one at a time,
    /// the way chunks of a streaming response would arrive.
    fn merge_one_by_one(deltas: &[ToolCall]) -> Vec<ToolCall> {
        deltas.iter().fold(Vec::new(), |merged, chunk| {
            ToolCall::merge_deltas(merged, std::slice::from_ref(chunk))
        })
    }

    #[test]
    fn test_tool_approval_variants() {
        let (approved, denied, quit) = (
            ToolApproval::Approved,
            ToolApproval::Denied("Invalid request".to_string()),
            ToolApproval::Quit,
        );

        assert_eq!(approved, ToolApproval::Approved);
        assert_eq!(denied, ToolApproval::Denied("Invalid request".to_string()));
        assert_eq!(quit, ToolApproval::Quit);
    }

    #[test]
    fn test_tool_approval_serialization() {
        let cases = [
            ToolApproval::Approved,
            ToolApproval::Denied("Reason".to_string()),
        ];

        for original in cases {
            let encoded = serde_json::to_string(&original).expect("Failed to serialize");
            let decoded: ToolApproval =
                serde_json::from_str(&encoded).expect("Failed to deserialize");
            assert_eq!(original, decoded);
        }
    }

    #[test]
    fn test_property_creation() {
        let Property {
            prop_type,
            description,
        } = Property {
            prop_type: "string".to_string(),
            description: "The user's name".to_string(),
        };

        assert_eq!(prop_type, "string");
        assert_eq!(description, "The user's name");
    }

    #[test]
    fn test_property_serialization() {
        let original = Property {
            prop_type: "number".to_string(),
            description: "Age in years".to_string(),
        };

        let encoded = serde_json::to_value(&original).expect("Failed to serialize");
        assert_eq!(encoded["type"], "number");
        assert_eq!(encoded["description"], "Age in years");

        let decoded: Property = serde_json::from_value(encoded).expect("Failed to deserialize");
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_parameters_creation() {
        let schema = single_string_param("location", "City name");

        assert_eq!(schema.param_type, "object");
        assert_eq!(schema.properties.len(), 1);
        assert_eq!(schema.required, vec!["location"]);
    }

    #[test]
    fn test_parameters_serialization() {
        let schema = single_string_param("name", "Name");

        let encoded = serde_json::to_value(&schema).expect("Failed to serialize");
        assert_eq!(encoded["type"], "object");
        assert!(encoded["properties"].is_object());
        assert!(encoded["required"].is_array());

        let decoded: Parameters = serde_json::from_value(encoded).expect("Failed to deserialize");
        assert_eq!(schema, decoded);
    }

    #[test]
    fn test_function_creation() {
        let weather = function(
            "get_weather",
            "Get weather for a location",
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            }),
        );

        assert_eq!(weather.name, "get_weather");
        assert_eq!(weather.description, "Get weather for a location");
        assert!(weather.parameters.is_object());
    }

    #[test]
    fn test_function_serialization() {
        let original = function(
            "calculate",
            "Perform calculation",
            serde_json::json!({"type": "object"}),
        );

        let encoded = serde_json::to_value(&original).expect("Failed to serialize");
        assert_eq!(encoded["name"], "calculate");
        assert_eq!(encoded["description"], "Perform calculation");

        let decoded: Function = serde_json::from_value(encoded).expect("Failed to deserialize");
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_tool_builder() {
        let wrapped = function("test_func", "A test function", serde_json::json!({}));
        let tool = Tool::builder().function(wrapped).build();

        // The type falls back to the builder default when not set explicitly.
        assert_eq!(tool.r#type, "function");
        assert_eq!(tool.function.name, "test_func");
    }

    #[test]
    fn test_tool_builder_with_custom_type() {
        let wrapped = function("custom_func", "Custom function", serde_json::json!({}));
        let tool = Tool::builder()
            .r#type("custom".to_string())
            .function(wrapped)
            .build();

        assert_eq!(tool.r#type, "custom");
        assert_eq!(tool.function.name, "custom_func");
    }

    #[test]
    fn test_tool_serialization() {
        let original = Tool::builder()
            .function(function("test", "Test", serde_json::json!({})))
            .build();

        let encoded = serde_json::to_value(&original).expect("Failed to serialize");
        assert_eq!(encoded["type"], "function");
        assert_eq!(encoded["function"]["name"], "test");

        let decoded: Tool = serde_json::from_value(encoded).expect("Failed to deserialize");
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_function_call_creation() {
        let invocation = FunctionCall {
            name: "my_function".to_string(),
            arguments: vec!["arg1".to_string(), "arg2".to_string()],
        };

        assert_eq!(invocation.name, "my_function");
        assert_eq!(invocation.arguments.len(), 2);
        assert_eq!(invocation.arguments[0], "arg1");
    }

    #[test]
    fn test_tool_call_new() {
        let created = ToolCall::new("get_weather", vec!["NYC".to_string()]);

        assert!(!created.id.is_empty());
        assert_eq!(created.function.name, "get_weather");
        assert_eq!(created.function.arguments, vec!["NYC"]);
        assert_eq!(created.call_type, "function");
    }

    #[test]
    fn test_tool_call_new_with_array_literal() {
        let raw_json = r#"{"key": "value"}"#;
        let created = ToolCall::new("test_func", [raw_json]);

        assert_eq!(created.function.name, "test_func");
        assert_eq!(created.function.arguments.len(), 1);
        assert_eq!(created.function.arguments[0], r#"{"key": "value"}"#);
    }

    #[test]
    fn test_tool_call_new_empty_args() {
        let created = ToolCall::new("no_args_func", Vec::<String>::new());

        assert_eq!(created.function.name, "no_args_func");
        assert!(created.function.arguments.is_empty());
    }

    #[test]
    fn test_tool_call_new_multiple_args() {
        let inputs = vec!["arg1".to_string(), "arg2".to_string(), "arg3".to_string()];
        let created = ToolCall::new("multi_arg_func", inputs);

        let stored = &created.function.arguments;
        assert_eq!(stored.len(), 3);
        assert_eq!(stored[0], "arg1");
        assert_eq!(stored[1], "arg2");
        assert_eq!(stored[2], "arg3");
    }

    #[test]
    fn test_tool_call_serialization() {
        let original = ToolCall::new("test_function", vec!["test_arg".to_string()]);

        let encoded = serde_json::to_value(&original).expect("Failed to serialize");
        assert_eq!(encoded["function"]["name"], "test_function");
        assert_eq!(encoded["call_type"], "function");

        let decoded: ToolCall = serde_json::from_value(encoded).expect("Failed to deserialize");
        assert_eq!(original.function.name, decoded.function.name);
        assert_eq!(original.function.arguments, decoded.function.arguments);
    }

    #[test]
    fn test_tool_call_unique_ids() {
        let first = ToolCall::new("func", Vec::<String>::new());
        let second = ToolCall::new("func", Vec::<String>::new());

        assert_ne!(first.id, second.id);
    }

    #[test]
    fn test_tool_call_delta_merging() {
        // One tool call whose JSON arguments arrive split across three chunks:
        // the opening fragment, a middle fragment, and the closing fragment.
        let deltas = [
            delta("call_123", "test_function", r#"{"param1": ""#),
            delta("call_123", "test_function", r#"hello", "param2": "#),
            delta("call_123", "test_function", r#"123}"#),
        ];

        let tool_calls = merge_one_by_one(&deltas);

        // All three chunks collapse into a single call with a single argument.
        assert_eq!(tool_calls.len(), 1);

        let merged = &tool_calls[0];
        assert_eq!(merged.id, "call_123");
        assert_eq!(merged.function.name, "test_function");
        assert_eq!(merged.function.arguments.len(), 1);
        assert_eq!(
            merged.function.arguments[0],
            r#"{"param1": "hello", "param2": 123}"#
        );

        // The reassembled argument string has to parse as JSON.
        let parsed: serde_json::Value = serde_json::from_str(&merged.function.arguments[0])
            .expect("Merged arguments should be valid JSON");
        assert_eq!(parsed["param1"], "hello");
        assert_eq!(parsed["param2"], 123);
    }

    #[test]
    fn test_multiple_tool_call_delta_merging() {
        // Two distinct tool calls whose chunks are interleaved in one stream.
        let deltas = [
            delta("call_1", "func1", r#"{"a":"#),
            delta("call_2", "func2", r#"{"b":"#),
            delta("call_1", "func1", r#"1}"#),
            delta("call_2", "func2", r#"2}"#),
        ];

        let tool_calls = merge_one_by_one(&deltas);

        // Each ID ends up as its own fully merged call, in first-seen order.
        assert_eq!(tool_calls.len(), 2);

        let call1 = &tool_calls[0];
        assert_eq!(call1.id, "call_1");
        assert_eq!(call1.function.name, "func1");
        assert_eq!(call1.function.arguments[0], r#"{"a":1}"#);

        let call2 = &tool_calls[1];
        assert_eq!(call2.id, "call_2");
        assert_eq!(call2.function.name, "func2");
        assert_eq!(call2.function.arguments[0], r#"{"b":2}"#);

        // Both reassembled argument strings have to parse as JSON.
        serde_json::from_str::<serde_json::Value>(&call1.function.arguments[0])
            .expect("First call should be valid JSON");
        serde_json::from_str::<serde_json::Value>(&call2.function.arguments[0])
            .expect("Second call should be valid JSON");
    }
}
515
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Deserializing arbitrary bytes as a `ToolCall` may fail but must not panic.
        #[test]
        fn fuzz_tool_call_deserialization(data in prop::collection::vec(any::<u8>(), 0..1000)) {
            let _ = serde_json::from_slice::<ToolCall>(&data);
        }

        // A `FunctionCall` with arbitrary name/arguments survives a JSON round trip.
        #[test]
        fn fuzz_function_call_with_arbitrary_args(
            name in ".*",
            args in prop::collection::vec(".*", 0..10),
        ) {
            let call = FunctionCall {
                name,
                arguments: args,
            };

            let json = serde_json::to_string(&call).unwrap();
            let parsed: FunctionCall = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&call.name, &parsed.name);
            prop_assert_eq!(&call.arguments, &parsed.arguments);
        }

        // `ToolCall::new` stores the name and arguments verbatim, including
        // control characters and other ASCII special characters.
        #[test]
        fn fuzz_tool_call_new_with_special_chars(
            func_name in r"[a-zA-Z0-9_\-\.]{1,50}",
            // Single backslashes: inside a raw string, `\x00` must reach the regex
            // engine as a hex escape so the class really spans ASCII 0x00..=0x7F.
            args in prop::collection::vec(r"[\x00-\x7F]*", 0..5),
        ) {
            let call = ToolCall::new(func_name.as_str(), args.clone());

            prop_assert_eq!(&call.function.name, &func_name);
            prop_assert_eq!(&call.function.arguments, &args);
            prop_assert_eq!(call.call_type.as_str(), "function");
            prop_assert!(!call.id.is_empty());
        }

        // Deserializing arbitrary bytes as a `Tool` may fail but must not panic.
        #[test]
        fn fuzz_tool_deserialization(data in prop::collection::vec(any::<u8>(), 0..1000)) {
            let _ = serde_json::from_slice::<Tool>(&data);
        }

        // A `Function` round-trips through JSON for a range of parameter shapes,
        // including values that are not object schemas at all.
        #[test]
        fn fuzz_function_with_arbitrary_json_params(
            name in ".*",
            description in ".*",
        ) {
            let params_variants = [
                serde_json::json!({}),
                serde_json::json!({"type": "object"}),
                serde_json::json!({"type": "object", "properties": {}, "required": []}),
                serde_json::json!(null),
                serde_json::json!([]),
                serde_json::json!("string"),
            ];

            for params in params_variants {
                let func = Function {
                    name: name.clone(),
                    description: description.clone(),
                    parameters: params,
                };

                let json = serde_json::to_string(&func).unwrap();
                let parsed: Function = serde_json::from_str(&json).unwrap();
                prop_assert_eq!(&func.name, &parsed.name);
                prop_assert_eq!(&func.description, &parsed.description);
                prop_assert_eq!(&func.parameters, &parsed.parameters);
            }
        }

        // `Parameters` with a varying number of properties round-trips through JSON.
        #[test]
        fn fuzz_parameters_with_arbitrary_properties(
            num_props in 0usize..10,
        ) {
            let properties: HashMap<String, Property> = (0..num_props)
                .map(|i| {
                    (
                        format!("prop_{}", i),
                        Property {
                            prop_type: format!("type_{}", i % 3),
                            description: format!("desc_{}", i),
                        },
                    )
                })
                .collect();

            let params = Parameters {
                param_type: "object".to_string(),
                properties,
                required: (0..num_props).map(|i| format!("prop_{}", i)).collect(),
            };

            let json = serde_json::to_string(&params).unwrap();
            let parsed: Parameters = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&params.param_type, &parsed.param_type);
            prop_assert_eq!(params.properties.len(), parsed.properties.len());
            prop_assert_eq!(&params.required, &parsed.required);
        }

        // Every `ToolApproval` variant round-trips through JSON, whatever the reason text.
        #[test]
        fn fuzz_tool_approval_serialization(
            approval_type in 0usize..3,
            reason in ".*",
        ) {
            let approval = match approval_type {
                0 => ToolApproval::Approved,
                1 => ToolApproval::Denied(reason),
                _ => ToolApproval::Quit,
            };

            let json = serde_json::to_string(&approval).unwrap();
            let parsed: ToolApproval = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(approval, parsed);
        }

        // Argument strings are opaque to `ToolCall`: malformed JSON fragments are
        // stored as-is and the call itself still round-trips through JSON.
        #[test]
        fn fuzz_tool_call_with_malformed_json_args(
            func_name in ".*",
            num_args in 0usize..10,
        ) {
            let malformed_jsons = [
                "{",
                "}",
                "[",
                "]",
                "null",
                "undefined",
                r#"{"incomplete": "#,
                r#"{"key": "value"}"#,
                "",
                "   ",
            ];

            // Cycle through the samples so any requested count can be produced.
            let args: Vec<String> = malformed_jsons
                .iter()
                .cycle()
                .take(num_args)
                .map(|sample| sample.to_string())
                .collect();

            let call = ToolCall::new(func_name.as_str(), args.clone());

            prop_assert_eq!(&call.function.name, &func_name);
            prop_assert_eq!(&call.function.arguments, &args);

            let json = serde_json::to_string(&call).unwrap();
            let parsed: ToolCall = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&call.function.arguments, &parsed.function.arguments);
        }
    }
}