#![deny(missing_docs)]
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
/// Errors produced while looking up or invoking a tool.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum ToolError {
    /// A requested tool does not exist.
    ///
    /// The payload is rendered as `tool not found: {payload}`; callers
    /// conventionally pass the tool name.
    #[error("tool not found: {0}")]
    NotFound(String),
    /// The tool ran but failed; the payload is rendered as
    /// `execution failed: {payload}`.
    #[error("execution failed: {0}")]
    ExecutionFailed(String),
    /// The input given to the tool was rejected; the payload is rendered as
    /// `invalid input: {payload}`.
    #[error("invalid input: {0}")]
    InvalidInput(String),
    /// Any other error, wrapped unchanged.
    ///
    /// `Display` prints the wrapped error's own message. Because of `#[from]`,
    /// boxed errors convert into this variant via `From`/`?`, and the wrapped
    /// error is also reported as this error's `source()`.
    #[error("{0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
/// Hint a tool gives about whether its invocations may overlap with others.
///
/// Only the values are defined here; how a caller acts on them (for example,
/// whether a scheduler runs `Shared` tools in parallel) is decided elsewhere.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ToolConcurrencyHint {
    /// The tool presumably tolerates running concurrently with other calls.
    Shared,
    /// The tool should be treated as requiring exclusive execution.
    ///
    /// This is the `Default` value, and what `ToolDyn::concurrency_hint`
    /// returns unless a tool overrides it.
    #[default]
    Exclusive,
}
/// Optional streaming extension of [`ToolDyn`].
///
/// A tool opts in by implementing this trait and returning `Some(self)` from
/// [`ToolDyn::maybe_streaming`]. Callers can then receive output piece by
/// piece through a callback instead of as one final JSON value.
pub trait ToolDynStreaming: Send + Sync + 'static + ToolDyn {
    /// Runs the tool on `input`, invoking `on_chunk` for each piece of output.
    ///
    /// The callback may borrow from the caller for `'a`, the same lifetime for
    /// which the returned future borrows `self`. The future resolves to
    /// `Ok(())` once the tool has finished emitting chunks.
    ///
    /// # Errors
    ///
    /// The future resolves to a [`ToolError`] if the invocation fails.
    fn call_streaming<'a>(
        &'a self,
        input: serde_json::Value,
        on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
    ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>>;
}
/// Object-safe interface for a tool that accepts and returns JSON values.
///
/// Tools are stored and shared as `Arc<dyn ToolDyn>` (see [`ToolRegistry`]).
/// Every method therefore takes `&self`, and async work is returned as a
/// boxed, `Send` future.
pub trait ToolDyn: Send + Sync {
    /// Name the tool reports; [`ToolRegistry::register`] keys tools by it.
    fn name(&self) -> &str;

    /// Human-readable description of the tool.
    fn description(&self) -> &str;

    /// Schema for the `input` accepted by [`ToolDyn::call`], as a JSON value.
    ///
    /// The tools in this module return JSON-Schema-style objects such as
    /// `{"type": "object"}`.
    fn input_schema(&self) -> serde_json::Value;

    /// Invokes the tool with `input`, resolving to its JSON result.
    ///
    /// # Errors
    ///
    /// The returned future resolves to a [`ToolError`] if the invocation fails.
    fn call(
        &self,
        input: serde_json::Value,
    ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>;

    /// Returns a streaming view of this tool if it implements
    /// [`ToolDynStreaming`].
    ///
    /// The default returns `None`. Streaming tools typically override this to
    /// return `Some(self)`.
    fn maybe_streaming(&self) -> Option<&dyn ToolDynStreaming> {
        None
    }

    /// Concurrency hint for this tool.
    ///
    /// Defaults to [`ToolConcurrencyHint::Exclusive`].
    fn concurrency_hint(&self) -> ToolConcurrencyHint {
        ToolConcurrencyHint::Exclusive
    }
}
pub struct AliasedTool {
alias: String,
inner: Arc<dyn ToolDyn>,
}
impl AliasedTool {
pub fn new(alias: impl Into<String>, inner: Arc<dyn ToolDyn>) -> Self {
Self {
alias: alias.into(),
inner,
}
}
pub fn inner(&self) -> &Arc<dyn ToolDyn> {
&self.inner
}
}
impl ToolDyn for AliasedTool {
fn name(&self) -> &str {
&self.alias
}
fn description(&self) -> &str {
self.inner.description()
}
fn input_schema(&self) -> serde_json::Value {
self.inner.input_schema()
}
fn call(
&self,
input: serde_json::Value,
) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>> {
self.inner.call(input)
}
fn concurrency_hint(&self) -> ToolConcurrencyHint {
self.inner.concurrency_hint()
}
}
/// Name-keyed collection of shared tools.
///
/// Each tool is stored as an `Arc<dyn ToolDyn>` under the string its
/// [`ToolDyn::name`] returned at registration time. Cloning the registry
/// clones the map of `Arc` handles, not the tools themselves.
#[derive(Clone)]
pub struct ToolRegistry {
    tools: HashMap<String, Arc<dyn ToolDyn>>,
}

impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    /// Adds `tool` under its current [`ToolDyn::name`].
    ///
    /// If a tool is already registered under that name, it is replaced.
    pub fn register(&mut self, tool: Arc<dyn ToolDyn>) {
        self.tools.insert(tool.name().to_string(), tool);
    }

    /// Looks up a tool by exact name, returning `None` if none is registered.
    pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolDyn>> {
        self.tools.get(name)
    }

    /// Iterates over all registered tools.
    ///
    /// The order is unspecified, because the registry is backed by a `HashMap`.
    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn ToolDyn>> {
        self.tools.values()
    }

    /// Number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` when no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Mutex;

    /// Compiles only if `T` is `Send + Sync`; calling it is a no-op.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn tool_dyn_is_object_safe() {
        // Naming `dyn ToolDyn` already requires object safety; the bound check
        // also confirms shared tool handles can cross threads.
        assert_send_sync::<Arc<dyn ToolDyn>>();
    }

    #[test]
    fn tool_error_display() {
        assert_eq!(
            ToolError::NotFound("bash".into()).to_string(),
            "tool not found: bash"
        );
        assert_eq!(
            ToolError::ExecutionFailed("timeout".into()).to_string(),
            "execution failed: timeout"
        );
        assert_eq!(
            ToolError::InvalidInput("missing field".into()).to_string(),
            "invalid input: missing field"
        );
        // `Other` prints the wrapped error's message unchanged.
        let boxed: Box<dyn std::error::Error + Send + Sync> = "boom".into();
        assert_eq!(ToolError::from(boxed).to_string(), "boom");
    }

    #[test]
    fn concurrency_hint_defaults_to_exclusive() {
        assert_eq!(
            ToolConcurrencyHint::default(),
            ToolConcurrencyHint::Exclusive
        );
        assert_eq!(EchoTool.concurrency_hint(), ToolConcurrencyHint::Exclusive);
    }

    /// Resolves to `{"echoed": input}`.
    struct EchoTool;

    impl ToolDyn for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "Echoes input back"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type": "object"})
        }
        fn call(
            &self,
            input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async move { Ok(json!({"echoed": input})) })
        }
    }

    /// Always resolves to `ToolError::ExecutionFailed("always fails")`.
    struct FailTool;

    impl ToolDyn for FailTool {
        fn name(&self) -> &str {
            "fail"
        }
        fn description(&self) -> &str {
            "Always fails"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type": "object"})
        }
        fn call(
            &self,
            _input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async { Err(ToolError::ExecutionFailed("always fails".into())) })
        }
    }

    #[test]
    fn registry_add_and_get() {
        let mut reg = ToolRegistry::new();
        assert!(reg.is_empty());
        reg.register(Arc::new(EchoTool));
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
        assert_eq!(reg.get("echo").map(|t| t.name()), Some("echo"));
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn registry_iter() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));
        reg.register(Arc::new(FailTool));
        // Iteration order is unspecified, so compare sorted names.
        let mut names: Vec<&str> = reg.iter().map(|t| t.name()).collect();
        names.sort_unstable();
        assert_eq!(names, ["echo", "fail"]);
    }

    #[tokio::test]
    async fn registry_call_tool() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));
        let tool = reg.get("echo").unwrap();
        let result = tool.call(json!({"msg": "hello"})).await.unwrap();
        assert_eq!(result, json!({"echoed": {"msg": "hello"}}));
    }

    #[tokio::test]
    async fn aliased_tool_exposes_alias_name_and_delegates() {
        let inner: Arc<dyn ToolDyn> = Arc::new(EchoTool);
        let aliased = AliasedTool::new("echo_alias", Arc::clone(&inner));
        assert_eq!(aliased.inner().name(), "echo");

        let tool: Arc<dyn ToolDyn> = Arc::new(aliased);
        assert_eq!(tool.name(), "echo_alias");
        assert_eq!(tool.description(), inner.description());
        assert_eq!(tool.input_schema(), inner.input_schema());
        assert_eq!(tool.concurrency_hint(), inner.concurrency_hint());
        let result = tool.call(json!({"msg": "hi"})).await.unwrap();
        assert_eq!(result, json!({"echoed": {"msg": "hi"}}));

        // The registry keys the wrapper by its alias, not by the inner name.
        let mut reg = ToolRegistry::new();
        reg.register(tool);
        assert!(reg.get("echo_alias").is_some());
        assert!(reg.get("echo").is_none());
    }

    #[tokio::test]
    async fn registry_call_failing_tool() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(FailTool));
        let tool = reg.get("fail").unwrap();
        let result = tool.call(json!({})).await;
        assert!(
            matches!(&result, Err(ToolError::ExecutionFailed(msg)) if msg == "always fails"),
            "unexpected result: {result:?}"
        );
    }

    #[test]
    fn registry_overwrite() {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(EchoTool));
        assert_eq!(reg.len(), 1);
        // Same name, different implementation: the later registration wins.
        reg.register(Arc::new(AliasedTool::new("echo", Arc::new(FailTool))));
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.get("echo").unwrap().description(), "Always fails");
    }

    /// Emits the chunks "one", "two" and "three" when streamed.
    struct StreamerTool;

    impl ToolDyn for StreamerTool {
        fn name(&self) -> &str {
            "streamer"
        }
        fn description(&self) -> &str {
            "Streams chunks"
        }
        fn input_schema(&self) -> serde_json::Value {
            json!({"type":"object"})
        }
        fn call(
            &self,
            _input: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, ToolError>> + Send + '_>>
        {
            Box::pin(async { Ok(serde_json::json!({"status":"done"})) })
        }
        fn maybe_streaming(&self) -> Option<&dyn ToolDynStreaming> {
            Some(self)
        }
    }

    impl ToolDynStreaming for StreamerTool {
        fn call_streaming<'a>(
            &'a self,
            _input: serde_json::Value,
            on_chunk: Box<dyn Fn(&str) + Send + Sync + 'a>,
        ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>> {
            Box::pin(async move {
                on_chunk("one");
                on_chunk("two");
                on_chunk("three");
                Ok(())
            })
        }
    }

    #[test]
    fn maybe_streaming_reflects_streaming_support() {
        assert!(StreamerTool.maybe_streaming().is_some());
        assert!(EchoTool.maybe_streaming().is_none());
    }

    #[tokio::test]
    async fn streaming_tool_emits_chunks_and_completes() {
        let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&seen);
        let tool = StreamerTool;
        // Go through the public discovery path rather than the concrete impl.
        let streaming = tool
            .maybe_streaming()
            .expect("StreamerTool advertises streaming support");
        let on_chunk = Box::new(move |chunk: &str| {
            sink.lock().unwrap().push(chunk.to_string());
        });
        let res = streaming.call_streaming(json!({}), on_chunk).await;
        assert!(res.is_ok());
        let got = seen.lock().unwrap().clone();
        assert_eq!(got, ["one", "two", "three"]);
    }
}