use async_trait::async_trait;
/// A named, asynchronously invoked capability that consumes and produces JSON.
///
/// The `Send + Sync + 'static` supertraits let a tool be stored and shared as
/// an `Arc<dyn Tool>` trait object.
#[async_trait]
pub trait Tool: Send + Sync + 'static {
    /// Identifier of this tool.
    fn name(&self) -> &str;

    /// Optional human-readable summary of what the tool does.
    ///
    /// The default implementation returns `None`.
    fn description(&self) -> Option<&str> {
        None
    }

    /// JSON description of the input this tool accepts.
    ///
    /// NOTE(review): presumably a JSON Schema object (the in-file test tool
    /// returns one) — confirm against whatever consumes this value.
    fn schema(&self) -> serde_json::Value;

    /// Executes the tool with `input`, returning its JSON result.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the call fails, e.g. `InvalidInput` for
    /// malformed input or `Execution` for an underlying failure.
    async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError>;
}
/// Errors a [`Tool`] invocation (or tool lookup) can produce.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The supplied input was rejected; the string explains why.
    #[error("invalid tool input: {0}")]
    InvalidInput(String),

    /// The tool failed while running. The wrapped error's message is embedded
    /// in this error's `Display` output rather than exposed via `source()`.
    #[error("tool execution failed: {0}")]
    Execution(Box<dyn std::error::Error + Send + Sync>),

    /// No tool is registered under `name`.
    #[error("no tool registered with name '{name}'")]
    Unknown { name: String },
}
impl ToolError {
    /// Builds [`ToolError::InvalidInput`] from any string-like message.
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput(msg.into())
    }

    /// Builds [`ToolError::Execution`] from anything that converts into a
    /// boxed, thread-safe error.
    ///
    /// This accepts every `Error + Send + Sync + 'static` type, as before. It
    /// also accepts:
    /// - an already-boxed `Box<dyn Error + Send + Sync>`, stored without
    ///   re-boxing;
    /// - plain `&str` / `String` messages.
    pub fn execution<E>(err: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        Self::Execution(err.into())
    }

    /// Builds [`ToolError::Unknown`] for a tool name with no registration.
    pub fn unknown(name: impl Into<String>) -> Self {
        Self::Unknown { name: name.into() }
    }
}
/// Verdict produced by a [`ToolApprover`] for a requested tool call.
///
/// NOTE(review): what each variant actually does is decided by the code that
/// consumes `ToolApprover`, which is not visible here. The per-variant notes
/// below are inferred from the names and should be confirmed against that
/// caller.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ApprovalDecision {
    /// Presumably: allow the call with its original input.
    Approve,
    /// Presumably: allow the call, but run it with this replacement input.
    ApproveWithInput(serde_json::Value),
    /// Presumably: skip the tool and use this value in place of its output.
    Substitute(serde_json::Value),
    /// Presumably: refuse this call; the string carries the reason.
    Deny(String),
    /// Presumably: halt further processing; the string carries the reason.
    Stop(String),
}
/// Decides, per call, whether and how a tool invocation should proceed.
///
/// The `Send + Sync + 'static` supertraits allow implementations to live
/// behind an `Arc<dyn ToolApprover>`, which is what [`fn_approver`] returns.
#[async_trait]
pub trait ToolApprover: Send + Sync + 'static {
    /// Returns the [`ApprovalDecision`] for calling `tool_name` with `input`.
    async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision;
}
/// Adapts an async closure into a shared [`ToolApprover`].
///
/// `handler` is called with the tool name and input by reference and must
/// return a future that resolves to an [`ApprovalDecision`].
///
/// `Fut` is a single `'static` type that cannot depend on the argument
/// lifetimes, so the future may not borrow `tool_name` or `input`. Copy what
/// it needs (e.g. via `to_owned()`) before building an `async move` block.
#[must_use]
pub fn fn_approver<F, Fut>(handler: F) -> std::sync::Arc<dyn ToolApprover>
where
    F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
{
    let adapter = FnApprover { handler };
    // Unsized coercion to `Arc<dyn ToolApprover>` happens at the return site.
    std::sync::Arc::new(adapter)
}
/// Private adapter constructed by [`fn_approver`]. It stores a closure and
/// implements [`ToolApprover`] by delegating each request to it.
struct FnApprover<F> {
    /// Closure called once per approval request.
    handler: F,
}
#[async_trait]
impl<F, Fut> ToolApprover for FnApprover<F>
where
    F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
{
    /// Passes both arguments unchanged to the wrapped closure and awaits the
    /// future it returns.
    async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision {
        let handler = &self.handler;
        let pending = handler(tool_name, input);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};
    use std::sync::Arc;

    /// Minimal hand-written tool used to exercise the `Tool` trait.
    struct AddTool;

    #[async_trait]
    impl Tool for AddTool {
        // Keep the trait's exact `&str` signature instead of the
        // `&'static str` clippy suggests for literal returns.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "add"
        }

        // Same reason as `name`: mirror the trait signature exactly.
        #[allow(clippy::unnecessary_literal_bound)]
        fn description(&self) -> Option<&str> {
            Some("Add two numbers and return the sum.")
        }

        fn schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["a", "b"]
            })
        }

        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let a = input
                .get("a")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
            let b = input
                .get("b")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
            Ok(json!({"sum": a + b}))
        }
    }

    #[tokio::test]
    async fn manual_impl_round_trips_a_value() {
        let tool = AddTool;
        let result = tool.invoke(json!({"a": 2, "b": 3})).await.unwrap();
        assert_eq!(result, json!({"sum": 5.0}));
    }

    #[tokio::test]
    async fn trait_object_dispatch_works() {
        let tool: Arc<dyn Tool> = Arc::new(AddTool);
        assert_eq!(tool.name(), "add");
        assert_eq!(
            tool.description(),
            Some("Add two numbers and return the sum.")
        );
        assert!(tool.schema().is_object());
        let result = tool.invoke(json!({"a": 4, "b": 1})).await.unwrap();
        assert_eq!(result["sum"], 5.0);
    }

    #[tokio::test]
    async fn invalid_input_propagates_message() {
        let tool = AddTool;
        let err = tool.invoke(json!({"a": 1})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'b'"), "{msg}");
    }

    #[test]
    fn invalid_input_constructor_takes_string_or_str() {
        let _ = ToolError::invalid_input("plain str");
        let _ = ToolError::invalid_input(String::from("owned"));
    }

    #[test]
    fn execution_wraps_any_std_error() {
        let inner = std::io::Error::other("disk on fire");
        let err = ToolError::execution(inner);
        let display = format!("{err}");
        assert!(display.contains("disk on fire"), "{display}");
        let ToolError::Execution(_) = err else {
            panic!("expected Execution");
        };
    }

    /// Pins the exact user-facing text of the non-wrapping variants.
    #[test]
    fn error_messages_are_stable() {
        assert_eq!(
            ToolError::invalid_input("bad").to_string(),
            "invalid tool input: bad"
        );
        let unknown = ToolError::Unknown {
            name: "mul".to_owned(),
        };
        assert_eq!(unknown.to_string(), "no tool registered with name 'mul'");
    }

    /// `fn_approver` must forward both arguments to the closure and return
    /// whatever decision its future resolves to.
    #[tokio::test]
    async fn fn_approver_forwards_name_and_input() {
        let approver = fn_approver(|name, input| {
            // The returned future must be `'static`, so extract owned data
            // from the borrowed arguments before building it.
            let name = name.to_owned();
            let has_a = input.get("a").is_some();
            async move {
                match (name.as_str(), has_a) {
                    ("add", true) => ApprovalDecision::Approve,
                    ("add", false) => ApprovalDecision::Substitute(json!({"sum": 0})),
                    _ => ApprovalDecision::Deny(format!("{name} is not allowed")),
                }
            }
        });

        let approved = approver.approve("add", &json!({"a": 1})).await;
        assert!(
            matches!(approved, ApprovalDecision::Approve),
            "{approved:?}"
        );

        let ApprovalDecision::Substitute(output) = approver.approve("add", &json!({})).await
        else {
            panic!("expected Substitute");
        };
        assert_eq!(output, json!({"sum": 0}));

        let ApprovalDecision::Deny(reason) = approver.approve("rm", &json!({"a": 1})).await
        else {
            panic!("expected Deny");
        };
        assert_eq!(reason, "rm is not allowed");
    }

    #[test]
    fn tool_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn Tool>();
        assert_send_sync::<dyn ToolApprover>();
        assert_send_sync::<ToolError>();
    }
}