use crabllm_core::{FunctionDef, Tool, ToolType};
use heck::ToSnakeCase;
use schemars::JsonSchema;
use std::collections::BTreeMap;
use tokio::sync::{mpsc, oneshot};
/// Sending half of an unbounded `tokio` mpsc channel that carries [`ToolRequest`]s.
pub type ToolSender = mpsc::UnboundedSender<ToolRequest>;
/// A single tool invocation submitted through a [`ToolSender`].
///
/// The result is delivered back to the submitter through the `reply`
/// one-shot channel.
pub struct ToolRequest {
/// Name of the tool being requested.
/// NOTE(review): presumably a key registered in `ToolRegistry` — confirm with the dispatcher.
pub name: String,
/// Tool arguments, passed through as an unparsed string.
/// NOTE(review): likely the raw JSON arguments produced by the model — confirm.
pub args: String,
/// Identifier of the agent associated with this request (assumed from the name — confirm).
pub agent: String,
/// One-shot channel for the outcome: `Ok(String)` on success, `Err(String)` on failure.
pub reply: oneshot::Sender<Result<String, String>>,
/// Optional task identifier; `None` when no task id is attached.
pub task_id: Option<u64>,
/// Identifier of the request's originator.
/// NOTE(review): how this differs from `agent` is not visible here — confirm.
pub sender: String,
/// Optional conversation identifier; `None` when no conversation id is attached.
pub conversation_id: Option<u64>,
}
/// Name-keyed collection of [`Tool`] definitions.
///
/// Tools are stored in a [`BTreeMap`] keyed by `tool.function.name`, so every
/// snapshot method returns tools sorted by name. Inserting a tool whose name
/// is already registered replaces the earlier entry.
#[derive(Default, Clone)]
pub struct ToolRegistry {
    /// Registered tools, keyed by their function name.
    tools: BTreeMap<String, Tool>,
}

impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `tool` under its function name, replacing any tool already
    /// registered with the same name.
    pub fn insert(&mut self, tool: Tool) {
        self.tools.insert(tool.function.name.clone(), tool);
    }

    /// Registers every tool in `tools`, in iteration order.
    ///
    /// If several tools share a name, the last one wins.
    pub fn insert_all(&mut self, tools: Vec<Tool>) {
        for tool in tools {
            self.insert(tool);
        }
    }

    /// Removes the tool named `name`.
    ///
    /// Returns `true` if a tool with that name was registered.
    pub fn remove(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }

    /// Returns `true` if a tool named `name` is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Returns the number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Returns a cloned snapshot of all registered tools, sorted by name.
    pub fn tools(&self) -> Vec<Tool> {
        self.tools.values().cloned().collect()
    }

    /// Returns a cloned snapshot of the tools whose names appear in `names`,
    /// sorted by name.
    ///
    /// An empty `names` slice means "no filter" and returns every tool.
    /// Names that are not registered are ignored. Duplicates in `names` do
    /// not duplicate tools in the result.
    pub fn filtered_snapshot(&self, names: &[String]) -> Vec<Tool> {
        if names.is_empty() {
            return self.tools();
        }
        // Build the lookup set once, so the scan is O(tools + names) instead
        // of a linear search through `names` for every registered tool.
        let wanted: std::collections::HashSet<&str> = names.iter().map(String::as_str).collect();
        self.tools
            .iter()
            .filter(|(name, _)| wanted.contains(name.as_str()))
            .map(|(_, tool)| tool.clone())
            .collect()
    }
}
/// Attaches a static description string to a tool type.
///
/// The blanket `AsTool` implementation copies `DESCRIPTION` into the
/// generated tool's `description` field.
pub trait ToolDescription {
/// Description text for the tool.
const DESCRIPTION: &'static str;
}
/// Produces a [`Tool`] definition for the implementing type.
///
/// A blanket implementation covers every `T: JsonSchema + ToolDescription`.
/// It derives the tool's name, description, and parameter schema from `T`.
pub trait AsTool: ToolDescription {
/// Builds the [`Tool`] definition for `Self`.
fn as_tool() -> Tool;
}
impl<T> AsTool for T
where
    T: JsonSchema + ToolDescription,
{
    /// Builds a function-type [`Tool`] from `T`'s JSON Schema.
    ///
    /// - `name` is `T::schema_name()` converted to snake_case.
    /// - `description` is `T::DESCRIPTION`.
    /// - `parameters` is the schema generated by `schemars::schema_for!(T)`,
    ///   converted to a `serde_json` value.
    /// - `strict` is left unset.
    fn as_tool() -> Tool {
        // Serializing a generated schema into a JSON value should not fail in
        // practice. If it ever does, omit the parameters (`None`). The
        // alternative, `Some(Value::Null)`, would advertise a `null` schema,
        // which is not a usable parameter definition.
        let parameters = serde_json::to_value(schemars::schema_for!(T)).ok();
        Tool {
            kind: ToolType::Function,
            function: FunctionDef {
                name: T::schema_name().to_snake_case(),
                description: Some(T::DESCRIPTION.into()),
                parameters,
            },
            strict: None,
        }
    }
}