use std::convert::Infallible;
use std::sync::Arc;
use axum::extract::{DefaultBodyLimit, State};
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::{Json, Router};
use serde_json::json;
use tokio::net::TcpListener;
use tracing::{error, info};
use crate::errors::{MCSError, Result};
use crate::kg::GraphHandle;
use crate::server;
/// Shared state handed to the HTTP handlers via axum's `State` extractor.
///
/// Both fields are reference-counted, so cloning the state is cheap.
#[derive(Clone)]
pub struct HttpState {
/// Graph handle passed to `server::dispatch_http_body` for each POSTed request.
kg: Arc<GraphHandle>,
/// Expected credential, compared by `server::token_matches` against the request's
/// `Authorization` header; `None` disables the check entirely.
auth_token: Option<Arc<str>>,
}
pub fn router(state: HttpState) -> Router {
Router::new()
.route("/mcp", post(post_handler).get(get_handler))
.route("/", post(post_handler).get(get_handler))
.layer(DefaultBodyLimit::max(server::MAX_REQUEST_BYTES))
.with_state(state)
}
/// Binds `addr` and serves the MCP HTTP transport until the server stops.
///
/// # Errors
/// Returns `MCSError::IoError` if binding the listener fails or if `axum::serve` returns
/// an I/O error.
pub async fn run(addr: &str, kg: Arc<GraphHandle>, auth_token: Option<Arc<str>>) -> Result<()> {
    let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;

    let auth_mode = match &auth_token {
        Some(_) => "on",
        None => "off",
    };
    info!(
        "Listening for HTTP (Streamable) MCP on http://{addr}/mcp (auth {})",
        auth_mode
    );

    let app = router(HttpState { kg, auth_token });
    axum::serve(listener, app).await.map_err(MCSError::IoError)?;
    Ok(())
}
/// Returns `true` when the client's `Accept` header lists `text/event-stream`.
///
/// Every `Accept` field line is inspected, not just the first. HTTP permits a list-valued
/// header to be split across several lines.
///
/// The match ignores ASCII case, because media types are case-insensitive. Values that are
/// not visible ASCII (and so fail `HeaderValue::to_str`) are skipped.
fn wants_sse(headers: &HeaderMap) -> bool {
    headers
        .get_all(header::ACCEPT)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .any(|accept| accept.to_ascii_lowercase().contains("text/event-stream"))
}
/// Decides whether a request may proceed.
///
/// With no configured token every request is allowed. Otherwise the `Authorization` header
/// must be present, must be valid visible ASCII, and must satisfy `server::token_matches`
/// against the configured token.
fn authorized(state: &HttpState, headers: &HeaderMap) -> bool {
    let Some(expected) = state.auth_token.as_ref() else {
        // Authentication disabled.
        return true;
    };
    let Some(presented) = headers
        .get(header::AUTHORIZATION)
        .and_then(|raw| raw.to_str().ok())
    else {
        return false;
    };
    server::token_matches(presented, expected)
}
async fn post_handler(State(state): State<HttpState>, headers: HeaderMap, body: String) -> Response {
if !authorized(&state, &headers) {
return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
}
let kg = state.kg;
let result = tokio::task::spawn_blocking(move || server::dispatch_http_body(&body, &kg)).await;
let outcome = match result {
Ok(inner) => inner,
Err(join_err) => {
error!("dispatch task panicked: {join_err}");
return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
}
};
match outcome {
Ok(None) => StatusCode::ACCEPTED.into_response(),
Ok(Some(value)) => {
if wants_sse(&headers) {
let json = serde_json::to_string(&value).unwrap();
let stream = futures::stream::once(async move {
Ok::<Event, Infallible>(Event::default().data(json))
});
Sse::new(stream).into_response()
} else {
Json(value).into_response()
}
}
Err(e) => {
let resp = json!({
"jsonrpc": "2.0",
"error": { "code": -32700, "message": format!("Parse error: {e}") },
"id": null
});
(StatusCode::BAD_REQUEST, Json(resp)).into_response()
}
}
}
/// Handles GET on the MCP endpoint by opening an SSE stream that never yields an event.
///
/// Unauthorized requests get `401 Unauthorized`. For authorized clients the connection stays
/// open, with axum's default `KeepAlive` sending periodic keep-alive messages. No
/// server-initiated messages are ever pushed on this stream.
async fn get_handler(State(state): State<HttpState>, headers: HeaderMap) -> Response {
    if authorized(&state, &headers) {
        let idle = futures::stream::pending::<std::result::Result<Event, Infallible>>();
        let sse = Sse::new(idle).keep_alive(KeepAlive::default());
        sse.into_response()
    } else {
        (StatusCode::UNAUTHORIZED, "Unauthorized").into_response()
    }
}