aperion_shield/transport/
http_server.rs1use std::collections::HashMap;
27use std::convert::Infallible;
28use std::net::SocketAddr;
29use std::sync::Arc;
30
31use bytes::Bytes;
32use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
33use hyper::body::{Frame, Incoming};
34use hyper::server::conn::http1;
35use hyper::service::service_fn;
36use hyper::{Method, Request, Response, StatusCode};
37use hyper_util::rt::TokioIo;
38use log::{error, info, warn};
39use serde_json::Value;
40use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
41
42type RespBody = BoxBody<Bytes, Infallible>;
43
44#[async_trait::async_trait]
47pub trait RequestGate: Send + Sync {
48 async fn intercept(&self, req: &Value) -> Option<Value>;
51}
52
53pub struct HttpDownstream {
55 pub pending: Mutex<HashMap<String, oneshot::Sender<String>>>,
57 pub broadcast: broadcast::Sender<String>,
60}
61
62impl HttpDownstream {
63 pub fn new() -> Arc<Self> {
64 let (tx, _) = broadcast::channel(super::CHANNEL_DEPTH);
65 Arc::new(Self { pending: Mutex::new(HashMap::new()), broadcast: tx })
66 }
67
68 pub async fn route_upstream_frame(&self, frame: String) {
72 if let Ok(parsed) = serde_json::from_str::<Value>(&frame) {
73 if let Some(id) = parsed.get("id") {
74 if !id.is_null() && parsed.get("method").is_none() {
75 let key = canonical_id(id);
76 if let Some(tx) = self.pending.lock().await.remove(&key) {
77 let _ = tx.send(frame);
78 return;
79 }
80 }
81 }
82 }
83 let _ = self.broadcast.send(frame);
85 }
86}
87
88pub fn canonical_id(id: &Value) -> String {
90 id.to_string()
91}
92
93pub async fn serve(
95 addr: SocketAddr,
96 gate: Arc<dyn RequestGate>,
97 to_upstream: mpsc::Sender<String>,
98 state: Arc<HttpDownstream>,
99) -> anyhow::Result<()> {
100 if !addr.ip().is_loopback() {
101 warn!(
102 "[shield] --http-listen {} is NOT loopback -- anyone who can reach this port \
103 can drive your MCP tools. Prefer 127.0.0.1.",
104 addr
105 );
106 }
107 let listener = tokio::net::TcpListener::bind(addr).await?;
108 info!("[shield] HTTP downstream listening on http://{} (Streamable HTTP MCP)", addr);
109 serve_on(listener, gate, to_upstream, state).await
110}
111
112pub async fn serve_on(
115 listener: tokio::net::TcpListener,
116 gate: Arc<dyn RequestGate>,
117 to_upstream: mpsc::Sender<String>,
118 state: Arc<HttpDownstream>,
119) -> anyhow::Result<()> {
120 loop {
121 let (stream, _peer) = match listener.accept().await {
122 Ok(x) => x,
123 Err(e) => {
124 error!("[shield] http accept error: {}", e);
125 continue;
126 }
127 };
128 let io = TokioIo::new(stream);
129 let gate = gate.clone();
130 let to_upstream = to_upstream.clone();
131 let state = state.clone();
132 tokio::spawn(async move {
133 let svc = service_fn(move |req: Request<Incoming>| {
134 let gate = gate.clone();
135 let to_upstream = to_upstream.clone();
136 let state = state.clone();
137 async move { Ok::<_, Infallible>(handle(req, gate, to_upstream, state).await) }
138 });
139 if let Err(e) = http1::Builder::new().serve_connection(io, svc).await {
140 log::debug!("[shield] http connection ended: {}", e);
142 }
143 });
144 }
145}
146
147async fn handle(
148 req: Request<Incoming>,
149 gate: Arc<dyn RequestGate>,
150 to_upstream: mpsc::Sender<String>,
151 state: Arc<HttpDownstream>,
152) -> Response<RespBody> {
153 match *req.method() {
154 Method::POST => handle_post(req, gate, to_upstream, state).await,
155 Method::GET => handle_get_sse(req, state).await,
156 Method::DELETE => text(StatusCode::OK, "session terminated"),
157 _ => text(StatusCode::METHOD_NOT_ALLOWED, "use POST / GET / DELETE"),
158 }
159}
160
161async fn handle_post(
162 req: Request<Incoming>,
163 gate: Arc<dyn RequestGate>,
164 to_upstream: mpsc::Sender<String>,
165 state: Arc<HttpDownstream>,
166) -> Response<RespBody> {
167 let body = match req.into_body().collect().await {
168 Ok(b) => b.to_bytes(),
169 Err(e) => return text(StatusCode::BAD_REQUEST, &format!("body read error: {}", e)),
170 };
171 let parsed: Value = match serde_json::from_slice(&body) {
172 Ok(v) => v,
173 Err(e) => return text(StatusCode::BAD_REQUEST, &format!("invalid JSON: {}", e)),
174 };
175 if parsed.is_array() {
176 return text(
177 StatusCode::BAD_REQUEST,
178 "JSON-RPC batching is not supported (removed in MCP 2025-06-18)",
179 );
180 }
181
182 let frame = parsed.to_string();
183 let is_initialize = parsed.get("method").and_then(|m| m.as_str()) == Some("initialize");
184 let id = parsed.get("id").cloned().unwrap_or(Value::Null);
185 let is_request = parsed.get("method").is_some() && !id.is_null();
186
187 if !is_request {
188 if to_upstream.send(frame).await.is_err() {
190 return text(StatusCode::BAD_GATEWAY, "upstream gone");
191 }
192 return text(StatusCode::ACCEPTED, "");
193 }
194
195 if let Some(decision_resp) = gate.intercept(&parsed).await {
197 return json_response(decision_resp.to_string(), is_initialize);
198 }
199
200 let (tx, rx) = oneshot::channel::<String>();
202 let key = canonical_id(&id);
203 state.pending.lock().await.insert(key.clone(), tx);
204
205 if to_upstream.send(frame).await.is_err() {
206 state.pending.lock().await.remove(&key);
207 return text(StatusCode::BAD_GATEWAY, "upstream gone");
208 }
209
210 match tokio::time::timeout(std::time::Duration::from_secs(300), rx).await {
212 Ok(Ok(resp_frame)) => json_response(resp_frame, is_initialize),
213 Ok(Err(_)) => text(StatusCode::BAD_GATEWAY, "upstream closed without responding"),
214 Err(_) => {
215 state.pending.lock().await.remove(&key);
216 text(StatusCode::GATEWAY_TIMEOUT, "upstream response timeout")
217 }
218 }
219}
220
221async fn handle_get_sse(req: Request<Incoming>, state: Arc<HttpDownstream>) -> Response<RespBody> {
222 let wants_sse = req
223 .headers()
224 .get("accept")
225 .and_then(|v| v.to_str().ok())
226 .map(|a| a.contains("text/event-stream"))
227 .unwrap_or(false);
228 if !wants_sse {
229 return text(
230 StatusCode::OK,
231 "aperion-shield Streamable HTTP MCP endpoint. POST JSON-RPC here; \
232 GET with Accept: text/event-stream for the server-initiated stream.",
233 );
234 }
235
236 let rx = state.broadcast.subscribe();
237 let stream = futures_util::stream::unfold(rx, |mut rx| async move {
238 loop {
239 match rx.recv().await {
240 Ok(frame) => {
241 let chunk = Bytes::from(format!("data: {}\n\n", frame));
242 return Some((Ok::<_, Infallible>(Frame::data(chunk)), rx));
243 }
244 Err(broadcast::error::RecvError::Lagged(n)) => {
245 warn!("[shield] SSE subscriber lagged, skipped {} frames", n);
246 continue;
247 }
248 Err(broadcast::error::RecvError::Closed) => return None,
249 }
250 }
251 });
252
253 Response::builder()
254 .status(StatusCode::OK)
255 .header("content-type", "text/event-stream")
256 .header("cache-control", "no-store")
257 .body(BoxBody::new(StreamBody::new(stream)))
258 .unwrap()
259}
260
261fn json_response(frame: String, mint_session: bool) -> Response<RespBody> {
262 let mut b = Response::builder()
263 .status(StatusCode::OK)
264 .header("content-type", "application/json");
265 if mint_session {
266 b = b.header("mcp-session-id", uuid::Uuid::new_v4().simple().to_string());
267 }
268 b.body(BoxBody::new(Full::new(Bytes::from(frame)))).unwrap()
269}
270
271fn text(status: StatusCode, msg: &str) -> Response<RespBody> {
272 Response::builder()
273 .status(status)
274 .header("content-type", "text/plain; charset=utf-8")
275 .body(BoxBody::new(Full::new(Bytes::from(msg.to_string()))))
276 .unwrap()
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn canonical_id_distinguishes_number_and_string() {
        assert_eq!(canonical_id(&json!(1)), "1");
        assert_eq!(canonical_id(&json!("1")), "\"1\"");
        assert_ne!(canonical_id(&json!(1)), canonical_id(&json!("1")));
    }

    #[tokio::test]
    async fn route_completes_waiting_post() {
        let state = HttpDownstream::new();
        let (tx, rx) = oneshot::channel();
        state.pending.lock().await.insert("7".to_string(), tx);
        state
            .route_upstream_frame(r#"{"jsonrpc":"2.0","id":7,"result":{}}"#.to_string())
            .await;
        let frame = rx.await.unwrap();
        assert!(frame.contains("\"id\":7"));
    }

    #[tokio::test]
    async fn route_broadcasts_unmatched_frames() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        state
            .route_upstream_frame(
                r#"{"jsonrpc":"2.0","method":"notifications/progress"}"#.to_string(),
            )
            .await;
        let frame = sub.recv().await.unwrap();
        assert!(frame.contains("notifications/progress"));
    }

    #[tokio::test]
    async fn upstream_request_with_id_is_broadcast_not_routed() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        state
            .route_upstream_frame(
                r#"{"jsonrpc":"2.0","id":9,"method":"sampling/createMessage"}"#.to_string(),
            )
            .await;
        assert!(sub.recv().await.unwrap().contains("sampling/createMessage"));
    }

    /// A reply whose id is the *string* "7" must not complete a waiter registered for the
    /// *number* 7; it has no waiter of its own, so it falls through to the broadcast.
    #[tokio::test]
    async fn string_id_reply_does_not_complete_numeric_waiter() {
        let state = HttpDownstream::new();
        let (tx, mut rx) = oneshot::channel();
        state.pending.lock().await.insert(canonical_id(&json!(7)), tx);
        let mut sub = state.broadcast.subscribe();
        state
            .route_upstream_frame(r#"{"jsonrpc":"2.0","id":"7","result":{}}"#.to_string())
            .await;
        assert!(rx.try_recv().is_err(), "numeric waiter received a string-id reply");
        assert!(state.pending.lock().await.contains_key("7"));
        assert!(sub.recv().await.unwrap().contains("\"id\":\"7\""));
    }

    #[tokio::test]
    async fn null_id_error_reply_is_broadcast() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        state
            .route_upstream_frame(
                r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"Parse error"}}"#
                    .to_string(),
            )
            .await;
        assert!(sub.recv().await.unwrap().contains("-32700"));
    }

    #[tokio::test]
    async fn non_json_frame_is_broadcast_verbatim() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        state.route_upstream_frame("not json".to_string()).await;
        assert_eq!(sub.recv().await.unwrap(), "not json");
    }
}