Skip to main content

mcpr_core/proxy/pipeline/middlewares/
url_map.rs

//! Response-side middleware: rewrites the upstream base URL to the proxy
//! URL in OAuth discovery documents and JSON passthrough bodies.
//!
//! `Response::Raw` rewriting is gated on a JSON `Content-Type` header;
//! non-JSON `Raw` responses, and every other response variant, pass through
//! untouched.
6
7use std::sync::Arc;
8
9use arc_swap::ArcSwap;
10use async_trait::async_trait;
11use axum::body::{Body, Bytes};
12use axum::http::HeaderMap;
13use axum::http::header::CONTENT_TYPE;
14
15use crate::proxy::RewriteConfig;
16use crate::proxy::pipeline::middleware::ResponseMiddleware;
17use crate::proxy::pipeline::values::{Context, Response};
18use crate::proxy::sse::split_upstream;
19
20pub struct UrlMapMiddleware {
21    config: Arc<ArcSwap<RewriteConfig>>,
22}
23
24impl UrlMapMiddleware {
25    pub fn new(config: Arc<ArcSwap<RewriteConfig>>) -> Self {
26        Self { config }
27    }
28}
29
30#[async_trait]
31impl ResponseMiddleware for UrlMapMiddleware {
32    fn name(&self) -> &'static str {
33        "url_map"
34    }
35
36    async fn on_response(&self, resp: Response, _cx: &mut Context) -> Response {
37        match resp {
38            Response::OauthJson {
39                doc,
40                status,
41                headers,
42            } => {
43                let bytes = serde_json::to_vec(&doc).unwrap_or_default();
44                let rewritten = rewrite_bytes(&self.config, Bytes::from(bytes));
45                let doc = serde_json::from_slice(&rewritten).unwrap_or(doc);
46                Response::OauthJson {
47                    doc,
48                    status,
49                    headers,
50                }
51            }
52            Response::Raw {
53                body,
54                status,
55                headers,
56            } if is_json(&headers) => {
57                let bytes = axum::body::to_bytes(body, usize::MAX)
58                    .await
59                    .unwrap_or_default();
60                let rewritten = rewrite_bytes(&self.config, bytes);
61                Response::Raw {
62                    body: Body::from(rewritten),
63                    status,
64                    headers,
65                }
66            }
67            other => other,
68        }
69    }
70}
71
72fn rewrite_bytes(config: &ArcSwap<RewriteConfig>, body: Bytes) -> Bytes {
73    let cfg = config.load();
74    let (upstream_base, _) = split_upstream(&cfg.mcp_upstream);
75    let upstream_base = upstream_base.trim_end_matches('/');
76    let proxy_url = cfg.proxy_url.trim_end_matches('/');
77
78    if !contains_slice(&body, upstream_base.as_bytes()) {
79        return body;
80    }
81
82    let body_str = String::from_utf8_lossy(&body);
83    Bytes::from(body_str.replace(upstream_base, proxy_url).into_bytes())
84}
85
86fn is_json(headers: &HeaderMap) -> bool {
87    headers
88        .get(CONTENT_TYPE)
89        .and_then(|v| v.to_str().ok())
90        .map(|ct| ct.contains("json"))
91        .unwrap_or(false)
92}
93
/// Reports whether `needle` occurs as a contiguous run inside `haystack`.
///
/// An empty `needle` is considered present in every haystack, including an
/// empty one.
fn contains_slice(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` panics, so the empty needle is answered up front.
        0 => true,
        width if width > haystack.len() => false,
        width => haystack
            .windows(width)
            .any(|candidate| candidate == needle),
    }
}
100
#[cfg(test)]
// Test names follow the `subject__scenario` convention.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::StatusCode;
    use serde_json::json;

    use crate::proxy::pipeline::middlewares::test_support::{test_context, test_proxy_state};

    fn middleware(proxy: &Arc<crate::proxy::ProxyState>) -> UrlMapMiddleware {
        UrlMapMiddleware::new(Arc::clone(&proxy.rewrite_config))
    }

    /// Runs the middleware once against a fresh test proxy state and context.
    async fn run(resp: Response) -> Response {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        middleware(&proxy).on_response(resp, &mut cx).await
    }

    /// Builds an `OauthJson` response with status 200 and no headers.
    fn oauth(doc: serde_json::Value) -> Response {
        Response::OauthJson {
            doc,
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        }
    }

    /// Builds a `Raw` 200 response with the given content type and body.
    fn raw(content_type: &str, body: &'static str) -> Response {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, content_type.parse().unwrap());
        Response::Raw {
            body: Body::from(body),
            status: StatusCode::OK,
            headers,
        }
    }

    /// Collects a body into a UTF-8 string.
    async fn body_string(body: Body) -> String {
        let collected = axum::body::to_bytes(body, usize::MAX).await.unwrap();
        String::from_utf8(collected.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn on_response__oauth_rewrites_upstream_to_proxy() {
        let input = oauth(json!({"issuer": "http://upstream.test/auth"}));
        let Response::OauthJson { doc, .. } = run(input).await else {
            panic!("expected OauthJson");
        };
        assert_eq!(doc["issuer"].as_str(), Some("https://proxy.test/auth"));
    }

    #[tokio::test]
    async fn on_response__oauth_no_match_is_identity() {
        let input = oauth(json!({"issuer": "http://other.example.com"}));
        let Response::OauthJson { doc, .. } = run(input).await else {
            panic!("expected OauthJson");
        };
        assert_eq!(doc["issuer"].as_str(), Some("http://other.example.com"));
    }

    #[tokio::test]
    async fn on_response__raw_json_rewrites() {
        let input = raw("application/json", r#"{"url":"http://upstream.test/path"}"#);
        let Response::Raw { body, .. } = run(input).await else {
            panic!("expected Raw");
        };
        let s = body_string(body).await;
        assert!(
            s.contains("https://proxy.test/path"),
            "expected proxy url in {s}"
        );
    }

    #[tokio::test]
    async fn on_response__raw_non_json_passthrough() {
        let input = raw("text/html", "http://upstream.test/");
        let Response::Raw { body, .. } = run(input).await else {
            panic!("expected Raw");
        };
        assert_eq!(body_string(body).await, "http://upstream.test/");
    }

    #[tokio::test]
    async fn on_response__mcp_buffered_passthrough() {
        let out = run(Response::Upstream502 { reason: "x".into() }).await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }
}