/// Selects the channel the MCP server uses to talk to its client.
///
/// Passed to [`serve`], which dispatches on the variant.
#[derive(Clone, Debug)]
pub enum TransportConfig {
    /// Run the server through `serve_stdio`.
    Stdio,
    /// Run the server over HTTP with the given settings.
    ///
    /// Only present when the `http-transport` feature is enabled.
    #[cfg(feature = "http-transport")]
    Http(HttpConfig),
}
/// Settings for the HTTP transport.
///
/// Built with [`HttpConfig::localhost`] or [`HttpConfig::with_bearer`], or directly.
#[cfg(feature = "http-transport")]
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// TCP port the listener binds to.
    pub port: u16,
    /// Local IP address the listener binds to.
    pub bind: std::net::IpAddr,
    /// Authentication policy; the HTTP server builds its request validator from it.
    pub auth: crate::mcp::auth::AuthConfig,
}
#[cfg(feature = "http-transport")]
impl HttpConfig {
    /// Returns a config that binds to IPv4 loopback (`127.0.0.1`) on `port` and uses
    /// `AuthConfig::localhost_only()` authentication.
    #[must_use]
    pub fn localhost(port: u16) -> Self {
        let loopback = std::net::Ipv4Addr::LOCALHOST;
        Self {
            port,
            bind: std::net::IpAddr::from(loopback),
            auth: crate::mcp::auth::AuthConfig::localhost_only(),
        }
    }

    /// Returns a config that binds to `bind:port` and authenticates requests with the
    /// bearer `token` (via `AuthConfig::bearer`).
    #[must_use]
    pub fn with_bearer(port: u16, bind: std::net::IpAddr, token: String) -> Self {
        let auth = crate::mcp::auth::AuthConfig::bearer(token);
        Self { port, bind, auth }
    }
}
/// Runs the MCP server on the transport selected by `config`.
///
/// The stdio branch calls the synchronous `serve_stdio` without awaiting anything, so this
/// future does not yield to the async runtime until that call returns.
///
/// # Errors
///
/// Propagates whatever error the selected transport returns.
pub async fn serve(config: TransportConfig) -> anyhow::Result<()> {
    match config {
        TransportConfig::Stdio => serve_stdio(),
        #[cfg(feature = "http-transport")]
        TransportConfig::Http(http_cfg) => serve_http(http_cfg).await,
    }
}
/// Stdio transport entry point: a direct delegation to
/// `crate::mcp::server::run_stdio`, whose result is returned unchanged.
fn serve_stdio() -> anyhow::Result<()> {
    crate::mcp::server::run_stdio()
}
#[cfg(feature = "http-transport")]
mod http {
    //! Axum HTTP transport for the MCP server.
    //!
    //! A single `/mcp` route is served: `POST` takes one JSON-RPC request and replies with its
    //! response (or `204 No Content` when there is none), and `GET` opens a Server-Sent Events
    //! stream that relays notifications emitted while handling POSTs. Both methods go through
    //! [`check_auth`] before doing any work; other methods get `405 Method Not Allowed`.

    use std::convert::Infallible;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use std::time::Duration;

    use axum::extract::{ConnectInfo, State};
    use axum::http::{header, HeaderMap, StatusCode};
    use axum::response::{IntoResponse, NoContent, Response, Sse};
    use axum::routing::post;
    use axum::{Json, Router};
    use serde_json::Value;
    use tokio::sync::broadcast;
    use tokio_stream::wrappers::BroadcastStream;
    use tokio_stream::StreamExt as _;
    use tracing::{debug, error, info, warn};

    use crate::mcp::auth::{AuthError, BearerValidator};
    use crate::mcp::protocol::{JsonRpcRequest, JsonRpcResponse, RequestId, RpcError};

    /// Capacity of the broadcast channel feeding SSE subscribers. A subscriber that falls
    /// further behind than this misses the oldest events (reported as a lag, logged below).
    const SSE_CHANNEL_CAPACITY: usize = 64;
    /// Interval between keep-alive messages on otherwise idle SSE streams.
    const SSE_KEEPALIVE: Duration = Duration::from_secs(15);

    /// State shared by every request handler (wrapped in an `Arc` by [`start`]).
    #[derive(Clone)]
    pub(super) struct AppState {
        /// Checks the source IP and `Authorization` header of each request.
        validator: BearerValidator,
        /// Fan-out of server notifications to connected SSE clients.
        sse_tx: broadcast::Sender<SseEvent>,
    }

    /// One Server-Sent Event: the event name and its data payload.
    #[derive(Debug, Clone)]
    pub(super) struct SseEvent {
        pub event: String,
        pub data: String,
    }

    impl AppState {
        /// Creates the state with a fresh broadcast channel.
        ///
        /// The initial receiver is dropped right away. Each SSE client subscribes on connect,
        /// so sends made while nobody is subscribed fail and are ignored by the caller.
        pub fn new(validator: BearerValidator) -> Self {
            let (sse_tx, _) = broadcast::channel(SSE_CHANNEL_CAPACITY);
            Self { validator, sse_tx }
        }
    }

    /// Applies the source-IP policy first, then the `Authorization` header policy.
    ///
    /// On failure, returns a ready-to-send `401 Unauthorized` response. The detailed error is
    /// logged, and the client only receives a fixed, generic message.
    // `Response` is a comparatively large type, but the Err arm is returned straight from the
    // handlers; boxing it would only add an allocation on the rejection path.
    #[allow(clippy::result_large_err)]
    fn check_auth(
        headers: &HeaderMap,
        peer: SocketAddr,
        validator: &BearerValidator,
    ) -> Result<(), Response> {
        if let Err(e) = validator.validate_source_ip(peer.ip()) {
            warn!(%peer, "rejected non-localhost request: {e}");
            return Err(unauthorized("Non-localhost request rejected"));
        }
        // A header value that is not valid visible ASCII is passed to the validator as `None`,
        // exactly like an absent header.
        let raw = headers
            .get(header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok());
        if let Err(e) = validator.validate_header(raw) {
            let msg = match e {
                AuthError::MissingHeader => "Authorization header required",
                AuthError::UnsupportedScheme => "Unsupported authorization scheme",
                AuthError::InvalidToken => "Invalid bearer token",
                _ => "Authorization failed",
            };
            warn!(%peer, "auth failure: {e}");
            return Err(unauthorized(msg));
        }
        Ok(())
    }

    /// Builds a `401 Unauthorized` response with a `WWW-Authenticate: Bearer` challenge and
    /// `msg` as the body.
    fn unauthorized(msg: &'static str) -> Response {
        (
            StatusCode::UNAUTHORIZED,
            [("WWW-Authenticate", "Bearer")],
            msg,
        )
            .into_response()
    }

    /// `POST /mcp`: authenticates the caller, dispatches one JSON-RPC request, and returns
    /// the reply.
    ///
    /// Anything the server writes to its output sink while handling the request is treated
    /// as newline-delimited notifications and broadcast to SSE subscribers. The reply is:
    /// - the JSON-RPC response as JSON;
    /// - `204 No Content` when the server produced no response;
    /// - `500` if the response cannot be serialized;
    /// - a JSON-RPC parse-error object if the body does not deserialize into a request.
    pub(super) async fn post_mcp(
        ConnectInfo(peer): ConnectInfo<SocketAddr>,
        State(state): State<Arc<AppState>>,
        headers: HeaderMap,
        Json(body): Json<Value>,
    ) -> Response {
        if let Err(resp) = check_auth(&headers, peer, &state.validator) {
            return resp;
        }
        debug!(%peer, "POST /mcp");

        let rpc_req: JsonRpcRequest = match serde_json::from_value(body) {
            Ok(r) => r,
            Err(e) => {
                // NOTE(review): JSON-RPC 2.0 requires a null `id` when the request id cannot
                // be determined; `RequestId::Number(0)` is used instead — switch if a null id
                // is representable. The body is already valid JSON here, so "Invalid Request"
                // may be the more accurate error code than PARSE_ERROR.
                let resp = JsonRpcResponse::err(
                    RequestId::Number(0),
                    RpcError::new(RpcError::PARSE_ERROR, format!("Parse error: {e}")),
                );
                return Json(serde_json::to_value(&resp).unwrap_or(Value::Null)).into_response();
            }
        };

        let mut sink = Vec::<u8>::new();
        let maybe_resp = {
            // NOTE(review): a fresh `ServerHandle` is constructed for every request, so any
            // state it keeps does not carry over between POSTs — confirm this is intended.
            let mut server = crate::mcp::server::ServerHandle::new();
            server.handle(&rpc_req, &mut sink)
        };

        if !sink.is_empty() {
            match String::from_utf8(sink) {
                Ok(notifications) => {
                    for line in notifications.lines().filter(|line| !line.is_empty()) {
                        // `send` only fails when no SSE client is subscribed; dropping the
                        // notification in that case is intended.
                        let _ = state.sse_tx.send(SseEvent {
                            event: "notification".into(),
                            data: line.to_owned(),
                        });
                    }
                }
                Err(e) => {
                    warn!(%peer, "dropping server notifications that are not valid UTF-8: {e}");
                }
            }
        }

        match maybe_resp {
            Some(resp) => match serde_json::to_value(&resp) {
                Ok(v) => Json(v).into_response(),
                Err(e) => {
                    error!("response serialization failed: {e}");
                    StatusCode::INTERNAL_SERVER_ERROR.into_response()
                }
            },
            None => NoContent.into_response(),
        }
    }

    /// `GET /mcp`: authenticates the caller and opens a Server-Sent Events stream of server
    /// notifications.
    ///
    /// Events broadcast after this subscription are forwarded. If the client falls behind
    /// the channel capacity, the skipped events are dropped and a warning is logged. Idle
    /// streams get a keep-alive message every [`SSE_KEEPALIVE`].
    pub(super) async fn get_mcp_sse(
        ConnectInfo(peer): ConnectInfo<SocketAddr>,
        State(state): State<Arc<AppState>>,
        headers: HeaderMap,
    ) -> Response {
        if let Err(resp) = check_auth(&headers, peer, &state.validator) {
            return resp;
        }
        info!(%peer, "SSE client connected");
        let rx = state.sse_tx.subscribe();
        let stream = BroadcastStream::new(rx).filter_map(move |item| match item {
            Ok(ev) => Some(Ok::<_, Infallible>(
                axum::response::sse::Event::default()
                    .event(ev.event)
                    .data(ev.data),
            )),
            Err(lagged) => {
                // Previously these were dropped silently; surface them so lost events are
                // visible in the logs.
                warn!(%peer, "SSE client fell behind, events dropped: {lagged}");
                None
            }
        });
        Sse::new(stream)
            .keep_alive(
                axum::response::sse::KeepAlive::new()
                    .interval(SSE_KEEPALIVE)
                    .text("keep-alive"),
            )
            .into_response()
    }

    /// Fallback for `/mcp` methods other than GET and POST.
    async fn method_not_allowed() -> Response {
        (
            StatusCode::METHOD_NOT_ALLOWED,
            [("Allow", "GET, POST")],
            "/mcp only accepts GET (SSE stream) and POST (JSON-RPC)",
        )
            .into_response()
    }

    /// Validates `cfg`, binds the listener, and serves `/mcp` until Ctrl+C is received.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the bind address is unsafe for the configured auth mode ("unsafe server
    ///   configuration");
    /// - the address cannot be bound;
    /// - the server itself fails.
    pub(super) async fn start(cfg: super::HttpConfig) -> anyhow::Result<()> {
        use anyhow::Context as _;

        let validator = BearerValidator::new(cfg.auth.clone());
        // Refuse to start before printing anything or binding a socket.
        validator
            .check_bind_safety(cfg.bind)
            .context("unsafe server configuration")?;
        print_startup_banner(&cfg, &cfg.auth);

        let state = Arc::new(AppState::new(validator));
        let addr = SocketAddr::new(cfg.bind, cfg.port);
        let mcp_route = post(post_mcp).get(get_mcp_sse).fallback(method_not_allowed);
        // Handlers extract the peer address, so connect info must be attached to the service.
        let app = Router::new()
            .route("/mcp", mcp_route)
            .with_state(state)
            .into_make_service_with_connect_info::<SocketAddr>();

        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .with_context(|| format!("Failed to bind to {addr}"))?;
        info!(%addr, "MCP HTTP server listening");
        axum::serve(listener, app)
            .with_graceful_shutdown(shutdown_signal())
            .await
            .context("HTTP server error")?;
        info!("MCP HTTP server stopped");
        Ok(())
    }

    /// Prints the listen URL and auth mode to stderr.
    ///
    /// In bearer mode the token itself is printed so it can be copied into a client config.
    /// NOTE(review): this writes the secret to stderr, which may end up in captured logs.
    fn print_startup_banner(cfg: &super::HttpConfig, auth: &crate::mcp::auth::AuthConfig) {
        eprintln!("MCP HTTP server starting");
        // Format through `SocketAddr` so IPv6 binds are bracketed (`http://[::1]:8741/mcp`);
        // plain `{ip}:{port}` produced an invalid URL for IPv6.
        eprintln!(" Address : http://{}/mcp", SocketAddr::new(cfg.bind, cfg.port));
        match auth {
            crate::mcp::auth::AuthConfig::LocalhostOnly => {
                eprintln!(" Auth : localhost-only (no token required)");
            }
            crate::mcp::auth::AuthConfig::Bearer(token) => {
                eprintln!(" Auth : Bearer token");
                eprintln!(" Token : {token}");
                eprintln!();
                eprintln!(" Add this to your MCP client config:");
                eprintln!(" Authorization: Bearer {token}");
            }
        }
        eprintln!();
    }

    /// Resolves when Ctrl+C is received, triggering graceful shutdown.
    ///
    /// If the signal handler cannot be installed, the error is logged and this future never
    /// resolves. Previously the error was ignored, which made the future resolve at once and
    /// shut the server down immediately after start-up.
    async fn shutdown_signal() {
        match tokio::signal::ctrl_c().await {
            Ok(()) => info!("shutdown signal received"),
            Err(e) => {
                error!("failed to install Ctrl+C handler, signal-based shutdown disabled: {e}");
                std::future::pending::<()>().await;
            }
        }
    }
}
/// Runs the HTTP transport described by `cfg` until it stops or fails.
///
/// Thin bridge into the private `http` module so [`serve`] and the tests can start it.
#[cfg(feature = "http-transport")]
async fn serve_http(cfg: HttpConfig) -> anyhow::Result<()> {
    let server = http::start(cfg);
    server.await
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn transport_config_stdio_variant_exists() {
        let cfg = TransportConfig::Stdio;
        // Exercise the derived Clone and Debug so the test checks more than construction.
        assert_eq!(format!("{:?}", cfg.clone()), "Stdio");
    }

    #[cfg(feature = "http-transport")]
    mod http_tests {
        use super::super::*;
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        #[test]
        fn http_config_localhost_binds_to_loopback() {
            let cfg = HttpConfig::localhost(8741);
            assert!(cfg.bind.is_loopback());
            assert_eq!(cfg.port, 8741);
            assert!(cfg.auth.is_localhost_only());
        }

        #[test]
        fn http_config_with_bearer_stores_token() {
            let cfg =
                HttpConfig::with_bearer(9000, IpAddr::V4(Ipv4Addr::LOCALHOST), "axt_tok".into());
            assert!(cfg.auth.is_bearer());
            assert_eq!(cfg.port, 9000);
            assert_eq!(cfg.bind, IpAddr::V4(Ipv4Addr::LOCALHOST));
        }

        #[test]
        fn http_config_with_bearer_keeps_ipv6_bind() {
            let cfg =
                HttpConfig::with_bearer(9001, IpAddr::V6(Ipv6Addr::LOCALHOST), "axt_tok".into());
            assert!(cfg.auth.is_bearer());
            assert_eq!(cfg.port, 9001);
            assert_eq!(cfg.bind, IpAddr::V6(Ipv6Addr::LOCALHOST));
        }

        #[tokio::test]
        async fn serve_refuses_unsafe_config() {
            use crate::mcp::auth::AuthConfig;
            let cfg = HttpConfig {
                port: 19999,
                bind: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
                auth: AuthConfig::localhost_only(),
            };
            // NOTE: if the bind-safety check ever regressed, `serve_http` would bind and run
            // until Ctrl+C, so this test would hang instead of failing.
            let err = serve_http(cfg)
                .await
                .expect_err("0.0.0.0 with localhost-only auth must be refused");
            let msg = err.to_string();
            assert!(
                msg.contains("unsafe") || msg.contains("configuration"),
                "{msg}"
            );
        }
    }
}