use std::convert::Infallible;
use std::fmt;
use tokio::sync::mpsc;
use tower::Layer;
use tower::util::BoxCloneService;
use crate::client::{ClientTransport, McpClient};
use crate::error::{Error, Result};
use crate::router::{RouterRequest, RouterResponse};
use crate::transport::CatchError;
use super::backend::{Backend, BackendService, ListChanged};
use super::service::{BackendEntry, McpProxy};
/// A backend that was left out of the proxy during building, together with
/// the error that caused it and the stage at which it failed.
#[derive(Debug)]
pub struct SkippedBackend {
/// Namespace the backend was registered under.
pub namespace: String,
/// Error that caused the backend to be skipped.
pub error: Error,
/// Stage (connect or initialize) at which the backend failed.
pub phase: SkippedPhase,
}
/// Stage at which a backend failed and was skipped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SkippedPhase {
/// Connecting the backend's transport failed (recorded by `McpProxyBuilder::backend`).
Connect,
/// The backend connected, but its `initialize` call failed during `McpProxyBuilder::build`.
Initialize,
}
impl fmt::Display for SkippedBackend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"[{}] {:?} failed: {}",
self.namespace, self.phase, self.error
)
}
}
/// Outcome of `McpProxyBuilder::build`: the assembled proxy plus every backend
/// that was skipped. Connection failures come first, followed by
/// initialization failures in registration order.
pub struct ProxyBuildResult {
/// Proxy serving all successfully initialized backends.
pub proxy: McpProxy,
/// Backends left out of `proxy`, each with the reason it was skipped.
pub skipped: Vec<SkippedBackend>,
}
/// A backend registered with the builder but not yet initialized.
struct PendingBackend {
/// Namespace the backend is registered under.
namespace: String,
backend: Backend,
/// Receiver of `ListChanged` events for this backend. `None` for backends
/// added via `backend_client`, which get no invalidation channel.
invalidation_rx: Option<mpsc::Receiver<ListChanged>>,
/// Service installed by `backend_layer`. When `None`, `build` uses the
/// backend's default entry via `BackendEntry::from_backend`.
custom_service: Option<BoxCloneService<RouterRequest, RouterResponse, Infallible>>,
}
/// A backend whose connection attempt in `McpProxyBuilder::backend` failed.
/// `build` turns each one into a `SkippedBackend` with `SkippedPhase::Connect`.
struct ConnectionFailure {
/// Namespace the backend was going to be registered under.
namespace: String,
/// Error returned by the connection attempt.
error: Error,
}
pub struct McpProxyBuilder {
name: String,
version: String,
separator: String,
pending: Vec<PendingBackend>,
notification_tx: Option<crate::context::NotificationSender>,
connection_failures: Vec<ConnectionFailure>,
instructions: Option<String>,
}
impl McpProxyBuilder {
pub(crate) fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
Self {
name: name.into(),
version: version.into(),
separator: "_".to_string(),
pending: Vec::new(),
notification_tx: None,
connection_failures: Vec::new(),
instructions: None,
}
}
pub fn separator(mut self, sep: impl Into<String>) -> Self {
self.separator = sep.into();
self
}
pub fn notification_sender(mut self, tx: crate::context::NotificationSender) -> Self {
self.notification_tx = Some(tx);
self
}
pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
self.instructions = Some(instructions.into());
self
}
pub fn backend_client(mut self, namespace: impl Into<String>, client: McpClient) -> Self {
let backend = Backend::from_client(namespace, client, self.separator.clone());
self.pending.push(PendingBackend {
namespace: backend.namespace.clone(),
backend,
invalidation_rx: None,
custom_service: None,
});
self
}
pub async fn backend(
mut self,
namespace: impl Into<String>,
transport: impl ClientTransport,
) -> Self {
let namespace = namespace.into();
let (invalidation_tx, invalidation_rx) = mpsc::channel(16);
match Backend::connect(
namespace.clone(),
transport,
self.separator.clone(),
invalidation_tx,
)
.await
{
Ok(backend) => {
self.pending.push(PendingBackend {
namespace,
backend,
invalidation_rx: Some(invalidation_rx),
custom_service: None,
});
}
Err(e) => {
tracing::error!(
namespace = %namespace,
error = %e,
"Failed to connect backend"
);
self.connection_failures.push(ConnectionFailure {
namespace,
error: e,
});
}
}
self
}
pub async fn backend_try(
mut self,
namespace: impl Into<String>,
transport: impl ClientTransport,
) -> Result<Self> {
let namespace = namespace.into();
let (invalidation_tx, invalidation_rx) = mpsc::channel(16);
let backend = Backend::connect(
namespace.clone(),
transport,
self.separator.clone(),
invalidation_tx,
)
.await?;
self.pending.push(PendingBackend {
namespace,
backend,
invalidation_rx: Some(invalidation_rx),
custom_service: None,
});
Ok(self)
}
pub fn backend_layer<L>(mut self, layer: L) -> Self
where
L: Layer<BackendService> + Send + 'static,
L::Service: tower_service::Service<RouterRequest, Response = RouterResponse>
+ Clone
+ Send
+ 'static,
<L::Service as tower_service::Service<RouterRequest>>::Error: fmt::Display + Send,
<L::Service as tower_service::Service<RouterRequest>>::Future: Send,
{
let pending = self
.pending
.last_mut()
.expect("backend_layer called before adding a backend");
let base = pending.backend.service();
let layered = layer.layer(base);
let caught = CatchError::new(layered);
pending.custom_service = Some(BoxCloneService::new(caught));
self
}
pub async fn build(mut self) -> Result<ProxyBuildResult> {
if self.pending.is_empty() {
return Err(Error::internal("No backends configured"));
}
for pb in &mut self.pending {
pb.backend.separator = self.separator.clone();
}
let namespaces: Vec<&str> = self
.pending
.iter()
.map(|pb| pb.namespace.as_str())
.collect();
{
let mut sorted = namespaces.clone();
sorted.sort();
sorted.dedup();
if sorted.len() != namespaces.len() {
return Err(Error::internal("Duplicate backend namespaces"));
}
}
let prefixes: Vec<String> = namespaces
.iter()
.map(|ns| format!("{}{}", ns, self.separator))
.collect();
for (i, prefix_i) in prefixes.iter().enumerate() {
for (j, prefix_j) in prefixes.iter().enumerate() {
if i != j && prefix_j.starts_with(prefix_i.as_str()) {
return Err(Error::internal(format!(
"Ambiguous namespace prefixes: \"{}\" and \"{}\" with separator \"{}\". \
The prefix \"{}\" is a prefix of \"{}\", which makes routing ambiguous. \
Use a different separator (e.g., \".\") or rename the namespaces.",
namespaces[i], namespaces[j], self.separator, prefix_i, prefix_j,
)));
}
}
}
let name = self.name.clone();
let version = self.version.clone();
let init_futures: Vec<_> = self
.pending
.into_iter()
.map(|mut pb| {
let name = name.clone();
let version = version.clone();
async move {
match pb.backend.initialize(&name, &version).await {
Ok(instructions) => {
pb.backend.instructions = instructions;
{
let cache = pb.backend.cache.read().await;
tracing::info!(
namespace = %pb.namespace,
tools = cache.tools.len(),
resources = cache.resources.len(),
prompts = cache.prompts.len(),
"Backend initialized"
);
}
Ok(pb)
}
Err(e) => {
tracing::error!(
namespace = %pb.namespace,
error = %e,
"Failed to initialize backend, skipping"
);
Err(SkippedBackend {
namespace: pb.namespace,
error: e,
phase: SkippedPhase::Initialize,
})
}
}
}
})
.collect();
let results = futures::future::join_all(init_futures).await;
let mut backends = Vec::new();
let mut entries = Vec::new();
let mut invalidation_rxs = Vec::new();
let mut skipped: Vec<SkippedBackend> = self
.connection_failures
.into_iter()
.map(|f| SkippedBackend {
namespace: f.namespace,
error: f.error,
phase: SkippedPhase::Connect,
})
.collect();
for result in results {
let pb = match result {
Ok(pb) => pb,
Err(s) => {
skipped.push(s);
continue;
}
};
let entry = if let Some(svc) = pb.custom_service {
BackendEntry::from_backend_with_service(&pb.backend, svc)
} else {
BackendEntry::from_backend(&pb.backend)
};
if let Some(rx) = pb.invalidation_rx {
invalidation_rxs.push((backends.len(), rx));
}
entries.push(entry);
backends.push(pb.backend);
}
if backends.is_empty() {
return Err(Error::internal("All backends failed to initialize"));
}
let instructions = if let Some(custom) = self.instructions {
Some(custom)
} else {
let mut parts = vec![format!(
"MCP proxy aggregating {} backend servers.",
backends.len()
)];
for b in &backends {
if let Some(inst) = &b.instructions {
parts.push(format!("[{}] {}", b.namespace, inst));
}
}
if parts.len() > 1 {
Some(parts.join("\n\n"))
} else {
Some(parts.remove(0))
}
};
let proxy = McpProxy::new(
self.name,
self.version,
backends,
entries,
self.notification_tx,
instructions,
self.separator.clone(),
);
for (backend_idx, rx) in invalidation_rxs {
proxy.spawn_invalidation_watcher(backend_idx, rx);
}
Ok(ProxyBuildResult { proxy, skipped })
}
pub async fn build_strict(self) -> Result<McpProxy> {
let result = self.build().await?;
if let Some(first) = result.skipped.into_iter().next() {
return Err(Error::internal(format!(
"Backend \"{}\" failed to initialize: {}",
first.namespace, first.error
)));
}
Ok(result.proxy)
}
}