mcpr_core/proxy/pipeline/middlewares/
csp_rewrite.rs1use std::sync::Arc;
14
15use arc_swap::ArcSwap;
16use async_trait::async_trait;
17use serde_json::Value;
18
19use crate::protocol::mcp::{ClientMethod, ResourcesMethod, ToolsMethod};
20use crate::proxy::pipeline::middleware::ResponseMiddleware;
21use crate::proxy::pipeline::values::{Context, Response};
22use crate::proxy::{RewriteConfig, rewrite_response};
23
/// Byte patterns whose presence in a raw JSON-RPC `result` means it may carry
/// CSP / widget-domain metadata worth rewriting.
///
/// This list only feeds a cheap pre-filter that runs before the result is parsed.
/// A false positive costs one parse; a false negative silently skips the
/// rewrite. Err on the side of matching too much.
const MARKERS: &[&[u8]] = &[
    // snake_case spellings
    b"connect_domains",
    b"resource_domains",
    b"frame_domains",
    // camelCase spellings
    b"connectDomains",
    b"resourceDomains",
    b"frameDomains",
    // container / widget keys
    b"openai/widgetCSP",
    b"ui.csp",
    b"openai/widgetDomain",
    // JSON permits `/` to be escaped as `\/` (some encoders, e.g. PHP's
    // `json_encode`, do so by default). serde unescapes these keys when the
    // result is parsed, but the raw-byte scan would otherwise never match them.
    b"openai\\/widgetCSP",
    b"openai\\/widgetDomain",
];
38
39pub struct CspRewriteMiddleware {
40 config: Arc<ArcSwap<RewriteConfig>>,
41}
42
43impl CspRewriteMiddleware {
44 pub fn new(config: Arc<ArcSwap<RewriteConfig>>) -> Self {
45 Self { config }
46 }
47}
48
49#[async_trait]
50impl ResponseMiddleware for CspRewriteMiddleware {
51 fn name(&self) -> &'static str {
52 "csp_rewrite"
53 }
54
55 async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
56 let Response::McpBuffered {
57 envelope,
58 mut message,
59 status,
60 headers,
61 } = resp
62 else {
63 return resp;
64 };
65
66 let eligible = matches!(
67 cx.working.request_method,
68 Some(ClientMethod::Tools(ToolsMethod::List))
69 | Some(ClientMethod::Tools(ToolsMethod::Call))
70 | Some(ClientMethod::Resources(ResourcesMethod::List))
71 | Some(ClientMethod::Resources(ResourcesMethod::TemplatesList))
72 | Some(ClientMethod::Resources(ResourcesMethod::Read))
73 );
74 let raw_bytes = message.envelope.result.as_ref().map(|r| r.get().as_bytes());
75 let should_rewrite = eligible && raw_bytes.map(has_markers).unwrap_or(false);
76 if !should_rewrite {
77 return Response::McpBuffered {
78 envelope,
79 message,
80 status,
81 headers,
82 };
83 }
84
85 let method_str = cx
86 .working
87 .request_method
88 .as_ref()
89 .and_then(crate::protocol::mcp::ClientMethod::as_str)
90 .unwrap_or("");
91 let Ok(result_val) = serde_json::from_slice::<Value>(raw_bytes.unwrap()) else {
92 return Response::McpBuffered {
93 envelope,
94 message,
95 status,
96 headers,
97 };
98 };
99
100 let mut wrapper = Value::Object(Default::default());
101 wrapper["result"] = result_val;
102 let cfg = self.config.load();
103 if rewrite_response(method_str, &mut wrapper, &cfg) {
104 let rewritten = wrapper
105 .get_mut("result")
106 .map(std::mem::take)
107 .unwrap_or(Value::Null);
108 if let Ok(boxed) = serde_json::value::to_raw_value(&rewritten) {
109 message.envelope.result = Some(boxed);
110 }
111 }
112
113 Response::McpBuffered {
114 envelope,
115 message,
116 status,
117 headers,
118 }
119 }
120}
121
122fn has_markers(body: &[u8]) -> bool {
123 MARKERS.iter().any(|m| contains_slice(body, m))
124}
125
/// Reports whether `needle` occurs as a contiguous run inside `haystack`.
///
/// An empty `needle` counts as present in every haystack, including an empty one.
fn contains_slice(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` panics, so the empty needle is answered up front.
        0 => true,
        // `windows(n)` yields nothing when `n > haystack.len()`, which covers
        // a needle longer than the haystack.
        width => haystack.windows(width).any(|window| window == needle),
    }
}
132
#[cfg(test)]
// Test names follow `subject__scenario`; the double underscore trips `non_snake_case`.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::StatusCode;

    use crate::proxy::CspConfig;
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_buffered_response, set_request_method, test_context, test_proxy_state,
    };

    /// Builds the middleware against the proxy's shared rewrite config handle.
    fn middleware(proxy: &Arc<crate::proxy::ProxyState>) -> CspRewriteMiddleware {
        CspRewriteMiddleware::new(proxy.rewrite_config.clone())
    }

    /// Decodes the JSON-RPC `result` of a buffered response, falling back to `Value::Null`.
    fn extract_result(resp: &Response) -> Value {
        match resp {
            Response::McpBuffered { message, .. } => {
                message.envelope.result_as::<Value>().unwrap_or(Value::Null)
            }
            _ => panic!("expected McpBuffered"),
        }
    }

    #[tokio::test]
    async fn on_response__non_eligible_method_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        // No request method is set, so markers alone must not trigger a rewrite.
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"_meta":{"openai/widgetCSP":{"connect_domains":["http://upstream.test"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect = result["tools"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        assert_eq!(connect[0].as_str(), Some("http://upstream.test"));
    }

    #[tokio::test]
    async fn on_response__no_markers_passthrough_identity() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"name":"one"}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        assert_eq!(result["tools"][0]["name"], "one");
    }

    #[tokio::test]
    async fn on_response__markers_trigger_rewrite_for_tools_list() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"_meta":{"openai/widgetCSP":{"connect_domains":["http://upstream.test/api"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect = result["tools"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        let rewritten = connect.iter().any(|v| {
            v.as_str()
                .map(|s| s.contains("proxy.test"))
                .unwrap_or(false)
        });
        assert!(
            rewritten,
            "expected upstream rewritten into proxy URL, got {connect:?}"
        );
    }

    #[tokio::test]
    async fn on_response__resources_list_rewrites_csp() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Resources(ResourcesMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"resources":[{"uri":"ui://widget/x","_meta":{"openai/widgetCSP":{"connect_domains":["http://localhost:9002"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect = result["resources"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        assert!(
            connect
                .iter()
                .any(|v| v.as_str() == Some("https://proxy.test")),
            "expected proxy URL injected, got {connect:?}"
        );
        assert!(
            !connect
                .iter()
                .any(|v| v.as_str().unwrap_or("").contains("localhost")),
            "expected localhost stripped, got {connect:?}"
        );
    }

    #[tokio::test]
    async fn on_response__resources_templates_list_rewrites_csp() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Resources(ResourcesMethod::TemplatesList),
        );
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"resourceTemplates":[{"uriTemplate":"ui://widget/{n}","_meta":{"openai/widgetCSP":{"connect_domains":["http://localhost:9002"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect =
            result["resourceTemplates"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
                .as_array()
                .unwrap();
        assert!(
            connect
                .iter()
                .any(|v| v.as_str() == Some("https://proxy.test")),
            "expected proxy URL injected, got {connect:?}"
        );
        assert!(
            !connect
                .iter()
                .any(|v| v.as_str().unwrap_or("").contains("localhost")),
            "expected localhost stripped, got {connect:?}"
        );
    }

    #[tokio::test]
    async fn on_response__arc_swap_hot_reload_uses_new_config() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));

        // Swap in a new config after the proxy state exists; the middleware
        // must observe it without being rebuilt around a new handle.
        proxy.rewrite_config.store(Arc::new(RewriteConfig {
            proxy_url: "https://proxy-v2.test".into(),
            proxy_domain: "proxy-v2.test".into(),
            mcp_upstream: "http://upstream.test".into(),
            csp: CspConfig::default(),
        }));

        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"_meta":{"openai/widgetCSP":{"connect_domains":["http://upstream.test/api"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect = result["tools"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        let seen_v2 = connect.iter().any(|v| {
            v.as_str()
                .map(|s| s.contains("proxy-v2.test"))
                .unwrap_or(false)
        });
        assert!(seen_v2, "expected v2 proxy host in rewritten output");
    }

    #[tokio::test]
    async fn on_response__non_buffered_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = Response::Upstream502 { reason: "x".into() };

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }

    #[test]
    fn has_markers__finds_snake_case() {
        assert!(has_markers(br#"{"connect_domains":["http://a"]}"#));
    }

    #[test]
    fn has_markers__finds_camel_case() {
        assert!(has_markers(br#"{"connectDomains":["http://a"]}"#));
    }

    #[test]
    fn has_markers__finds_openai_shape() {
        assert!(has_markers(br#"{"openai/widgetCSP":{}}"#));
    }

    #[test]
    fn has_markers__plain_tool_call_no_markers() {
        assert!(!has_markers(
            br#"{"content":[{"type":"text","text":"hi"}]}"#
        ));
    }

    #[test]
    fn contains_slice__empty_needle_always_matches() {
        assert!(contains_slice(b"", b""));
        assert!(contains_slice(b"abc", b""));
    }

    #[test]
    fn contains_slice__needle_longer_than_haystack_is_absent() {
        assert!(!contains_slice(b"ab", b"abc"));
    }

    #[test]
    fn contains_slice__finds_match_at_end() {
        assert!(contains_slice(b"xxabc", b"abc"));
        assert!(!contains_slice(b"xxab", b"abc"));
    }
}