genai_rs/
function_calling.rs1use async_trait::async_trait;
2use inventory;
3use serde_json::Value;
4use std::collections::HashMap;
5use std::error::Error;
6use std::sync::Arc;
7use tracing::warn;
8
9use crate::FunctionDeclaration;
10
/// Errors produced while invoking a [`CallableFunction`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug)]
#[non_exhaustive]
pub enum FunctionError {
    /// The supplied arguments did not fit what the function expects; the
    /// payload is a human-readable explanation.
    ArgumentMismatch(String),
    /// The function itself failed; the payload is the underlying error, which
    /// is also exposed through [`Error::source`].
    ExecutionError(Box<dyn Error + Send + Sync>),
}
21
22impl std::fmt::Display for FunctionError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::ArgumentMismatch(msg) => write!(f, "Argument mismatch: {msg}"),
26 Self::ExecutionError(err) => write!(f, "Function execution error: {err}"),
27 }
28 }
29}
30
31impl Error for FunctionError {
32 fn source(&self) -> Option<&(dyn Error + 'static)> {
33 match self {
34 Self::ExecutionError(err) => Some(err.as_ref()),
35 Self::ArgumentMismatch(_) => None,
36 }
37 }
38}
39
40#[async_trait]
42pub trait CallableFunction: Send + Sync {
43 fn declaration(&self) -> FunctionDeclaration;
45
46 async fn call(&self, args: Value) -> Result<Value, FunctionError>;
50}
51
52pub trait ToolService: Send + Sync {
84 fn tools(&self) -> Vec<Arc<dyn CallableFunction>>;
88}
89
90pub struct CallableFunctionFactory {
93 pub factory_fn: fn() -> Box<dyn CallableFunction>,
94}
95
96impl CallableFunctionFactory {
97 pub const fn new(factory_fn: fn() -> Box<dyn CallableFunction>) -> Self {
98 Self { factory_fn }
99 }
100}
101
102pub use inventory::submit;
106
107inventory::collect!(CallableFunctionFactory);
108
109pub(crate) struct FunctionRegistry {
111 functions: HashMap<String, Box<dyn CallableFunction>>,
112}
113
114impl FunctionRegistry {
115 fn new() -> Self {
117 Self {
118 functions: HashMap::new(),
119 }
120 }
121
122 fn register_raw(&mut self, function: Box<dyn CallableFunction>) {
124 let name = function.declaration().name().to_string();
125 if self.functions.contains_key(&name) {
126 warn!(
127 "Duplicate function name in auto-registration: function='{}'. Last registration will be used.",
128 name
129 );
130 }
131 self.functions.insert(name, function);
132 }
133
134 pub(crate) fn get(&self, name: &str) -> Option<&dyn CallableFunction> {
136 self.functions.get(name).map(std::convert::AsRef::as_ref)
137 }
138
139 pub(crate) fn all_declarations(&self) -> Vec<FunctionDeclaration> {
141 self.functions.values().map(|f| f.declaration()).collect()
142 }
143}
144
145static GLOBAL_FUNCTION_REGISTRY: std::sync::LazyLock<FunctionRegistry> =
147 std::sync::LazyLock::new(|| {
148 let mut registry = FunctionRegistry::new();
149
150 for factory in inventory::iter::<CallableFunctionFactory> {
151 let function = (factory.factory_fn)();
152 registry.register_raw(function);
153 }
154
155 registry
156 });
157
158pub(crate) fn get_global_function_registry() -> &'static FunctionRegistry {
161 &GLOBAL_FUNCTION_REGISTRY
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionDeclaration;
    use async_trait::async_trait;
    use serde_json::json;

    /// Stateless function submitted to the global registry through `inventory`.
    struct TestFunctionGlobal;

    #[async_trait]
    impl CallableFunction for TestFunctionGlobal {
        fn declaration(&self) -> FunctionDeclaration {
            FunctionDeclaration::new(
                "test_function_global".to_string(),
                "A global test function".to_string(),
                crate::FunctionParameters::new(
                    "object".to_string(),
                    json!({"param": {"type": "string"}}),
                    vec!["param".to_string()],
                ),
            )
        }

        async fn call(&self, args: Value) -> Result<Value, FunctionError> {
            args.get("param").and_then(Value::as_str).map_or_else(
                || {
                    Err(FunctionError::ArgumentMismatch(
                        "Missing param for Global".to_string(),
                    ))
                },
                |p| Ok(json!({ "result": format!("Global says: Hello, {p}") })),
            )
        }
    }

    /// Factory with the `fn() -> Box<dyn CallableFunction>` shape that
    /// `CallableFunctionFactory` requires.
    fn test_function_global_callable_factory() -> Box<dyn CallableFunction> {
        Box::new(TestFunctionGlobal)
    }

    inventory::submit! {
        CallableFunctionFactory::new(test_function_global_callable_factory)
    }

    #[test]
    fn test_global_registry_population_and_access() {
        let registry = get_global_function_registry();
        let retrieved_func = registry.get("test_function_global");
        assert!(
            retrieved_func.is_some(),
            "Function 'test_function_global' should be in the global registry."
        );
        assert_eq!(
            retrieved_func.unwrap().declaration().name(),
            "test_function_global"
        );
    }

    #[tokio::test]
    async fn test_call_global_registered_function() {
        let registry = get_global_function_registry();
        let retrieved_func = registry
            .get("test_function_global")
            .expect("Global function not found");

        let args = json!({ "param": "GlobalInventoryWorld" });
        let result = retrieved_func.call(args).await;
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            json!({ "result": "Global says: Hello, GlobalInventoryWorld" })
        );
    }

    /// Tool whose output depends on configuration carried in its own field.
    struct GreetTool {
        greeting_prefix: String,
    }

    #[async_trait]
    impl CallableFunction for GreetTool {
        fn declaration(&self) -> FunctionDeclaration {
            FunctionDeclaration::new(
                "greet".to_string(),
                "Greets a person with a custom prefix".to_string(),
                crate::FunctionParameters::new(
                    "object".to_string(),
                    json!({"name": {"type": "string"}}),
                    vec!["name".to_string()],
                ),
            )
        }

        async fn call(&self, args: Value) -> Result<Value, FunctionError> {
            args.get("name").and_then(Value::as_str).map_or_else(
                || {
                    Err(FunctionError::ArgumentMismatch(
                        "Missing 'name' argument".to_string(),
                    ))
                },
                |name| Ok(json!({ "message": format!("{} {name}!", self.greeting_prefix) })),
            )
        }
    }

    /// Service that builds a `GreetTool` from its configured prefix.
    struct GreetingService {
        prefix: String,
    }

    impl ToolService for GreetingService {
        fn tools(&self) -> Vec<Arc<dyn CallableFunction>> {
            vec![Arc::new(GreetTool {
                greeting_prefix: self.prefix.clone(),
            })]
        }
    }

    #[test]
    fn test_tool_service_returns_tools() {
        let service = GreetingService {
            prefix: "Hello".to_string(),
        };
        let tools = service.tools();

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].declaration().name(), "greet");
    }

    #[tokio::test]
    async fn test_tool_service_tool_can_be_called() {
        let service = GreetingService {
            prefix: "Howdy".to_string(),
        };
        let tools = service.tools();
        let greet_tool = &tools[0];

        let result = greet_tool.call(json!({ "name": "Partner" })).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), json!({ "message": "Howdy Partner!" }));
    }

    #[tokio::test]
    async fn test_tool_service_with_different_config() {
        let formal_service = GreetingService {
            prefix: "Good morning, Mr.".to_string(),
        };
        let casual_service = GreetingService {
            prefix: "Hey".to_string(),
        };

        let formal_tools = formal_service.tools();
        let casual_tools = casual_service.tools();

        let formal_result = formal_tools[0].call(json!({ "name": "Smith" })).await;
        let casual_result = casual_tools[0].call(json!({ "name": "Joe" })).await;

        assert_eq!(
            formal_result.unwrap(),
            json!({ "message": "Good morning, Mr. Smith!" })
        );
        assert_eq!(casual_result.unwrap(), json!({ "message": "Hey Joe!" }));
    }

    #[test]
    fn test_registry_returns_none_for_unknown_function() {
        let registry = get_global_function_registry();

        let result = registry.get("this_function_definitely_does_not_exist_xyz123");
        assert!(
            result.is_none(),
            "Registry should return None for unknown functions"
        );
    }

    #[test]
    fn test_registry_all_declarations_contains_registered() {
        let registry = get_global_function_registry();
        let declarations = registry.all_declarations();

        let names: Vec<_> = declarations.iter().map(|d| d.name()).collect();
        assert!(
            names.contains(&"test_function_global"),
            "all_declarations should include registered function"
        );
    }

    #[test]
    fn test_tool_service_tools_are_independent() {
        let service = GreetingService {
            prefix: "Hi".to_string(),
        };

        let tools1 = service.tools();
        let tools2 = service.tools();

        // Both calls describe the same tool...
        assert_eq!(
            tools1[0].declaration().name(),
            tools2[0].declaration().name()
        );

        // ...but each call allocates its own instance.
        assert!(!Arc::ptr_eq(&tools1[0], &tools2[0]));
    }

    #[test]
    fn test_registry_duplicate_registration_last_wins() {
        // Two functions that deliberately report the same name; only the
        // description tells them apart.
        struct FirstFunc;

        #[async_trait]
        impl CallableFunction for FirstFunc {
            fn declaration(&self) -> FunctionDeclaration {
                FunctionDeclaration::new(
                    "duplicate_name".to_string(),
                    "First function".to_string(),
                    crate::FunctionParameters::new("object".to_string(), json!({}), vec![]),
                )
            }

            async fn call(&self, _args: Value) -> Result<Value, FunctionError> {
                Ok(json!("first"))
            }
        }

        struct SecondFunc;

        #[async_trait]
        impl CallableFunction for SecondFunc {
            fn declaration(&self) -> FunctionDeclaration {
                FunctionDeclaration::new(
                    "duplicate_name".to_string(),
                    "Second function".to_string(),
                    crate::FunctionParameters::new("object".to_string(), json!({}), vec![]),
                )
            }

            async fn call(&self, _args: Value) -> Result<Value, FunctionError> {
                Ok(json!("second"))
            }
        }

        // A private registry keeps this test from touching the global one.
        let mut registry = FunctionRegistry::new();
        registry.register_raw(Box::new(FirstFunc));
        registry.register_raw(Box::new(SecondFunc));

        let func = registry
            .get("duplicate_name")
            .expect("Function should exist");
        assert_eq!(
            func.declaration().description(),
            "Second function",
            "Last registered function should win"
        );
    }

    #[test]
    fn test_empty_tool_service() {
        struct EmptyService;

        impl ToolService for EmptyService {
            fn tools(&self) -> Vec<Arc<dyn CallableFunction>> {
                vec![]
            }
        }

        let service = EmptyService;
        let tools = service.tools();

        assert!(tools.is_empty(), "Empty service should return no tools");
    }
}