use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::tool::{ToolError, ToolHandler};
use crate::{ToolDefinition, ToolRegistry};
#[derive(Debug, Clone, thiserror::Error)]
pub enum McpError {
#[error("MCP protocol error: {0}")]
Protocol(String),
#[error("MCP tool execution error: {0}")]
ToolExecution(String),
}
/// Abstraction over an MCP service that can enumerate and invoke tools.
///
/// The trait is object safe (methods return boxed futures rather than using `async fn`),
/// so it can be shared as `Arc<dyn McpService>`. Implementors must be `Send + Sync`.
pub trait McpService: Send + Sync {
    /// Fetches the definitions of all tools the service exposes.
    ///
    /// # Errors
    /// Resolves to an [`McpError`] if the tool list cannot be obtained.
    fn list_tools(
        &self,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>;

    /// Invokes the tool called `name` with the JSON `args`, resolving to its textual result.
    ///
    /// By lifetime elision, the `'_` in the returned future is tied to `&self` only. The future
    /// therefore cannot borrow `name`; an implementation that needs the name inside the future
    /// must copy it out before building the future.
    ///
    /// # Errors
    /// Resolves to an [`McpError`] if the call fails.
    fn call_tool(
        &self,
        name: &str,
        args: serde_json::Value,
    ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>>;
}
/// Adapter that exposes a single tool of an [`McpService`] as a [`ToolHandler`].
struct McpToolHandler {
    /// Service that executes calls for this tool; may be shared with other handlers.
    service: Arc<dyn McpService>,
    /// Definition of the tool; its `name` is what gets passed to `call_tool`.
    definition: ToolDefinition,
}
impl McpToolHandler {
fn new(service: Arc<dyn McpService>, definition: ToolDefinition) -> Self {
Self {
service,
definition,
}
}
}
impl std::fmt::Debug for McpToolHandler {
    /// Shows only the tool name. The service is left out because `McpService` does not
    /// require `Debug`; the output is marked non-exhaustive to signal the omission.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tool_name = &self.definition.name;
        let mut builder = f.debug_struct("McpToolHandler");
        builder.field("tool", tool_name);
        builder.finish_non_exhaustive()
    }
}
impl ToolHandler<()> for McpToolHandler {
fn definition(&self) -> ToolDefinition {
self.definition.clone()
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
_ctx: &'a (),
) -> Pin<Box<dyn Future<Output = Result<crate::tool::ToolOutput, ToolError>> + Send + 'a>> {
Box::pin(async move {
self.service
.call_tool(&self.definition.name, input)
.await
.map(crate::tool::ToolOutput::new)
.map_err(|e| ToolError::new(e.to_string()))
})
}
}
/// Extension methods for populating a [`ToolRegistry`] from an [`McpService`].
pub trait McpRegistryExt {
    /// Registers a handler for every tool the service lists.
    ///
    /// Resolves to the number of tools registered, or to the [`McpError`] returned while
    /// listing the service's tools.
    fn register_mcp_service<S: McpService + 'static>(
        &mut self,
        service: &Arc<S>,
    ) -> impl Future<Output = Result<usize, McpError>> + Send;

    /// Registers handlers only for listed tools whose name appears in `tool_names`.
    ///
    /// Resolves to the number of tools registered, or to the [`McpError`] returned while
    /// listing the service's tools.
    fn register_mcp_tools_by_name<S: McpService + 'static>(
        &mut self,
        service: &Arc<S>,
        tool_names: &[&str],
    ) -> impl Future<Output = Result<usize, McpError>> + Send;
}
impl McpRegistryExt for ToolRegistry<()> {
async fn register_mcp_service<S: McpService + 'static>(
&mut self,
service: &Arc<S>,
) -> Result<usize, McpError> {
let tools = service.list_tools().await?;
let count = tools.len();
for definition in tools {
let handler =
McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
self.register(handler);
}
Ok(count)
}
async fn register_mcp_tools_by_name<S: McpService + 'static>(
&mut self,
service: &Arc<S>,
tool_names: &[&str],
) -> Result<usize, McpError> {
let tools = service.list_tools().await?;
let mut count = 0;
for definition in tools {
if tool_names.contains(&definition.name.as_str()) {
let handler =
McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
self.register(handler);
count += 1;
}
}
Ok(count)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JsonSchema;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// In-memory service with a fixed tool list that counts `call_tool` invocations.
    struct MockMcpService {
        tools: Vec<ToolDefinition>,
        call_count: AtomicUsize,
    }

    impl MockMcpService {
        fn new(tools: Vec<ToolDefinition>) -> Self {
            Self {
                tools,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl McpService for MockMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let tools = self.tools.clone();
            Box::pin(async move { Ok(tools) })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let result = format!("Called {name}");
            Box::pin(async move { Ok(result) })
        }
    }

    /// Service whose `list_tools` always fails with `McpError::Protocol`.
    struct FailingListService;

    impl McpService for FailingListService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            Box::pin(async {
                Err::<Vec<ToolDefinition>, _>(McpError::Protocol("listing unavailable".to_string()))
            })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            let message = format!("unexpected call to {name}");
            Box::pin(async move { Err::<String, _>(McpError::ToolExecution(message)) })
        }
    }

    /// Service that lists its tools successfully but fails every `call_tool`.
    struct FailingCallService {
        tools: Vec<ToolDefinition>,
    }

    impl McpService for FailingCallService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let tools = self.tools.clone();
            Box::pin(async move { Ok(tools) })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            let message = format!("{name} failed");
            Box::pin(async move { Err::<String, _>(McpError::ToolExecution(message)) })
        }
    }

    fn test_tool(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("{name} description"),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: None,
        }
    }

    #[test]
    fn test_trait_is_object_safe() {
        fn assert_object_safe(_: &dyn McpService) {}
        let mock = MockMcpService::new(vec![]);
        assert_object_safe(&mock);
    }

    #[tokio::test]
    async fn test_register_mcp_service() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
        ]));
        let mut registry = ToolRegistry::new();
        let count = registry.register_mcp_service(&service).await.unwrap();
        assert_eq!(count, 2);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
            test_tool("tool_c"),
        ]));
        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "tool_c"])
            .await
            .unwrap();
        assert_eq!(count, 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_none());
        assert!(registry.get("tool_c").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name_ignores_unknown_names() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("tool_a")]));
        let mut registry: ToolRegistry<()> = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["missing"])
            .await
            .unwrap();
        assert_eq!(count, 0);
        assert_eq!(registry.len(), 0);
        assert!(registry.get("tool_a").is_none());
    }

    #[tokio::test]
    async fn test_register_propagates_list_error() {
        let service = Arc::new(FailingListService);
        let mut registry: ToolRegistry<()> = ToolRegistry::new();

        let err = registry.register_mcp_service(&service).await.unwrap_err();
        assert!(matches!(err, McpError::Protocol(_)));

        let err = registry
            .register_mcp_tools_by_name(&service, &["tool_a"])
            .await
            .unwrap_err();
        assert!(matches!(err, McpError::Protocol(_)));

        assert_eq!(registry.len(), 0);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("my_tool")]));
        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();
        let handler = registry.get("my_tool").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await.unwrap();
        assert_eq!(result.content, "Called my_tool");
        assert_eq!(service.call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution_error_is_reported() {
        let service = Arc::new(FailingCallService {
            tools: vec![test_tool("broken")],
        });
        let mut registry: ToolRegistry<()> = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();
        let handler = registry.get("broken").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await;
        assert!(result.is_err());
    }
}