Skip to main content

mcpr_core/proxy/pipeline/middlewares/
csp_rewrite.rs

//! Response-side middleware: rewrite widget CSP directives carried in
//! list / call / read results.
//!
//! Holds the same `ArcSwap<RewriteConfig>` handle as `ProxyState`, so
//! `mcpr.toml` reloads swap the inner `Arc` without restarting the
//! middleware.
//!
//! Fast path: byte-scan the raw `result` bytes for CSP-shaped keys
//! (`connect_domains`, `openai/widgetCSP`, etc.). Miss → no parse, no
//! allocation. Hit → deserialize, mutate via `rewrite_response`, and
//! re-wrap into the message's `result` field.

13use std::sync::Arc;
14
15use arc_swap::ArcSwap;
16use async_trait::async_trait;
17use serde_json::Value;
18
19use crate::protocol::mcp::{ClientMethod, ResourcesMethod, ToolsMethod};
20use crate::proxy::pipeline::middleware::ResponseMiddleware;
21use crate::proxy::pipeline::values::{Context, Response};
22use crate::proxy::{RewriteConfig, rewrite_response};
23
/// CSP-shaped keys that `rewrite_response` can mutate, matched as raw byte
/// substrings of the serialized `result`. If none of them appear there is
/// nothing to rewrite, so the JSON parse is skipped entirely.
///
/// A false positive only costs one parse (the rewrite itself runs on the
/// parsed value), while a false negative silently skips a needed rewrite —
/// so list every byte-level spelling a key can take.
const MARKERS: &[&[u8]] = &[
    // snake_case directive keys.
    b"connect_domains",
    b"resource_domains",
    b"frame_domains",
    // camelCase directive keys.
    b"connectDomains",
    b"resourceDomains",
    b"frameDomains",
    // `openai/*` `_meta` keys.
    b"openai/widgetCSP",
    b"openai/widgetDomain",
    // JSON permits escaping `/` as `\/`, and some encoders (e.g. PHP's
    // `json_encode`) do so by default. The parsed key is identical but the
    // raw bytes are not, so the escaped spelling must be matched too.
    br"openai\/widgetCSP",
    br"openai\/widgetDomain",
    // NOTE(review): this only matches a literal dotted `"ui.csp"` key. A
    // nested `"ui":{"csp":{...}}` object does not contain these bytes and is
    // caught only via its inner `*Domains` keys — confirm against
    // `rewrite_response`.
    b"ui.csp",
];
38
39pub struct CspRewriteMiddleware {
40    config: Arc<ArcSwap<RewriteConfig>>,
41}
42
43impl CspRewriteMiddleware {
44    pub fn new(config: Arc<ArcSwap<RewriteConfig>>) -> Self {
45        Self { config }
46    }
47}
48
49#[async_trait]
50impl ResponseMiddleware for CspRewriteMiddleware {
51    fn name(&self) -> &'static str {
52        "csp_rewrite"
53    }
54
55    async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
56        let Response::McpBuffered {
57            envelope,
58            mut message,
59            status,
60            headers,
61        } = resp
62        else {
63            return resp;
64        };
65
66        let eligible = matches!(
67            cx.working.request_method,
68            Some(ClientMethod::Tools(ToolsMethod::List))
69                | Some(ClientMethod::Tools(ToolsMethod::Call))
70                | Some(ClientMethod::Resources(ResourcesMethod::List))
71                | Some(ClientMethod::Resources(ResourcesMethod::TemplatesList))
72                | Some(ClientMethod::Resources(ResourcesMethod::Read))
73        );
74        let raw_bytes = message.envelope.result.as_ref().map(|r| r.get().as_bytes());
75        let should_rewrite = eligible && raw_bytes.map(has_markers).unwrap_or(false);
76        if !should_rewrite {
77            return Response::McpBuffered {
78                envelope,
79                message,
80                status,
81                headers,
82            };
83        }
84
85        let method_str = cx
86            .working
87            .request_method
88            .as_ref()
89            .and_then(crate::protocol::mcp::ClientMethod::as_str)
90            .unwrap_or("");
91        let Ok(result_val) = serde_json::from_slice::<Value>(raw_bytes.unwrap()) else {
92            return Response::McpBuffered {
93                envelope,
94                message,
95                status,
96                headers,
97            };
98        };
99
100        let mut wrapper = Value::Object(Default::default());
101        wrapper["result"] = result_val;
102        let cfg = self.config.load();
103        if rewrite_response(method_str, &mut wrapper, &cfg) {
104            let rewritten = wrapper
105                .get_mut("result")
106                .map(std::mem::take)
107                .unwrap_or(Value::Null);
108            if let Ok(boxed) = serde_json::value::to_raw_value(&rewritten) {
109                message.envelope.result = Some(boxed);
110            }
111        }
112
113        Response::McpBuffered {
114            envelope,
115            message,
116            status,
117            headers,
118        }
119    }
120}
121
122fn has_markers(body: &[u8]) -> bool {
123    MARKERS.iter().any(|m| contains_slice(body, m))
124}
125
/// Naive byte-substring search.
///
/// An empty `needle` matches any haystack (even an empty one); a needle
/// longer than the haystack never matches. The empty case is handled
/// explicitly because `windows(0)` panics.
fn contains_slice(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        0 => true,
        width if width > haystack.len() => false,
        width => haystack
            .windows(width)
            .any(|candidate| candidate == needle),
    }
}
132
#[cfg(test)]
// Test names follow a `subject__scenario` convention; the double underscore
// is rejected by the `non_snake_case` lint, hence the allow.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::StatusCode;

    use crate::proxy::CspConfig;
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_buffered_response, set_request_method, test_context, test_proxy_state,
    };

    /// Build the middleware under test around the proxy's shared
    /// `rewrite_config` handle.
    fn middleware(proxy: &Arc<crate::proxy::ProxyState>) -> CspRewriteMiddleware {
        CspRewriteMiddleware::new(proxy.rewrite_config.clone())
    }

    /// Decode the JSON-RPC `result` of a buffered response as a `Value`,
    /// falling back to `Value::Null` when `result_as` yields nothing.
    /// Panics on any other `Response` variant.
    fn extract_result(resp: &Response) -> Value {
        match resp {
            Response::McpBuffered { message, .. } => {
                message.envelope.result_as::<Value>().unwrap_or(Value::Null)
            }
            _ => panic!("expected McpBuffered"),
        }
    }

    #[tokio::test]
    async fn on_response__non_eligible_method_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        // No request_method stashed → ineligible.
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"_meta":{"openai/widgetCSP":{"connect_domains":["http://upstream.test"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        // Identity: connect_domains still points at upstream, unchanged.
        let connect = result["tools"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        assert_eq!(connect[0].as_str(), Some("http://upstream.test"));
    }

    #[tokio::test]
    async fn on_response__no_markers_passthrough_identity() {
        // Eligible method, but the payload contains no marker bytes, so the
        // fast path should forward it untouched.
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"name":"one"}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        assert_eq!(result["tools"][0]["name"], "one");
    }

    #[tokio::test]
    async fn on_response__markers_trigger_rewrite_for_tools_list() {
        // NOTE(review): assumes `test_proxy_state` configures `proxy.test` as
        // the proxy host — the assertion below depends on it.
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"_meta":{"openai/widgetCSP":{"connect_domains":["http://upstream.test/api"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect = result["tools"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        let rewritten = connect.iter().any(|v| {
            v.as_str()
                .map(|s| s.contains("proxy.test"))
                .unwrap_or(false)
        });
        assert!(
            rewritten,
            "expected upstream rewritten into proxy URL, got {connect:?}"
        );
    }

    #[tokio::test]
    async fn on_response__resources_list_rewrites_csp() {
        // Regression: the eligibility gate must accept resources/list, or the
        // operator's csp.domain never replaces upstream localhost in widget
        // resource listings.
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Resources(ResourcesMethod::List));
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"resources":[{"uri":"ui://widget/x","_meta":{"openai/widgetCSP":{"connect_domains":["http://localhost:9002"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect = result["resources"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        assert!(
            connect
                .iter()
                .any(|v| v.as_str() == Some("https://proxy.test")),
            "expected proxy URL injected, got {connect:?}"
        );
        assert!(
            !connect
                .iter()
                .any(|v| v.as_str().unwrap_or("").contains("localhost")),
            "expected localhost stripped, got {connect:?}"
        );
    }

    #[tokio::test]
    async fn on_response__resources_templates_list_rewrites_csp() {
        // Regression: same gate fix for resources/templates/list. Without it,
        // ChatGPT's template fetch sees raw upstream CSP with localhost.
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(
            &mut cx,
            ClientMethod::Resources(ResourcesMethod::TemplatesList),
        );
        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"resourceTemplates":[{"uriTemplate":"ui://widget/{n}","_meta":{"openai/widgetCSP":{"connect_domains":["http://localhost:9002"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect =
            result["resourceTemplates"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
                .as_array()
                .unwrap();
        assert!(
            connect
                .iter()
                .any(|v| v.as_str() == Some("https://proxy.test")),
            "expected proxy URL injected, got {connect:?}"
        );
        assert!(
            !connect
                .iter()
                .any(|v| v.as_str().unwrap_or("").contains("localhost")),
            "expected localhost stripped, got {connect:?}"
        );
    }

    #[tokio::test]
    async fn on_response__arc_swap_hot_reload_uses_new_config() {
        // NOTE(review): the middleware below is constructed *after* the
        // `store`; building it first would exercise the hot-reload path
        // (store into an already-held handle) more directly.
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));

        // Swap to a config with a different proxy_url.
        proxy.rewrite_config.store(Arc::new(RewriteConfig {
            proxy_url: "https://proxy-v2.test".into(),
            proxy_domain: "proxy-v2.test".into(),
            mcp_upstream: "http://upstream.test".into(),
            csp: CspConfig::default(),
        }));

        let resp = mcp_buffered_response(
            r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"_meta":{"openai/widgetCSP":{"connect_domains":["http://upstream.test/api"]}}}]}}"#,
            StatusCode::OK,
        );

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        let result = extract_result(&out);
        let connect = result["tools"][0]["_meta"]["openai/widgetCSP"]["connect_domains"]
            .as_array()
            .unwrap();
        let seen_v2 = connect.iter().any(|v| {
            v.as_str()
                .map(|s| s.contains("proxy-v2.test"))
                .unwrap_or(false)
        });
        assert!(seen_v2, "expected v2 proxy host in rewritten output");
    }

    #[tokio::test]
    async fn on_response__non_buffered_passthrough() {
        // Non-`McpBuffered` variants must come back as the same variant.
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        set_request_method(&mut cx, ClientMethod::Tools(ToolsMethod::List));
        let resp = Response::Upstream502 { reason: "x".into() };

        let out = middleware(&proxy).on_response(resp, &mut cx).await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }

    #[test]
    fn has_markers__finds_snake_case() {
        assert!(has_markers(br#"{"connect_domains":["http://a"]}"#));
    }

    #[test]
    fn has_markers__finds_openai_shape() {
        assert!(has_markers(br#"{"openai/widgetCSP":{}}"#));
    }

    #[test]
    fn has_markers__plain_tool_call_no_markers() {
        assert!(!has_markers(
            br#"{"content":[{"type":"text","text":"hi"}]}"#
        ));
    }
}