1use std::convert::Infallible;
17use std::sync::Arc;
18
19use axum::extract::{DefaultBodyLimit, State};
20use axum::http::{header, HeaderMap, StatusCode};
21use axum::response::sse::{Event, KeepAlive, Sse};
22use axum::response::{IntoResponse, Response};
23use axum::routing::post;
24use axum::{Json, Router};
25use serde_json::json;
26use tokio::net::TcpListener;
27use tracing::{error, info};
28
29use crate::errors::{MCSError, Result};
30use crate::kg::GraphHandle;
31use crate::server;
32
33#[derive(Clone)]
36pub struct HttpState {
37 kg: Arc<GraphHandle>,
38 auth_token: Option<Arc<str>>,
39}
40
41pub fn router(state: HttpState) -> Router {
44 Router::new()
45 .route("/mcp", post(post_handler).get(get_handler))
46 .route("/", post(post_handler).get(get_handler))
47 .layer(DefaultBodyLimit::max(server::MAX_REQUEST_BYTES))
48 .with_state(state)
49}
50
51pub async fn run(
57 addr: &str,
58 kg: Arc<GraphHandle>,
59 auth_token: Option<Arc<str>>,
60 tls_cert: Option<std::path::PathBuf>,
61 tls_key: Option<std::path::PathBuf>,
62) -> Result<()> {
63 let auth = if auth_token.is_some() { "on" } else { "off" };
64 let state = HttpState { kg, auth_token };
65
66 if let (Some(cert), Some(key)) = (tls_cert, tls_key) {
67 let tls = crate::tls::server_config(&cert, &key)
68 .await
69 .map_err(MCSError::IoError)?;
70 let socket_addr = resolve_addr(addr)?;
71 info!("Listening for HTTPS (Streamable) MCP on https://{socket_addr}/mcp (TLS, auth {auth})");
72 axum_server::bind_rustls(socket_addr, tls)
73 .serve(router(state).into_make_service())
74 .await
75 .map_err(MCSError::IoError)?;
76 } else {
77 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
78 info!("Listening for HTTP (Streamable) MCP on http://{addr}/mcp (auth {auth})");
79 axum::serve(listener, router(state))
80 .await
81 .map_err(MCSError::IoError)?;
82 }
83 Ok(())
84}
85
86fn resolve_addr(addr: &str) -> Result<std::net::SocketAddr> {
89 use std::net::ToSocketAddrs;
90 addr.to_socket_addrs()
91 .map_err(MCSError::IoError)?
92 .next()
93 .ok_or_else(|| {
94 MCSError::IoError(std::io::Error::new(
95 std::io::ErrorKind::InvalidInput,
96 format!("could not resolve bind address '{addr}'"),
97 ))
98 })
99}
100
101fn wants_sse(headers: &HeaderMap) -> bool {
102 headers
103 .get(header::ACCEPT)
104 .and_then(|v| v.to_str().ok())
105 .is_some_and(|a| a.contains("text/event-stream"))
106}
107
108fn authorized(state: &HttpState, headers: &HeaderMap) -> bool {
111 match state.auth_token {
112 None => true,
113 Some(ref expected) => headers
114 .get(header::AUTHORIZATION)
115 .and_then(|v| v.to_str().ok())
116 .is_some_and(|presented| server::token_matches(presented, expected)),
117 }
118}
119
120async fn post_handler(State(state): State<HttpState>, headers: HeaderMap, body: String) -> Response {
121 if !authorized(&state, &headers) {
122 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
123 }
124 let kg = state.kg;
125 let result = tokio::task::spawn_blocking(move || server::dispatch_http_body(&body, &kg)).await;
128
129 let outcome = match result {
130 Ok(inner) => inner,
131 Err(join_err) => {
132 error!("dispatch task panicked: {join_err}");
133 return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
134 }
135 };
136
137 match outcome {
138 Ok(None) => StatusCode::ACCEPTED.into_response(),
140 Ok(Some(value)) => {
141 if wants_sse(&headers) {
142 let json = serde_json::to_string(&value).unwrap();
144 let stream = futures::stream::once(async move {
145 Ok::<Event, Infallible>(Event::default().data(json))
146 });
147 Sse::new(stream).into_response()
148 } else {
149 Json(value).into_response()
150 }
151 }
152 Err(e) => {
153 let resp = json!({
155 "jsonrpc": "2.0",
156 "error": { "code": -32700, "message": format!("Parse error: {e}") },
157 "id": null
158 });
159 (StatusCode::BAD_REQUEST, Json(resp)).into_response()
160 }
161 }
162}
163
164async fn get_handler(State(state): State<HttpState>, headers: HeaderMap) -> Response {
165 if !authorized(&state, &headers) {
166 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
167 }
168 let stream = futures::stream::pending::<std::result::Result<Event, Infallible>>();
170 Sse::new(stream)
171 .keep_alive(KeepAlive::default())
172 .into_response()
173}
174
175