Skip to main content

mcp_memory/
http.rs

1//! MCP **Streamable HTTP** transport (the 2025-03-26 transport that
2//! superseded the older HTTP+SSE pair).
3//!
4//! * `POST /mcp` — the client sends one JSON-RPC message (or a batch array).
5//!   The reply is delivered as `application/json` by default, or as a one-shot
6//!   `text/event-stream` (SSE) event when the client `Accept`s it. A body of
7//!   only notifications gets `202 Accepted` with no content.
8//! * `GET /mcp` — opens a standalone server→client SSE stream. This server has
9//!   no server-initiated messages, so the stream simply stays open with
10//!   keep-alives; it exists for spec compliance.
11//!
12//! `/` is also wired to the same handlers for convenience. The JSON-RPC
13//! semantics are identical to the stdio and TCP transports — only framing
14//! differs (see [`crate::server::dispatch_http_body`]).
15
16use std::convert::Infallible;
17use std::sync::Arc;
18
19use axum::extract::{DefaultBodyLimit, State};
20use axum::http::{header, HeaderMap, StatusCode};
21use axum::response::sse::{Event, KeepAlive, Sse};
22use axum::response::{IntoResponse, Response};
23use axum::routing::post;
24use axum::{Json, Router};
25use serde_json::json;
26use tokio::net::TcpListener;
27use tracing::{error, info};
28
29use crate::errors::{MCSError, Result};
30use crate::kg::GraphHandle;
31use crate::server;
32use crate::vector_store::VectorStore;
33
34/// Shared state for the HTTP handlers: the graph, the optional vector store, and
35/// an optional bearer token required on every request when present.
36#[derive(Clone)]
37pub struct HttpState {
38    kg: Arc<GraphHandle>,
39    vs: Option<Arc<VectorStore>>,
40    auth_token: Option<Arc<str>>,
41}
42
43/// Build the axum router for the HTTP transport. Exposed so tests can drive it
44/// with `tower::ServiceExt::oneshot` without binding a socket.
45pub fn router(state: HttpState) -> Router {
46    Router::new()
47        .route("/mcp", post(post_handler).get(get_handler))
48        .route("/", post(post_handler).get(get_handler))
49        .layer(DefaultBodyLimit::max(server::MAX_REQUEST_BYTES))
50        .with_state(state)
51}
52
53/// Bind `addr` and serve the HTTP transport until the process is killed.
54///
55/// When `tls_cert` and `tls_key` are both set, the transport is served over TLS
56/// (HTTPS); otherwise it stays plaintext. The caller (`config.rs`) guarantees
57/// the two are set together.
58pub async fn run(
59    addr: &str,
60    kg: Arc<GraphHandle>,
61    vs: Option<Arc<VectorStore>>,
62    auth_token: Option<Arc<str>>,
63    tls_cert: Option<std::path::PathBuf>,
64    tls_key: Option<std::path::PathBuf>,
65) -> Result<()> {
66    let auth = if auth_token.is_some() { "on" } else { "off" };
67    let state = HttpState { kg, vs, auth_token };
68
69    if let (Some(cert), Some(key)) = (tls_cert, tls_key) {
70        let tls = crate::tls::server_config(&cert, &key)
71            .await
72            .map_err(MCSError::IoError)?;
73        let socket_addr = resolve_addr(addr)?;
74        info!("Listening for HTTPS (Streamable) MCP on https://{socket_addr}/mcp (TLS, auth {auth})");
75        axum_server::bind_rustls(socket_addr, tls)
76            .serve(router(state).into_make_service())
77            .await
78            .map_err(MCSError::IoError)?;
79    } else {
80        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
81        info!("Listening for HTTP (Streamable) MCP on http://{addr}/mcp (auth {auth})");
82        axum::serve(listener, router(state))
83            .await
84            .map_err(MCSError::IoError)?;
85    }
86    Ok(())
87}
88
89/// Resolve a `host:port` string to a single `SocketAddr` for `axum_server`,
90/// which binds an address rather than an already-bound listener.
91fn resolve_addr(addr: &str) -> Result<std::net::SocketAddr> {
92    use std::net::ToSocketAddrs;
93    addr.to_socket_addrs()
94        .map_err(MCSError::IoError)?
95        .next()
96        .ok_or_else(|| {
97            MCSError::IoError(std::io::Error::new(
98                std::io::ErrorKind::InvalidInput,
99                format!("could not resolve bind address '{addr}'"),
100            ))
101        })
102}
103
104fn wants_sse(headers: &HeaderMap) -> bool {
105    headers
106        .get(header::ACCEPT)
107        .and_then(|v| v.to_str().ok())
108        .is_some_and(|a| a.contains("text/event-stream"))
109}
110
111/// `true` when the request is allowed: either no token is configured, or the
112/// `Authorization` header carries the expected bearer token.
113fn authorized(state: &HttpState, headers: &HeaderMap) -> bool {
114    match state.auth_token {
115        None => true,
116        Some(ref expected) => headers
117            .get(header::AUTHORIZATION)
118            .and_then(|v| v.to_str().ok())
119            .is_some_and(|presented| server::token_matches(presented, expected)),
120    }
121}
122
123async fn post_handler(State(state): State<HttpState>, headers: HeaderMap, body: String) -> Response {
124    if !authorized(&state, &headers) {
125        return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
126    }
127    let kg = state.kg;
128    let vs = state.vs;
129    // The dispatch path locks the graph and may perform a blocking fsync, so
130    // run it off the async worker pool (keeps the HTTP reactor responsive).
131    let result = tokio::task::spawn_blocking(move || {
132        server::dispatch_http_body(&body, &kg, vs.as_deref())
133    })
134    .await;
135
136    let outcome = match result {
137        Ok(inner) => inner,
138        Err(join_err) => {
139            error!("dispatch task panicked: {join_err}");
140            return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
141        }
142    };
143
144    match outcome {
145        // Body held only notifications → nothing to return.
146        Ok(None) => StatusCode::ACCEPTED.into_response(),
147        Ok(Some(value)) => {
148            if wants_sse(&headers) {
149                // One JSON-RPC reply delivered as a single SSE event, then close.
150                let json = serde_json::to_string(&value).unwrap();
151                let stream = futures::stream::once(async move {
152                    Ok::<Event, Infallible>(Event::default().data(json))
153                });
154                Sse::new(stream).into_response()
155            } else {
156                Json(value).into_response()
157            }
158        }
159        Err(e) => {
160            // Malformed JSON body → JSON-RPC parse error.
161            let resp = json!({
162                "jsonrpc": "2.0",
163                "error": { "code": -32700, "message": format!("Parse error: {e}") },
164                "id": null
165            });
166            (StatusCode::BAD_REQUEST, Json(resp)).into_response()
167        }
168    }
169}
170
171async fn get_handler(State(state): State<HttpState>, headers: HeaderMap) -> Response {
172    if !authorized(&state, &headers) {
173        return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
174    }
175    // No server-initiated messages: an open, keep-alive'd stream for compliance.
176    let stream = futures::stream::pending::<std::result::Result<Event, Infallible>>();
177    Sse::new(stream)
178        .keep_alive(KeepAlive::default())
179        .into_response()
180}
181
182