Skip to main content

aperion_shield/transport/
http_server.rs

1//! IDE-facing Streamable HTTP MCP server (v0.9, `--http-listen`).
2//!
3//! Shield itself listens as an MCP server speaking the Streamable HTTP
4//! transport, so hosts that only talk HTTP can still put Shield in
5//! front of any upstream (stdio child process or remote HTTP server):
6//!
7//! * `POST <any path>` with a JSON-RPC **request** -> the request runs
8//!   through the same Shield gate as the stdio path (Block / Approval /
9//!   Warn / identity), is forwarded upstream on pass, and the matching
10//!   upstream response is returned as `application/json`.
11//! * `POST` with a **notification or client response** -> forwarded
12//!   upstream, `202 Accepted`.
13//! * `GET` with `Accept: text/event-stream` -> a long-lived SSE stream
14//!   carrying server-initiated messages (notifications, requests the
15//!   upstream pushes outside a POST exchange).
16//! * `DELETE` -> `200` (session termination; sessions here are lenient).
17//!
18//! An `Mcp-Session-Id` is minted on `initialize` and echoed back; Shield
19//! does not currently reject requests with missing/stale session ids
20//! (lenient mode -- enforcement adds nothing while Shield fronts exactly
21//! one upstream per process).
22//!
23//! JSON-RPC batch arrays are rejected with 400 -- the 2025-06-18 MCP
24//! revision removed batching support.
25
26use std::collections::HashMap;
27use std::convert::Infallible;
28use std::net::SocketAddr;
29use std::sync::Arc;
30
31use bytes::Bytes;
32use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
33use hyper::body::{Frame, Incoming};
34use hyper::server::conn::http1;
35use hyper::service::service_fn;
36use hyper::{Method, Request, Response, StatusCode};
37use hyper_util::rt::TokioIo;
38use log::{error, info, warn};
39use serde_json::Value;
40use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
41
42type RespBody = BoxBody<Bytes, Infallible>;
43
44/// How the relay core lets the HTTP server run requests through the
45/// Shield gate. Implemented by `main.rs` on its `Shield` state.
46#[async_trait::async_trait]
47pub trait RequestGate: Send + Sync {
48    /// `Some(response)` when Shield answers the request itself (block /
49    /// denied approval / pending identity); `None` to forward upstream.
50    async fn intercept(&self, req: &Value) -> Option<Value>;
51}
52
53/// Shared state between the HTTP server and the relay core.
54pub struct HttpDownstream {
55    /// Responses the HTTP layer is waiting on, keyed by canonical id.
56    pub pending: Mutex<HashMap<String, oneshot::Sender<String>>>,
57    /// Frames with no waiting POST (server-initiated traffic) fan out to
58    /// every open GET SSE stream.
59    pub broadcast: broadcast::Sender<String>,
60}
61
62impl HttpDownstream {
63    pub fn new() -> Arc<Self> {
64        let (tx, _) = broadcast::channel(super::CHANNEL_DEPTH);
65        Arc::new(Self { pending: Mutex::new(HashMap::new()), broadcast: tx })
66    }
67
68    /// Route one upstream frame: complete the waiting POST if there is
69    /// one, otherwise broadcast to SSE subscribers. Called by the relay
70    /// core for every (post-interception) upstream frame.
71    pub async fn route_upstream_frame(&self, frame: String) {
72        if let Ok(parsed) = serde_json::from_str::<Value>(&frame) {
73            if let Some(id) = parsed.get("id") {
74                if !id.is_null() && parsed.get("method").is_none() {
75                    let key = canonical_id(id);
76                    if let Some(tx) = self.pending.lock().await.remove(&key) {
77                        let _ = tx.send(frame);
78                        return;
79                    }
80                }
81            }
82        }
83        // No waiter -- fan out (errors just mean no open GET streams).
84        let _ = self.broadcast.send(frame);
85    }
86}
87
88/// Canonical map key for a JSON-RPC id (number or string).
89pub fn canonical_id(id: &Value) -> String {
90    id.to_string()
91}
92
93/// Serve the downstream HTTP endpoint until the process exits.
94pub async fn serve(
95    addr: SocketAddr,
96    gate: Arc<dyn RequestGate>,
97    to_upstream: mpsc::Sender<String>,
98    state: Arc<HttpDownstream>,
99) -> anyhow::Result<()> {
100    if !addr.ip().is_loopback() {
101        warn!(
102            "[shield] --http-listen {} is NOT loopback -- anyone who can reach this port \
103             can drive your MCP tools. Prefer 127.0.0.1.",
104            addr
105        );
106    }
107    let listener = tokio::net::TcpListener::bind(addr).await?;
108    info!("[shield] HTTP downstream listening on http://{} (Streamable HTTP MCP)", addr);
109    serve_on(listener, gate, to_upstream, state).await
110}
111
112/// Accept-loop over an already-bound listener. Split out from [`serve`]
113/// so integration tests can bind port 0 and learn the real address.
114pub async fn serve_on(
115    listener: tokio::net::TcpListener,
116    gate: Arc<dyn RequestGate>,
117    to_upstream: mpsc::Sender<String>,
118    state: Arc<HttpDownstream>,
119) -> anyhow::Result<()> {
120    loop {
121        let (stream, _peer) = match listener.accept().await {
122            Ok(x) => x,
123            Err(e) => {
124                error!("[shield] http accept error: {}", e);
125                continue;
126            }
127        };
128        let io = TokioIo::new(stream);
129        let gate = gate.clone();
130        let to_upstream = to_upstream.clone();
131        let state = state.clone();
132        tokio::spawn(async move {
133            let svc = service_fn(move |req: Request<Incoming>| {
134                let gate = gate.clone();
135                let to_upstream = to_upstream.clone();
136                let state = state.clone();
137                async move { Ok::<_, Infallible>(handle(req, gate, to_upstream, state).await) }
138            });
139            if let Err(e) = http1::Builder::new().serve_connection(io, svc).await {
140                // Normal for SSE streams the client drops.
141                log::debug!("[shield] http connection ended: {}", e);
142            }
143        });
144    }
145}
146
147async fn handle(
148    req: Request<Incoming>,
149    gate: Arc<dyn RequestGate>,
150    to_upstream: mpsc::Sender<String>,
151    state: Arc<HttpDownstream>,
152) -> Response<RespBody> {
153    match *req.method() {
154        Method::POST => handle_post(req, gate, to_upstream, state).await,
155        Method::GET => handle_get_sse(req, state).await,
156        Method::DELETE => text(StatusCode::OK, "session terminated"),
157        _ => text(StatusCode::METHOD_NOT_ALLOWED, "use POST / GET / DELETE"),
158    }
159}
160
161async fn handle_post(
162    req: Request<Incoming>,
163    gate: Arc<dyn RequestGate>,
164    to_upstream: mpsc::Sender<String>,
165    state: Arc<HttpDownstream>,
166) -> Response<RespBody> {
167    let body = match req.into_body().collect().await {
168        Ok(b) => b.to_bytes(),
169        Err(e) => return text(StatusCode::BAD_REQUEST, &format!("body read error: {}", e)),
170    };
171    let parsed: Value = match serde_json::from_slice(&body) {
172        Ok(v) => v,
173        Err(e) => return text(StatusCode::BAD_REQUEST, &format!("invalid JSON: {}", e)),
174    };
175    if parsed.is_array() {
176        return text(
177            StatusCode::BAD_REQUEST,
178            "JSON-RPC batching is not supported (removed in MCP 2025-06-18)",
179        );
180    }
181
182    let frame = parsed.to_string();
183    let is_initialize = parsed.get("method").and_then(|m| m.as_str()) == Some("initialize");
184    let id = parsed.get("id").cloned().unwrap_or(Value::Null);
185    let is_request = parsed.get("method").is_some() && !id.is_null();
186
187    if !is_request {
188        // Notification or client->server response: forward, 202.
189        if to_upstream.send(frame).await.is_err() {
190            return text(StatusCode::BAD_GATEWAY, "upstream gone");
191        }
192        return text(StatusCode::ACCEPTED, "");
193    }
194
195    // Run the Shield gate exactly like the stdio path.
196    if let Some(decision_resp) = gate.intercept(&parsed).await {
197        return json_response(decision_resp.to_string(), is_initialize);
198    }
199
200    // Register the waiter BEFORE forwarding so the response can't race us.
201    let (tx, rx) = oneshot::channel::<String>();
202    let key = canonical_id(&id);
203    state.pending.lock().await.insert(key.clone(), tx);
204
205    if to_upstream.send(frame).await.is_err() {
206        state.pending.lock().await.remove(&key);
207        return text(StatusCode::BAD_GATEWAY, "upstream gone");
208    }
209
210    // Approvals can legitimately take a minute -- be generous.
211    match tokio::time::timeout(std::time::Duration::from_secs(300), rx).await {
212        Ok(Ok(resp_frame)) => json_response(resp_frame, is_initialize),
213        Ok(Err(_)) => text(StatusCode::BAD_GATEWAY, "upstream closed without responding"),
214        Err(_) => {
215            state.pending.lock().await.remove(&key);
216            text(StatusCode::GATEWAY_TIMEOUT, "upstream response timeout")
217        }
218    }
219}
220
221async fn handle_get_sse(req: Request<Incoming>, state: Arc<HttpDownstream>) -> Response<RespBody> {
222    let wants_sse = req
223        .headers()
224        .get("accept")
225        .and_then(|v| v.to_str().ok())
226        .map(|a| a.contains("text/event-stream"))
227        .unwrap_or(false);
228    if !wants_sse {
229        return text(
230            StatusCode::OK,
231            "aperion-shield Streamable HTTP MCP endpoint. POST JSON-RPC here; \
232             GET with Accept: text/event-stream for the server-initiated stream.",
233        );
234    }
235
236    let rx = state.broadcast.subscribe();
237    let stream = futures_util::stream::unfold(rx, |mut rx| async move {
238        loop {
239            match rx.recv().await {
240                Ok(frame) => {
241                    let chunk = Bytes::from(format!("data: {}\n\n", frame));
242                    return Some((Ok::<_, Infallible>(Frame::data(chunk)), rx));
243                }
244                Err(broadcast::error::RecvError::Lagged(n)) => {
245                    warn!("[shield] SSE subscriber lagged, skipped {} frames", n);
246                    continue;
247                }
248                Err(broadcast::error::RecvError::Closed) => return None,
249            }
250        }
251    });
252
253    Response::builder()
254        .status(StatusCode::OK)
255        .header("content-type", "text/event-stream")
256        .header("cache-control", "no-store")
257        .body(BoxBody::new(StreamBody::new(stream)))
258        .unwrap()
259}
260
261fn json_response(frame: String, mint_session: bool) -> Response<RespBody> {
262    let mut b = Response::builder()
263        .status(StatusCode::OK)
264        .header("content-type", "application/json");
265    if mint_session {
266        b = b.header("mcp-session-id", uuid::Uuid::new_v4().simple().to_string());
267    }
268    b.body(BoxBody::new(Full::new(Bytes::from(frame)))).unwrap()
269}
270
271fn text(status: StatusCode, msg: &str) -> Response<RespBody> {
272    Response::builder()
273        .status(status)
274        .header("content-type", "text/plain; charset=utf-8")
275        .body(BoxBody::new(Full::new(Bytes::from(msg.to_string()))))
276        .unwrap()
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Numeric and string ids with the same digits map to different keys.
    #[test]
    fn canonical_id_distinguishes_number_and_string() {
        let numeric = canonical_id(&json!(1));
        let textual = canonical_id(&json!("1"));
        assert_eq!(numeric, "1");
        assert_eq!(textual, "\"1\"");
        assert_ne!(numeric, textual);
    }

    /// A response whose id has a registered waiter is delivered to it.
    #[tokio::test]
    async fn route_completes_waiting_post() {
        let state = HttpDownstream::new();
        let (waiter, delivered) = oneshot::channel();
        state.pending.lock().await.insert(String::from("7"), waiter);

        let upstream = String::from(r#"{"jsonrpc":"2.0","id":7,"result":{}}"#);
        state.route_upstream_frame(upstream).await;

        let frame = delivered.await.unwrap();
        assert!(frame.contains("\"id\":7"));
    }

    /// A frame without an id goes to the broadcast subscribers.
    #[tokio::test]
    async fn route_broadcasts_unmatched_frames() {
        let state = HttpDownstream::new();
        let mut subscriber = state.broadcast.subscribe();

        let upstream = String::from(r#"{"jsonrpc":"2.0","method":"notifications/progress"}"#);
        state.route_upstream_frame(upstream).await;

        let frame = subscriber.recv().await.unwrap();
        assert!(frame.contains("notifications/progress"));
    }

    /// A frame with BOTH method and id is an upstream-initiated request
    /// (e.g. sampling), not a response -- it must go to SSE.
    #[tokio::test]
    async fn upstream_request_with_id_is_broadcast_not_routed() {
        let state = HttpDownstream::new();
        let mut subscriber = state.broadcast.subscribe();

        let upstream =
            String::from(r#"{"jsonrpc":"2.0","id":9,"method":"sampling/createMessage"}"#);
        state.route_upstream_frame(upstream).await;

        let frame = subscriber.recv().await.unwrap();
        assert!(frame.contains("sampling/createMessage"));
    }
}