use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
use hyper::body::{Frame, Incoming};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use log::{error, info, warn};
use serde_json::Value;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
type RespBody = BoxBody<Bytes, Infallible>;
#[async_trait::async_trait]
pub trait RequestGate: Send + Sync {
async fn intercept(&self, req: &Value) -> Option<Value>;
}
pub struct HttpDownstream {
pub pending: Mutex<HashMap<String, oneshot::Sender<String>>>,
pub broadcast: broadcast::Sender<String>,
}
impl HttpDownstream {
pub fn new() -> Arc<Self> {
let (tx, _) = broadcast::channel(super::CHANNEL_DEPTH);
Arc::new(Self { pending: Mutex::new(HashMap::new()), broadcast: tx })
}
pub async fn route_upstream_frame(&self, frame: String) {
if let Ok(parsed) = serde_json::from_str::<Value>(&frame) {
if let Some(id) = parsed.get("id") {
if !id.is_null() && parsed.get("method").is_none() {
let key = canonical_id(id);
if let Some(tx) = self.pending.lock().await.remove(&key) {
let _ = tx.send(frame);
return;
}
}
}
}
let _ = self.broadcast.send(frame);
}
}
/// Turns a JSON-RPC `id` into the string key used by [`HttpDownstream::pending`].
///
/// The key is the JSON text of the id, so the number `1` and the string
/// `"1"` produce different keys (`1` vs `"1"`) and never collide.
pub fn canonical_id(id: &Value) -> String {
    format!("{id}")
}
/// Binds `addr` and serves the Streamable HTTP MCP endpoint on it.
///
/// A warning is logged before binding if `addr` is not a loopback address.
/// Bind failures are returned as errors; otherwise this awaits [`serve_on`]
/// with the bound listener.
pub async fn serve(
    addr: SocketAddr,
    gate: Arc<dyn RequestGate>,
    to_upstream: mpsc::Sender<String>,
    state: Arc<HttpDownstream>,
) -> anyhow::Result<()> {
    warn_if_not_loopback(addr);
    let listener = tokio::net::TcpListener::bind(addr).await?;
    info!(
        "[shield] HTTP downstream listening on http://{} (Streamable HTTP MCP)",
        addr
    );
    serve_on(listener, gate, to_upstream, state).await
}

/// Logs a warning when `addr` is reachable from beyond the local host.
/// Anyone who can reach the port can drive the proxied MCP tools.
fn warn_if_not_loopback(addr: SocketAddr) {
    if addr.ip().is_loopback() {
        return;
    }
    warn!(
        "[shield] --http-listen {} is NOT loopback -- anyone who can reach this port \
         can drive your MCP tools. Prefer 127.0.0.1.",
        addr
    );
}
/// Accepts connections from `listener` forever, serving each on its own task.
///
/// Each connection is driven by hyper's HTTP/1 server. Every request on it is
/// dispatched to `handle` with clones of the shared `gate`, `to_upstream`
/// and `state` handles. Accept errors are logged and the loop continues.
/// Errors that are likely to repeat immediately (anything other than a
/// per-connection reset/abort/refusal) trigger a short pause first. The loop
/// has no exit, so this future only ends if it is dropped.
pub async fn serve_on(
    listener: tokio::net::TcpListener,
    gate: Arc<dyn RequestGate>,
    to_upstream: mpsc::Sender<String>,
    state: Arc<HttpDownstream>,
) -> anyhow::Result<()> {
    // Pause after an accept error that will probably recur straight away
    // (e.g. running out of file descriptors), so the loop does not spin at
    // 100% CPU and flood the log.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    loop {
        let (stream, _peer) = match listener.accept().await {
            Ok(x) => x,
            Err(e) => {
                error!("[shield] http accept error: {}", e);
                if !is_per_connection_accept_error(&e) {
                    tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                }
                continue;
            }
        };
        let io = TokioIo::new(stream);
        let gate = gate.clone();
        let to_upstream = to_upstream.clone();
        let state = state.clone();
        tokio::spawn(async move {
            let svc = service_fn(move |req: Request<Incoming>| {
                let gate = gate.clone();
                let to_upstream = to_upstream.clone();
                let state = state.clone();
                async move { Ok::<_, Infallible>(handle(req, gate, to_upstream, state).await) }
            });
            if let Err(e) = http1::Builder::new().serve_connection(io, svc).await {
                log::debug!("[shield] http connection ended: {}", e);
            }
        });
    }
}

/// Returns `true` for accept errors that concern only the single incoming
/// connection (the peer refused, aborted or reset it). Such errors say
/// nothing about the listener itself, so the next accept can run at once.
fn is_per_connection_accept_error(e: &std::io::Error) -> bool {
    matches!(
        e.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}
/// Dispatches one HTTP request by method.
///
/// `POST` carries JSON-RPC messages, and `GET` opens the server-to-client
/// SSE stream (or returns a short usage note). `DELETE` is acknowledged with
/// `200` without touching any state, and every other method gets `405`.
async fn handle(
    req: Request<Incoming>,
    gate: Arc<dyn RequestGate>,
    to_upstream: mpsc::Sender<String>,
    state: Arc<HttpDownstream>,
) -> Response<RespBody> {
    let method = req.method().clone();
    if method == Method::POST {
        handle_post(req, gate, to_upstream, state).await
    } else if method == Method::GET {
        handle_get_sse(req, state).await
    } else if method == Method::DELETE {
        text(StatusCode::OK, "session terminated")
    } else {
        text(StatusCode::METHOD_NOT_ALLOWED, "use POST / GET / DELETE")
    }
}
async fn handle_post(
req: Request<Incoming>,
gate: Arc<dyn RequestGate>,
to_upstream: mpsc::Sender<String>,
state: Arc<HttpDownstream>,
) -> Response<RespBody> {
let body = match req.into_body().collect().await {
Ok(b) => b.to_bytes(),
Err(e) => return text(StatusCode::BAD_REQUEST, &format!("body read error: {}", e)),
};
let parsed: Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(e) => return text(StatusCode::BAD_REQUEST, &format!("invalid JSON: {}", e)),
};
if parsed.is_array() {
return text(
StatusCode::BAD_REQUEST,
"JSON-RPC batching is not supported (removed in MCP 2025-06-18)",
);
}
let frame = parsed.to_string();
let is_initialize = parsed.get("method").and_then(|m| m.as_str()) == Some("initialize");
let id = parsed.get("id").cloned().unwrap_or(Value::Null);
let is_request = parsed.get("method").is_some() && !id.is_null();
if !is_request {
if to_upstream.send(frame).await.is_err() {
return text(StatusCode::BAD_GATEWAY, "upstream gone");
}
return text(StatusCode::ACCEPTED, "");
}
if let Some(decision_resp) = gate.intercept(&parsed).await {
return json_response(decision_resp.to_string(), is_initialize);
}
let (tx, rx) = oneshot::channel::<String>();
let key = canonical_id(&id);
state.pending.lock().await.insert(key.clone(), tx);
if to_upstream.send(frame).await.is_err() {
state.pending.lock().await.remove(&key);
return text(StatusCode::BAD_GATEWAY, "upstream gone");
}
match tokio::time::timeout(std::time::Duration::from_secs(300), rx).await {
Ok(Ok(resp_frame)) => json_response(resp_frame, is_initialize),
Ok(Err(_)) => text(StatusCode::BAD_GATEWAY, "upstream closed without responding"),
Err(_) => {
state.pending.lock().await.remove(&key);
text(StatusCode::GATEWAY_TIMEOUT, "upstream response timeout")
}
}
}
async fn handle_get_sse(req: Request<Incoming>, state: Arc<HttpDownstream>) -> Response<RespBody> {
let wants_sse = req
.headers()
.get("accept")
.and_then(|v| v.to_str().ok())
.map(|a| a.contains("text/event-stream"))
.unwrap_or(false);
if !wants_sse {
return text(
StatusCode::OK,
"aperion-shield Streamable HTTP MCP endpoint. POST JSON-RPC here; \
GET with Accept: text/event-stream for the server-initiated stream.",
);
}
let rx = state.broadcast.subscribe();
let stream = futures_util::stream::unfold(rx, |mut rx| async move {
loop {
match rx.recv().await {
Ok(frame) => {
let chunk = Bytes::from(format!("data: {}\n\n", frame));
return Some((Ok::<_, Infallible>(Frame::data(chunk)), rx));
}
Err(broadcast::error::RecvError::Lagged(n)) => {
warn!("[shield] SSE subscriber lagged, skipped {} frames", n);
continue;
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
});
Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/event-stream")
.header("cache-control", "no-store")
.body(BoxBody::new(StreamBody::new(stream)))
.unwrap()
}
/// Builds a `200 OK` response whose `application/json` body is `frame`.
///
/// When `mint_session` is true (callers pass it for `initialize`), a fresh
/// random UUID v4 in its hyphen-less "simple" form is attached as the
/// `mcp-session-id` header.
fn json_response(frame: String, mint_session: bool) -> Response<RespBody> {
    let base = Response::builder()
        .status(StatusCode::OK)
        .header("content-type", "application/json");
    let builder = if mint_session {
        let session_id = uuid::Uuid::new_v4().simple().to_string();
        base.header("mcp-session-id", session_id)
    } else {
        base
    };
    let payload: RespBody = BoxBody::new(Full::new(Bytes::from(frame)));
    builder.body(payload).unwrap()
}
fn text(status: StatusCode, msg: &str) -> Response<RespBody> {
Response::builder()
.status(status)
.header("content-type", "text/plain; charset=utf-8")
.body(BoxBody::new(Full::new(Bytes::from(msg.to_string()))))
.unwrap()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn canonical_id_distinguishes_number_and_string() {
        assert_eq!(canonical_id(&json!(1)), "1");
        assert_eq!(canonical_id(&json!("1")), "\"1\"");
        assert_ne!(canonical_id(&json!(1)), canonical_id(&json!("1")));
    }

    #[tokio::test]
    async fn route_completes_waiting_post() {
        let state = HttpDownstream::new();
        let (tx, rx) = oneshot::channel();
        state.pending.lock().await.insert("7".to_string(), tx);
        state
            .route_upstream_frame(r#"{"jsonrpc":"2.0","id":7,"result":{}}"#.to_string())
            .await;
        let frame = rx.await.unwrap();
        assert!(frame.contains("\"id\":7"));
    }

    #[tokio::test]
    async fn route_broadcasts_unmatched_frames() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        state
            .route_upstream_frame(r#"{"jsonrpc":"2.0","method":"notifications/progress"}"#.to_string())
            .await;
        let frame = sub.recv().await.unwrap();
        assert!(frame.contains("notifications/progress"));
    }

    #[tokio::test]
    async fn upstream_request_with_id_is_broadcast_not_routed() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        // A waiter for id 9 must not be completed by a server-initiated
        // request that happens to carry the same id.
        let (tx, mut rx) = oneshot::channel();
        state.pending.lock().await.insert("9".to_string(), tx);
        state
            .route_upstream_frame(
                r#"{"jsonrpc":"2.0","id":9,"method":"sampling/createMessage"}"#.to_string(),
            )
            .await;
        assert!(sub.recv().await.unwrap().contains("sampling/createMessage"));
        assert!(rx.try_recv().is_err());
        assert!(state.pending.lock().await.contains_key("9"));
    }

    #[tokio::test]
    async fn route_keeps_number_and_string_ids_apart() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        let (tx, mut rx) = oneshot::channel();
        state.pending.lock().await.insert(canonical_id(&json!(7)), tx);
        state
            .route_upstream_frame(r#"{"jsonrpc":"2.0","id":"7","result":{}}"#.to_string())
            .await;
        // The string id "7" must not complete the waiter for numeric id 7.
        assert!(rx.try_recv().is_err());
        assert!(state.pending.lock().await.contains_key("7"));
        assert!(sub.recv().await.unwrap().contains("\"id\":\"7\""));
    }

    #[tokio::test]
    async fn null_id_response_is_broadcast() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        state
            .route_upstream_frame(
                r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"Parse error"}}"#
                    .to_string(),
            )
            .await;
        assert!(sub.recv().await.unwrap().contains("-32700"));
    }

    #[tokio::test]
    async fn unparseable_frame_is_broadcast_verbatim() {
        let state = HttpDownstream::new();
        let mut sub = state.broadcast.subscribe();
        state.route_upstream_frame("not json".to_string()).await;
        assert_eq!(sub.recv().await.unwrap(), "not json");
    }
}