genai_rs/
function_calling.rs

1use async_trait::async_trait;
2use inventory;
3use serde_json::Value;
4use std::collections::HashMap;
5use std::error::Error;
6use std::sync::Arc;
7use tracing::warn;
8
9use crate::FunctionDeclaration;
10
/// Error produced when a [`CallableFunction`] cannot complete successfully.
///
/// Marked `#[non_exhaustive]` so further variants can be introduced in a later
/// release without breaking downstream `match` expressions.
#[derive(Debug)]
#[non_exhaustive]
pub enum FunctionError {
    /// The supplied arguments could not be used; the string explains why.
    ArgumentMismatch(String),
    /// The function ran but failed; wraps the underlying error.
    ExecutionError(Box<dyn Error + Send + Sync>),
}

impl std::fmt::Display for FunctionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders as "<label>: <detail>".
        let (label, detail): (&str, &dyn std::fmt::Display) = match self {
            Self::ArgumentMismatch(reason) => ("Argument mismatch", reason),
            Self::ExecutionError(cause) => ("Function execution error", cause),
        };
        write!(f, "{label}: {detail}")
    }
}

impl Error for FunctionError {
    /// Exposes the wrapped error of `ExecutionError`; `ArgumentMismatch` has no cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::ArgumentMismatch(_) => None,
            Self::ExecutionError(cause) => Some(&**cause),
        }
    }
}
39
40/// A trait for functions that can be called by the model.
41#[async_trait]
42pub trait CallableFunction: Send + Sync {
43    /// Returns the declaration of the function.
44    fn declaration(&self) -> FunctionDeclaration;
45
46    /// Executes the function with the given arguments.
47    /// The arguments are provided as a serde_json::Value,
48    /// and the function should return a serde_json::Value.
49    async fn call(&self, args: Value) -> Result<Value, FunctionError>;
50}
51
52/// A provider of callable functions with shared state/dependencies.
53///
54/// Implement this trait on structs that need to provide tools with access to
55/// shared resources like databases, APIs, or configuration. This enables
56/// dependency injection for tool functions.
57///
58/// # Example
59///
60/// ```ignore
61/// use genai_rs::{CallableFunction, ToolService, FunctionDeclaration};
62/// use std::sync::Arc;
63///
64/// struct WeatherService {
65///     api_key: String,
66/// }
67///
68/// impl ToolService for WeatherService {
69///     fn tools(&self) -> Vec<Arc<dyn CallableFunction>> {
70///         vec![
71///             Arc::new(GetWeatherTool { api_key: self.api_key.clone() }),
72///         ]
73///     }
74/// }
75///
76/// // Use with InteractionBuilder:
77/// let service = Arc::new(WeatherService { api_key: "...".into() });
78/// client.interaction()
79///     .with_tool_service(service)
80///     .create_with_auto_functions()
81///     .await?;
82/// ```
83pub trait ToolService: Send + Sync {
84    /// Returns the callable functions provided by this service.
85    ///
86    /// Each function can hold references to shared state from the service.
87    fn tools(&self) -> Vec<Arc<dyn CallableFunction>>;
88}
89
90/// A factory for creating instances of `CallableFunction`.
91/// Instances of this struct will be collected by `inventory`.
92pub struct CallableFunctionFactory {
93    pub factory_fn: fn() -> Box<dyn CallableFunction>,
94}
95
96impl CallableFunctionFactory {
97    pub const fn new(factory_fn: fn() -> Box<dyn CallableFunction>) -> Self {
98        Self { factory_fn }
99    }
100}
101
102// Declare that we want to collect `CallableFunctionFactory` instances.
103// This needs to be visible to the macros that will submit to it.
104// The `pub` keyword here is important.
105pub use inventory::submit;
106
107inventory::collect!(CallableFunctionFactory);
108
109/// A registry for callable functions.
110pub(crate) struct FunctionRegistry {
111    functions: HashMap<String, Box<dyn CallableFunction>>,
112}
113
114impl FunctionRegistry {
115    /// Creates a new empty function registry.
116    fn new() -> Self {
117        Self {
118            functions: HashMap::new(),
119        }
120    }
121
122    /// Registers a function directly.
123    fn register_raw(&mut self, function: Box<dyn CallableFunction>) {
124        let name = function.declaration().name().to_string();
125        if self.functions.contains_key(&name) {
126            warn!(
127                "Duplicate function name in auto-registration: function='{}'. Last registration will be used.",
128                name
129            );
130        }
131        self.functions.insert(name, function);
132    }
133
134    /// Retrieves a function by its name.
135    pub(crate) fn get(&self, name: &str) -> Option<&dyn CallableFunction> {
136        self.functions.get(name).map(std::convert::AsRef::as_ref)
137    }
138
139    /// Returns an iterator over all registered function declarations.
140    pub(crate) fn all_declarations(&self) -> Vec<FunctionDeclaration> {
141        self.functions.values().map(|f| f.declaration()).collect()
142    }
143}
144
145/// Global registry, populated automatically via inventory.
146static GLOBAL_FUNCTION_REGISTRY: std::sync::LazyLock<FunctionRegistry> =
147    std::sync::LazyLock::new(|| {
148        let mut registry = FunctionRegistry::new();
149
150        for factory in inventory::iter::<CallableFunctionFactory> {
151            let function = (factory.factory_fn)();
152            registry.register_raw(function);
153        }
154
155        registry
156    });
157
158/// Provides access to the global function registry.
159/// This is intended for internal use by the client for automatic function execution.
160pub(crate) fn get_global_function_registry() -> &'static FunctionRegistry {
161    &GLOBAL_FUNCTION_REGISTRY
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionDeclaration;
    use async_trait::async_trait;
    use serde_json::json;

    // Dummy function for testing purposes.
    // In real usage, this would be generated by the macro.
    struct TestFunctionGlobal;

    #[async_trait]
    impl CallableFunction for TestFunctionGlobal {
        fn declaration(&self) -> FunctionDeclaration {
            FunctionDeclaration::new(
                "test_function_global".to_string(),
                "A global test function".to_string(),
                crate::FunctionParameters::new(
                    "object".to_string(),
                    json!({"param": {"type": "string"}}),
                    vec!["param".to_string()],
                ),
            )
        }

        async fn call(&self, args: Value) -> Result<Value, FunctionError> {
            args.get("param").and_then(Value::as_str).map_or_else(
                || {
                    Err(FunctionError::ArgumentMismatch(
                        "Missing param for Global".to_string(),
                    ))
                },
                |p| Ok(json!({ "result": format!("Global says: Hello, {p}") })),
            )
        }
    }

    // Manually create a factory function for the test, similar to what the macro would do.
    fn test_function_global_callable_factory() -> Box<dyn CallableFunction> {
        Box::new(TestFunctionGlobal)
    }

    // Simulate macro-based registration so the global registry has a known entry.
    // This must be at item level (outside any test function) to be collected by inventory.
    inventory::submit! {
        CallableFunctionFactory::new(test_function_global_callable_factory)
    }

    #[test]
    fn test_global_registry_population_and_access() {
        let registry = get_global_function_registry();
        let retrieved_func = registry
            .get("test_function_global")
            .expect("Function 'test_function_global' should be in the global registry.");
        assert_eq!(retrieved_func.declaration().name(), "test_function_global");
    }

    #[tokio::test]
    async fn test_call_global_registered_function() {
        let registry = get_global_function_registry();
        let retrieved_func = registry
            .get("test_function_global")
            .expect("Global function not found");

        let args = json!({ "param": "GlobalInventoryWorld" });
        // `expect` (rather than `assert!(is_ok())`) reports the actual error on failure.
        let result = retrieved_func
            .call(args)
            .await
            .expect("call with a valid 'param' should succeed");
        assert_eq!(
            result,
            json!({ "result": "Global says: Hello, GlobalInventoryWorld" })
        );
    }

    #[tokio::test]
    async fn test_call_global_registered_function_missing_param() {
        let registry = get_global_function_registry();
        let retrieved_func = registry
            .get("test_function_global")
            .expect("Global function not found");

        let err = retrieved_func
            .call(json!({}))
            .await
            .expect_err("a call without 'param' should be rejected");
        assert!(
            matches!(err, FunctionError::ArgumentMismatch(msg) if msg == "Missing param for Global"),
            "missing argument should yield ArgumentMismatch"
        );
    }

    #[test]
    fn test_function_error_display_and_source() {
        let mismatch = FunctionError::ArgumentMismatch("missing 'x'".to_string());
        assert_eq!(mismatch.to_string(), "Argument mismatch: missing 'x'");
        assert!(mismatch.source().is_none());

        let execution = FunctionError::ExecutionError(Box::new(std::io::Error::other("boom")));
        assert_eq!(execution.to_string(), "Function execution error: boom");
        assert_eq!(
            execution.source().map(|e| e.to_string()),
            Some("boom".to_string())
        );
    }

    // Test ToolService trait for dependency injection

    /// A tool that holds shared state from its service.
    struct GreetTool {
        greeting_prefix: String,
    }

    #[async_trait]
    impl CallableFunction for GreetTool {
        fn declaration(&self) -> FunctionDeclaration {
            FunctionDeclaration::new(
                "greet".to_string(),
                "Greets a person with a custom prefix".to_string(),
                crate::FunctionParameters::new(
                    "object".to_string(),
                    json!({"name": {"type": "string"}}),
                    vec!["name".to_string()],
                ),
            )
        }

        async fn call(&self, args: Value) -> Result<Value, FunctionError> {
            args.get("name").and_then(Value::as_str).map_or_else(
                || {
                    Err(FunctionError::ArgumentMismatch(
                        "Missing 'name' argument".to_string(),
                    ))
                },
                |name| Ok(json!({ "message": format!("{} {name}!", self.greeting_prefix) })),
            )
        }
    }

    /// A service that provides tools with shared configuration.
    struct GreetingService {
        prefix: String,
    }

    impl ToolService for GreetingService {
        fn tools(&self) -> Vec<Arc<dyn CallableFunction>> {
            vec![Arc::new(GreetTool {
                greeting_prefix: self.prefix.clone(),
            })]
        }
    }

    #[test]
    fn test_tool_service_returns_tools() {
        let service = GreetingService {
            prefix: "Hello".to_string(),
        };
        let tools = service.tools();

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].declaration().name(), "greet");
    }

    #[tokio::test]
    async fn test_tool_service_tool_can_be_called() {
        let service = GreetingService {
            prefix: "Howdy".to_string(),
        };
        let tools = service.tools();
        let greet_tool = &tools[0];

        let result = greet_tool
            .call(json!({ "name": "Partner" }))
            .await
            .expect("greet with a valid 'name' should succeed");
        assert_eq!(result, json!({ "message": "Howdy Partner!" }));
    }

    #[tokio::test]
    async fn test_tool_service_with_different_config() {
        // Demonstrate that different service instances produce different tool behavior
        let formal_service = GreetingService {
            prefix: "Good morning, Mr.".to_string(),
        };
        let casual_service = GreetingService {
            prefix: "Hey".to_string(),
        };

        let formal_tools = formal_service.tools();
        let casual_tools = casual_service.tools();

        let formal_result = formal_tools[0]
            .call(json!({ "name": "Smith" }))
            .await
            .expect("formal greet should succeed");
        let casual_result = casual_tools[0]
            .call(json!({ "name": "Joe" }))
            .await
            .expect("casual greet should succeed");

        assert_eq!(
            formal_result,
            json!({ "message": "Good morning, Mr. Smith!" })
        );
        assert_eq!(casual_result, json!({ "message": "Hey Joe!" }));
    }

    #[test]
    fn test_registry_returns_none_for_unknown_function() {
        let registry = get_global_function_registry();

        // Looking up a function that doesn't exist should return None
        let result = registry.get("this_function_definitely_does_not_exist_xyz123");
        assert!(
            result.is_none(),
            "Registry should return None for unknown functions"
        );
    }

    #[test]
    fn test_registry_all_declarations_contains_registered() {
        let registry = get_global_function_registry();
        let declarations = registry.all_declarations();

        // Our test function should be in the list
        let names: Vec<_> = declarations.iter().map(|d| d.name()).collect();
        assert!(
            names.contains(&"test_function_global"),
            "all_declarations should include registered function"
        );
    }

    #[test]
    fn test_tool_service_tools_are_independent() {
        // Verify that calling tools() multiple times returns independent instances
        let service = GreetingService {
            prefix: "Hi".to_string(),
        };

        let tools1 = service.tools();
        let tools2 = service.tools();

        // Both should have the same declaration
        assert_eq!(
            tools1[0].declaration().name(),
            tools2[0].declaration().name()
        );

        // But they should be separate Arc instances (different pointers)
        // This ensures each call to tools() creates fresh tool instances
        assert!(!Arc::ptr_eq(&tools1[0], &tools2[0]));
    }

    #[test]
    fn test_registry_duplicate_registration_last_wins() {
        // Test that when two functions with the same name are registered,
        // the last one wins (and a warning is logged)
        let mut registry = FunctionRegistry::new();

        // First function
        struct FirstFunc;
        #[async_trait]
        impl CallableFunction for FirstFunc {
            fn declaration(&self) -> FunctionDeclaration {
                FunctionDeclaration::new(
                    "duplicate_name".to_string(),
                    "First function".to_string(),
                    crate::FunctionParameters::new("object".to_string(), json!({}), vec![]),
                )
            }
            async fn call(&self, _args: Value) -> Result<Value, FunctionError> {
                Ok(json!("first"))
            }
        }

        // Second function with same name
        struct SecondFunc;
        #[async_trait]
        impl CallableFunction for SecondFunc {
            fn declaration(&self) -> FunctionDeclaration {
                FunctionDeclaration::new(
                    "duplicate_name".to_string(),
                    "Second function".to_string(),
                    crate::FunctionParameters::new("object".to_string(), json!({}), vec![]),
                )
            }
            async fn call(&self, _args: Value) -> Result<Value, FunctionError> {
                Ok(json!("second"))
            }
        }

        // Register first, then second with same name
        registry.register_raw(Box::new(FirstFunc));
        registry.register_raw(Box::new(SecondFunc));

        // Last registration should win
        let func = registry
            .get("duplicate_name")
            .expect("Function should exist");
        assert_eq!(
            func.declaration().description(),
            "Second function",
            "Last registered function should win"
        );
    }

    #[test]
    fn test_empty_tool_service() {
        // A tool service that provides no tools
        struct EmptyService;

        impl ToolService for EmptyService {
            fn tools(&self) -> Vec<Arc<dyn CallableFunction>> {
                vec![]
            }
        }

        let service = EmptyService;
        let tools = service.tools();

        assert!(tools.is_empty(), "Empty service should return no tools");
    }
}