use alloc::{boxed::Box, format, vec::Vec};
use super::{
rust_tool::{ErasedTool, RustTool, definition_of},
types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
};
use crate::compat::HashMap;
/// One entry in a [`ToolRegistry`]: the tool's definition, built once when the
/// tool is registered, together with the type-erased tool used for dispatch.
struct RegisteredTool {
    /// Definition produced by `definition_of` at registration time; cloned
    /// (not rebuilt) whenever definitions are requested.
    definition: ToolDefinition,
    /// The tool itself, boxed behind `ErasedTool` so tools of different
    /// concrete types can be stored in the same map.
    erased: Box<dyn ErasedTool>,
}
/// A collection of tools, looked up by name, that can list their definitions
/// and dispatch calls to them by name.
pub struct ToolRegistry {
    /// Registered tools keyed by the tool type's `NAME` constant.
    tools: HashMap<&'static str, RegisteredTool>,
}
/// Formats the registry as its tool count plus the list of tool names.
///
/// Names are sorted so the output does not depend on the map's iteration
/// order (which is unspecified and may vary between runs).
impl core::fmt::Debug for ToolRegistry {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut names: Vec<&str> = self
            .tools
            .values()
            .map(|r| r.definition.name.as_str())
            .collect();
        // Deterministic output: HashMap iteration order is not stable.
        names.sort_unstable();
        f.debug_struct("ToolRegistry")
            .field("tool_count", &self.tools.len())
            .field("tool_names", &names)
            .finish()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
#[must_use]
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
let definition = definition_of(&tool)
.unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
self.tools.insert(
T::NAME,
RegisteredTool {
definition,
erased: Box::new(tool),
},
);
self
}
#[must_use]
pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
self.register(tool);
self
}
#[must_use]
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.tools
.values()
.map(|entry| entry.definition.clone())
.collect()
}
pub async fn dispatch(
&self,
name: &str,
args: serde_json::Value,
ctx: &ToolContext,
) -> Result<ToolOutput, ToolError> {
let entry = self
.tools
.get(name)
.ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
entry.erased.call_erased(args, ctx).await
}
pub async fn dispatch_str(
&self,
name: &str,
args_json: &str,
ctx: &ToolContext,
) -> Result<ToolOutput, ToolError> {
let args = serde_json::from_str(args_json)
.map_err(|e| ToolError::new(format!("Malformed JSON arguments: {e}")))?;
self.dispatch(name, args, ctx).await
}
#[must_use]
pub fn len(&self) -> usize {
self.tools.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
self.tools
.iter()
.map(|(name, entry)| (*name, entry.definition.clone()))
}
}
/// Enables `for (name, definition) in &registry { ... }`.
///
/// Yields exactly the same items, in the same order, as
/// [`ToolRegistry::iter`]; each definition is cloned.
impl<'a> IntoIterator for &'a ToolRegistry {
    type Item = (&'static str, ToolDefinition);
    type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;

    fn into_iter(self) -> Self::IntoIter {
        // Delegate instead of duplicating the mapping, so the two iteration
        // paths cannot drift apart.
        Box::new(self.iter())
    }
}
#[cfg(all(test, feature = "std"))]
mod tests;