#![deny(clippy::print_stdout)]
use std::collections::HashMap;
use std::sync::Arc;
use rmcp::ErrorData as McpError;
use rmcp::ServerHandler;
use rmcp::model::{
CallToolRequestParams, CallToolResult, Content, Implementation, JsonObject, ListToolsResult,
PaginatedRequestParams, ServerCapabilities, ServerInfo, Tool,
};
use rmcp::service::{RequestContext, RoleServer};
use serde_json::Value;
use crate::error::{OutrigError, Result};
use crate::mcp::{self, McpClient, McpTool, McpToolResult};
use crate::tool_name;
/// A backing MCP server the proxy can enumerate tools on and forward calls to.
///
/// Implementors must be `Send + Sync + 'static`; the futures returned by the
/// async-style methods must be `Send`.
pub trait BackingClient: Send + Sync + 'static {
/// Returns this client's name.
///
/// `ProxyServer::build` rejects client lists in which two clients report the
/// same name, and passes the name to `tool_name::sanitize` when deriving the
/// public tool names.
fn name(&self) -> &str;
/// Asynchronously lists the tools this backing server offers.
fn list_tools(&self) -> impl Future<Output = Result<Vec<McpTool>>> + Send;
/// Asynchronously invokes the tool called `name` with `args`.
///
/// `name` is the backend's own tool name (not the proxy's public name).
/// `ProxyServer::dispatch_call` passes either a JSON object or `Value::Null`
/// (when the request carried no arguments) as `args`.
fn call_tool(
&self,
name: &str,
args: Value,
) -> impl Future<Output = Result<McpToolResult>> + Send;
}
/// Makes the concrete [`McpClient`] usable as a [`BackingClient`].
///
/// Each method forwards to the `McpClient` associated function of the same
/// name using explicit path syntax.
// NOTE(review): these calls are expected to resolve to inherent methods on
// `McpClient` (declared in `crate::mcp`, not visible here) — confirm they
// exist, otherwise the path calls would resolve back to this trait impl.
impl BackingClient for McpClient {
// Forwards to `McpClient::name`.
fn name(&self) -> &str {
McpClient::name(self)
}
// Forwards to `McpClient::list_tools`; the returned future is handed back as-is.
fn list_tools(&self) -> impl Future<Output = Result<Vec<McpTool>>> + Send {
McpClient::list_tools(self)
}
// Forwards to `McpClient::call_tool` with the tool name and arguments unchanged.
fn call_tool(
&self,
name: &str,
args: Value,
) -> impl Future<Output = Result<McpToolResult>> + Send {
McpClient::call_tool(self, name, args)
}
}
/// Lets an `Arc<T>` act as a [`BackingClient`] whenever its pointee is one.
///
/// Every method simply borrows the pointee and calls the same method on it,
/// so a shared handle behaves exactly like the client it wraps.
impl<T> BackingClient for Arc<T>
where
    T: BackingClient + ?Sized,
{
    /// Name reported by the wrapped client.
    fn name(&self) -> &str {
        let target: &T = self;
        target.name()
    }

    /// Tool listing produced by the wrapped client.
    fn list_tools(&self) -> impl Future<Output = Result<Vec<McpTool>>> + Send {
        let target: &T = self;
        target.list_tools()
    }

    /// Tool invocation performed by the wrapped client; `name` and `args`
    /// are passed through untouched.
    fn call_tool(
        &self,
        name: &str,
        args: Value,
    ) -> impl Future<Output = Result<McpToolResult>> + Send {
        let target: &T = self;
        target.call_tool(name, args)
    }
}
/// One tool exposed by the proxy, together with the routing information
/// needed to forward a call to the backing server that hosts it.
#[derive(Debug, Clone)]
struct ToolEntry {
// Name under which the proxy exposes the tool; computed in `build` by
// `tool_name::sanitize(server_name, upstream_tool_name)`.
public_name: String,
// The tool's name as reported by the backing server; this is what is passed
// to `BackingClient::call_tool` on dispatch.
backend_tool: String,
// Upstream description, or the empty string when the upstream tool had none.
description: String,
// Upstream input schema; `build` only accepts schemas that are JSON objects.
input_schema: Arc<JsonObject>,
// Index of the hosting client within `ProxyInner::clients`.
client_idx: usize,
}
/// State shared by all clones of a `ProxyServer`; populated once in
/// `ProxyServer::build` and only read afterwards by the methods in this file.
struct ProxyInner<C> {
// Backing clients, in the order they were passed to `build`.
clients: Vec<C>,
// Every exposed tool, in discovery order (client by client).
tools: Vec<ToolEntry>,
// Maps a tool's public name to its index in `tools`.
by_public_name: HashMap<String, usize>,
// Server description returned verbatim (cloned) by `get_info`.
server_info: ServerInfo,
}
/// MCP server that re-exposes the tools of several backing clients under
/// per-server-prefixed public names and forwards calls to the owning client.
///
/// The state lives behind an `Arc`, so cloning a `ProxyServer` only clones
/// that pointer. `C` defaults to `Arc<McpClient>`.
pub struct ProxyServer<C = Arc<McpClient>> {
// Shared, immutable-after-build proxy state.
inner: Arc<ProxyInner<C>>,
}
impl<C> Clone for ProxyServer<C> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<C: BackingClient> ProxyServer<C> {
pub async fn build(clients: Vec<C>) -> Result<Self> {
let mut seen_names: HashMap<&str, usize> = HashMap::with_capacity(clients.len());
for (idx, client) in clients.iter().enumerate() {
let name = client.name();
if let Some(prev) = seen_names.insert(name, idx) {
return Err(OutrigError::Configuration(format!(
"mcp_proxy: duplicate backing-client name {name:?} \
(clients[{prev}] and clients[{idx}])"
)));
}
}
let mut tools: Vec<ToolEntry> = Vec::new();
let mut by_public_name: HashMap<String, usize> = HashMap::new();
for (client_idx, client) in clients.iter().enumerate() {
let server_name = client.name().to_string();
let upstream = client.list_tools().await?;
for tool in upstream {
let public_name = tool_name::sanitize(&server_name, &tool.name);
if let Some(prev_idx) = by_public_name.get(&public_name) {
let prev = &tools[*prev_idx];
let prev_server = clients[prev.client_idx].name();
return Err(OutrigError::Configuration(format!(
"mcp_proxy: public name {public_name:?} produced by both \
({prev_server:?}, {prev_tool:?}) and ({server_name:?}, {tool_name:?})",
prev_tool = prev.backend_tool,
tool_name = tool.name,
)));
}
let input_schema = match tool.input_schema {
Value::Object(map) => Arc::new(map),
other => {
return Err(OutrigError::Configuration(format!(
"mcp_proxy: tool {server_name:?}::{tool_name:?} input_schema is \
not a JSON object (got {kind})",
tool_name = tool.name,
kind = mcp::kind_of(&other),
)));
}
};
by_public_name.insert(public_name.clone(), tools.len());
tools.push(ToolEntry {
public_name,
backend_tool: tool.name,
description: tool.description.unwrap_or_default(),
input_schema,
client_idx,
});
}
}
let server_info = ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
.with_server_info(Implementation::new("outrig", env!("CARGO_PKG_VERSION")))
.with_instructions(
"Tools are namespaced as <server>__<tool>; the prefix identifies which \
backing MCP server hosts the tool.",
);
Ok(Self {
inner: Arc::new(ProxyInner {
clients,
tools,
by_public_name,
server_info,
}),
})
}
/// Iterates over the public names of all exposed tools, in the order the
/// tools were recorded by `build`.
pub fn iter_public_names(&self) -> impl Iterator<Item = &str> {
    let entries = &self.inner.tools;
    entries.iter().map(|entry| &*entry.public_name)
}
/// Returns `(client name, number of exposed tools)` for every backing
/// client, in client order. Clients without any tool are listed with `0`.
pub fn per_server_counts(&self) -> Vec<(&str, usize)> {
    let clients = &self.inner.clients;

    // One counter per client, addressed by the entry's `client_idx`.
    let mut tally = vec![0usize; clients.len()];
    self.inner
        .tools
        .iter()
        .for_each(|entry| tally[entry.client_idx] += 1);

    let mut report = Vec::with_capacity(clients.len());
    for (client, count) in clients.iter().zip(tally) {
        report.push((client.name(), count));
    }
    report
}
pub fn list_tools_inner(&self) -> ListToolsResult {
let tools = self
.inner
.tools
.iter()
.map(|entry| {
Tool::new(
entry.public_name.clone(),
entry.description.clone(),
entry.input_schema.clone(),
)
})
.collect();
ListToolsResult {
next_cursor: None,
meta: None,
tools,
}
}
pub async fn dispatch_call(&self, request: CallToolRequestParams) -> CallToolResult {
let public_name = request.name.as_ref();
let Some(&idx) = self.inner.by_public_name.get(public_name) else {
return CallToolResult::error(vec![Content::text(format!(
"unknown tool: {public_name}"
))]);
};
let entry = &self.inner.tools[idx];
let args = request.arguments.map(Value::Object).unwrap_or(Value::Null);
let client = &self.inner.clients[entry.client_idx];
match client.call_tool(&entry.backend_tool, args).await {
Ok(result) if result.is_error => {
CallToolResult::error(vec![Content::text(result.content_text)])
}
Ok(result) => CallToolResult::success(vec![Content::text(result.content_text)]),
Err(e) => {
let server = client.name();
tracing::warn!(
target: "outrig::mcp_proxy",
"backing server {server:?} call to {tool:?} failed: {e}",
tool = entry.backend_tool,
);
CallToolResult::error(vec![Content::text(format!(
"outrig: backing server `{server}` call failed: {e}"
))])
}
}
}
}
/// `rmcp` server glue: each handler delegates to the matching inherent
/// method on [`ProxyServer`] and wraps the result in `Ok`.
impl<C: BackingClient> ServerHandler for ProxyServer<C> {
    /// Returns a copy of the server description assembled in `build`.
    fn get_info(&self) -> ServerInfo {
        let info = &self.inner.server_info;
        info.clone()
    }

    /// Serves the full tool listing; the pagination request and the
    /// request context are ignored.
    async fn list_tools(
        &self,
        _request: Option<PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> std::result::Result<ListToolsResult, McpError> {
        let listing = self.list_tools_inner();
        Ok(listing)
    }

    /// Forwards the call through `dispatch_call`. Always returns `Ok`:
    /// failures are carried inside the `CallToolResult` itself.
    async fn call_tool(
        &self,
        request: CallToolRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> std::result::Result<CallToolResult, McpError> {
        let outcome = self.dispatch_call(request).await;
        Ok(outcome)
    }
}