use super::types::*;
use crate::utils::error::gateway_error::Result;
use serde_json::Value;
use std::collections::HashMap;
/// Behaviour a callable function must provide to be stored in a
/// [`FunctionCallingHandler`].
///
/// Implementors must be `Send + Sync`; the handler keeps them as boxed trait
/// objects.
#[async_trait::async_trait]
pub trait FunctionExecutor: Send + Sync {
    /// Returns the [`FunctionDefinition`] describing this function.
    fn get_schema(&self) -> FunctionDefinition;

    /// Runs the function with the supplied JSON `arguments`.
    ///
    /// Resolves to the JSON value produced by the function, or to the
    /// crate's gateway error on failure.
    async fn execute(&self, arguments: Value) -> Result<Value>;

    /// Checks `arguments` before execution.
    ///
    /// The default implementation performs no checks and accepts every
    /// input; override it to reject malformed arguments.
    fn validate_arguments(&self, _arguments: &Value) -> Result<()> {
        Ok(())
    }
}
/// Registry pairing each function name with its schema and its executor.
///
/// Both maps are keyed by the name passed to `register_function`, which
/// inserts into them together.
pub struct FunctionCallingHandler {
    /// Schema captured from each executor at registration time.
    pub(crate) functions: HashMap<String, FunctionDefinition>,
    /// Boxed executor stored for each registered function name.
    pub(crate) executors: HashMap<String, Box<dyn FunctionExecutor>>,
}
impl FunctionCallingHandler {
    /// Creates a handler with no functions registered.
    pub fn new() -> Self {
        Self {
            functions: HashMap::new(),
            executors: HashMap::new(),
        }
    }

    /// Registers `executor` under `name`.
    ///
    /// The schema is read from [`FunctionExecutor::get_schema`] once, here,
    /// and stored alongside the boxed executor. Registering a second
    /// executor under an existing `name` replaces both the stored schema and
    /// the stored executor.
    ///
    /// # Errors
    ///
    /// This implementation currently always returns `Ok(())`; the `Result`
    /// return type is part of the public signature and is kept for callers.
    pub fn register_function<F>(&mut self, name: String, executor: F) -> Result<()>
    where
        F: FunctionExecutor + 'static,
    {
        let schema = executor.get_schema();
        // The same key is used for both maps so they stay in sync.
        self.functions.insert(name.clone(), schema);
        self.executors.insert(name, Box::new(executor));
        Ok(())
    }

    /// Returns one `"function"`-typed [`ToolDefinition`] per registered
    /// function, ordered by registered name.
    ///
    /// `HashMap` iteration order is arbitrary and can differ between
    /// processes, so the entries are sorted by key to give callers a stable,
    /// reproducible list.
    pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
        let mut registered: Vec<_> = self.functions.iter().collect();
        // Keys are unique, so an unstable sort cannot reorder equal elements.
        registered.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
        registered
            .into_iter()
            .map(|(_, function)| ToolDefinition {
                tool_type: "function".to_string(),
                function: function.clone(),
            })
            .collect()
    }
}
impl Default for FunctionCallingHandler {
fn default() -> Self {
Self::new()
}
}