use std::sync::Arc;
use anyhow::Result;
use axum::{
Router,
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::{any_service, delete},
};
use rmcp::transport::streamable_http_server::{
SessionManager, StreamableHttpService, session::local::LocalSessionManager,
};
use super::middleware::{OriginConfig, with_guards};
use crate::service::RudofMcpService;
pub async fn run_mcp_http(
bind_address: &str,
port: u16,
route_path: &str,
allowed_networks: Option<Vec<String>>,
) -> Result<()> {
let bind_addr = format_bind_address(bind_address, port);
let canonical_uri = format_canonical_uri(bind_address, port, route_path);
let origin_config = match allowed_networks {
Some(networks) if !networks.is_empty() => {
OriginConfig::new(networks).map_err(|e| anyhow::anyhow!("Invalid network configuration: {}", e))?
},
_ => {
tracing::info!("No custom networks specified, using localhost-only configuration");
OriginConfig::localhost_only()
},
};
let session_manager = Arc::new(LocalSessionManager::default());
let mcp_service_factory = move || Ok(RudofMcpService::new());
let rmcp_service = StreamableHttpService::new(mcp_service_factory, session_manager.clone(), Default::default());
let router = Router::new()
.route(
route_path,
delete({
let sm = session_manager.clone();
move |headers| handle_delete_session(headers, sm.clone())
})
.fallback_service(any_service(rmcp_service)),
);
let guarded_router = with_guards(router, origin_config);
let listener = std::net::TcpListener::bind(&bind_addr)?;
let server = axum_server::Server::from_tcp(listener).serve(guarded_router.into_make_service());
tracing::info!("MCP HTTP server listening on {}", canonical_uri);
tokio::select! {
result = server => {
if let Err(e) = result {
tracing::error!("Server error: {}", e);
}
}
_ = tokio::signal::ctrl_c() => {
tracing::debug!("Shutdown signal received, stopping HTTP server...");
}
}
Ok(())
}
async fn handle_delete_session(headers: HeaderMap, session_manager: Arc<LocalSessionManager>) -> impl IntoResponse {
match headers.get("Mcp-Session-Id").and_then(|v| v.to_str().ok()) {
Some(id) => {
let id_arc = Arc::from(id.to_string());
match session_manager.close_session(&id_arc).await {
Ok(()) => {
tracing::info!(session_id = %id, "Session terminated successfully");
(StatusCode::NO_CONTENT, "").into_response()
},
Err(e) => {
tracing::error!(session_id = %id, error = %e, "Session not found or already expired");
(StatusCode::NOT_FOUND, "Session not found or already expired").into_response()
},
}
},
None => {
tracing::error!("Missing Mcp-Session-Id header in DELETE request");
(StatusCode::BAD_REQUEST, "Missing Mcp-Session-Id header").into_response()
},
}
}
/// Builds a `host:port` string suitable for `TcpListener::bind`.
///
/// Hosts containing `:` are treated as IPv6 literals and wrapped in brackets
/// (`::1` -> `[::1]:port`). An address the user already bracketed (`[::1]`) is
/// unwrapped first, so it is not double-bracketed into an unbindable `[[::1]]:port`.
fn format_bind_address(address: &str, port: u16) -> String {
    let host = address
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(address);
    if host.contains(':') {
        format!("[{}]:{}", host, port)
    } else {
        format!("{}:{}", host, port)
    }
}
/// Builds the `http://host:port/path` URI advertised in logs for the server.
///
/// A wildcard (unspecified) bind address is not dialable, so it is replaced by the
/// loopback of the same family: `0.0.0.0` -> `127.0.0.1`, and any spelling of the IPv6
/// unspecified address (`::`, `[::]`, `0:0:0:0:0:0:0:0`) -> `[::1]`. Other hosts
/// containing `:` are bracketed as IPv6 literals; already-bracketed input is unwrapped
/// first, so it is not double-bracketed. `route_path` is appended verbatim.
fn format_canonical_uri(address: &str, port: u16, route_path: &str) -> String {
    let host = address
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(address);
    let authority = match host.parse::<std::net::IpAddr>() {
        Ok(ip) if ip.is_unspecified() => {
            if ip.is_ipv4() {
                "127.0.0.1".to_string()
            } else {
                "[::1]".to_string()
            }
        },
        _ if host.contains(':') => format!("[{}]", host),
        _ => host.to_string(),
    };
    format!("http://{}:{}{}", authority, port, route_path)
}