use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use async_trait::async_trait;
use crate::messages::tools::{CustomTool, Tool as MessagesTool};
use crate::tool_dispatch::tool::{Tool, ToolError};
/// A name-keyed collection of [`Tool`]s that can be listed as
/// [`MessagesTool`] definitions and dispatched by name.
///
/// Tools are stored as `Arc<dyn Tool>`, so trait-based tools (added with
/// [`ToolRegistry::register_tool`]) and closure-based tools (added with
/// [`ToolRegistry::register`] / [`ToolRegistry::register_described`]) live
/// side by side. Registering a tool under a name that is already taken
/// replaces the previous entry.
#[derive(Default)]
pub struct ToolRegistry {
    /// Registered tools, keyed by the name they are dispatched under.
    tools: HashMap<String, Arc<dyn Tool>>,
}

impl ToolRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a trait-based tool under the name returned by [`Tool::name`].
    ///
    /// Any tool already registered under that name is replaced. Returns
    /// `&mut Self` so registrations can be chained.
    pub fn register_tool<T: Tool>(&mut self, tool: T) -> &mut Self {
        let name = tool.name().to_owned();
        self.tools.insert(name, Arc::new(tool));
        self
    }

    /// Registers an async closure as a tool named `name` with input `schema`
    /// and no description.
    ///
    /// The closure receives the raw JSON input and resolves to the tool's JSON
    /// output or a [`ToolError`]. Any tool already registered under `name` is
    /// replaced.
    pub fn register<F, Fut>(
        &mut self,
        name: impl Into<String>,
        schema: serde_json::Value,
        handler: F,
    ) -> &mut Self
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
    {
        // `FnTool::name` is exactly `name`, so `register_tool` keys it identically.
        self.register_tool(FnTool::new(name, schema, handler))
    }

    /// Like [`ToolRegistry::register`], but also attaches a human-readable
    /// `description` that is forwarded by [`ToolRegistry::to_messages_tools`].
    pub fn register_described<F, Fut>(
        &mut self,
        name: impl Into<String>,
        description: impl Into<String>,
        schema: serde_json::Value,
        handler: F,
    ) -> &mut Self
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
    {
        self.register_tool(FnTool::new(name, schema, handler).with_description(description))
    }

    /// Looks up a registered tool by name.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
        self.tools.get(name)
    }

    /// Returns `true` if a tool is registered under `name`.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Number of registered tools.
    #[must_use]
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Iterates over the registered tool names.
    ///
    /// The iteration order is unspecified and may differ between runs.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.tools.keys().map(String::as_str)
    }

    /// Converts every registered tool into a [`MessagesTool::Custom`] definition
    /// carrying its name, schema, and (if present) description.
    ///
    /// The result is sorted by tool name. The registry is backed by a
    /// `HashMap` whose iteration order is arbitrary, and sorting keeps the
    /// generated tool list identical for identical registries.
    #[must_use]
    pub fn to_messages_tools(&self) -> Vec<MessagesTool> {
        let mut entries: Vec<(&String, &Arc<dyn Tool>)> = self.tools.iter().collect();
        // Keys are unique, so an unstable sort is still fully deterministic.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        entries
            .into_iter()
            .map(|(_, tool)| {
                let mut custom = CustomTool::new(tool.name(), tool.schema());
                if let Some(desc) = tool.description() {
                    custom = custom.description(desc);
                }
                MessagesTool::Custom(custom)
            })
            .collect()
    }

    /// Invokes the tool registered under `name` with `input`.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::Unknown`] if no tool is registered under `name`;
    /// otherwise returns whatever error the tool itself produces.
    pub async fn dispatch(
        &self,
        name: &str,
        input: serde_json::Value,
    ) -> Result<serde_json::Value, ToolError> {
        let tool = self.tools.get(name).ok_or_else(|| ToolError::Unknown {
            name: name.to_owned(),
        })?;
        tool.invoke(input).await
    }
}
impl std::fmt::Debug for ToolRegistry {
    /// Formats the registry as `ToolRegistry { tools: [...] }`, listing only
    /// the registered tool names, in sorted order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut names: Vec<&str> = self.tools.keys().map(String::as_str).collect();
        // HashMap key order is arbitrary; sort so the output is stable across runs.
        names.sort_unstable();
        f.debug_struct("ToolRegistry")
            .field("tools", &names)
            .finish()
    }
}
pub struct FnTool<F, Fut>
where
F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
{
name: String,
schema: serde_json::Value,
description: Option<String>,
handler: F,
_phantom: PhantomData<fn() -> Fut>,
}
impl<F, Fut> FnTool<F, Fut>
where
    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
{
    /// Wraps `handler` as a tool named `name` whose input is described by
    /// `schema`. The tool starts out without a description.
    pub fn new(name: impl Into<String>, schema: serde_json::Value, handler: F) -> Self {
        let name: String = name.into();
        Self {
            name,
            schema,
            handler,
            description: None,
            _phantom: PhantomData,
        }
    }

    /// Consumes the tool and returns it with `description` set, replacing any
    /// description it already had.
    #[must_use]
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }
}
#[async_trait]
impl<F, Fut> Tool for FnTool<F, Fut>
where
    F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
{
    /// The name supplied at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The description, if one was attached with `with_description`.
    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }

    /// A copy of the input schema supplied at construction.
    fn schema(&self) -> serde_json::Value {
        serde_json::Value::clone(&self.schema)
    }

    /// Calls the wrapped closure with `input` and awaits its result.
    async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
        let handler = &self.handler;
        let pending = handler(input);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::tools::Tool as MessagesTool;
    use pretty_assertions::assert_eq;
    use serde_json::{Value, json};

    /// Schema shared by the closure-based test tools.
    fn echo_schema() -> Value {
        json!({"type": "object", "properties": {"text": {"type": "string"}}})
    }

    /// Trait-based test tool: upper-cases the `text` field of its input.
    struct UpperTool;

    #[async_trait]
    impl Tool for UpperTool {
        // The literal return is fine here; the trait signature dictates `&str`.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "upper"
        }

        fn schema(&self) -> Value {
            json!({"type": "object", "properties": {"text": {"type": "string"}}})
        }

        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let s = input
                .get("text")
                .and_then(Value::as_str)
                .ok_or_else(|| ToolError::invalid_input("missing 'text'"))?;
            Ok(json!({"upper": s.to_uppercase()}))
        }
    }

    #[test]
    fn new_registry_is_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.get("anything").is_none());
        assert_eq!(registry.names().count(), 0);
    }

    #[tokio::test]
    async fn register_and_dispatch_closure_tool() {
        let mut registry = ToolRegistry::new();
        registry.register("echo", echo_schema(), |input| async move { Ok(input) });
        assert!(registry.contains("echo"));
        assert_eq!(registry.len(), 1);
        let result = registry
            .dispatch("echo", json!({"text": "hi"}))
            .await
            .unwrap();
        assert_eq!(result, json!({"text": "hi"}));
    }

    #[tokio::test]
    async fn register_and_dispatch_trait_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let result = registry
            .dispatch("upper", json!({"text": "rust"}))
            .await
            .unwrap();
        assert_eq!(result, json!({"upper": "RUST"}));
    }

    #[tokio::test]
    async fn closure_and_trait_tools_coexist() {
        let mut registry = ToolRegistry::new();
        registry
            .register_tool(UpperTool)
            .register("echo", echo_schema(), |input| async move { Ok(input) });
        assert_eq!(registry.len(), 2);
        let names: std::collections::HashSet<_> = registry.names().collect();
        assert!(names.contains("upper"));
        assert!(names.contains("echo"));
        let r1 = registry
            .dispatch("upper", json!({"text": "ok"}))
            .await
            .unwrap();
        let r2 = registry
            .dispatch("echo", json!({"text": "ok"}))
            .await
            .unwrap();
        assert_eq!(r1, json!({"upper": "OK"}));
        assert_eq!(r2, json!({"text": "ok"}));
    }

    #[tokio::test]
    async fn dispatch_unknown_returns_unknown_error() {
        let registry = ToolRegistry::new();
        let err = registry.dispatch("nope", json!({})).await.unwrap_err();
        let ToolError::Unknown { name } = err else {
            panic!("expected Unknown variant");
        };
        assert_eq!(name, "nope");
    }

    #[tokio::test]
    async fn dispatch_propagates_invalid_input_error_from_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let err = registry.dispatch("upper", json!({})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'text'"));
    }

    #[tokio::test]
    async fn duplicate_register_replaces_previous_entry() {
        let mut registry = ToolRegistry::new();
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "first"}))
        });
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "second"}))
        });
        assert_eq!(registry.len(), 1);
        let r = registry.dispatch("dup", json!({})).await.unwrap();
        assert_eq!(r, json!({"version": "second"}));
    }

    #[test]
    fn register_described_exposes_description_and_schema_via_get() {
        let mut registry = ToolRegistry::new();
        registry.register_described("echo", "Echoes.", echo_schema(), |input| async move {
            Ok(input)
        });
        let tool = registry.get("echo").expect("echo should be registered");
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), Some("Echoes."));
        assert_eq!(tool.schema(), echo_schema());
    }

    #[tokio::test]
    async fn fn_tool_can_be_built_directly_and_registered() {
        let tool = FnTool::new("echo", echo_schema(), |input| async move { Ok(input) })
            .with_description("Echoes.");
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), Some("Echoes."));

        let mut registry = ToolRegistry::new();
        registry.register_tool(tool);
        let out = registry
            .dispatch("echo", json!({"text": "x"}))
            .await
            .unwrap();
        assert_eq!(out, json!({"text": "x"}));
    }

    #[test]
    fn to_messages_tools_includes_name_schema_and_description() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool).register_described(
            "echo",
            "Returns its input verbatim.",
            echo_schema(),
            |input| async move { Ok(input) },
        );
        let tools = registry.to_messages_tools();
        assert_eq!(tools.len(), 2);

        // Index by name so the assertions below do not depend on list order.
        let by_name: std::collections::HashMap<String, CustomTool> = tools
            .into_iter()
            .map(|t| {
                let MessagesTool::Custom(ct) = t else {
                    panic!("expected custom variant");
                };
                (ct.name.clone(), ct)
            })
            .collect();

        let echo = by_name.get("echo").expect("echo definition present");
        assert_eq!(
            echo.description.as_deref(),
            Some("Returns its input verbatim.")
        );
        assert!(echo.input_schema.is_object());

        let upper = by_name.get("upper").expect("upper definition present");
        assert_eq!(upper.description, None);
    }

    #[tokio::test]
    async fn registry_works_through_dyn_dispatch() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let tool: &Arc<dyn Tool> = registry.get("upper").unwrap();
        let r = tool.invoke(json!({"text": "abc"})).await.unwrap();
        assert_eq!(r, json!({"upper": "ABC"}));
    }

    #[test]
    fn debug_impl_lists_tool_names() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let dbg = format!("{registry:?}");
        assert!(dbg.contains("upper"), "{dbg}");
    }

    #[test]
    fn registry_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ToolRegistry>();
    }
}