mcpr_core/proxy/pipeline/middlewares/
envelope_seal.rs1use async_trait::async_trait;
9use axum::body::Body;
10use axum::http::HeaderValue;
11use axum::http::header::CONTENT_TYPE;
12
13use crate::proxy::pipeline::middleware::ResponseMiddleware;
14use crate::proxy::pipeline::values::{Context, Envelope, Response};
15use crate::proxy::sse::wrap_as_sse;
16
/// Response middleware that seals a buffered MCP message back into a raw HTTP body.
///
/// It turns `Response::McpBuffered` into `Response::Raw`. The JSON-RPC message is
/// re-serialized as either a plain JSON body or a single SSE event, depending on the
/// response's `Envelope`. The type carries no state, so values can be created freely
/// (including via `Default`) and copied.
#[derive(Debug, Clone, Copy, Default)]
pub struct EnvelopeSealMiddleware;

19#[async_trait]
20impl ResponseMiddleware for EnvelopeSealMiddleware {
21 fn name(&self) -> &'static str {
22 "envelope_seal"
23 }
24
25 async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
26 let Response::McpBuffered {
27 envelope,
28 message,
29 status,
30 mut headers,
31 } = resp
32 else {
33 return resp;
34 };
35
36 let json_bytes = message.envelope.to_bytes();
37 let (bytes, ct) = match envelope {
38 Envelope::Json => (json_bytes, "application/json"),
39 Envelope::Sse => (wrap_as_sse(&json_bytes), "text/event-stream"),
40 };
41 headers.insert(CONTENT_TYPE, HeaderValue::from_static(ct));
42
43 if !cx.working.tags.as_slice().contains(&"rewritten") {
49 cx.working.tags.push("rewritten");
50 }
51 if matches!(envelope, Envelope::Sse) && !cx.working.tags.as_slice().contains(&"sse") {
52 cx.working.tags.push("sse");
53 }
54 cx.working.response_size = Some(bytes.len() as u64);
55
56 Response::Raw {
57 body: Body::from(bytes),
58 status,
59 headers,
60 }
61 }
62}
63
#[cfg(test)]
// Test names follow a `method__scenario` convention; the double underscore is
// what trips the non_snake_case lint.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::{HeaderMap, StatusCode};
    use serde_json::Value;

    use crate::protocol::jsonrpc::JsonRpcEnvelope;
    use crate::protocol::mcp::{McpMessage, MessageKind, ServerKind};
    use crate::proxy::pipeline::middlewares::test_support::{test_context, test_proxy_state};

    /// A minimal successful JSON-RPC result used by several tests.
    const OK_RESULT: &str = r#"{"jsonrpc":"2.0","id":1,"result":{"ok":true}}"#;

    /// Builds a buffered server-result response with an explicit status and headers.
    fn buffered_with(
        envelope: Envelope,
        body: &str,
        status: StatusCode,
        headers: HeaderMap,
    ) -> Response {
        let message = McpMessage {
            envelope: JsonRpcEnvelope::parse(body.as_bytes()).unwrap(),
            kind: MessageKind::Server(ServerKind::Result),
        };
        Response::McpBuffered {
            envelope,
            message,
            status,
            headers,
        }
    }

    /// Builds a buffered server-result response with `200 OK` and no headers.
    fn buffered(envelope: Envelope, body: &str) -> Response {
        buffered_with(envelope, body, StatusCode::OK, HeaderMap::new())
    }

    /// Unwraps a `Response::Raw` into its UTF-8 body text, headers and status.
    ///
    /// Panics if the response is any other variant.
    async fn into_parts(resp: Response) -> (String, HeaderMap, StatusCode) {
        match resp {
            Response::Raw {
                body,
                status,
                headers,
            } => {
                let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
                (String::from_utf8(bytes.to_vec()).unwrap(), headers, status)
            }
            _ => panic!("expected Raw"),
        }
    }

    #[tokio::test]
    async fn on_response__json_envelope_emits_raw_with_json_content_type() {
        let mut cx = test_context(test_proxy_state());
        let resp = buffered(Envelope::Json, OK_RESULT);

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;
        let (body, headers, status) = into_parts(out).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(
            headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(),
            "application/json"
        );
        let v: Value = serde_json::from_str(&body).unwrap();
        assert_eq!(v["jsonrpc"], "2.0");
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["ok"], true);
    }

    #[tokio::test]
    async fn on_response__sse_envelope_wraps_as_event_stream() {
        let mut cx = test_context(test_proxy_state());
        let resp = buffered(Envelope::Sse, OK_RESULT);

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;
        let (body, headers, _) = into_parts(out).await;
        assert_eq!(
            headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(),
            "text/event-stream"
        );
        assert!(body.starts_with("data: "), "got {body:?}");
        assert!(body.ends_with("\n\n"));
    }

    #[tokio::test]
    async fn on_response__error_envelope_preserves_code_and_message() {
        let mut cx = test_context(test_proxy_state());
        let resp = buffered(
            Envelope::Json,
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"bad req"}}"#,
        );

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;
        let (body, _, _) = into_parts(out).await;
        let v: Value = serde_json::from_str(&body).unwrap();
        assert_eq!(v["error"]["code"], -32600);
        assert_eq!(v["error"]["message"], "bad req");
    }

    #[tokio::test]
    async fn on_response__non_buffered_passthrough() {
        let mut cx = test_context(test_proxy_state());
        let tags_before = cx.working.tags.as_slice().to_vec();
        let size_before = cx.working.response_size;
        let resp = Response::Upstream502 {
            reason: "boom".into(),
        };

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;
        assert!(matches!(out, Response::Upstream502 { .. }));
        // The passthrough path must not touch the working context.
        assert_eq!(cx.working.tags.as_slice(), tags_before.as_slice());
        assert_eq!(cx.working.response_size, size_before);
    }

    #[tokio::test]
    async fn on_response__preserves_status_and_custom_headers() {
        let mut cx = test_context(test_proxy_state());
        let mut headers = HeaderMap::new();
        headers.insert("x-trace-id", "abc".parse().unwrap());
        let resp = buffered_with(
            Envelope::Json,
            r#"{"jsonrpc":"2.0","id":1,"result":{}}"#,
            StatusCode::ACCEPTED,
            headers,
        );

        let out = EnvelopeSealMiddleware.on_response(resp, &mut cx).await;
        let (_, headers, status) = into_parts(out).await;
        assert_eq!(status, StatusCode::ACCEPTED);
        assert_eq!(headers.get("x-trace-id").unwrap().to_str().unwrap(), "abc");
    }

    #[tokio::test]
    async fn on_response__json_envelope_tags_rewritten_and_records_size() {
        let mut cx = test_context(test_proxy_state());

        let out = EnvelopeSealMiddleware
            .on_response(buffered(Envelope::Json, OK_RESULT), &mut cx)
            .await;
        let (body, _, _) = into_parts(out).await;
        assert!(cx.working.tags.as_slice().contains(&"rewritten"));
        assert_eq!(cx.working.response_size, Some(body.len() as u64));
    }

    #[tokio::test]
    async fn on_response__sse_envelope_tags_sse_and_records_framed_size() {
        let mut cx = test_context(test_proxy_state());

        let out = EnvelopeSealMiddleware
            .on_response(buffered(Envelope::Sse, OK_RESULT), &mut cx)
            .await;
        let (body, _, _) = into_parts(out).await;
        let tags = cx.working.tags.as_slice();
        assert!(tags.contains(&"rewritten"));
        assert!(tags.contains(&"sse"));
        // The recorded size is that of the SSE-framed body, not the bare JSON.
        assert_eq!(cx.working.response_size, Some(body.len() as u64));
    }

    #[tokio::test]
    async fn on_response__repeated_calls_do_not_duplicate_tags() {
        let mut cx = test_context(test_proxy_state());

        for _ in 0..2 {
            let _ = EnvelopeSealMiddleware
                .on_response(buffered(Envelope::Sse, OK_RESULT), &mut cx)
                .await;
        }

        let count = |tag: &str| {
            cx.working
                .tags
                .as_slice()
                .iter()
                .filter(|&&t| t == tag)
                .count()
        };
        assert_eq!(count("rewritten"), 1);
        assert_eq!(count("sse"), 1);
    }
}