1use std::convert::Infallible;
17use std::sync::Arc;
18
19use axum::extract::{DefaultBodyLimit, State};
20use axum::http::{header, HeaderMap, StatusCode};
21use axum::response::sse::{Event, KeepAlive, Sse};
22use axum::response::{IntoResponse, Response};
23use axum::routing::post;
24use axum::{Json, Router};
25use serde_json::json;
26use tokio::net::TcpListener;
27use tracing::{error, info};
28
29use crate::errors::{MCSError, Result};
30use crate::kg::GraphHandle;
31use crate::server;
32use crate::vector_store::VectorStore;
33
/// State shared by every HTTP handler in this module.
///
/// All fields are reference-counted (or `None`), so the derived `Clone`
/// only bumps reference counts rather than copying the underlying data.
#[derive(Clone)]
pub struct HttpState {
    /// Graph handle passed to `server::dispatch_http_body` for each POST.
    kg: Arc<GraphHandle>,
    /// Optional vector store; handed to the dispatcher as `Option<&VectorStore>`.
    vs: Option<Arc<VectorStore>>,
    /// When `Some`, each request's `Authorization` header value must be
    /// accepted by `server::token_matches`; when `None`, every request is
    /// treated as authorized.
    auth_token: Option<Arc<str>>,
}
42
43pub fn router(state: HttpState) -> Router {
46 Router::new()
47 .route("/mcp", post(post_handler).get(get_handler))
48 .route("/", post(post_handler).get(get_handler))
49 .layer(DefaultBodyLimit::max(server::MAX_REQUEST_BYTES))
50 .with_state(state)
51}
52
53pub async fn run(
59 addr: &str,
60 kg: Arc<GraphHandle>,
61 vs: Option<Arc<VectorStore>>,
62 auth_token: Option<Arc<str>>,
63 tls_cert: Option<std::path::PathBuf>,
64 tls_key: Option<std::path::PathBuf>,
65) -> Result<()> {
66 let auth = if auth_token.is_some() { "on" } else { "off" };
67 let state = HttpState { kg, vs, auth_token };
68
69 if let (Some(cert), Some(key)) = (tls_cert, tls_key) {
70 let tls = crate::tls::server_config(&cert, &key)
71 .await
72 .map_err(MCSError::IoError)?;
73 let socket_addr = resolve_addr(addr)?;
74 info!("Listening for HTTPS (Streamable) MCP on https://{socket_addr}/mcp (TLS, auth {auth})");
75 axum_server::bind_rustls(socket_addr, tls)
76 .serve(router(state).into_make_service())
77 .await
78 .map_err(MCSError::IoError)?;
79 } else {
80 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
81 info!("Listening for HTTP (Streamable) MCP on http://{addr}/mcp (auth {auth})");
82 axum::serve(listener, router(state))
83 .await
84 .map_err(MCSError::IoError)?;
85 }
86 Ok(())
87}
88
89fn resolve_addr(addr: &str) -> Result<std::net::SocketAddr> {
92 use std::net::ToSocketAddrs;
93 addr.to_socket_addrs()
94 .map_err(MCSError::IoError)?
95 .next()
96 .ok_or_else(|| {
97 MCSError::IoError(std::io::Error::new(
98 std::io::ErrorKind::InvalidInput,
99 format!("could not resolve bind address '{addr}'"),
100 ))
101 })
102}
103
104fn wants_sse(headers: &HeaderMap) -> bool {
105 headers
106 .get(header::ACCEPT)
107 .and_then(|v| v.to_str().ok())
108 .is_some_and(|a| a.contains("text/event-stream"))
109}
110
111fn authorized(state: &HttpState, headers: &HeaderMap) -> bool {
114 match state.auth_token {
115 None => true,
116 Some(ref expected) => headers
117 .get(header::AUTHORIZATION)
118 .and_then(|v| v.to_str().ok())
119 .is_some_and(|presented| server::token_matches(presented, expected)),
120 }
121}
122
123async fn post_handler(State(state): State<HttpState>, headers: HeaderMap, body: String) -> Response {
124 if !authorized(&state, &headers) {
125 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
126 }
127 let kg = state.kg;
128 let vs = state.vs;
129 let result = tokio::task::spawn_blocking(move || {
132 server::dispatch_http_body(&body, &kg, vs.as_deref())
133 })
134 .await;
135
136 let outcome = match result {
137 Ok(inner) => inner,
138 Err(join_err) => {
139 error!("dispatch task panicked: {join_err}");
140 return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
141 }
142 };
143
144 match outcome {
145 Ok(None) => StatusCode::ACCEPTED.into_response(),
147 Ok(Some(value)) => {
148 if wants_sse(&headers) {
149 let json = serde_json::to_string(&value).unwrap();
151 let stream = futures::stream::once(async move {
152 Ok::<Event, Infallible>(Event::default().data(json))
153 });
154 Sse::new(stream).into_response()
155 } else {
156 Json(value).into_response()
157 }
158 }
159 Err(e) => {
160 let resp = json!({
162 "jsonrpc": "2.0",
163 "error": { "code": -32700, "message": format!("Parse error: {e}") },
164 "id": null
165 });
166 (StatusCode::BAD_REQUEST, Json(resp)).into_response()
167 }
168 }
169}
170
171async fn get_handler(State(state): State<HttpState>, headers: HeaderMap) -> Response {
172 if !authorized(&state, &headers) {
173 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
174 }
175 let stream = futures::stream::pending::<std::result::Result<Event, Infallible>>();
177 Sse::new(stream)
178 .keep_alive(KeepAlive::default())
179 .into_response()
180}
181
182