use crate::{CallToolResult, ListToolsResult, Tool, ToolProvider};
use async_trait::async_trait;
use protocol_transport_core::{
ProtocolError, SseTransport, Transport, TransportFactory, UniversalRequest,
};
use serde_json::json;
use std::collections::HashMap;
/// Configuration for an [`McpProxy`]: the upstream MCP targets plus proxy-wide settings.
#[derive(Debug, Clone)]
pub struct McpProxyConfig {
    /// Upstream targets. Each target's `name` becomes the `name:` prefix of the tools it exposes.
    pub servers: Vec<McpProxyTarget>,
    /// Presumably a credential for the proxy itself — TODO confirm.
    /// NOTE(review): not read anywhere in this file.
    pub proxy_auth: Option<String>,
    /// Timeout in seconds (`McpProxyBuilder` defaults it to 30).
    /// NOTE(review): not read anywhere in this file, so requests are not bounded by it here.
    pub timeout_seconds: u64,
}
/// One upstream MCP server that the proxy reaches over an SSE transport.
#[derive(Debug, Clone)]
pub struct McpProxyTarget {
    /// Key for this target's transport, and the `server` part of `server:tool` names.
    pub name: String,
    /// Endpoint handed to `TransportFactory::mcp_sse` / `mcp_sse_auth`.
    pub sse_endpoint: String,
    /// When set, the transport is built with `TransportFactory::mcp_sse_auth` using this token.
    pub auth_token: Option<String>,
    /// Free-form description; not read anywhere in this file.
    pub description: Option<String>,
}
/// Aggregates several upstream MCP servers behind one tool namespace.
///
/// Tools from target `foo` are exposed as `foo:<tool>`, and calls to such names
/// are routed back to `foo`'s SSE transport.
pub struct McpProxy {
    config: McpProxyConfig,
    // One transport per target, keyed by `McpProxyTarget::name`.
    sse_transports: HashMap<String, SseTransport>,
}
impl McpProxy {
pub fn new(config: McpProxyConfig) -> Self {
let mut sse_transports = HashMap::new();
for server in &config.servers {
let transport = match &server.auth_token {
Some(token) => TransportFactory::mcp_sse_auth(&server.sse_endpoint, token),
None => TransportFactory::mcp_sse(&server.sse_endpoint),
};
sse_transports.insert(server.name.clone(), transport);
}
Self {
config,
sse_transports,
}
}
async fn send_to_server(
&self,
server_name: &str,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value, ProtocolError> {
let transport = self.sse_transports.get(server_name).ok_or_else(|| {
ProtocolError::internal_error(&format!("Unknown server: {}", server_name))
})?;
let request = UniversalRequest {
method: method.to_string(),
uri: "/".to_string(),
headers: HashMap::new(),
body: json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": 1
})
.to_string()
.into_bytes(),
protocol: "MCP".to_string(),
correlation_id: format!("{}-{}", method.replace("/", "-"), server_name),
};
let response = transport
.send(request)
.await
.map_err(|e| ProtocolError::internal_error(&format!("Transport error: {:?}", e)))?;
let response_body = String::from_utf8(response.body)
.map_err(|e| ProtocolError::Parsing(format!("Invalid UTF-8 response: {}", e)))?;
let response_json: serde_json::Value = serde_json::from_str(&response_body)
.map_err(|e| ProtocolError::Parsing(format!("Invalid JSON response: {}", e)))?;
response_json
.get("result")
.ok_or_else(|| ProtocolError::Parsing("Missing 'result' field".to_string()))
.map(|v| v.clone())
}
pub async fn list_tools_async(&self) -> Result<Vec<Tool>, ProtocolError> {
let mut all_tools = Vec::new();
for server in &self.config.servers {
match self
.send_to_server(&server.name, "tools/list", json!({}))
.await
{
Ok(result) => {
let list_result: ListToolsResult =
serde_json::from_value(result).map_err(|e| {
ProtocolError::Parsing(format!("Invalid tools list format: {}", e))
})?;
let mut tools = list_result.tools;
for tool in &mut tools {
tool.name = format!("{}:{}", server.name, tool.name);
}
all_tools.extend(tools);
}
Err(e) => {
log::warn!(
"Failed to list tools from proxy target '{}': {:?}",
server.name,
e
);
}
}
}
Ok(all_tools)
}
pub async fn call_tool_async(
&self,
name: &str,
arguments: Option<serde_json::Value>,
) -> Result<CallToolResult, ProtocolError> {
let parts: Vec<&str> = name.splitn(2, ':').collect();
if parts.len() != 2 {
return Err(ProtocolError::internal_error(
"Tool name must be in format 'server:tool'",
));
}
let server_name = parts[0];
let tool_name = parts[1];
let params = json!({
"name": tool_name,
"arguments": arguments
});
let result = self
.send_to_server(server_name, "tools/call", params)
.await?;
let call_result: CallToolResult = serde_json::from_value(result).map_err(|e| {
ProtocolError::Parsing(format!("Invalid tool call result format: {}", e))
})?;
Ok(call_result)
}
pub async fn health_check_all(&self) -> HashMap<String, bool> {
let mut health_status = HashMap::new();
for server in &self.config.servers {
if let Some(transport) = self.sse_transports.get(&server.name) {
let is_healthy = transport.health_check().await.is_ok();
health_status.insert(server.name.clone(), is_healthy);
} else {
health_status.insert(server.name.clone(), false);
}
}
health_status
}
}
#[async_trait]
impl ToolProvider for McpProxy {
    /// Always fails.
    ///
    /// This trait method is synchronous, but the proxy can only list its targets'
    /// tools asynchronously. Use `McpProxy::list_tools_async` instead.
    fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
        Err(ProtocolError::internal_error(
            "Async tool listing not supported in sync context. Use async proxy methods.",
        ))
    }

    /// Routes a `server:tool` call to the matching upstream target.
    ///
    /// This method is already async, so it awaits `McpProxy::call_tool_async`
    /// directly instead of rejecting the call. Name validation (`server:tool`
    /// format) happens there and yields the same error as before for malformed names.
    async fn call_tool(
        &self,
        name: &str,
        arguments: Option<serde_json::Value>,
    ) -> Result<CallToolResult, ProtocolError> {
        self.call_tool_async(name, arguments).await
    }
}
/// Fluent builder for [`McpProxy`].
///
/// Every configuration method consumes and returns the builder. Marking the type
/// `#[must_use]` makes the compiler flag a chain whose result is dropped before
/// `build()` is called.
#[must_use = "an McpProxyBuilder does nothing until `build()` is called"]
pub struct McpProxyBuilder {
    // Targets registered via `add_server` / `add_server_with_auth`, in call order.
    servers: Vec<McpProxyTarget>,
    // Copied into `McpProxyConfig::proxy_auth` by `build()`.
    proxy_auth: Option<String>,
    // Copied into `McpProxyConfig::timeout_seconds` by `build()`.
    timeout_seconds: u64,
}
impl McpProxyBuilder {
pub fn new() -> Self {
Self {
servers: Vec::new(),
proxy_auth: None,
timeout_seconds: 30,
}
}
pub fn add_server(mut self, name: &str, sse_endpoint: &str) -> Self {
self.servers.push(McpProxyTarget {
name: name.to_string(),
sse_endpoint: sse_endpoint.to_string(),
auth_token: None,
description: None,
});
self
}
pub fn add_server_with_auth(
mut self,
name: &str,
sse_endpoint: &str,
auth_token: &str,
) -> Self {
self.servers.push(McpProxyTarget {
name: name.to_string(),
sse_endpoint: sse_endpoint.to_string(),
auth_token: Some(auth_token.to_string()),
description: None,
});
self
}
pub fn with_proxy_auth(mut self, auth_token: &str) -> Self {
self.proxy_auth = Some(auth_token.to_string());
self
}
pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
self.timeout_seconds = timeout_seconds;
self
}
pub fn build(self) -> McpProxy {
let config = McpProxyConfig {
servers: self.servers,
proxy_auth: self.proxy_auth,
timeout_seconds: self.timeout_seconds,
};
McpProxy::new(config)
}
}
impl Default for McpProxyBuilder {
fn default() -> Self {
Self::new()
}
}