use std::convert::Infallible;
use std::sync::Arc;
use axum::extract::{DefaultBodyLimit, State};
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::{Json, Router};
use futures::Stream;
use serde_json::json;
use tokio::net::TcpListener;
use tracing::{error, info};
use crate::errors::{MCSError, Result};
use crate::kg::GraphHandle;
use crate::server;
/// Shared state handed to every handler: the graph handle (`crate::kg::GraphHandle`)
/// wrapped in an `Arc` so it can be shared across concurrently running requests.
type AppState = Arc<GraphHandle>;
/// Builds the MCP HTTP router.
///
/// The same POST/GET handler pair is mounted at both `/mcp` and `/`, request
/// bodies are capped at `server::MAX_REQUEST_BYTES`, and `kg` becomes the
/// state that handlers extract via `State`.
pub fn router(kg: AppState) -> Router {
    let endpoint = post(post_handler).get(get_handler);
    let body_limit = DefaultBodyLimit::max(server::MAX_REQUEST_BYTES);

    Router::new()
        .route("/mcp", endpoint.clone())
        .route("/", endpoint)
        .layer(body_limit)
        .with_state(kg)
}
/// Binds a TCP listener on `addr` and serves the MCP router on it.
///
/// # Errors
/// Returns `MCSError::IoError` if binding fails or if the server future
/// completes with an I/O error.
pub async fn run(addr: &str, kg: AppState) -> Result<()> {
    let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
    info!("Listening for HTTP (Streamable) MCP on http://{addr}/mcp");

    let app = router(kg);
    axum::serve(listener, app).await.map_err(MCSError::IoError)
}
/// Returns `true` when the client's `Accept` header lists `text/event-stream`.
///
/// Every `Accept` field is examined, because HTTP allows the list to be split
/// across several header lines. Media types are compared case-insensitively.
/// An entry explicitly marked `q=0` ("not acceptable") does not count.
/// Wildcards such as `*/*` do not select SSE, so those clients receive plain JSON.
/// Header values that are not valid visible ASCII are skipped.
fn wants_sse(headers: &HeaderMap) -> bool {
    headers
        .get_all(header::ACCEPT)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|list| list.split(','))
        .any(accepts_event_stream)
}

/// Checks one `Accept` list element (e.g. `text/event-stream;q=0.9`) for a
/// non-zero-quality `text/event-stream` media range.
fn accepts_event_stream(range: &str) -> bool {
    let mut parts = range.split(';');
    // `split` always yields at least one item, so this is the media type.
    let media_type = parts.next().unwrap_or("").trim();
    if !media_type.eq_ignore_ascii_case("text/event-stream") {
        return false;
    }
    // A quality value of zero means the client explicitly refuses this type.
    !parts.any(|param| {
        param.split_once('=').is_some_and(|(key, value)| {
            key.trim().eq_ignore_ascii_case("q")
                && value.trim().parse::<f32>().is_ok_and(|q| q <= 0.0)
        })
    })
}
async fn post_handler(State(kg): State<AppState>, headers: HeaderMap, body: String) -> Response {
let result = tokio::task::spawn_blocking(move || server::dispatch_http_body(&body, &kg)).await;
let outcome = match result {
Ok(inner) => inner,
Err(join_err) => {
error!("dispatch task panicked: {join_err}");
return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
}
};
match outcome {
Ok(None) => StatusCode::ACCEPTED.into_response(),
Ok(Some(value)) => {
if wants_sse(&headers) {
let json = serde_json::to_string(&value).unwrap();
let stream = futures::stream::once(async move {
Ok::<Event, Infallible>(Event::default().data(json))
});
Sse::new(stream).into_response()
} else {
Json(value).into_response()
}
}
Err(e) => {
let resp = json!({
"jsonrpc": "2.0",
"error": { "code": -32700, "message": format!("Parse error: {e}") },
"id": null
});
(StatusCode::BAD_REQUEST, Json(resp)).into_response()
}
}
}
/// Handles GET by opening an SSE stream that never yields an event.
///
/// The underlying stream stays pending forever, so the connection remains open
/// until the client goes away. Only the keep-alive messages configured by
/// `KeepAlive::default()` are sent on it.
async fn get_handler() -> Sse<impl Stream<Item = std::result::Result<Event, Infallible>>> {
    let heartbeat = KeepAlive::default();
    let idle = futures::stream::pending::<std::result::Result<Event, Infallible>>();
    Sse::new(idle).keep_alive(heartbeat)
}