use crate::resources::ResourceMetadata;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::sync::Arc;
/// Flattened, serializable description of a tool: its name, its description,
/// and a JSON Schema object describing its parameters.
///
/// Produced by `ToolMetadata::to_tool_definition`, which always emits a
/// schema of the form `{"type": "object", "properties": {...}, "required": [...]}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Human-readable description of what the tool does.
pub description: String,
/// JSON Schema (as a raw JSON value) for the tool's parameters.
pub parameters_schema: serde_json::Value,
}
/// Outcome of a single tool invocation.
///
/// Build one with [`ToolResult::success`] or [`ToolResult::failure`], then
/// optionally chain [`ToolResult::with_metadata`] calls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// `true` for results built by `success`, `false` for `failure`.
    pub success: bool,
    /// Output payload; `Value::Null` for results built by `failure`.
    pub output: serde_json::Value,
    /// Error message; `Some` for results built by `failure`, `None` for `success`.
    pub error: Option<String>,
    /// Free-form extra data attached with `with_metadata`.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl ToolResult {
    /// Builds a successful result carrying `output`, with no error and empty metadata.
    pub fn success(output: serde_json::Value) -> Self {
        Self {
            output,
            success: true,
            error: None,
            metadata: HashMap::default(),
        }
    }

    /// Builds a failed result whose `error` holds the given message and whose
    /// `output` is `Value::Null`.
    pub fn failure(error: String) -> Self {
        Self {
            error: Some(error),
            success: false,
            output: serde_json::Value::Null,
            metadata: HashMap::default(),
        }
    }

    /// Inserts (or overwrites) the metadata entry `key` and returns the result.
    pub fn with_metadata(self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let mut result = self;
        result.metadata.insert(key.into(), value);
        result
    }
}
/// JSON type of a tool parameter.
///
/// Serialized in lowercase (`"string"`, `"number"`, `"integer"`, `"boolean"`,
/// `"array"`, `"object"`), the same names `param_type_to_schema` writes into
/// the JSON Schema `"type"` keyword.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ToolParameterType {
/// Schema `{"type": "string"}`.
String,
/// Schema `{"type": "number"}`.
Number,
/// Schema `{"type": "integer"}`.
Integer,
/// Schema `{"type": "boolean"}`.
Boolean,
/// Schema `{"type": "array"}`; the element type comes from `ToolParameter::items`.
Array,
/// Schema `{"type": "object"}`; nested fields come from `ToolParameter::properties`.
Object,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
pub name: String,
#[serde(rename = "type")]
pub param_type: ToolParameterType,
pub description: Option<String>,
pub required: bool,
pub default: Option<serde_json::Value>,
pub items: Option<Box<ToolParameterType>>,
pub properties: Option<HashMap<String, ToolParameter>>,
}
impl ToolParameter {
pub fn new(name: impl Into<String>, param_type: ToolParameterType) -> Self {
Self {
name: name.into(),
param_type,
description: None,
required: false,
default: None,
items: None,
properties: None,
}
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn required(mut self) -> Self {
self.required = true;
self
}
pub fn with_default(mut self, default: serde_json::Value) -> Self {
self.default = Some(default);
self
}
pub fn with_items(mut self, item_type: ToolParameterType) -> Self {
self.items = Some(Box::new(item_type));
self
}
pub fn with_properties(mut self, properties: HashMap<String, ToolParameter>) -> Self {
self.properties = Some(properties);
self
}
pub fn to_json_schema(&self) -> serde_json::Value {
let mut schema = match &self.param_type {
ToolParameterType::String => serde_json::json!({"type": "string"}),
ToolParameterType::Number => serde_json::json!({"type": "number"}),
ToolParameterType::Integer => serde_json::json!({"type": "integer"}),
ToolParameterType::Boolean => serde_json::json!({"type": "boolean"}),
ToolParameterType::Array => {
let items_schema = self
.items
.as_ref()
.map(|t| param_type_to_schema(t))
.unwrap_or_else(|| serde_json::json!({}));
serde_json::json!({"type": "array", "items": items_schema})
}
ToolParameterType::Object => {
if let Some(props) = &self.properties {
let properties: serde_json::Map<String, serde_json::Value> = props
.iter()
.map(|(k, v)| (k.clone(), v.to_json_schema()))
.collect();
let required: Vec<&str> = props
.values()
.filter(|p| p.required)
.map(|p| p.name.as_str())
.collect();
serde_json::json!({
"type": "object",
"properties": properties,
"required": required
})
} else {
serde_json::json!({"type": "object"})
}
}
};
if let Some(desc) = &self.description {
schema["description"] = serde_json::Value::String(desc.clone());
}
schema
}
}
/// Returns the bare JSON Schema `{"type": ...}` for a parameter type, with no
/// nested `items` or `properties`.
fn param_type_to_schema(t: &ToolParameterType) -> serde_json::Value {
    let json_type = match t {
        ToolParameterType::String => "string",
        ToolParameterType::Number => "number",
        ToolParameterType::Integer => "integer",
        ToolParameterType::Boolean => "boolean",
        ToolParameterType::Array => "array",
        ToolParameterType::Object => "object",
    };
    serde_json::json!({ "type": json_type })
}
/// Protocol-agnostic description of a tool.
///
/// Holds the tool's name, its description, its typed parameters, and extra
/// data supplied by the protocol that provides it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolMetadata {
    /// Tool name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Parameters, in declaration order.
    pub parameters: Vec<ToolParameter>,
    /// Extra protocol-specific key/value data.
    pub protocol_metadata: HashMap<String, serde_json::Value>,
}
impl ToolMetadata {
    /// Creates metadata with no parameters and no protocol metadata.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            parameters: Vec::new(),
            protocol_metadata: HashMap::new(),
        }
    }

    /// Appends `param` to the parameter list.
    pub fn with_parameter(self, param: ToolParameter) -> Self {
        let mut meta = self;
        meta.parameters.push(param);
        meta
    }

    /// Inserts (or overwrites) a protocol metadata entry.
    pub fn with_protocol_metadata(
        self,
        key: impl Into<String>,
        value: serde_json::Value,
    ) -> Self {
        let mut meta = self;
        meta.protocol_metadata.insert(key.into(), value);
        meta
    }

    /// Converts this metadata into a [`ToolDefinition`].
    ///
    /// The parameter schema is an object whose `properties` hold each
    /// parameter's `to_json_schema()` keyed by name. Its `required` array
    /// lists the names of required parameters in declaration order.
    pub fn to_tool_definition(&self) -> ToolDefinition {
        let properties: serde_json::Map<String, serde_json::Value> = self
            .parameters
            .iter()
            .map(|param| (param.name.clone(), param.to_json_schema()))
            .collect();
        let required: Vec<String> = self
            .parameters
            .iter()
            .filter(|param| param.required)
            .map(|param| param.name.clone())
            .collect();
        ToolDefinition {
            name: self.name.clone(),
            description: self.description.clone(),
            parameters_schema: serde_json::json!({
                "type": "object",
                "properties": properties,
                "required": required
            }),
        }
    }
}
/// Backend that lists and executes tools, and can optionally expose resources.
///
/// `Tool` and `ToolRegistry` share implementations as `Arc<dyn ToolProtocol>`,
/// which is why the trait requires `Send + Sync`. All fallible methods
/// return boxed `Send + Sync` errors. Methods are made async through
/// `#[async_trait]`.
#[async_trait]
pub trait ToolProtocol: Send + Sync {
/// Runs the tool named `tool_name` with the given JSON `parameters`.
///
/// NOTE(review): whether an unknown tool yields `Err` or an `Ok` result with
/// `success == false` is up to each implementation — confirm per backend.
async fn execute(
&self,
tool_name: &str,
parameters: serde_json::Value,
) -> Result<ToolResult, Box<dyn Error + Send + Sync>>;
/// Returns metadata for every tool this protocol offers.
///
/// `ToolRegistry` calls this when a protocol is added, and registers one
/// `Tool` per returned entry.
async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>>;
/// Returns metadata for the single tool named `tool_name`.
async fn get_tool_metadata(
&self,
tool_name: &str,
) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>>;
/// Returns an identifier for this protocol implementation.
fn protocol_name(&self) -> &str;
/// Optional setup hook. The default does nothing and returns `Ok(())`.
///
/// Takes `&mut self`, so it cannot be called through a shared `Arc`.
/// Presumably it runs before the protocol is wrapped for use by
/// `ToolRegistry` — TODO confirm.
async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
Ok(())
}
/// Optional teardown hook. The default does nothing and returns `Ok(())`.
async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
Ok(())
}
/// Lists the resources this protocol exposes. The default returns an empty list.
async fn list_resources(&self) -> Result<Vec<ResourceMetadata>, Box<dyn Error + Send + Sync>> {
Ok(Vec::new())
}
/// Reads the resource at `uri` as a string.
///
/// The default always fails with "Resource not found: <uri>".
async fn read_resource(&self, uri: &str) -> Result<String, Box<dyn Error + Send + Sync>> {
Err(format!("Resource not found: {}", uri).into())
}
/// Whether this protocol exposes resources. The default is `false`.
///
/// Implementations that override the resource methods should presumably
/// return `true` here.
fn supports_resources(&self) -> bool {
false
}
}
/// Errors raised by the tool layer itself, as opposed to errors returned by a
/// protocol backend.
#[derive(Debug, Clone)]
pub enum ToolError {
    /// No tool with the given name is registered.
    NotFound(String),
    /// The tool ran but failed; carries a description.
    ExecutionFailed(String),
    /// The supplied parameters were rejected; carries a description.
    InvalidParameters(String),
    /// A protocol-level failure; carries a description.
    ProtocolError(String),
}
impl fmt::Display for ToolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotFound(name) => write!(f, "Tool not found: {}", name),
            Self::ExecutionFailed(msg) => write!(f, "Tool execution failed: {}", msg),
            Self::InvalidParameters(msg) => write!(f, "Invalid parameters: {}", msg),
            Self::ProtocolError(msg) => write!(f, "Protocol error: {}", msg),
        }
    }
}
impl Error for ToolError {}
/// A callable tool: its metadata plus the protocol that executes it.
pub struct Tool {
    metadata: ToolMetadata,
    protocol: Arc<dyn ToolProtocol>,
}
impl Tool {
    /// Creates a tool named `name` and backed by `protocol`, with no
    /// parameters and no protocol metadata.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        protocol: Arc<dyn ToolProtocol>,
    ) -> Self {
        let metadata = ToolMetadata::new(name, description);
        Self { metadata, protocol }
    }

    /// Appends `param` to this tool's parameter list.
    pub fn with_parameter(self, param: ToolParameter) -> Self {
        let mut tool = self;
        tool.metadata.parameters.push(param);
        tool
    }

    /// Inserts (or overwrites) a protocol metadata entry on this tool.
    pub fn with_protocol_metadata(
        self,
        key: impl Into<String>,
        value: serde_json::Value,
    ) -> Self {
        let mut tool = self;
        tool.metadata.protocol_metadata.insert(key.into(), value);
        tool
    }

    /// Returns this tool's metadata.
    pub fn metadata(&self) -> &ToolMetadata {
        &self.metadata
    }

    /// Executes the tool by forwarding `parameters` and the tool's name to
    /// its protocol.
    pub async fn execute(
        &self,
        parameters: serde_json::Value,
    ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
        let tool_name = self.metadata.name.as_str();
        self.protocol.execute(tool_name, parameters).await
    }
}
pub struct ToolRegistry {
tools: HashMap<String, Tool>,
tool_to_protocol: HashMap<String, String>,
protocols: HashMap<String, Arc<dyn ToolProtocol>>,
primary_protocol: Option<Arc<dyn ToolProtocol>>,
}
impl ToolRegistry {
pub fn new(protocol: Arc<dyn ToolProtocol>) -> Self {
Self {
tools: HashMap::new(),
tool_to_protocol: HashMap::new(),
protocols: {
let mut m = HashMap::new();
m.insert("primary".to_string(), protocol.clone());
m
},
primary_protocol: Some(protocol),
}
}
pub fn empty() -> Self {
Self {
tools: HashMap::new(),
tool_to_protocol: HashMap::new(),
protocols: HashMap::new(),
primary_protocol: None,
}
}
pub async fn add_protocol(
&mut self,
protocol_name: &str,
protocol: Arc<dyn ToolProtocol>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let discovered_tools = protocol.list_tools().await?;
self.protocols
.insert(protocol_name.to_string(), protocol.clone());
for tool_meta in discovered_tools {
let tool_name = tool_meta.name.clone();
let tool = Tool::new(
tool_name.clone(),
tool_meta.description.clone(),
protocol.clone(),
);
let mut tool = tool;
for param in &tool_meta.parameters {
tool = tool.with_parameter(param.clone());
}
for (key, value) in &tool_meta.protocol_metadata {
tool = tool.with_protocol_metadata(key.clone(), value.clone());
}
self.tools.insert(tool_name.clone(), tool);
self.tool_to_protocol
.insert(tool_name, protocol_name.to_string());
}
Ok(())
}
pub fn remove_protocol(&mut self, protocol_name: &str) {
self.protocols.remove(protocol_name);
let tools_to_remove: Vec<String> = self
.tool_to_protocol
.iter()
.filter(|(_, pn)| *pn == protocol_name)
.map(|(tn, _)| tn.clone())
.collect();
for tool_name in tools_to_remove {
self.tools.remove(&tool_name);
self.tool_to_protocol.remove(&tool_name);
}
}
pub fn add_tool(&mut self, tool: Tool) {
self.tools.insert(tool.metadata.name.clone(), tool);
}
pub fn remove_tool(&mut self, name: &str) -> Option<Tool> {
self.tool_to_protocol.remove(name);
self.tools.remove(name)
}
pub fn get_tool(&self, name: &str) -> Option<&Tool> {
self.tools.get(name)
}
pub fn list_tools(&self) -> Vec<&ToolMetadata> {
self.tools.values().map(|t| &t.metadata).collect()
}
pub async fn discover_tools_from_primary(
&mut self,
) -> Result<(), Box<dyn Error + Send + Sync>> {
if let Some(protocol) = &self.primary_protocol {
let discovered_tools = protocol.list_tools().await?;
for tool_meta in discovered_tools {
let tool_name = tool_meta.name.clone();
let tool = Tool::new(
tool_name.clone(),
tool_meta.description.clone(),
protocol.clone(),
);
let mut tool = tool;
for param in &tool_meta.parameters {
tool = tool.with_parameter(param.clone());
}
for (key, value) in &tool_meta.protocol_metadata {
tool = tool.with_protocol_metadata(key.clone(), value.clone());
}
self.tools.insert(tool_name.clone(), tool);
self.tool_to_protocol
.insert(tool_name, "primary".to_string());
}
Ok(())
} else {
Err("No primary protocol available".into())
}
}
pub fn get_tool_protocol(&self, tool_name: &str) -> Option<&str> {
self.tool_to_protocol.get(tool_name).map(|s| s.as_str())
}
pub fn list_protocols(&self) -> Vec<&str> {
self.protocols.keys().map(|s| s.as_str()).collect()
}
pub async fn execute_tool(
&self,
tool_name: &str,
parameters: serde_json::Value,
) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
let tool = self
.tools
.get(tool_name)
.ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
tool.execute(parameters).await
}
pub fn protocol(&self) -> Option<&Arc<dyn ToolProtocol>> {
self.primary_protocol.as_ref()
}
pub fn to_tool_definitions(&self) -> Vec<ToolDefinition> {
self.tools
.values()
.map(|t| t.metadata.to_tool_definition())
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal protocol stub. It lists no tools, echoes the tool name on
    /// execution, and returns a fixed metadata record.
    struct MockProtocol;

    #[async_trait]
    impl ToolProtocol for MockProtocol {
        async fn execute(
            &self,
            tool_name: &str,
            _parameters: serde_json::Value,
        ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
            let payload = serde_json::json!({
                "tool": tool_name,
                "result": "mock_result"
            });
            Ok(ToolResult::success(payload))
        }

        async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
            Ok(Vec::new())
        }

        async fn get_tool_metadata(
            &self,
            _tool_name: &str,
        ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
            let metadata = ToolMetadata::new("mock_tool", "A mock tool");
            Ok(metadata)
        }

        fn protocol_name(&self) -> &str {
            "mock"
        }
    }

    /// Adds a tool backed by `protocol` and attributes it to the protocol
    /// registered as `owner`, as discovery would.
    fn add_owned_tool(
        registry: &mut ToolRegistry,
        protocol: &Arc<MockProtocol>,
        name: &str,
        description: &str,
        owner: &str,
    ) {
        registry.add_tool(Tool::new(name, description, protocol.clone()));
        registry
            .tool_to_protocol
            .insert(name.to_string(), owner.to_string());
    }

    #[tokio::test]
    async fn test_get_tool_protocol() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("local", protocol.clone())
            .await
            .unwrap();
        add_owned_tool(
            &mut registry,
            &protocol,
            "calculator",
            "Performs calculations",
            "local",
        );
        assert_eq!(registry.get_tool_protocol("calculator"), Some("local"));
        assert_eq!(registry.get_tool_protocol("nonexistent"), None);
    }

    #[tokio::test]
    async fn test_remove_protocol_removes_tools() {
        let protocol = Arc::new(MockProtocol);
        let mut registry = ToolRegistry::empty();
        registry
            .add_protocol("protocol1", protocol.clone())
            .await
            .unwrap();
        for (name, description) in [("tool1", "First tool"), ("tool2", "Second tool")] {
            add_owned_tool(&mut registry, &protocol, name, description, "protocol1");
        }
        assert_eq!(registry.list_tools().len(), 2);

        registry.remove_protocol("protocol1");

        assert!(registry.list_tools().is_empty());
        for name in ["tool1", "tool2"] {
            assert_eq!(registry.get_tool_protocol(name), None);
        }
    }
}