use crate::resources::ResourceMetadata;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::sync::Arc;
/// A tool described by a name, a description and one JSON Schema covering all of its
/// parameters. [`ToolMetadata::to_tool_definition`] produces values of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter schema. When built by `ToolMetadata::to_tool_definition` this is an
    /// `object` schema with `properties` and `required` keys.
    pub parameters_schema: serde_json::Value,
}
/// Outcome of executing a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// `true` for results built with `ToolResult::success`, `false` for `ToolResult::failure`.
    pub success: bool,
    /// Tool output. `ToolResult::failure` sets this to JSON `null`.
    pub output: serde_json::Value,
    /// Error message. `None` from `ToolResult::success`, `Some` from `ToolResult::failure`.
    pub error: Option<String>,
    /// Extra key/value data, attached via `ToolResult::with_metadata`.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl ToolResult {
    /// Builds a successful result carrying `output`, with no error and empty metadata.
    pub fn success(output: serde_json::Value) -> Self {
        let metadata = HashMap::new();
        Self {
            success: true,
            output,
            error: None,
            metadata,
        }
    }

    /// Builds a failed result: `output` is JSON `null` and `error` holds the message.
    pub fn failure(error: String) -> Self {
        // Start from an empty success with a null output, then flip the outcome fields.
        Self {
            success: false,
            error: Some(error),
            ..Self::success(serde_json::Value::Null)
        }
    }

    /// Inserts (or overwrites) one metadata entry and returns the updated result.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }
}
/// JSON Schema primitive type of a tool parameter.
///
/// `rename_all = "lowercase"` makes serde (de)serialize variants as `"string"`,
/// `"number"`, `"integer"`, `"boolean"`, `"array"` and `"object"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ToolParameterType {
    String,
    Number,
    Integer,
    Boolean,
    Array,
    Object,
}
/// A single named parameter of a tool, convertible to a JSON Schema fragment via
/// `ToolParameter::to_json_schema`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Parameter name.
    pub name: String,
    /// Parameter type. Serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub param_type: ToolParameterType,
    /// Optional human-readable description, emitted as `description` in the schema.
    pub description: Option<String>,
    /// Whether the parameter is listed in the enclosing schema's `required` array.
    pub required: bool,
    /// Optional default value. Stored only; `to_json_schema` does not emit it.
    pub default: Option<serde_json::Value>,
    /// Element type, consulted only when `param_type` is `Array`.
    pub items: Option<Box<ToolParameterType>>,
    /// Nested parameters keyed by property name, consulted only when `param_type` is `Object`.
    pub properties: Option<HashMap<String, ToolParameter>>,
}
impl ToolParameter {
pub fn new(name: impl Into<String>, param_type: ToolParameterType) -> Self {
Self {
name: name.into(),
param_type,
description: None,
required: false,
default: None,
items: None,
properties: None,
}
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn required(mut self) -> Self {
self.required = true;
self
}
pub fn with_default(mut self, default: serde_json::Value) -> Self {
self.default = Some(default);
self
}
pub fn with_items(mut self, item_type: ToolParameterType) -> Self {
self.items = Some(Box::new(item_type));
self
}
pub fn with_properties(mut self, properties: HashMap<String, ToolParameter>) -> Self {
self.properties = Some(properties);
self
}
pub fn to_json_schema(&self) -> serde_json::Value {
let mut schema = match &self.param_type {
ToolParameterType::String => serde_json::json!({"type": "string"}),
ToolParameterType::Number => serde_json::json!({"type": "number"}),
ToolParameterType::Integer => serde_json::json!({"type": "integer"}),
ToolParameterType::Boolean => serde_json::json!({"type": "boolean"}),
ToolParameterType::Array => {
let items_schema = self
.items
.as_ref()
.map(|t| param_type_to_schema(t))
.unwrap_or_else(|| serde_json::json!({}));
serde_json::json!({"type": "array", "items": items_schema})
}
ToolParameterType::Object => {
if let Some(props) = &self.properties {
let properties: serde_json::Map<String, serde_json::Value> = props
.iter()
.map(|(k, v)| (k.clone(), v.to_json_schema()))
.collect();
let required: Vec<&str> = props
.values()
.filter(|p| p.required)
.map(|p| p.name.as_str())
.collect();
serde_json::json!({
"type": "object",
"properties": properties,
"required": required
})
} else {
serde_json::json!({"type": "object"})
}
}
};
if let Some(desc) = &self.description {
schema["description"] = serde_json::Value::String(desc.clone());
}
schema
}
}
/// Returns the bare `{"type": "<name>"}` schema for a parameter type, without `items`
/// or `properties`.
fn param_type_to_schema(t: &ToolParameterType) -> serde_json::Value {
    let type_name = match t {
        ToolParameterType::String => "string",
        ToolParameterType::Number => "number",
        ToolParameterType::Integer => "integer",
        ToolParameterType::Boolean => "boolean",
        ToolParameterType::Array => "array",
        ToolParameterType::Object => "object",
    };
    serde_json::json!({ "type": type_name })
}
/// Description of a tool: name, description, ordered parameters and free-form
/// protocol-specific data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolMetadata {
    /// Tool name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Parameters in declaration order.
    pub parameters: Vec<ToolParameter>,
    /// Free-form key/value data. Not included in `to_tool_definition` output.
    pub protocol_metadata: HashMap<String, serde_json::Value>,
}
impl ToolMetadata {
    /// Creates metadata with a name and description, no parameters and no
    /// protocol-specific entries.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let (name, description) = (name.into(), description.into());
        Self {
            name,
            description,
            parameters: Vec::new(),
            protocol_metadata: HashMap::new(),
        }
    }

    /// Appends a parameter, keeping declaration order.
    pub fn with_parameter(mut self, param: ToolParameter) -> Self {
        self.parameters.extend(std::iter::once(param));
        self
    }

    /// Inserts (or overwrites) one protocol-specific metadata entry.
    pub fn with_protocol_metadata(
        mut self,
        key: impl Into<String>,
        value: serde_json::Value,
    ) -> Self {
        let key: String = key.into();
        self.protocol_metadata.insert(key, value);
        self
    }

    /// Flattens the parameters into a single `object` JSON Schema and wraps it, together
    /// with name and description, in a [`ToolDefinition`].
    ///
    /// `required` lists the required parameters in declaration order.
    pub fn to_tool_definition(&self) -> ToolDefinition {
        let properties: serde_json::Map<String, serde_json::Value> = self
            .parameters
            .iter()
            .map(|param| (param.name.clone(), param.to_json_schema()))
            .collect();
        let required: Vec<String> = self
            .parameters
            .iter()
            .filter(|param| param.required)
            .map(|param| param.name.clone())
            .collect();
        ToolDefinition {
            name: self.name.clone(),
            description: self.description.clone(),
            parameters_schema: serde_json::json!({
                "type": "object",
                "properties": properties,
                "required": required
            }),
        }
    }
}
/// A backend that lists and executes tools and can optionally serve resources.
///
/// The `Send + Sync` supertraits let implementors be shared as `Arc<dyn ToolProtocol>`.
#[async_trait]
pub trait ToolProtocol: Send + Sync {
    /// Executes the tool named `tool_name` with the given JSON `parameters`.
    async fn execute(
        &self,
        tool_name: &str,
        parameters: serde_json::Value,
    ) -> Result<ToolResult, Box<dyn Error + Send + Sync>>;
    /// Lists the tools this protocol provides.
    async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>>;
    /// Returns metadata for the single tool named `tool_name`.
    async fn get_tool_metadata(
        &self,
        tool_name: &str,
    ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>>;
    /// Identifier of this protocol implementation.
    ///
    /// `ToolRegistry` keys protocols by the name given at registration, not by this value.
    fn protocol_name(&self) -> &str;
    /// Optional setup hook. The default does nothing and returns `Ok(())`.
    async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
        Ok(())
    }
    /// Optional teardown hook. The default does nothing and returns `Ok(())`.
    async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
        Ok(())
    }
    /// Lists available resources. The default returns an empty list.
    async fn list_resources(&self) -> Result<Vec<ResourceMetadata>, Box<dyn Error + Send + Sync>> {
        Ok(Vec::new())
    }
    /// Reads the resource at `uri` as a string. The default always fails with a
    /// "Resource not found" error.
    async fn read_resource(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
        Err(format!("Resource not found: {}", uri).into())
    }
    /// Whether this protocol serves resources. The default is `false`.
    /// NOTE(review): presumably implementors that override the resource methods should
    /// return `true` here. Confirm with callers.
    fn supports_resources(&self) -> bool {
        false
    }
}
/// Error categories for tool lookup and execution. Each variant carries a detail string.
#[derive(Debug, Clone)]
pub enum ToolError {
    /// No tool with the given name. `ToolRegistry::execute_tool` returns this.
    NotFound(String),
    /// The tool ran but failed.
    ExecutionFailed(String),
    /// The supplied parameters were rejected.
    InvalidParameters(String),
    /// The underlying protocol reported an error.
    ProtocolError(String),
}
impl fmt::Display for ToolError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ToolError::NotFound(name) => write!(f, "Tool not found: {}", name),
ToolError::ExecutionFailed(msg) => write!(f, "Tool execution failed: {}", msg),
ToolError::InvalidParameters(msg) => write!(f, "Invalid parameters: {}", msg),
ToolError::ProtocolError(msg) => write!(f, "Protocol error: {}", msg),
}
}
}
impl Error for ToolError {}
/// A tool's metadata bound to the protocol that executes it.
pub struct Tool {
    /// Name, description, parameters and protocol metadata of the tool.
    metadata: ToolMetadata,
    /// Protocol that `Tool::execute` forwards calls to.
    protocol: Arc<dyn ToolProtocol>,
}
impl Tool {
    /// Creates a parameterless tool with the given name and description, bound to `protocol`.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        protocol: Arc<dyn ToolProtocol>,
    ) -> Self {
        let metadata = ToolMetadata::new(name, description);
        Self { metadata, protocol }
    }

    /// Appends a parameter to the tool's metadata.
    pub fn with_parameter(mut self, param: ToolParameter) -> Self {
        self.metadata = self.metadata.with_parameter(param);
        self
    }

    /// Inserts (or overwrites) one protocol-specific metadata entry.
    pub fn with_protocol_metadata(
        mut self,
        key: impl Into<String>,
        value: serde_json::Value,
    ) -> Self {
        self.metadata = self.metadata.with_protocol_metadata(key, value);
        self
    }

    /// Returns the tool's metadata.
    pub fn metadata(&self) -> &ToolMetadata {
        &self.metadata
    }

    /// Executes this tool by forwarding its name and `parameters` to the bound protocol.
    ///
    /// # Errors
    /// Whatever the protocol's `execute` returns.
    pub async fn execute(
        &self,
        parameters: serde_json::Value,
    ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
        let Self { metadata, protocol } = self;
        protocol.execute(&metadata.name, parameters).await
    }
}
/// Tools keyed by name, together with the protocols that provide them.
pub struct ToolRegistry {
    /// Registered tools by name.
    tools: HashMap<String, Tool>,
    /// Tool name -> registry name of the protocol the tool was discovered from.
    /// Tools added directly through `add_tool` need not have an entry.
    tool_to_protocol: HashMap<String, String>,
    /// Protocols by registry-assigned name (`"primary"` for the one passed to `new`).
    protocols: HashMap<String, Arc<dyn ToolProtocol>>,
    /// Protocol passed to `ToolRegistry::new`; `None` for `ToolRegistry::empty`.
    primary_protocol: Option<Arc<dyn ToolProtocol>>,
}
impl ToolRegistry {
pub fn new(protocol: Arc<dyn ToolProtocol>) -> Self {
Self {
tools: HashMap::new(),
tool_to_protocol: HashMap::new(),
protocols: {
let mut m = HashMap::new();
m.insert("primary".to_string(), protocol.clone());
m
},
primary_protocol: Some(protocol),
}
}
pub fn empty() -> Self {
Self {
tools: HashMap::new(),
tool_to_protocol: HashMap::new(),
protocols: HashMap::new(),
primary_protocol: None,
}
}
pub async fn add_protocol(
&mut self,
protocol_name: &str,
protocol: Arc<dyn ToolProtocol>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let discovered_tools = protocol.list_tools().await?;
self.protocols
.insert(protocol_name.to_string(), protocol.clone());
for tool_meta in discovered_tools {
let tool_name = tool_meta.name.clone();
let tool = Tool::new(
tool_name.clone(),
tool_meta.description.clone(),
protocol.clone(),
);
let mut tool = tool;
for param in &tool_meta.parameters {
tool = tool.with_parameter(param.clone());
}
for (key, value) in &tool_meta.protocol_metadata {
tool = tool.with_protocol_metadata(key.clone(), value.clone());
}
self.tools.insert(tool_name.clone(), tool);
self.tool_to_protocol
.insert(tool_name, protocol_name.to_string());
}
Ok(())
}
pub fn remove_protocol(&mut self, protocol_name: &str) {
self.protocols.remove(protocol_name);
let tools_to_remove: Vec<String> = self
.tool_to_protocol
.iter()
.filter(|(_, pn)| *pn == protocol_name)
.map(|(tn, _)| tn.clone())
.collect();
for tool_name in tools_to_remove {
self.tools.remove(&tool_name);
self.tool_to_protocol.remove(&tool_name);
}
}
pub fn add_tool(&mut self, tool: Tool) {
self.tools.insert(tool.metadata.name.clone(), tool);
}
pub fn remove_tool(&mut self, name: &str) -> Option<Tool> {
self.tool_to_protocol.remove(name);
self.tools.remove(name)
}
pub fn get_tool(&self, name: &str) -> Option<&Tool> {
self.tools.get(name)
}
pub fn list_tools(&self) -> Vec<&ToolMetadata> {
self.tools.values().map(|t| &t.metadata).collect()
}
pub async fn discover_tools_from_primary(
&mut self,
) -> Result<(), Box<dyn Error + Send + Sync>> {
if let Some(protocol) = &self.primary_protocol {
let discovered_tools = protocol.list_tools().await?;
for tool_meta in discovered_tools {
let tool_name = tool_meta.name.clone();
let tool = Tool::new(
tool_name.clone(),
tool_meta.description.clone(),
protocol.clone(),
);
let mut tool = tool;
for param in &tool_meta.parameters {
tool = tool.with_parameter(param.clone());
}
for (key, value) in &tool_meta.protocol_metadata {
tool = tool.with_protocol_metadata(key.clone(), value.clone());
}
self.tools.insert(tool_name.clone(), tool);
self.tool_to_protocol
.insert(tool_name, "primary".to_string());
}
Ok(())
} else {
Err("No primary protocol available".into())
}
}
pub fn get_tool_protocol(&self, tool_name: &str) -> Option<&str> {
self.tool_to_protocol.get(tool_name).map(|s| s.as_str())
}
pub fn list_protocols(&self) -> Vec<&str> {
self.protocols.keys().map(|s| s.as_str()).collect()
}
pub async fn execute_tool(
&self,
tool_name: &str,
parameters: serde_json::Value,
) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
let tool = self
.tools
.get(tool_name)
.ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
tool.execute(parameters).await
}
pub fn protocol(&self) -> Option<&Arc<dyn ToolProtocol>> {
self.primary_protocol.as_ref()
}
pub fn to_tool_definitions(&self) -> Vec<ToolDefinition> {
self.tools
.values()
.map(|t| t.metadata.to_tool_definition())
.collect()
}
}
// Unit tests for parameter/schema building, tool execution and registry bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;

    // Protocol stub: `execute` echoes the tool name with a fixed result and
    // `list_tools` reports no tools.
    struct MockProtocol;

    #[async_trait]
    impl ToolProtocol for MockProtocol {
        async fn execute(
            &self,
            tool_name: &str,
            _parameters: serde_json::Value,
        ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
            Ok(ToolResult::success(serde_json::json!({
                "tool": tool_name,
                "result": "mock_result"
            })))
        }
        async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
            Ok(vec![])
        }
        async fn get_tool_metadata(
            &self,
            _tool_name: &str,
        ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
            Ok(ToolMetadata::new("mock_tool", "A mock tool"))
        }
        fn protocol_name(&self) -> &str {
            "mock"
        }
    }

    // Every builder method sets its corresponding field.
    #[test]
    fn test_tool_parameter_builder() {
        let param = ToolParameter::new("test_param", ToolParameterType::String)
            .with_description("A test parameter")
            .required()
            .with_default(serde_json::json!("default_value"));
        assert_eq!(param.name, "test_param");
        assert_eq!(param.param_type, ToolParameterType::String);
        assert_eq!(param.description, Some("A test parameter".to_string()));
        assert!(param.required);
        assert_eq!(param.default, Some(serde_json::json!("default_value")));
    }

    // A tool forwards its own name to the protocol.
    #[tokio::test]
    async fn test_tool_execution() {
        let protocol = Arc::new(MockProtocol);
        let tool = Tool::new("test_tool", "A test tool", protocol.clone());
        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output["tool"], "test_tool");
    }

    // Manually added tools can be looked up, listed and executed.
    #[tokio::test]
    async fn test_tool_registry() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::new(protocol.clone());
        let tool = Tool::new("calculator", "Performs calculations", protocol.clone());
        registry.add_tool(tool);
        assert!(registry.get_tool("calculator").is_some());
        assert_eq!(registry.list_tools().len(), 1);
        let result = registry
            .execute_tool("calculator", serde_json::json!({}))
            .await
            .unwrap();
        assert!(result.success);
    }

    // `empty()` has no tools, no protocols and no primary protocol.
    #[tokio::test]
    async fn test_empty_registry_creation() {
        let registry = ToolRegistry::empty();
        assert_eq!(registry.list_tools().len(), 0);
        assert_eq!(registry.list_protocols().len(), 0);
        assert!(registry.protocol().is_none());
    }

    // `add_protocol` records the protocol under the supplied name.
    #[tokio::test]
    async fn test_add_single_protocol_to_empty_registry() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("mock", protocol.clone())
            .await
            .unwrap();
        assert_eq!(registry.list_protocols().len(), 1);
        assert!(registry.list_protocols().contains(&"mock"));
    }

    // Distinct names yield distinct protocol entries.
    #[tokio::test]
    async fn test_add_multiple_protocols() {
        let protocol1 = Arc::new(MockProtocol);
        let protocol2 = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("protocol1", protocol1.clone())
            .await
            .unwrap();
        registry
            .add_protocol("protocol2", protocol2.clone())
            .await
            .unwrap();
        assert_eq!(registry.list_protocols().len(), 2);
        assert!(registry.list_protocols().contains(&"protocol1"));
        assert!(registry.list_protocols().contains(&"protocol2"));
    }

    // `remove_protocol` drops the protocol entry.
    #[tokio::test]
    async fn test_remove_protocol() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("protocol1", protocol.clone())
            .await
            .unwrap();
        assert_eq!(registry.list_protocols().len(), 1);
        registry.remove_protocol("protocol1");
        assert_eq!(registry.list_protocols().len(), 0);
    }

    // The tool -> protocol mapping is reported by `get_tool_protocol`. `add_tool` does not
    // record an owner, so the test inserts the mapping directly.
    #[tokio::test]
    async fn test_get_tool_protocol() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("local", protocol.clone())
            .await
            .unwrap();
        let tool = Tool::new("calculator", "Performs calculations", protocol.clone());
        registry.add_tool(tool);
        registry
            .tool_to_protocol
            .insert("calculator".to_string(), "local".to_string());
        assert_eq!(registry.get_tool_protocol("calculator"), Some("local"));
        assert_eq!(registry.get_tool_protocol("nonexistent"), None);
    }

    // Removing a protocol removes the tools mapped to it, along with their mappings.
    #[tokio::test]
    async fn test_remove_protocol_removes_tools() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("protocol1", protocol.clone())
            .await
            .unwrap();
        let tool1 = Tool::new("tool1", "First tool", protocol.clone());
        registry.add_tool(tool1);
        registry
            .tool_to_protocol
            .insert("tool1".to_string(), "protocol1".to_string());
        let tool2 = Tool::new("tool2", "Second tool", protocol.clone());
        registry.add_tool(tool2);
        registry
            .tool_to_protocol
            .insert("tool2".to_string(), "protocol1".to_string());
        assert_eq!(registry.list_tools().len(), 2);
        registry.remove_protocol("protocol1");
        assert_eq!(registry.list_tools().len(), 0);
        assert_eq!(registry.get_tool_protocol("tool1"), None);
        assert_eq!(registry.get_tool_protocol("tool2"), None);
    }

    // `execute_tool` dispatches to the named tool's protocol.
    #[tokio::test]
    async fn test_execute_tool_through_registry() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("mock", protocol.clone())
            .await
            .unwrap();
        let tool = Tool::new("test_tool", "A test tool", protocol.clone());
        registry.add_tool(tool);
        let result = registry
            .execute_tool("test_tool", serde_json::json!({}))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output["tool"], "test_tool");
    }

    // `new()` registers its protocol as "primary" and exposes it through `protocol()`.
    #[tokio::test]
    async fn test_backwards_compatibility_single_protocol() {
        let protocol = Arc::new(MockProtocol);
        let registry = ToolRegistry::new(protocol.clone());
        assert!(registry.protocol().is_some());
        assert_eq!(registry.list_protocols().len(), 1);
        assert!(registry.list_protocols().contains(&"primary"));
    }

    // Discovery from the primary protocol registers every listed tool under "primary".
    #[tokio::test]
    async fn test_discover_tools_from_primary() {
        // Protocol whose `list_tools` returns a fixed set of metadata.
        struct TestProtocol {
            tools: Vec<ToolMetadata>,
        }
        #[async_trait]
        impl ToolProtocol for TestProtocol {
            async fn execute(
                &self,
                tool_name: &str,
                _parameters: serde_json::Value,
            ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
                Ok(ToolResult::success(serde_json::json!({
                    "tool": tool_name,
                })))
            }
            async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
                Ok(self.tools.clone())
            }
            async fn get_tool_metadata(
                &self,
                tool_name: &str,
            ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
                self.tools
                    .iter()
                    .find(|t| t.name == tool_name)
                    .cloned()
                    .ok_or_else(|| "Tool not found".into())
            }
            fn protocol_name(&self) -> &str {
                "test"
            }
        }
        let protocol = Arc::new(TestProtocol {
            tools: vec![
                ToolMetadata::new("tool1", "First tool"),
                ToolMetadata::new("tool2", "Second tool"),
            ],
        });
        let mut registry = ToolRegistry::new(protocol.clone());
        // `new()` does not discover tools by itself.
        assert_eq!(registry.list_tools().len(), 0);
        registry.discover_tools_from_primary().await.unwrap();
        assert_eq!(registry.list_tools().len(), 2);
        assert!(registry.get_tool("tool1").is_some());
        assert!(registry.get_tool("tool2").is_some());
        assert_eq!(registry.get_tool_protocol("tool1"), Some("primary"));
        assert_eq!(registry.get_tool_protocol("tool2"), Some("primary"));
    }

    // Scalar schemas include `description` only when one was set.
    #[test]
    fn test_to_json_schema_string() {
        let param =
            ToolParameter::new("q", ToolParameterType::String).with_description("Search query");
        let schema = param.to_json_schema();
        assert_eq!(schema["type"], "string");
        assert_eq!(schema["description"], "Search query");
    }
    #[test]
    fn test_to_json_schema_number() {
        let param = ToolParameter::new("value", ToolParameterType::Number);
        let schema = param.to_json_schema();
        assert_eq!(schema["type"], "number");
        assert!(schema.get("description").is_none());
    }
    #[test]
    fn test_to_json_schema_integer() {
        let schema = ToolParameter::new("n", ToolParameterType::Integer).to_json_schema();
        assert_eq!(schema["type"], "integer");
    }
    #[test]
    fn test_to_json_schema_boolean() {
        let schema = ToolParameter::new("flag", ToolParameterType::Boolean).to_json_schema();
        assert_eq!(schema["type"], "boolean");
    }

    // Array schemas carry `items`, typed when an item type was given.
    #[test]
    fn test_to_json_schema_array_with_items() {
        let param = ToolParameter::new("ids", ToolParameterType::Array)
            .with_items(ToolParameterType::Integer);
        let schema = param.to_json_schema();
        assert_eq!(schema["type"], "array");
        assert_eq!(schema["items"]["type"], "integer");
    }
    // Without an item type, `items` is still present.
    #[test]
    fn test_to_json_schema_array_without_items() {
        let schema = ToolParameter::new("items", ToolParameterType::Array).to_json_schema();
        assert_eq!(schema["type"], "array");
        assert!(schema.get("items").is_some());
    }

    // Object schemas nest properties and list only required ones in `required`.
    #[test]
    fn test_to_json_schema_object_with_properties() {
        use std::collections::HashMap;
        let mut props = HashMap::new();
        props.insert(
            "name".to_string(),
            ToolParameter::new("name", ToolParameterType::String)
                .with_description("Person's name")
                .required(),
        );
        props.insert(
            "age".to_string(),
            ToolParameter::new("age", ToolParameterType::Integer),
        );
        let param = ToolParameter::new("person", ToolParameterType::Object).with_properties(props);
        let schema = param.to_json_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["name"].is_object());
        assert!(schema["properties"]["age"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v.as_str() == Some("name")));
        assert!(!required.iter().any(|v| v.as_str() == Some("age")));
    }

    // `to_tool_definition` builds an object schema from the tool's parameters.
    #[test]
    fn test_to_tool_definition_roundtrip() {
        let meta = ToolMetadata::new("calculator", "Evaluates a math expression")
            .with_parameter(
                ToolParameter::new("expression", ToolParameterType::String)
                    .with_description("The expression")
                    .required(),
            )
            .with_parameter(
                ToolParameter::new("precision", ToolParameterType::Integer)
                    .with_description("Decimal places"),
            );
        let def = meta.to_tool_definition();
        assert_eq!(def.name, "calculator");
        assert_eq!(def.description, "Evaluates a math expression");
        assert_eq!(def.parameters_schema["type"], "object");
        assert!(def.parameters_schema["properties"]["expression"].is_object());
        assert!(def.parameters_schema["properties"]["precision"].is_object());
        let required = def.parameters_schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v.as_str() == Some("expression")));
        assert!(!required.iter().any(|v| v.as_str() == Some("precision")));
    }

    // The registry yields one definition per registered tool.
    #[test]
    fn test_to_tool_definitions_collects_all() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        let tool_a = Tool::new("tool_a", "First tool", protocol.clone());
        registry.add_tool(tool_a);
        let tool_b = Tool::new("tool_b", "Second tool", protocol.clone());
        registry.add_tool(tool_b);
        let defs = registry.to_tool_definitions();
        assert_eq!(defs.len(), 2);
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }
    #[test]
    fn test_to_tool_definitions_empty_registry() {
        let registry = ToolRegistry::empty();
        let defs = registry.to_tool_definitions();
        assert!(defs.is_empty());
    }
}