use std::path::{Path, PathBuf};
use crate::error::Error;
use crate::router::McpRouter;
use crate::transport::http::{HttpTransport, SessionConfig, SessionHandle};
/// MCP transport that exposes the HTTP-based MCP endpoint on a Unix domain
/// socket instead of a TCP port.
///
/// Request handling and session configuration are forwarded to a wrapped
/// [`HttpTransport`]; this type adds socket-path handling on top of it.
pub struct UnixSocketTransport {
    /// Whether `serve` should try to remove an existing file at the socket
    /// path before binding.
    cleanup_on_bind: bool,
    /// The HTTP transport that builds the router actually served on the socket.
    inner: HttpTransport,
}
impl UnixSocketTransport {
pub fn new(router: McpRouter) -> Self {
Self {
inner: HttpTransport::new(router),
cleanup_on_bind: true,
}
}
pub fn from_service<S>(service: S) -> Self
where
S: tower::Service<
crate::router::RouterRequest,
Response = crate::router::RouterResponse,
Error = std::convert::Infallible,
> + Clone
+ Send
+ 'static,
S::Future: Send,
{
Self {
inner: HttpTransport::from_service(service),
cleanup_on_bind: true,
}
}
pub fn with_sampling(mut self) -> Self {
self.inner = self.inner.with_sampling();
self
}
pub fn require_sessions(mut self) -> Self {
self.inner = self.inner.require_sessions();
self
}
pub fn session_config(mut self, config: SessionConfig) -> Self {
self.inner = self.inner.session_config(config);
self
}
pub fn session_ttl(mut self, ttl: std::time::Duration) -> Self {
self.inner = self.inner.session_ttl(ttl);
self
}
pub fn max_sessions(mut self, max: usize) -> Self {
self.inner = self.inner.max_sessions(max);
self
}
pub fn session_store(
mut self,
store: std::sync::Arc<dyn crate::session_store::SessionStore>,
) -> Self {
self.inner = self.inner.session_store(store);
self
}
pub fn event_store(
mut self,
store: std::sync::Arc<dyn crate::event_store::EventStore>,
) -> Self {
self.inner = self.inner.event_store(store);
self
}
pub fn auto_reinitialize_sessions(mut self, enabled: bool) -> Self {
self.inner = self.inner.auto_reinitialize_sessions(enabled);
self
}
pub fn disable_origin_validation(mut self) -> Self {
self.inner = self.inner.disable_origin_validation();
self
}
pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
self.inner = self.inner.allowed_origins(origins);
self
}
pub fn disable_host_validation(mut self) -> Self {
self.inner = self.inner.disable_host_validation();
self
}
pub fn allowed_hosts(mut self, hosts: Vec<String>) -> Self {
self.inner = self.inner.allowed_hosts(hosts);
self
}
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<McpRouter> + Send + Sync + 'static,
L::Service: tower::Service<crate::router::RouterRequest, Response = crate::router::RouterResponse>
+ Clone
+ Send
+ 'static,
<L::Service as tower::Service<crate::router::RouterRequest>>::Error:
std::fmt::Display + Send,
<L::Service as tower::Service<crate::router::RouterRequest>>::Future: Send,
{
self.inner = self.inner.layer(layer);
self
}
pub fn cleanup_on_bind(mut self, cleanup: bool) -> Self {
self.cleanup_on_bind = cleanup;
self
}
pub fn into_router(self) -> axum::Router {
self.inner.into_router()
}
pub fn into_router_with_handle(self) -> (axum::Router, SessionHandle) {
self.inner.into_router_with_handle()
}
pub async fn serve<P: AsRef<Path>>(self, path: P) -> crate::Result<()> {
let path = path.as_ref().to_path_buf();
if self.cleanup_on_bind {
cleanup_socket(&path);
}
let listener = tokio::net::UnixListener::bind(&path).map_err(|e| {
Error::Transport(format!(
"Failed to bind Unix socket {}: {}",
path.display(),
e
))
})?;
tracing::info!("MCP Unix socket transport listening on {}", path.display());
let router = self.inner.into_router();
axum::serve(listener, router)
.await
.map_err(|e| Error::Transport(format!("Server error: {}", e)))?;
Ok(())
}
}
/// Best-effort removal of a leftover Unix socket file at `path` before binding.
///
/// Only an entry whose own file type (symlinks are not followed) is a Unix
/// socket gets removed. A mistyped socket path therefore never deletes a
/// regular file, directory or symlink; such entries are left in place and
/// the subsequent `bind` reports the conflict.
///
/// A missing path is silently ignored. Any other inspection or removal
/// failure is logged as a warning rather than returned.
fn cleanup_socket(path: &PathBuf) {
    let metadata = match std::fs::symlink_metadata(path) {
        Ok(metadata) => metadata,
        // Nothing to clean up.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return,
        Err(e) => {
            tracing::warn!(
                "Failed to inspect existing socket file {}: {}",
                path.display(),
                e
            );
            return;
        }
    };

    // Refuse to delete anything that is not a socket: the caller-supplied
    // path may point at an unrelated file by mistake.
    if !std::os::unix::fs::FileTypeExt::is_socket(&metadata.file_type()) {
        tracing::warn!(
            "Not removing {}: existing file is not a Unix socket",
            path.display()
        );
        return;
    }

    match std::fs::remove_file(path) {
        Ok(()) => {
            tracing::debug!("Removed existing socket file: {}", path.display());
        }
        // The socket disappeared between the check and the removal; fine.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => {
            tracing::warn!(
                "Failed to remove existing socket file {}: {}",
                path.display(),
                e
            );
        }
    }
}