use crate::model::HistoryEntry;
use crabllm_core::{FunctionDef, Tool, ToolType};
use heck::ToSnakeCase;
use schemars::JsonSchema;
use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc};
/// Boxed, pinned future produced by tool dispatch.
///
/// The future is `Send`, may borrow data for `'a`, and resolves to
/// `Ok(String)` with the tool's output or `Err(String)` with an error message.
pub type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>>;
/// Asynchronous router that executes a tool call by name.
///
/// Implementors must be `Send + Sync + 'static`, so they can be shared across
/// threads and held without borrowing from a shorter-lived owner.
pub trait ToolDispatcher: Send + Sync + 'static {
    /// Runs the tool called `name` with the raw argument string `args`.
    ///
    /// `agent`, `sender`, and `conversation_id` describe the calling context.
    /// NOTE(review): their exact meaning is defined by callers — confirm there.
    /// Every borrowed input must outlive the returned future (`'a`).
    ///
    /// The future resolves to `Ok(output)` on success or `Err(message)` on
    /// failure.
    fn dispatch<'a>(
        &'a self,
        name: &'a str,
        args: &'a str,
        agent: &'a str,
        sender: &'a str,
        conversation_id: Option<u64>,
    ) -> ToolFuture<'a>;
}
/// Owned description of a single tool invocation, handed by value to a
/// `ToolHandler`.
///
/// `ToolDispatcher::dispatch` borrows its inputs. Here, by contrast, every field
/// is owned, which lets a handler return a `'static` future that still uses
/// them. `Debug` is derived so dispatches can be logged and inspected.
#[derive(Clone, Debug)]
pub struct ToolDispatch {
    /// Raw argument string for the tool (presumably JSON — confirm with callers).
    pub args: String,
    /// Identifier of the agent the call is made for (NOTE(review): confirm).
    pub agent: String,
    /// Identifier of whoever sent the triggering request (NOTE(review): confirm).
    pub sender: String,
    /// Conversation the call belongs to, if any.
    pub conversation_id: Option<u64>,
}
/// Shareable, thread-safe async callback that implements a single tool.
///
/// It takes ownership of a `ToolDispatch` and returns a boxed `Send` future
/// that resolves to the tool's output (`Ok`) or an error message (`Err`).
/// The boxed future carries no lifetime parameter, so it is `'static` and
/// cannot borrow from the closure's captured state; clone what it needs.
pub type ToolHandler = Arc<
    dyn Fn(ToolDispatch) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>>
        + Send
        + Sync,
>;
/// Shareable, thread-safe callback that maps a slice of history entries to a
/// new vector of entries.
///
/// NOTE(review): the name suggests it runs before an agent run and may inject
/// or rewrite history — confirm at the call site.
pub type BeforeRunHook = Arc<dyn Fn(&[HistoryEntry]) -> Vec<HistoryEntry> + Send + Sync>;
/// A tool bundled with its implementation and optional extras.
pub struct ToolEntry {
    /// Function-calling schema describing the tool.
    pub schema: Tool,
    /// Callback that executes the tool.
    pub handler: ToolHandler,
    /// Optional system-prompt text associated with this tool.
    /// NOTE(review): confirm how callers splice it into the prompt.
    pub system_prompt: Option<String>,
    /// Optional history hook; see `BeforeRunHook`.
    pub before_run: Option<BeforeRunHook>,
}
/// Name-keyed collection of tool schemas.
///
/// Only `Tool` schemas are stored here, not handlers. Because it is backed by
/// a `BTreeMap`, entries are kept sorted by name and every snapshot comes back
/// in a stable order.
#[derive(Default, Clone)]
pub struct ToolRegistry {
    // Tool schemas keyed by `tool.function.name`.
    tools: BTreeMap<String, Tool>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `tool` under its function name, replacing any tool already
    /// registered with that name.
    pub fn insert(&mut self, tool: Tool) {
        self.tools.insert(tool.function.name.clone(), tool);
    }

    /// Registers every tool yielded by `tools`.
    ///
    /// Name collisions replace the existing tool, exactly as in
    /// [`ToolRegistry::insert`]; when `tools` itself contains duplicates, the
    /// last one wins. Accepts any iterable: a `Vec<Tool>`, an array, or an
    /// iterator adaptor.
    pub fn insert_all(&mut self, tools: impl IntoIterator<Item = Tool>) {
        self.tools
            .extend(tools.into_iter().map(|tool| (tool.function.name.clone(), tool)));
    }

    /// Removes the tool named `name`, returning `true` if one was registered.
    pub fn remove(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }

    /// Returns `true` if a tool named `name` is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Returns clones of all registered tools, sorted by name.
    pub fn tools(&self) -> Vec<Tool> {
        self.tools.values().cloned().collect()
    }

    /// Returns clones of the tools whose names appear in `names`, sorted by name.
    ///
    /// An empty `names` slice means "no filter" and returns every tool. Names
    /// that are not registered are ignored. Duplicates in `names` do not produce
    /// duplicate tools.
    pub fn filtered_snapshot(&self, names: &[String]) -> Vec<Tool> {
        if names.is_empty() {
            return self.tools();
        }
        // Build the lookup set once. Otherwise every registered name would be
        // compared against every requested name (O(n * m)).
        let wanted: std::collections::BTreeSet<&str> =
            names.iter().map(String::as_str).collect();
        self.tools
            .iter()
            .filter(|(name, _)| wanted.contains(name.as_str()))
            .map(|(_, tool)| tool.clone())
            .collect()
    }
}
/// Produces a function-calling `Tool` definition for a type.
///
/// A blanket impl covers every `JsonSchema` type.
pub trait AsTool {
    /// Builds the `Tool` that describes `Self`.
    fn as_tool() -> Tool;
}
/// Blanket implementation: any `JsonSchema` type doubles as a tool definition.
///
/// The generated tool has these parts:
/// - its name is the schema name converted to `snake_case`;
/// - its description is the root schema's `description`, if present;
/// - the full generated schema becomes the function's `parameters`.
impl<T: JsonSchema> AsTool for T {
    fn as_tool() -> Tool {
        let schema = schemars::schema_for!(T);
        // Copy the root-level description out before `schema` is consumed below.
        // NOTE(review): presumably this comes from the type's doc comment — confirm.
        let description = schema
            .get("description")
            .and_then(|v| v.as_str())
            .map(str::to_owned);
        Tool {
            kind: ToolType::Function,
            function: FunctionDef {
                name: T::schema_name().to_snake_case(),
                description,
                // `Schema` already wraps a JSON value, so unwrap it directly.
                // Re-serializing with `to_value(..).unwrap_or_default()` could
                // only fail by silently advertising `parameters: null`.
                parameters: Some(serde_json::Value::from(schema)),
            },
            strict: None,
        }
    }
}
/// Null dispatcher: it knows no tools, so every call immediately fails with a
/// "tool not registered" error that names the requested tool.
impl ToolDispatcher for () {
    fn dispatch<'a>(
        &'a self,
        name: &'a str,
        _args: &'a str,
        _agent: &'a str,
        _sender: &'a str,
        _conversation_id: Option<u64>,
    ) -> ToolFuture<'a> {
        let missing = format!("tool not registered: {name}");
        // The outcome is known up front, so hand back an already-resolved future.
        Box::pin(std::future::ready(Err::<String, String>(missing)))
    }
}