use std::convert::Infallible;
use std::sync::Arc;
use axum::extract::{DefaultBodyLimit, State};
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::{Json, Router};
use serde_json::json;
use tokio::net::TcpListener;
use tracing::{error, info};
use crate::errors::{MCSError, Result};
use crate::kg::GraphHandle;
use crate::server;
/// State shared by every HTTP handler in this module.
///
/// Must be `Clone` because axum's `State` extractor hands each handler its own copy;
/// both fields are `Arc`s, so a clone only bumps reference counts.
///
/// NOTE(review): the fields are private and no constructor is provided, so code outside
/// this module cannot build an `HttpState` to pass to the public `router` function.
#[derive(Clone)]
pub struct HttpState {
    /// Graph handle passed to `server::dispatch_http_body` for each POSTed body.
    kg: Arc<GraphHandle>,
    /// Credential checked by `authorized`; `None` disables the check and admits every request.
    auth_token: Option<Arc<str>>,
}
pub fn router(state: HttpState) -> Router {
Router::new()
.route("/mcp", post(post_handler).get(get_handler))
.route("/", post(post_handler).get(get_handler))
.layer(DefaultBodyLimit::max(server::MAX_REQUEST_BYTES))
.with_state(state)
}
/// Binds `addr` and serves the Streamable-HTTP MCP endpoint (see `router`).
///
/// The startup log reports the address the socket is actually bound to, so binding to
/// port `0` shows the OS-assigned port and IPv6 addresses are printed in bracketed,
/// URL-safe form. If the socket cannot report its address, `addr` is logged as given.
/// When `auth_token` is `Some`, both handlers require a matching `Authorization` header.
///
/// # Errors
/// Returns [`MCSError::IoError`] if binding fails or if the server stops with an I/O error.
pub async fn run(addr: &str, kg: Arc<GraphHandle>, auth_token: Option<Arc<str>>) -> Result<()> {
    let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
    // Best effort: this address is only used for the log line, so a failure here must not
    // abort startup.
    let bound = listener
        .local_addr()
        .map_or_else(|_| addr.to_owned(), |local| local.to_string());
    info!(
        "Listening for HTTP (Streamable) MCP on http://{bound}/mcp (auth {})",
        if auth_token.is_some() { "on" } else { "off" }
    );
    let state = HttpState { kg, auth_token };
    axum::serve(listener, router(state)).await.map_err(MCSError::IoError)?;
    Ok(())
}
/// Returns `true` when the client's `Accept` header(s) list `text/event-stream`.
///
/// HTTP allows an `Accept` list to be split across several header lines, so every line is
/// inspected. Each comma-separated media range is compared to `text/event-stream` as a
/// whole token, case-insensitively (media types are case-insensitive), with any
/// `;` parameters such as `;q=0.9` stripped. Header values that are not visible ASCII
/// are skipped. Wildcards like `*/*` deliberately do not count, so such clients get plain JSON.
fn wants_sse(headers: &HeaderMap) -> bool {
    headers
        .get_all(header::ACCEPT)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|list| list.split(','))
        .map(|range| range.split_once(';').map_or(range, |(media, _params)| media))
        .any(|media| media.trim().eq_ignore_ascii_case("text/event-stream"))
}
/// Decides whether a request may proceed.
///
/// With no token configured, every request is allowed. Otherwise the request needs an
/// `Authorization` header whose value is visible ASCII and is accepted by
/// `server::token_matches` against the configured token. A missing or unreadable
/// header is rejected.
fn authorized(state: &HttpState, headers: &HeaderMap) -> bool {
    let Some(expected) = state.auth_token.as_ref() else {
        return true;
    };
    let Some(raw) = headers.get(header::AUTHORIZATION) else {
        return false;
    };
    raw.to_str()
        .is_ok_and(|presented| server::token_matches(presented, expected))
}
async fn post_handler(State(state): State<HttpState>, headers: HeaderMap, body: String) -> Response {
if !authorized(&state, &headers) {
return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
}
let kg = state.kg;
let result = tokio::task::spawn_blocking(move || server::dispatch_http_body(&body, &kg)).await;
let outcome = match result {
Ok(inner) => inner,
Err(join_err) => {
error!("dispatch task panicked: {join_err}");
return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
}
};
match outcome {
Ok(None) => StatusCode::ACCEPTED.into_response(),
Ok(Some(value)) => {
if wants_sse(&headers) {
let json = serde_json::to_string(&value).unwrap();
let stream = futures::stream::once(async move {
Ok::<Event, Infallible>(Event::default().data(json))
});
Sse::new(stream).into_response()
} else {
Json(value).into_response()
}
}
Err(e) => {
let resp = json!({
"jsonrpc": "2.0",
"error": { "code": -32700, "message": format!("Parse error: {e}") },
"id": null
});
(StatusCode::BAD_REQUEST, Json(resp)).into_response()
}
}
}
/// Opens the server-to-client SSE stream for an authorized `GET`.
///
/// The stream is `pending`, so no events are ever emitted on it. Only axum's default
/// keep-alive messages are sent while the client holds the connection open.
/// Unauthorized requests get `401 Unauthorized`.
async fn get_handler(State(state): State<HttpState>, headers: HeaderMap) -> Response {
    if authorized(&state, &headers) {
        let idle = futures::stream::pending::<std::result::Result<Event, Infallible>>();
        let heartbeat = KeepAlive::default();
        Sse::new(idle).keep_alive(heartbeat).into_response()
    } else {
        (StatusCode::UNAUTHORIZED, "Unauthorized").into_response()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Durability;
    use axum::http::HeaderValue;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds an `HttpState` backed by a graph at a fresh path under the system temp dir.
    ///
    /// The process id plus a per-process counter keeps paths unique across tests running
    /// concurrently. NOTE(review): nothing here removes the path afterwards.
    fn state(token: Option<&str>) -> HttpState {
        static SEQ: AtomicU32 = AtomicU32::new(0);
        // `temp_dir()` already returns a `PathBuf`; no conversion needed.
        let path: PathBuf = std::env::temp_dir().join(format!(
            "mcp_mem_http_auth_{}_{}.bin",
            std::process::id(),
            SEQ.fetch_add(1, Ordering::SeqCst)
        ));
        let kg = Arc::new(GraphHandle::new(&path, Durability::Async).unwrap());
        HttpState {
            kg,
            auth_token: token.map(Arc::from),
        }
    }

    /// Header map containing a single `Authorization` header.
    fn with_auth(value: &'static str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(header::AUTHORIZATION, HeaderValue::from_static(value));
        h
    }

    /// Header map containing a single `Accept` header.
    fn with_accept(value: &'static str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(header::ACCEPT, HeaderValue::from_static(value));
        h
    }

    #[test]
    fn no_token_configured_allows_any_request() {
        let s = state(None);
        assert!(authorized(&s, &HeaderMap::new()));
        assert!(authorized(&s, &with_auth("Bearer whatever")));
    }

    #[test]
    fn token_required_rejects_missing_and_wrong() {
        let s = state(Some("s3cr3t"));
        assert!(!authorized(&s, &HeaderMap::new()), "missing header rejected");
        assert!(!authorized(&s, &with_auth("Bearer wrong")), "wrong token rejected");
    }

    #[test]
    fn token_required_accepts_correct_bearer() {
        let s = state(Some("s3cr3t"));
        assert!(authorized(&s, &with_auth("Bearer s3cr3t")));
        assert!(authorized(&s, &with_auth("s3cr3t")));
    }

    #[test]
    fn token_required_rejects_non_ascii_header() {
        let s = state(Some("s3cr3t"));
        let mut h = HeaderMap::new();
        // 0xFF is a legal header byte but not visible ASCII, so `to_str` fails.
        h.insert(
            header::AUTHORIZATION,
            HeaderValue::from_bytes(b"Bearer s3cr3t\xff").unwrap(),
        );
        assert!(!authorized(&s, &h), "unreadable header rejected");
    }

    #[test]
    fn wants_sse_detects_event_stream_in_accept() {
        assert!(!wants_sse(&HeaderMap::new()), "no Accept header");
        assert!(!wants_sse(&with_accept("application/json")));
        assert!(wants_sse(&with_accept("text/event-stream")));
        assert!(wants_sse(&with_accept("application/json, text/event-stream")));
    }
}