1use std::convert::Infallible;
17use std::sync::Arc;
18
19use axum::extract::{DefaultBodyLimit, State};
20use axum::http::{header, HeaderMap, StatusCode};
21use axum::response::sse::{Event, KeepAlive, Sse};
22use axum::response::{IntoResponse, Response};
23use axum::routing::post;
24use axum::{Json, Router};
25use serde_json::json;
26use tokio::net::TcpListener;
27use tracing::{error, info};
28
29use crate::errors::{MCSError, Result};
30use crate::kg::GraphHandle;
31use crate::server;
32
33#[derive(Clone)]
36pub struct HttpState {
37 kg: Arc<GraphHandle>,
38 auth_token: Option<Arc<str>>,
39}
40
41pub fn router(state: HttpState) -> Router {
44 Router::new()
45 .route("/mcp", post(post_handler).get(get_handler))
46 .route("/", post(post_handler).get(get_handler))
47 .layer(DefaultBodyLimit::max(server::MAX_REQUEST_BYTES))
48 .with_state(state)
49}
50
51pub async fn run(addr: &str, kg: Arc<GraphHandle>, auth_token: Option<Arc<str>>) -> Result<()> {
53 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
54 info!(
55 "Listening for HTTP (Streamable) MCP on http://{addr}/mcp (auth {})",
56 if auth_token.is_some() { "on" } else { "off" }
57 );
58 let state = HttpState { kg, auth_token };
59 axum::serve(listener, router(state)).await.map_err(MCSError::IoError)?;
60 Ok(())
61}
62
63fn wants_sse(headers: &HeaderMap) -> bool {
64 headers
65 .get(header::ACCEPT)
66 .and_then(|v| v.to_str().ok())
67 .is_some_and(|a| a.contains("text/event-stream"))
68}
69
70fn authorized(state: &HttpState, headers: &HeaderMap) -> bool {
73 match state.auth_token {
74 None => true,
75 Some(ref expected) => headers
76 .get(header::AUTHORIZATION)
77 .and_then(|v| v.to_str().ok())
78 .is_some_and(|presented| server::token_matches(presented, expected)),
79 }
80}
81
82async fn post_handler(State(state): State<HttpState>, headers: HeaderMap, body: String) -> Response {
83 if !authorized(&state, &headers) {
84 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
85 }
86 let kg = state.kg;
87 let result = tokio::task::spawn_blocking(move || server::dispatch_http_body(&body, &kg)).await;
90
91 let outcome = match result {
92 Ok(inner) => inner,
93 Err(join_err) => {
94 error!("dispatch task panicked: {join_err}");
95 return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
96 }
97 };
98
99 match outcome {
100 Ok(None) => StatusCode::ACCEPTED.into_response(),
102 Ok(Some(value)) => {
103 if wants_sse(&headers) {
104 let json = serde_json::to_string(&value).unwrap();
106 let stream = futures::stream::once(async move {
107 Ok::<Event, Infallible>(Event::default().data(json))
108 });
109 Sse::new(stream).into_response()
110 } else {
111 Json(value).into_response()
112 }
113 }
114 Err(e) => {
115 let resp = json!({
117 "jsonrpc": "2.0",
118 "error": { "code": -32700, "message": format!("Parse error: {e}") },
119 "id": null
120 });
121 (StatusCode::BAD_REQUEST, Json(resp)).into_response()
122 }
123 }
124}
125
126async fn get_handler(State(state): State<HttpState>, headers: HeaderMap) -> Response {
127 if !authorized(&state, &headers) {
128 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
129 }
130 let stream = futures::stream::pending::<std::result::Result<Event, Infallible>>();
132 Sse::new(stream)
133 .keep_alive(KeepAlive::default())
134 .into_response()
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Durability;
    use axum::http::HeaderValue;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds an `HttpState` backed by a fresh graph file in the temp dir.
    ///
    /// The pid + counter suffix keeps concurrently running tests and test
    /// binaries from sharing a backing file.
    // NOTE(review): the backing file is never removed afterwards; consider
    // cleaning it up once `GraphHandle`'s on-disk footprint is confirmed.
    fn state(token: Option<&str>) -> HttpState {
        static SEQ: AtomicU32 = AtomicU32::new(0);
        // `temp_dir()` already returns a `PathBuf`; no conversion needed.
        let path = std::env::temp_dir().join(format!(
            "mcp_mem_http_auth_{}_{}.bin",
            std::process::id(),
            SEQ.fetch_add(1, Ordering::SeqCst)
        ));
        let kg = Arc::new(GraphHandle::new(&path, Durability::Async).unwrap());
        HttpState {
            kg,
            auth_token: token.map(Arc::from),
        }
    }

    /// Header map containing only the given `Authorization` value.
    fn with_auth(value: &'static str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(header::AUTHORIZATION, HeaderValue::from_static(value));
        h
    }

    /// Header map containing only the given `Accept` value.
    fn with_accept(value: &'static str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(header::ACCEPT, HeaderValue::from_static(value));
        h
    }

    #[test]
    fn no_token_configured_allows_any_request() {
        let s = state(None);
        assert!(authorized(&s, &HeaderMap::new()));
        assert!(authorized(&s, &with_auth("Bearer whatever")));
    }

    #[test]
    fn token_required_rejects_missing_and_wrong() {
        let s = state(Some("s3cr3t"));
        assert!(!authorized(&s, &HeaderMap::new()), "missing header rejected");
        assert!(!authorized(&s, &with_auth("Bearer wrong")), "wrong token rejected");
    }

    #[test]
    fn token_required_accepts_correct_bearer() {
        let s = state(Some("s3cr3t"));
        assert!(authorized(&s, &with_auth("Bearer s3cr3t")));
        assert!(authorized(&s, &with_auth("s3cr3t")));
    }

    #[test]
    fn sse_negotiated_only_when_event_stream_accepted() {
        assert!(!wants_sse(&HeaderMap::new()), "no Accept header means plain JSON");
        assert!(!wants_sse(&with_accept("application/json")));
        assert!(wants_sse(&with_accept("text/event-stream")));
        assert!(wants_sse(&with_accept("application/json, text/event-stream")));
    }
}