use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use axum::http::{HeaderValue, Method, StatusCode, header};
use axum::response::{IntoResponse, Json};
use axum::routing::{get, post};
use axum::{Router, middleware};
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use rmcp::transport::streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tower_http::cors::CorsLayer;
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
use super::auth::{AuthConfig, bearer_auth_middleware};
use crate::server::ServerHandler;
use crate::{Error, Result};
/// JSON body returned by `GET /health`.
#[derive(Serialize, Debug)]
struct HealthResponse {
    /// Liveness indicator; the health handler always sends `"ok"`.
    status: &'static str,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    version: &'static str,
}
/// JSON body accepted by `POST /admin/reload`.
#[derive(Deserialize, Debug)]
struct ReloadRequest {
    /// Force flag for the reload; `false` when the field is omitted.
    #[serde(default)]
    force: bool,
}
/// JSON body returned by `POST /admin/reload`.
#[derive(Serialize, Debug)]
struct ReloadResponse {
    /// Whether the reload request succeeded.
    success: bool,
    /// Human-readable outcome: an acknowledgement on success, error text otherwise.
    message: String,
    /// Entries reported as changed by the reload result.
    changed: Vec<String>,
}
/// Runtime settings for the HTTP transport, normally read from the environment.
#[derive(Clone, Debug)]
pub struct HttpConfig {
    /// Origin allowed by CORS (`MCP_CORS_ORIGIN`). When `None`, the CORS layer falls back
    /// to `http://localhost:3000`.
    pub cors_origin: Option<String>,
    /// Bearer token for the auth middleware (`MCP_HTTP_BEARER_TOKEN`). Leaving it `None`
    /// runs the server without authentication.
    pub bearer_token: Option<String>,
}
impl Default for HttpConfig {
    /// Localhost-only CORS origin (`http://localhost:3000`) and no bearer token.
    fn default() -> Self {
        let cors_origin = Some(String::from("http://localhost:3000"));
        Self {
            bearer_token: None,
            cors_origin,
        }
    }
}
impl HttpConfig {
    /// Build the configuration from environment variables.
    ///
    /// - `MCP_CORS_ORIGIN`: origin allowed by CORS.
    /// - `MCP_HTTP_BEARER_TOKEN`: bearer token for the auth middleware.
    ///
    /// Values are trimmed. Unset, non-UTF-8 or blank variables become `None`. An
    /// accidental `MCP_HTTP_BEARER_TOKEN=` therefore does not configure an empty token.
    /// A padded `MCP_CORS_ORIGIN` does not become an origin that can never match.
    pub fn from_env() -> Self {
        Self {
            cors_origin: Self::env_non_blank("MCP_CORS_ORIGIN"),
            bearer_token: Self::env_non_blank("MCP_HTTP_BEARER_TOKEN"),
        }
    }

    /// Read `key` from the environment, trimmed. Returns `None` if unset, not UTF-8, or blank.
    fn env_non_blank(key: &str) -> Option<String> {
        std::env::var(key)
            .ok()
            .map(|value| value.trim().to_owned())
            .filter(|value| !value.is_empty())
    }
}
/// Run the MCP streamable-HTTP server on `host:port` until `shutdown` resolves.
///
/// Settings come from [`HttpConfig::from_env`]. Routes:
/// - `GET /health`: liveness probe.
/// - `GET /metrics`: metrics, only with the `metrics` feature.
/// - `POST /admin/reload`: configuration reload request.
/// - `/mcp`: the rmcp [`StreamableHttpService`].
///
/// The router is wrapped, innermost first, in:
/// - the bearer-auth middleware;
/// - request tracing;
/// - a 60 s timeout answered with `408 Request Timeout`;
/// - the CORS layer.
///
/// When `shutdown` completes, the shared [`CancellationToken`] is cancelled. That token was
/// handed to the MCP service config. axum then shuts down gracefully: it stops accepting
/// connections, waits for in-flight requests, and this function returns `Ok(())`.
///
/// # Errors
///
/// Returns `Error::Transport` if the listener cannot be bound or the server exits with an
/// I/O error.
pub async fn run_http(
    handler: ServerHandler,
    host: IpAddr,
    port: u16,
    shutdown: impl Future<Output = ()> + Send + 'static,
) -> Result<()> {
    let addr = SocketAddr::new(host, port);
    let http_config = HttpConfig::from_env();
    let auth_config = AuthConfig::new(http_config.bearer_token.clone());
    emit_security_warnings(host, &http_config, &auth_config);

    // Shared with the MCP service so the shutdown signal also reaches its sessions.
    let cancellation_token = CancellationToken::new();
    let session_manager = Arc::new(LocalSessionManager::default());
    let config = StreamableHttpServerConfig {
        cancellation_token: cancellation_token.clone(),
        ..Default::default()
    };
    let mcp_service =
        StreamableHttpService::new(move || Ok(handler.clone()), session_manager, config);
    let cors = build_cors_layer(&http_config);

    // `mut` is always needed: `app` is reassigned below regardless of features.
    let mut app = Router::new().route("/health", get(health_handler));
    #[cfg(feature = "metrics")]
    {
        app = app.route("/metrics", get(metrics_handler));
    }
    app = app.route("/admin/reload", post(admin_reload_handler));
    let app = app
        .nest_service("/mcp", mcp_service)
        .layer(middleware::from_fn_with_state(
            axum::extract::Extension(auth_config.clone()),
            bearer_auth_middleware,
        ))
        .layer(axum::extract::Extension(auth_config))
        .layer(TraceLayer::new_for_http())
        .layer(TimeoutLayer::with_status_code(
            StatusCode::REQUEST_TIMEOUT,
            std::time::Duration::from_secs(60),
        ))
        .layer(cors);

    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .map_err(|e| Error::Transport(format!("Failed to bind to {addr}: {e}")))?;
    // Log only after a successful bind, using the real address (differs when port is 0).
    let local_addr = listener.local_addr().unwrap_or(addr);
    tracing::info!("HTTP server listening on {local_addr}");

    // `axum::serve` only returns if it is given a shutdown signal. The caller's `shutdown`
    // future must therefore drive the server itself, not just the MCP cancellation token.
    axum::serve(listener, app)
        .with_graceful_shutdown(async move {
            shutdown.await;
            cancellation_token.cancel();
        })
        .await
        .map_err(|e| Error::Transport(format!("HTTP server error: {e}")))?;
    tracing::info!("HTTP server shutdown complete");
    Ok(())
}
fn build_cors_layer(config: &HttpConfig) -> CorsLayer {
config
.cors_origin
.as_ref()
.and_then(|o| o.parse::<HeaderValue>().ok())
.map_or_else(
|| {
CorsLayer::new()
.allow_origin(
"http://localhost:3000"
.parse::<HeaderValue>()
.expect("valid header"),
)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION])
},
|origin_value| {
CorsLayer::new()
.allow_origin(origin_value)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION])
},
)
}
/// Log startup warnings about risky bind, auth and CORS settings.
///
/// - Warns when binding to the IPv4 or IPv6 unspecified address (all interfaces), or to
///   any other non-loopback address.
/// - Warns when the server is on a non-loopback address and authentication is disabled.
/// - Logs at info level when no CORS origin is configured and the localhost default
///   applies.
fn emit_security_warnings(host: IpAddr, http_config: &HttpConfig, auth_config: &AuthConfig) {
    let is_non_loopback = !host.is_loopback();
    let is_all_interfaces = host == IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        || host == IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED);
    if is_all_interfaces {
        // Report the actual address: it may be IPv6 `::`, not just `0.0.0.0`.
        tracing::warn!(
            "HTTP server binding to all interfaces ({host}). \
             This exposes the server to all network interfaces."
        );
    } else if is_non_loopback {
        tracing::warn!(
            "HTTP server binding to non-loopback address ({host}). \
             Ensure network security policies are in place."
        );
    }
    if !auth_config.is_enabled() && is_non_loopback {
        tracing::warn!(
            "SECURITY WARNING: HTTP server accessible from network without authentication. \
             Set MCP_HTTP_BEARER_TOKEN environment variable to enable authentication."
        );
    }
    if http_config.cors_origin.is_none() {
        tracing::info!(
            "CORS origin not configured (MCP_CORS_ORIGIN). \
             Using restrictive default: http://localhost:3000"
        );
    }
}
/// `GET /health`: responds with `{"status":"ok","version":<crate version>}`.
async fn health_handler() -> impl IntoResponse {
    let body = HealthResponse {
        version: env!("CARGO_PKG_VERSION"),
        status: "ok",
    };
    Json(body)
}
/// `GET /metrics`: renders the collected metrics as plain text.
///
/// The body comes from `crate::observability::render_metrics`. The content type is the
/// Prometheus text exposition format, version 0.0.4.
#[cfg(feature = "metrics")]
async fn metrics_handler() -> impl IntoResponse {
    const PROMETHEUS_TEXT_FORMAT: &str = "text/plain; version=0.0.4; charset=utf-8";
    let headers = [(header::CONTENT_TYPE, PROMETHEUS_TEXT_FORMAT)];
    let body = crate::observability::render_metrics();
    (headers, body)
}
async fn admin_reload_handler(Json(payload): Json<ReloadRequest>) -> impl IntoResponse {
use crate::config::{ReloadResult, ReloadTrigger};
tracing::info!(
trigger = %ReloadTrigger::HttpEndpoint { remote_addr: None },
force = payload.force,
"Configuration reload requested"
);
let result = ReloadResult::success(vec![]);
let response = ReloadResponse {
success: result.success,
message: if result.success {
"Configuration reload acknowledged".to_string()
} else {
result.error.unwrap_or_default()
},
changed: result.changed,
};
if result.success {
(StatusCode::OK, Json(response))
} else {
(StatusCode::INTERNAL_SERVER_ERROR, Json(response))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            status: "ok",
            version: "0.3.2",
        };
        let json = serde_json::to_string(&response).unwrap();
        // Pin the exact wire format (field names and order), not just substrings.
        assert_eq!(json, r#"{"status":"ok","version":"0.3.2"}"#);
    }

    #[test]
    fn test_reload_request_force_defaults_to_false() {
        let request: ReloadRequest = serde_json::from_str("{}").unwrap();
        assert!(!request.force);
        let request: ReloadRequest = serde_json::from_str(r#"{"force":true}"#).unwrap();
        assert!(request.force);
    }

    #[test]
    fn test_reload_response_serialization() {
        let response = ReloadResponse {
            success: true,
            message: "done".to_string(),
            changed: vec!["server.name".to_string()],
        };
        let json = serde_json::to_string(&response).unwrap();
        assert_eq!(
            json,
            r#"{"success":true,"message":"done","changed":["server.name"]}"#
        );
    }

    #[test]
    fn test_http_config_default() {
        let config = HttpConfig::default();
        assert_eq!(
            config.cors_origin,
            Some("http://localhost:3000".to_string())
        );
        assert!(config.bearer_token.is_none());
    }

    // The CORS tests below are smoke tests: they only check that building the layer does
    // not panic for the given configuration.
    #[test]
    fn test_build_cors_layer_with_origin() {
        let config = HttpConfig {
            cors_origin: Some("https://example.com".to_string()),
            bearer_token: None,
        };
        let _cors = build_cors_layer(&config);
    }

    #[test]
    fn test_build_cors_layer_without_origin() {
        let config = HttpConfig {
            cors_origin: None,
            bearer_token: None,
        };
        let _cors = build_cors_layer(&config);
    }

    #[test]
    fn test_build_cors_layer_with_invalid_origin_falls_back() {
        // A newline is not a valid header value byte, so the default origin must be used.
        let config = HttpConfig {
            cors_origin: Some("bad\norigin".to_string()),
            bearer_token: None,
        };
        let _cors = build_cors_layer(&config);
    }

    // The warning tests are smoke tests too: log output is not captured here.
    #[test]
    fn test_security_warnings_emitted_for_non_loopback() {
        let host = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let http_config = HttpConfig::default();
        let auth_config = AuthConfig::new(None);
        emit_security_warnings(host, &http_config, &auth_config);
    }

    #[test]
    fn test_security_warnings_for_ipv6_unspecified() {
        let host = IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED);
        let http_config = HttpConfig {
            cors_origin: None,
            bearer_token: None,
        };
        let auth_config = AuthConfig::new(None);
        emit_security_warnings(host, &http_config, &auth_config);
    }

    #[test]
    fn test_loopback_no_warning_needed() {
        let host = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let http_config = HttpConfig::default();
        let auth_config = AuthConfig::new(Some("token".to_string()));
        emit_security_warnings(host, &http_config, &auth_config);
    }
}