use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::ToolError;
use crate::provider::{Message, ToolCall, ToolSpec};
/// Result type used by tool execution and by registry dispatch.
pub type ToolResult<T> = Result<T, ToolError>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolOutput {
pub value: Value,
}
impl ToolOutput {
#[must_use]
pub const fn new(value: Value) -> Self {
Self { value }
}
#[must_use]
pub fn text(text: impl Into<String>) -> Self {
Self {
value: Value::String(text.into()),
}
}
#[must_use]
pub fn error(message: impl Into<String>) -> Self {
Self {
value: serde_json::json!({ "error": message.into() }),
}
}
}
/// A named, schema-described capability that can be invoked with JSON arguments.
///
/// Implementors must be `Send + Sync` so they can be shared (e.g. behind `Arc`) across tasks.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name the tool is advertised and looked up under.
    fn name(&self) -> &str;

    /// Human-readable description advertised alongside the tool.
    fn description(&self) -> &str;

    /// Schema describing the accepted arguments (e.g. `{"type": "object"}`).
    fn parameters_schema(&self) -> Value;

    /// Runs the tool with the given JSON arguments.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the invocation fails.
    async fn execute(&self, arguments: Value) -> ToolResult<ToolOutput>;

    /// Reports whether the tool is classified as read-only; the default is `false`.
    fn is_read_only(&self) -> bool {
        false
    }

    /// Reports whether the tool is classified as concurrency-safe; the default is `false`.
    fn is_concurrency_safe(&self) -> bool {
        false
    }

    /// Builds the [`ToolSpec`] for this tool from its name, description and schema.
    fn to_spec(&self) -> ToolSpec {
        let name = self.name();
        let description = self.description();
        let schema = self.parameters_schema();
        ToolSpec::new(name, description, schema)
    }
}
/// A [`Tool`] whose behaviour is supplied by an async function or closure.
///
/// The handler receives the raw JSON arguments and resolves to a JSON value, which the
/// `Tool` implementation wraps in a [`ToolOutput`].
pub struct FunctionTool<F> {
    name: String,
    description: String,
    parameters_schema: Value,
    // Invoked once per `execute` call.
    handler: F,
    // Classification flags; both start `false` and are flipped by the builder methods.
    read_only: bool,
    concurrency_safe: bool,
}

impl<F, Fut> FunctionTool<F>
where
    F: Fn(Value) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ToolResult<Value>> + Send + 'static,
{
    /// Creates a tool named `name` that delegates execution to `handler`.
    ///
    /// The new tool is neither read-only nor concurrency-safe until the corresponding builder
    /// method is called.
    #[must_use]
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters_schema: Value,
        handler: F,
    ) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            parameters_schema,
            handler,
            read_only: false,
            concurrency_safe: false,
        }
    }

    /// Marks the tool as read-only.
    #[must_use]
    pub fn read_only(self) -> Self {
        Self {
            read_only: true,
            ..self
        }
    }

    /// Marks the tool as concurrency-safe.
    #[must_use]
    pub fn concurrency_safe(self) -> Self {
        Self {
            concurrency_safe: true,
            ..self
        }
    }
}
#[async_trait]
impl<F, Fut> Tool for FunctionTool<F>
where
F: Fn(Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ToolResult<Value>> + Send + 'static,
{
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters_schema(&self) -> Value {
self.parameters_schema.clone()
}
async fn execute(&self, arguments: Value) -> ToolResult<ToolOutput> {
let value = (self.handler)(arguments).await?;
Ok(ToolOutput::new(value))
}
fn is_read_only(&self) -> bool {
self.read_only
}
fn is_concurrency_safe(&self) -> bool {
self.concurrency_safe
}
}
/// A [`Tool`] that is declared by name, description and schema but has no local
/// implementation: its `execute` returns [`ToolError::NotImplemented`].
///
/// An optional endpoint string can be attached; this type only stores and exposes it.
// Plain data, so `Debug`/`Clone` are derived for diagnostics and cheap duplication
// (Rust API guidelines C-DEBUG / C-COMMON-TRAITS).
#[derive(Debug, Clone)]
pub struct ExternalTool {
    name: String,
    description: String,
    parameters_schema: Value,
    endpoint: Option<String>,
    read_only: bool,
    concurrency_safe: bool,
}

impl ExternalTool {
    /// Creates an external tool with no endpoint that is neither read-only nor
    /// concurrency-safe.
    #[must_use]
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters_schema: Value,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            parameters_schema,
            endpoint: None,
            read_only: false,
            concurrency_safe: false,
        }
    }

    /// Attaches an endpoint string (e.g. a URL), replacing any previous one.
    #[must_use]
    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.endpoint = Some(endpoint.into());
        self
    }

    /// Returns the attached endpoint, if any.
    #[must_use]
    pub fn endpoint(&self) -> Option<&str> {
        self.endpoint.as_deref()
    }

    /// Marks the tool as read-only.
    #[must_use]
    pub fn read_only(mut self) -> Self {
        self.read_only = true;
        self
    }

    /// Marks the tool as concurrency-safe.
    #[must_use]
    pub fn concurrency_safe(mut self) -> Self {
        self.concurrency_safe = true;
        self
    }
}
#[async_trait]
impl Tool for ExternalTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters_schema(&self) -> Value {
self.parameters_schema.clone()
}
async fn execute(&self, _arguments: Value) -> ToolResult<ToolOutput> {
Err(ToolError::NotImplemented {
name: self.name.clone(),
})
}
fn is_read_only(&self) -> bool {
self.read_only
}
fn is_concurrency_safe(&self) -> bool {
self.concurrency_safe
}
}
pub struct ToolRegistry {
tools: std::sync::RwLock<HashMap<String, Arc<dyn Tool>>>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self {
tools: std::sync::RwLock::new(HashMap::new()),
}
}
}
impl Clone for ToolRegistry {
fn clone(&self) -> Self {
let guard = self
.tools
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner);
Self {
tools: std::sync::RwLock::new(guard.clone()),
}
}
}
impl ToolRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register<T>(&self, tool: T) -> Option<Arc<dyn Tool>>
where
T: Tool + 'static,
{
let name = tool.name().to_owned();
self.tools
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.insert(name, Arc::new(tool))
}
pub fn register_arc(&self, tool: Arc<dyn Tool>) -> Option<Arc<dyn Tool>> {
let name = tool.name().to_owned();
self.tools
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.insert(name, tool)
}
pub fn unregister(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.tools
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.remove(name)
}
#[must_use]
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.tools
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.get(name)
.map(Arc::clone)
}
#[must_use]
pub fn names(&self) -> Vec<String> {
self.tools
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.keys()
.cloned()
.collect()
}
#[must_use]
pub fn len(&self) -> usize {
self.tools
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.tools
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.is_empty()
}
#[must_use]
pub fn specs(&self) -> Vec<ToolSpec> {
let guard = self
.tools
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let mut specs: Vec<ToolSpec> = guard.values().map(|t| t.to_spec()).collect();
specs.sort_by(|a, b| a.name.cmp(&b.name));
specs
}
pub async fn execute(&self, call: &ToolCall) -> ToolResult<ToolOutput> {
let tool = self.get(&call.name).ok_or_else(|| ToolError::NotFound {
name: call.name.clone(),
})?;
tool.execute(call.arguments.clone()).await
}
pub async fn execute_to_message(&self, call: &ToolCall) -> ToolResult<Message> {
let output = self.execute(call).await?;
Ok(Message::tool_text(
call.id.clone(),
call.name.clone(),
output.value.to_string(),
))
}
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Echo handler: returns the `message` argument, or `null` when it is missing.
    async fn echo_handler(args: Value) -> ToolResult<Value> {
        Ok(args.get("message").cloned().unwrap_or(Value::Null))
    }

    #[test]
    fn tool_output_text_should_wrap_string_value() {
        assert_eq!(ToolOutput::text("done").value, json!("done"));
    }

    #[test]
    fn tool_output_error_should_wrap_message_in_error_object() {
        assert_eq!(ToolOutput::error("boom").value, json!({ "error": "boom" }));
    }

    #[test]
    fn tool_registry_should_be_empty_when_new() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn tool_registry_should_register_function_tool() {
        let registry = ToolRegistry::new();
        let tool = FunctionTool::new(
            "echo",
            "Echoes input",
            json!({"type": "object"}),
            echo_handler,
        );
        registry.register(tool);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("echo").is_some());
    }

    #[test]
    fn tool_registry_should_return_none_for_unknown_tool() {
        let registry = ToolRegistry::new();
        assert!(registry.get("unknown").is_none());
    }

    #[test]
    fn tool_registry_should_generate_specs() {
        let registry = ToolRegistry::new();
        let tool = FunctionTool::new(
            "echo",
            "Echoes input",
            json!({"type": "object"}),
            echo_handler,
        );
        registry.register(tool);
        let specs = registry.specs();
        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].name, "echo");
        assert_eq!(specs[0].description, "Echoes input");
    }

    #[test]
    fn tool_registry_should_replace_existing_tool() {
        let registry = ToolRegistry::new();
        let tool1 = FunctionTool::new("echo", "First", json!({}), echo_handler);
        let tool2 = FunctionTool::new("echo", "Second", json!({}), echo_handler);
        registry.register(tool1);
        let replaced = registry.register(tool2);
        assert!(replaced.is_some());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn tool_registry_should_unregister_tool() {
        let registry = ToolRegistry::new();
        registry.register(FunctionTool::new(
            "echo",
            "Echoes input",
            json!({}),
            echo_handler,
        ));
        assert!(registry.unregister("echo").is_some());
        assert!(registry.is_empty());
        assert!(registry.unregister("echo").is_none());
    }

    #[test]
    fn tool_registry_should_register_shared_tool() {
        let registry = ToolRegistry::new();
        let tool: Arc<dyn Tool> =
            Arc::new(ExternalTool::new("external", "External tool", json!({})));
        assert!(registry.register_arc(Arc::clone(&tool)).is_none());
        let stored = registry.get("external").unwrap();
        assert!(Arc::ptr_eq(&tool, &stored));
    }

    #[test]
    fn cloned_registry_should_not_see_later_changes() {
        let registry = ToolRegistry::new();
        registry.register(FunctionTool::new(
            "echo",
            "Echoes input",
            json!({}),
            echo_handler,
        ));
        let snapshot = registry.clone();
        registry.unregister("echo");
        assert!(registry.is_empty());
        assert!(snapshot.get("echo").is_some());
    }

    #[tokio::test]
    async fn tool_registry_should_execute_tool_call() {
        let registry = ToolRegistry::new();
        let tool = FunctionTool::new(
            "echo",
            "Echoes input",
            json!({"type": "object"}),
            echo_handler,
        );
        registry.register(tool);
        let call = ToolCall::new("call_1", "echo", json!({"message": "hello"}));
        let output = registry.execute(&call).await.unwrap();
        assert_eq!(output.value, json!("hello"));
    }

    #[tokio::test]
    async fn tool_registry_should_return_error_for_unknown_tool() {
        let registry = ToolRegistry::new();
        let call = ToolCall::new("call_1", "unknown", json!({}));
        let result = registry.execute(&call).await;
        assert!(matches!(result, Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn tool_registry_should_convert_output_to_message() {
        let registry = ToolRegistry::new();
        let tool = FunctionTool::new(
            "echo",
            "Echoes input",
            json!({"type": "object"}),
            echo_handler,
        );
        registry.register(tool);
        let call = ToolCall::new("call_1", "echo", json!({"message": "hello"}));
        let message = registry.execute_to_message(&call).await.unwrap();
        match message {
            Message::Tool {
                tool_call_id,
                name,
                content,
            } => {
                assert_eq!(tool_call_id, "call_1");
                assert_eq!(name, "echo");
                assert!(!content.is_empty());
            }
            _ => panic!("expected Message::Tool"),
        }
    }

    #[test]
    fn external_tool_should_have_name_and_no_endpoint_by_default() {
        let tool = ExternalTool::new("external", "External tool", json!({}));
        assert_eq!(tool.name(), "external");
        assert!(tool.endpoint().is_none());
    }

    #[test]
    fn external_tool_should_accept_endpoint() {
        let tool = ExternalTool::new("external", "External tool", json!({}))
            .with_endpoint("https://example.com/tool");
        assert_eq!(tool.endpoint(), Some("https://example.com/tool"));
    }

    #[tokio::test]
    async fn external_tool_execute_should_return_not_implemented() {
        let tool = ExternalTool::new("external", "External tool", json!({}));
        let result = tool.execute(json!({})).await;
        assert!(matches!(result, Err(ToolError::NotImplemented { .. })));
    }

    #[test]
    fn specs_should_return_sorted_by_name() {
        let registry = ToolRegistry::new();
        registry.register(FunctionTool::new(
            "zebra",
            "Zebra tool",
            json!({}),
            echo_handler,
        ));
        registry.register(FunctionTool::new(
            "alpha",
            "Alpha tool",
            json!({}),
            echo_handler,
        ));
        registry.register(FunctionTool::new(
            "mike",
            "Mike tool",
            json!({}),
            echo_handler,
        ));
        let specs = registry.specs();
        assert_eq!(specs.len(), 3);
        assert_eq!(specs[0].name, "alpha");
        assert_eq!(specs[1].name, "mike");
        assert_eq!(specs[2].name, "zebra");
    }

    #[test]
    fn function_tool_default_classification_is_false() {
        let tool = FunctionTool::new("test", "desc", json!({}), |_| async { Ok(json!(null)) });
        assert!(!tool.is_read_only());
        assert!(!tool.is_concurrency_safe());
    }

    #[test]
    fn function_tool_classification_builder() {
        let tool = FunctionTool::new("test", "desc", json!({}), |_| async { Ok(json!(null)) })
            .read_only()
            .concurrency_safe();
        assert!(tool.is_read_only());
        assert!(tool.is_concurrency_safe());
    }

    #[test]
    fn external_tool_default_classification_is_false() {
        let tool = ExternalTool::new("test", "desc", json!({}));
        assert!(!tool.is_read_only());
        assert!(!tool.is_concurrency_safe());
    }

    #[test]
    fn external_tool_classification_builder() {
        let tool = ExternalTool::new("test", "desc", json!({}))
            .read_only()
            .concurrency_safe();
        assert!(tool.is_read_only());
        assert!(tool.is_concurrency_safe());
    }
}