use crate::error::Result;
use crate::mcp_server::is_mcp_tool;
use crate::tool_types::{BuiltinTool, ToolCall, ToolDefinition, ToolHints};
use crate::tools::{Tool, ToolExecutionResult};
use crate::traits::ToolContext;
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
/// Backend that executes a single [`ToolCall`] on behalf of [`McpProxyTool`].
///
/// Every proxy built by [`build_mcp_proxy_tools`] receives a clone of the same
/// `Arc<dyn McpToolInvoker>`, which is why implementors must be `Send + Sync`.
#[async_trait]
pub trait McpToolInvoker: Send + Sync {
/// Executes `tool_call` and returns the raw tool result.
///
/// # Errors
///
/// An `Err` signals that the call could not be carried out at all. `McpProxyTool`
/// converts it into a tool error using the error's `Display` text instead of
/// propagating it.
async fn invoke(&self, tool_call: &ToolCall) -> Result<crate::tool_types::ToolResult>;
}
/// A [`Tool`] that reports a [`BuiltinTool`] definition as its metadata and forwards
/// every execution to an [`McpToolInvoker`].
pub struct McpProxyTool {
/// Definition returned by the `Tool` metadata methods. Its `name` is also used as the
/// name of every forwarded [`ToolCall`].
definition: BuiltinTool,
/// Shared backend that performs the forwarded calls.
invoker: Arc<dyn McpToolInvoker>,
}
impl McpProxyTool {
pub fn new(definition: BuiltinTool, invoker: Arc<dyn McpToolInvoker>) -> Self {
Self {
definition,
invoker,
}
}
async fn invoke(&self, tool_call_id: String, arguments: Value) -> ToolExecutionResult {
let call = ToolCall {
id: tool_call_id,
name: self.definition.name.clone(),
arguments,
};
match self.invoker.invoke(&call).await {
Ok(result) => tool_result_to_execution(result),
Err(error) => ToolExecutionResult::tool_error(error.to_string()),
}
}
}
#[async_trait]
impl Tool for McpProxyTool {
fn name(&self) -> &str {
&self.definition.name
}
fn display_name(&self) -> Option<&str> {
self.definition.display_name.as_deref()
}
fn description(&self) -> &str {
&self.definition.description
}
fn parameters_schema(&self) -> Value {
self.definition.parameters.clone()
}
fn hints(&self) -> ToolHints {
self.definition.hints.clone()
}
fn requires_context(&self) -> bool {
true
}
fn to_definition(&self) -> ToolDefinition {
ToolDefinition::Builtin(self.definition.clone())
}
async fn execute(&self, arguments: Value) -> ToolExecutionResult {
self.invoke(String::new(), arguments).await
}
async fn execute_with_context(
&self,
arguments: Value,
context: &ToolContext,
) -> ToolExecutionResult {
let tool_call_id = context.tool_call_id.clone().unwrap_or_default();
self.invoke(tool_call_id, arguments).await
}
}
/// Builds one [`McpProxyTool`] for each builtin definition whose name `is_mcp_tool` accepts.
///
/// Definitions that fail the name check, and all client-side definitions, are skipped.
/// Every proxy shares the same `invoker` through a cloned `Arc`. Output order follows
/// `definitions`.
pub fn build_mcp_proxy_tools(
    definitions: &[ToolDefinition],
    invoker: Arc<dyn McpToolInvoker>,
) -> Vec<Box<dyn Tool>> {
    let mut proxies: Vec<Box<dyn Tool>> = Vec::new();
    for definition in definitions {
        if !is_mcp_tool(definition.name()) {
            continue;
        }
        match definition {
            ToolDefinition::Builtin(builtin) => {
                proxies.push(Box::new(McpProxyTool::new(
                    builtin.clone(),
                    Arc::clone(&invoker),
                )));
            }
            // Only builtin definitions are proxied; client-side ones are skipped.
            ToolDefinition::ClientSide(_) => {}
        }
    }
    proxies
}
/// Converts an invoker [`crate::tool_types::ToolResult`] into a [`ToolExecutionResult`].
///
/// Precedence:
/// - A `connection_required` provider wins over an `error`.
/// - An `error` wins over the `result` payload.
///
/// A missing payload becomes JSON `null`. Images are attached only when the list is
/// present and non-empty. The `tool_call_id` and `raw_output` fields are ignored.
fn tool_result_to_execution(result: crate::tool_types::ToolResult) -> ToolExecutionResult {
    let crate::tool_types::ToolResult {
        result: payload,
        images,
        error,
        connection_required,
        ..
    } = result;

    match (connection_required, error) {
        (Some(provider), _) => ToolExecutionResult::ConnectionRequired { provider },
        (None, Some(message)) => ToolExecutionResult::ToolError(message),
        (None, None) => {
            // `Value::default()` is `Value::Null`.
            let value = payload.unwrap_or_default();
            match images.filter(|list| !list.is_empty()) {
                Some(images) => ToolExecutionResult::SuccessWithImages {
                    result: value,
                    images,
                },
                None => ToolExecutionResult::Success(value),
            }
        }
    }
}
// Unit tests for the MCP proxy:
// - which definitions get proxied;
// - what the proxy forwards to the invoker;
// - how invoker results and failures map to `ToolExecutionResult`.
#[cfg(test)]
mod tests {
use super::*;
use crate::tool_types::{DeferrablePolicy, ToolPolicy, ToolResult};
use std::sync::Mutex;
/// Builds a builtin definition named `name`.
///
/// The definition has a one-property (`q: string`) object schema, the "MCP Servers"
/// category and the open-world hint set.
fn mcp_def(name: &str) -> ToolDefinition {
ToolDefinition::Builtin(BuiltinTool {
name: name.to_string(),
display_name: None,
description: "an mcp tool".to_string(),
parameters: serde_json::json!({
"type": "object",
"properties": { "q": { "type": "string" } }
}),
policy: ToolPolicy::Auto,
category: Some("MCP Servers".to_string()),
deferrable: DeferrablePolicy::Automatic,
hints: ToolHints::default().with_open_world(true),
})
}
/// Test invoker that records every call it receives and returns a fixed result.
struct RecordingInvoker {
/// Calls observed, in arrival order.
calls: Mutex<Vec<ToolCall>>,
/// Result cloned and returned for every call.
result: ToolResult,
}
#[async_trait]
impl McpToolInvoker for RecordingInvoker {
async fn invoke(&self, tool_call: &ToolCall) -> Result<ToolResult> {
self.calls.lock().unwrap().push(tool_call.clone());
Ok(self.result.clone())
}
}
/// A successful `ToolResult` whose only populated field is `result`.
fn ok_result(value: Value) -> ToolResult {
ToolResult {
tool_call_id: String::new(),
result: Some(value),
images: None,
error: None,
connection_required: None,
raw_output: None,
}
}
/// Only definitions accepted by `is_mcp_tool` are proxied, and the proxy keeps
/// the original schema.
#[test]
fn build_proxies_only_for_mcp_tools() {
// One MCP-named definition plus an ordinary builtin (`read_file`) that must be skipped.
let defs = vec![
mcp_def("mcp_docs__search"),
ToolDefinition::Builtin(BuiltinTool {
name: "read_file".to_string(),
display_name: None,
description: "read".to_string(),
parameters: serde_json::json!({}),
policy: ToolPolicy::Auto,
category: None,
deferrable: DeferrablePolicy::Automatic,
hints: ToolHints::default(),
}),
];
let invoker = Arc::new(RecordingInvoker {
calls: Mutex::new(vec![]),
result: ok_result(serde_json::json!({})),
});
let tools = build_mcp_proxy_tools(&defs, invoker);
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name(), "mcp_docs__search");
assert!(tools[0].parameters_schema()["properties"]["q"].is_object());
}
/// The proxy forwards name, context tool-call id and arguments to the invoker, and
/// maps a successful result to `ToolExecutionResult::Success`.
#[tokio::test]
async fn proxy_delegates_to_invoker_and_maps_success() {
let invoker = Arc::new(RecordingInvoker {
calls: Mutex::new(vec![]),
result: ok_result(serde_json::json!({ "answer": 42 })),
});
let tool = McpProxyTool::new(
match mcp_def("mcp_docs__search") {
ToolDefinition::Builtin(b) => b,
_ => unreachable!(),
},
invoker.clone(),
);
// The context supplies the tool-call id that should reach the invoker.
let mut ctx = ToolContext::new(uuid::Uuid::new_v4().into());
ctx.tool_call_id = Some("call_1".to_string());
let result = tool
.execute_with_context(serde_json::json!({ "q": "hi" }), &ctx)
.await;
match result {
ToolExecutionResult::Success(v) => assert_eq!(v["answer"], 42),
other => panic!("expected success, got {other:?}"),
}
// Exactly one call was recorded, carrying the forwarded fields.
let calls = invoker.calls.lock().unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "mcp_docs__search");
assert_eq!(calls[0].id, "call_1");
assert_eq!(calls[0].arguments["q"], "hi");
}
/// A `ToolResult` with `error` set maps to `ToolError`, even though `result` is also
/// populated.
#[tokio::test]
async fn proxy_maps_tool_error() {
let invoker = Arc::new(RecordingInvoker {
calls: Mutex::new(vec![]),
result: ToolResult {
tool_call_id: String::new(),
result: Some(serde_json::json!({ "error": "boom" })),
images: None,
error: Some("boom".to_string()),
connection_required: None,
raw_output: None,
},
});
let tool = McpProxyTool::new(
match mcp_def("mcp_docs__search") {
ToolDefinition::Builtin(b) => b,
_ => unreachable!(),
},
invoker,
);
let result = tool.execute(serde_json::json!({})).await;
assert!(matches!(result, ToolExecutionResult::ToolError(ref m) if m == "boom"));
}
/// An `Err` from the invoker becomes a `ToolError` containing the error text.
#[tokio::test]
async fn proxy_maps_invoker_error_to_tool_error() {
/// Invoker that always fails.
struct FailingInvoker;
#[async_trait]
impl McpToolInvoker for FailingInvoker {
async fn invoke(&self, _tool_call: &ToolCall) -> Result<ToolResult> {
Err(crate::error::AgentLoopError::tool(
"MCP server not found for prefix: docs",
))
}
}
let tool = McpProxyTool::new(
match mcp_def("mcp_docs__search") {
ToolDefinition::Builtin(b) => b,
_ => unreachable!(),
},
Arc::new(FailingInvoker),
);
let result = tool.execute(serde_json::json!({})).await;
match result {
ToolExecutionResult::ToolError(m) => assert!(m.contains("MCP server not found")),
other => panic!("expected ToolError, got {other:?}"),
}
}
/// Registered proxies are visible through `ToolRegistry::has` and through
/// `tool_definitions()`, with their schema intact.
#[test]
fn mcp_tool_is_first_class_in_registry() {
use crate::tools::ToolRegistry;
let invoker = Arc::new(RecordingInvoker {
calls: Mutex::new(vec![]),
result: ok_result(serde_json::json!({})),
});
let defs = vec![mcp_def("mcp_docs__search")];
let mut registry = ToolRegistry::new();
for tool in build_mcp_proxy_tools(&defs, invoker) {
registry.register_boxed(tool);
}
assert!(registry.has("mcp_docs__search"));
let registry_defs = registry.tool_definitions();
let found = registry_defs
.iter()
.find(|d| d.name() == "mcp_docs__search")
.expect("MCP tool visible via registry introspection");
assert!(found.parameters()["properties"]["q"].is_object());
}
}