Skip to main content

mcpr_core/proxy/pipeline/middlewares/
envelope_seal.rs

1//! Response-side middleware: serialize the buffered MCP message and
2//! re-wrap as SSE if the upstream framing requires it.
3//!
4//! Emits `Response::Raw` carrying the final bytes and the correct
5//! `Content-Type` header, so the axum `IntoResponse` edge needs no
6//! discriminator beyond what is already on the response.
7
8use async_trait::async_trait;
9use axum::body::Body;
10use axum::http::HeaderValue;
11use axum::http::header::CONTENT_TYPE;
12
13use crate::proxy::pipeline::middleware::ResponseMiddleware;
14use crate::proxy::pipeline::values::{Context, Envelope, Response};
15use crate::proxy::sse::wrap_as_sse;
16
/// Response-side middleware that turns a buffered MCP message back into wire
/// bytes, re-framing it as SSE when the upstream response used SSE framing.
///
/// Stateless: the unit value `EnvelopeSealMiddleware` can be used anywhere.
#[derive(Debug, Clone, Copy, Default)]
pub struct EnvelopeSealMiddleware;
18
19#[async_trait]
20impl ResponseMiddleware for EnvelopeSealMiddleware {
21    fn name(&self) -> &'static str {
22        "envelope_seal"
23    }
24
25    async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
26        let Response::McpBuffered {
27            envelope,
28            message,
29            status,
30            mut headers,
31        } = resp
32        else {
33            return resp;
34        };
35
36        let json_bytes = message.envelope.to_bytes();
37        let (bytes, ct) = match envelope {
38            Envelope::Json => (json_bytes, "application/json"),
39            Envelope::Sse => (wrap_as_sse(&json_bytes), "text/event-stream"),
40        };
41        headers.insert(CONTENT_TYPE, HeaderValue::from_static(ct));
42
43        // Tag `rewritten` whenever we parsed JSON (regardless of whether
44        // a middleware mutated it), plus `sse` when the upstream body
45        // was SSE-framed. Both facts are known here; stashing them on
46        // `cx.working.tags` lets `emit::build_request_event` produce the
47        // `note` string without re-inspecting the response.
48        if !cx.working.tags.as_slice().contains(&"rewritten") {
49            cx.working.tags.push("rewritten");
50        }
51        if matches!(envelope, Envelope::Sse) && !cx.working.tags.as_slice().contains(&"sse") {
52            cx.working.tags.push("sse");
53        }
54        cx.working.response_size = Some(bytes.len() as u64);
55
56        Response::Raw {
57            body: Body::from(bytes),
58            status,
59            headers,
60        }
61    }
62}
63
#[cfg(test)]
// Test names follow a `method__scenario` convention; the double underscore
// trips rustc's `non_snake_case` lint, so it is silenced for this module.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::{HeaderMap, StatusCode};
    use serde_json::Value;

    use crate::protocol::jsonrpc::JsonRpcEnvelope;
    use crate::protocol::mcp::{McpMessage, MessageKind, ServerKind};
    use crate::proxy::pipeline::middlewares::test_support::{test_context, test_proxy_state};

    /// Minimal successful JSON-RPC result shared by most cases.
    const OK_RESULT: &str = r#"{"jsonrpc":"2.0","id":1,"result":{"ok":true}}"#;

    /// Builds a `Response::McpBuffered` whose message is `body` parsed as a
    /// JSON-RPC envelope and classified as a server result.
    fn buffered_with(
        envelope: Envelope,
        body: &str,
        status: StatusCode,
        headers: HeaderMap,
    ) -> Response {
        let message = McpMessage {
            envelope: JsonRpcEnvelope::parse(body.as_bytes())
                .expect("test fixture must be a valid JSON-RPC envelope"),
            kind: MessageKind::Server(ServerKind::Result),
        };
        Response::McpBuffered {
            envelope,
            message,
            status,
            headers,
        }
    }

    /// Shorthand for `buffered_with` with `200 OK` and no headers.
    fn buffered(envelope: Envelope, body: &str) -> Response {
        buffered_with(envelope, body, StatusCode::OK, HeaderMap::new())
    }

    /// Unpacks a `Response::Raw` into `(body as UTF-8, headers, status)`.
    /// Panics if the response is any other variant.
    async fn into_raw_parts(resp: Response) -> (String, HeaderMap, StatusCode) {
        let Response::Raw {
            body,
            status,
            headers,
        } = resp
        else {
            panic!("expected Response::Raw");
        };
        let bytes = axum::body::to_bytes(body, usize::MAX)
            .await
            .expect("in-memory body must be readable");
        let text = String::from_utf8(bytes.to_vec()).expect("body must be UTF-8");
        (text, headers, status)
    }

    /// Number of times `tag` appears in the context's working tags.
    fn tag_count(cx: &Context, tag: &str) -> usize {
        cx.working
            .tags
            .as_slice()
            .iter()
            .filter(|&&t| t == tag)
            .count()
    }

    #[tokio::test]
    async fn on_response__json_envelope_emits_raw_with_json_content_type() {
        let mut cx = test_context(test_proxy_state());

        let out = EnvelopeSealMiddleware
            .on_response(buffered(Envelope::Json, OK_RESULT), &mut cx)
            .await;
        let (body, headers, status) = into_raw_parts(out).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(
            headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(),
            "application/json"
        );
        let v: Value = serde_json::from_str(&body).unwrap();
        assert_eq!(v["jsonrpc"], "2.0");
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["ok"], true);
    }

    #[tokio::test]
    async fn on_response__sse_envelope_wraps_as_event_stream() {
        let mut cx = test_context(test_proxy_state());

        let out = EnvelopeSealMiddleware
            .on_response(buffered(Envelope::Sse, OK_RESULT), &mut cx)
            .await;
        let (body, headers, _) = into_raw_parts(out).await;

        assert_eq!(
            headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(),
            "text/event-stream"
        );
        assert!(body.starts_with("data: "), "got {body:?}");
        assert!(body.ends_with("\n\n"));
    }

    #[tokio::test]
    async fn on_response__error_envelope_preserves_code_and_message() {
        let mut cx = test_context(test_proxy_state());
        let resp = buffered(
            Envelope::Json,
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"bad req"}}"#,
        );

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;
        let (body, _, _) = into_raw_parts(out).await;

        let v: Value = serde_json::from_str(&body).unwrap();
        assert_eq!(v["error"]["code"], -32600);
        assert_eq!(v["error"]["message"], "bad req");
    }

    #[tokio::test]
    async fn on_response__non_buffered_passthrough() {
        let mut cx = test_context(test_proxy_state());
        let resp = Response::Upstream502 {
            reason: "boom".into(),
        };

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;

        assert!(matches!(out, Response::Upstream502 { .. }));
    }

    #[tokio::test]
    async fn on_response__preserves_status_and_custom_headers() {
        let mut cx = test_context(test_proxy_state());
        let mut headers = HeaderMap::new();
        headers.insert("x-trace-id", "abc".parse().unwrap());
        let resp = buffered_with(
            Envelope::Json,
            r#"{"jsonrpc":"2.0","id":1,"result":{}}"#,
            StatusCode::ACCEPTED,
            headers,
        );

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;
        let (_, headers, status) = into_raw_parts(out).await;

        assert_eq!(status, StatusCode::ACCEPTED);
        assert_eq!(headers.get("x-trace-id").unwrap().to_str().unwrap(), "abc");
    }

    #[tokio::test]
    async fn on_response__sse_records_tags_and_response_size() {
        let mut cx = test_context(test_proxy_state());

        let out = EnvelopeSealMiddleware
            .on_response(buffered(Envelope::Sse, OK_RESULT), &mut cx)
            .await;
        let (body, _, _) = into_raw_parts(out).await;

        assert_eq!(tag_count(&cx, "rewritten"), 1);
        assert_eq!(tag_count(&cx, "sse"), 1);
        assert_eq!(cx.working.response_size, Some(body.len() as u64));
    }

    #[tokio::test]
    async fn on_response__json_tags_rewritten_but_not_sse() {
        let mut cx = test_context(test_proxy_state());

        let out = EnvelopeSealMiddleware
            .on_response(buffered(Envelope::Json, OK_RESULT), &mut cx)
            .await;
        let (body, _, _) = into_raw_parts(out).await;

        assert_eq!(tag_count(&cx, "rewritten"), 1);
        assert_eq!(tag_count(&cx, "sse"), 0);
        assert_eq!(cx.working.response_size, Some(body.len() as u64));
    }

    #[tokio::test]
    async fn on_response__repeated_calls_do_not_duplicate_tags() {
        let mut cx = test_context(test_proxy_state());

        for _ in 0..2 {
            let _out = EnvelopeSealMiddleware
                .on_response(buffered(Envelope::Sse, OK_RESULT), &mut cx)
                .await;
        }

        assert_eq!(tag_count(&cx, "rewritten"), 1);
        assert_eq!(tag_count(&cx, "sse"), 1);
    }
}