use std::sync::Arc;
use async_trait::async_trait;
use super::client::McpClient;
use super::types::McpToolDefinition;
use crate::tool::{Tool, ToolDefinition, ToolError, ToolRegistry};
/// Adapter that exposes one tool advertised by an MCP server through the
/// crate's [`Tool`] trait.
///
/// The client handle is reference-counted, so many `McpTool`s created from
/// the same server can share a single [`McpClient`].
pub struct McpTool {
    /// Shared client through which `tools/call` requests are sent.
    client: Arc<McpClient>,
    /// Name, optional description and input schema of the wrapped tool
    /// (typically taken from the server's tool listing).
    mcp_definition: McpToolDefinition,
}

impl McpTool {
    /// Wraps `definition` so that invocations are dispatched through `client`.
    pub fn new(client: Arc<McpClient>, definition: McpToolDefinition) -> Self {
        let mcp_definition = definition;
        Self {
            client,
            mcp_definition,
        }
    }
}
#[async_trait]
impl Tool for McpTool {
    /// Converts the MCP tool metadata into the crate's [`ToolDefinition`].
    ///
    /// A missing MCP description is reported as an empty string.
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: self.mcp_definition.name.clone(),
            description: self.mcp_definition.description.clone().unwrap_or_default(),
            input_schema: self.mcp_definition.input_schema.clone(),
        }
    }

    /// Invokes the wrapped tool on the MCP server and returns its text output.
    ///
    /// `input` values of `null` or `{}` are passed to the client as `None`
    /// (no arguments); any other value is forwarded unchanged. The text of all
    /// content items is joined with `"\n"`. Items for which `as_text` yields
    /// `None` (presumably non-text content such as images) are skipped.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::ExecutionFailed`] in two cases:
    /// - The client call itself fails. The message is the client error's
    ///   `Display` text.
    /// - The server flags the result with `is_error`. The message is the
    ///   joined text, or a placeholder naming the tool when that text is
    ///   blank.
    async fn execute(&self, input: serde_json::Value) -> Result<String, ToolError> {
        // Collapse `null` and `{}` to `None` so both reach `call_tool` the same
        // way. Matching on `as_object()` avoids allocating a temporary
        // `json!({})` just to compare against it.
        let no_arguments =
            input.is_null() || matches!(input.as_object(), Some(map) if map.is_empty());
        let arguments = if no_arguments { None } else { Some(input) };

        let result = self
            .client
            .call_tool(&self.mcp_definition.name, arguments)
            .await
            .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;

        let text: String = result
            .content
            .iter()
            .filter_map(|c| c.as_text())
            .collect::<Vec<_>>()
            .join("\n");

        if !result.is_error {
            return Ok(text);
        }

        // An error result carrying no (or only blank) text would otherwise
        // surface as an empty, uninformative error message.
        if text.trim().is_empty() {
            Err(ToolError::ExecutionFailed(format!(
                "MCP tool '{}' reported an error without a message",
                self.mcp_definition.name
            )))
        } else {
            Err(ToolError::ExecutionFailed(text))
        }
    }
}
/// Connects to an MCP server over `transport` and adds every tool it
/// advertises to `registry`.
///
/// The steps are:
/// 1. The client is initialized.
/// 2. The server's tool list is fetched.
/// 3. Each listed tool is wrapped in an [`McpTool`]. All of these share one
///    [`McpClient`], which is returned to the caller.
///
/// # Errors
///
/// Propagates any error from `initialize` or `list_tools`. Both happen before
/// any registration, so on error `registry` is left unchanged.
pub async fn register_mcp_tools(
    registry: &mut ToolRegistry,
    transport: Arc<dyn super::transport::McpTransport>,
) -> Result<Arc<McpClient>, super::client::McpError> {
    let shared_client = Arc::new(McpClient::new(transport));
    shared_client.initialize().await?;

    let advertised = shared_client.list_tools().await?;
    for definition in advertised {
        let adapter = McpTool::new(Arc::clone(&shared_client), definition);
        registry.register(Box::new(adapter));
    }

    Ok(shared_client)
}
// Unit tests for the MCP tool adapter.
//
// Each test builds a `MockTransport`. Judging by usage below, `req_rx`
// yields the requests the client sends and `resp_tx` injects JSON-RPC
// responses back to it. A spawned task therefore plays the role of the MCP
// server and must answer requests in the order the client issues them.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::transport::MockTransport;
    use crate::mcp::types::{JsonRpcResponse, McpContent, McpToolDefinition, McpToolResult};

    // Checks that `definition()` maps the MCP metadata (name, description)
    // and that `execute` issues a `tools/call` request and returns the text
    // content of a successful result.
    #[tokio::test]
    async fn mcp_tool_implements_tool_trait() {
        let (transport, mut req_rx, resp_tx) = MockTransport::new();
        let client = Arc::new(McpClient::new(Arc::new(transport)));
        let tool = McpTool::new(
            client,
            McpToolDefinition {
                name: "greet".into(),
                description: Some("Say hello".into()),
                input_schema: serde_json::json!({"type": "object", "properties": {"name": {"type": "string"}}}),
            },
        );
        let def = tool.definition();
        assert_eq!(def.name, "greet");
        assert_eq!(def.description, "Say hello");
        // Fake server: answer exactly one request. The reply reuses the
        // request's id (presumably how the client correlates replies).
        let server_task = tokio::spawn(async move {
            let req = req_rx.recv().await.unwrap();
            assert_eq!(req.method, "tools/call");
            resp_tx.send(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: req.id,
                result: Some(serde_json::to_value(McpToolResult {
                    content: vec![McpContent::Text { text: "Hello, Alice!".into() }],
                    is_error: false,
                }).unwrap()),
                error: None,
            }).await.unwrap();
        });
        let result = tool.execute(serde_json::json!({"name": "Alice"})).await.unwrap();
        assert_eq!(result, "Hello, Alice!");
        // Joining the server task surfaces any assertion failure inside it.
        server_task.await.unwrap();
    }

    // A result flagged `is_error: true` must be turned into
    // `ToolError::ExecutionFailed`. The input `{}` exercises the path where
    // `execute` passes no arguments to `call_tool`.
    #[tokio::test]
    async fn mcp_tool_error_result() {
        let (transport, mut req_rx, resp_tx) = MockTransport::new();
        let client = Arc::new(McpClient::new(Arc::new(transport)));
        let tool = McpTool::new(
            client,
            McpToolDefinition {
                name: "fail".into(),
                description: None,
                input_schema: serde_json::json!({"type": "object"}),
            },
        );
        let server_task = tokio::spawn(async move {
            let req = req_rx.recv().await.unwrap();
            resp_tx.send(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: req.id,
                result: Some(serde_json::to_value(McpToolResult {
                    content: vec![McpContent::Text { text: "something broke".into() }],
                    is_error: true,
                }).unwrap()),
                error: None,
            }).await.unwrap();
        });
        let err = tool.execute(serde_json::json!({})).await.unwrap_err();
        assert!(matches!(err, ToolError::ExecutionFailed(_)));
        server_task.await.unwrap();
    }

    // End-to-end check of `register_mcp_tools`: the fake server answers two
    // requests in sequence, and both advertised tools must land in the
    // registry under their MCP names.
    #[tokio::test]
    async fn register_mcp_tools_populates_registry() {
        let (transport, mut req_rx, resp_tx) = MockTransport::new();
        let transport = Arc::new(transport);
        let server_task = tokio::spawn(async move {
            // First request: presumably `initialize`. The reply carries
            // protocolVersion / capabilities / serverInfo.
            let req = req_rx.recv().await.unwrap();
            resp_tx.send(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: req.id,
                result: Some(serde_json::json!({
                    "protocolVersion": "2024-11-05",
                    "capabilities": {"tools": {}},
                    "serverInfo": {"name": "test", "version": "1.0"}
                })),
                error: None,
            }).await.unwrap();
            // Second request: presumably `tools/list`. The reply advertises
            // two tools.
            let req = req_rx.recv().await.unwrap();
            resp_tx.send(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: req.id,
                result: Some(serde_json::json!({
                    "tools": [
                        {"name": "tool_a", "description": "A", "inputSchema": {"type": "object"}},
                        {"name": "tool_b", "description": "B", "inputSchema": {"type": "object"}}
                    ]
                })),
                error: None,
            }).await.unwrap();
        });
        let mut registry = ToolRegistry::new();
        register_mcp_tools(&mut registry, transport).await.unwrap();
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_some());
        server_task.await.unwrap();
    }
}