mcpr_core/proxy/pipeline/middlewares/
url_map.rs1use std::sync::Arc;
8
9use arc_swap::ArcSwap;
10use async_trait::async_trait;
11use axum::body::{Body, Bytes};
12use axum::http::HeaderMap;
13use axum::http::header::CONTENT_TYPE;
14
15use crate::proxy::RewriteConfig;
16use crate::proxy::pipeline::middleware::ResponseMiddleware;
17use crate::proxy::pipeline::values::{Context, Response};
18use crate::proxy::sse::split_upstream;
19
20pub struct UrlMapMiddleware {
21 config: Arc<ArcSwap<RewriteConfig>>,
22}
23
24impl UrlMapMiddleware {
25 pub fn new(config: Arc<ArcSwap<RewriteConfig>>) -> Self {
26 Self { config }
27 }
28}
29
30#[async_trait]
31impl ResponseMiddleware for UrlMapMiddleware {
32 fn name(&self) -> &'static str {
33 "url_map"
34 }
35
36 async fn on_response(&self, resp: Response, _cx: &mut Context) -> Response {
37 match resp {
38 Response::OauthJson {
39 doc,
40 status,
41 headers,
42 } => {
43 let bytes = serde_json::to_vec(&doc).unwrap_or_default();
44 let rewritten = rewrite_bytes(&self.config, Bytes::from(bytes));
45 let doc = serde_json::from_slice(&rewritten).unwrap_or(doc);
46 Response::OauthJson {
47 doc,
48 status,
49 headers,
50 }
51 }
52 Response::Raw {
53 body,
54 status,
55 headers,
56 } if is_json(&headers) => {
57 let bytes = axum::body::to_bytes(body, usize::MAX)
58 .await
59 .unwrap_or_default();
60 let rewritten = rewrite_bytes(&self.config, bytes);
61 Response::Raw {
62 body: Body::from(rewritten),
63 status,
64 headers,
65 }
66 }
67 other => other,
68 }
69 }
70}
71
72fn rewrite_bytes(config: &ArcSwap<RewriteConfig>, body: Bytes) -> Bytes {
73 let cfg = config.load();
74 let (upstream_base, _) = split_upstream(&cfg.mcp_upstream);
75 let upstream_base = upstream_base.trim_end_matches('/');
76 let proxy_url = cfg.proxy_url.trim_end_matches('/');
77
78 if !contains_slice(&body, upstream_base.as_bytes()) {
79 return body;
80 }
81
82 let body_str = String::from_utf8_lossy(&body);
83 Bytes::from(body_str.replace(upstream_base, proxy_url).into_bytes())
84}
85
86fn is_json(headers: &HeaderMap) -> bool {
87 headers
88 .get(CONTENT_TYPE)
89 .and_then(|v| v.to_str().ok())
90 .map(|ct| ct.contains("json"))
91 .unwrap_or(false)
92}
93
/// Reports whether `needle` occurs as a contiguous run of bytes in `haystack`.
///
/// An empty `needle` is considered to be contained in every haystack,
/// including an empty one.
fn contains_slice(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        0 => true,
        len if len > haystack.len() => false,
        len => haystack.windows(len).any(|candidate| candidate == needle),
    }
}
100
#[cfg(test)]
// Test names use a `method__scenario` double underscore, which the
// `non_snake_case` lint would otherwise flag.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::http::StatusCode;
    use serde_json::json;

    use crate::proxy::pipeline::middlewares::test_support::{test_context, test_proxy_state};

    fn middleware(proxy: &Arc<crate::proxy::ProxyState>) -> UrlMapMiddleware {
        UrlMapMiddleware::new(proxy.rewrite_config.clone())
    }

    /// Feeds `resp` through a middleware built from the shared test proxy state.
    async fn run(resp: Response) -> Response {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        middleware(&proxy).on_response(resp, &mut cx).await
    }

    /// Builds a 200 `OauthJson` response whose document has a single `issuer` key.
    fn oauth(issuer: &str) -> Response {
        Response::OauthJson {
            doc: json!({ "issuer": issuer }),
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        }
    }

    /// Builds a 200 `Raw` response with the given `Content-Type` and body.
    fn raw(content_type: &'static str, body: &'static str) -> Response {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, content_type.parse().unwrap());
        Response::Raw {
            body: Body::from(body),
            status: StatusCode::OK,
            headers,
        }
    }

    /// Extracts the `issuer` string from an `OauthJson` response.
    fn issuer_of(out: Response) -> Option<String> {
        match out {
            Response::OauthJson { doc, .. } => doc["issuer"].as_str().map(str::to_owned),
            _ => panic!("expected OauthJson"),
        }
    }

    /// Collects the body of a `Raw` response into a `String`.
    async fn raw_body(out: Response) -> String {
        match out {
            Response::Raw { body, .. } => {
                let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
                std::str::from_utf8(&bytes).unwrap().to_owned()
            }
            _ => panic!("expected Raw"),
        }
    }

    #[tokio::test]
    async fn on_response__oauth_rewrites_upstream_to_proxy() {
        let out = run(oauth("http://upstream.test/auth")).await;
        assert_eq!(issuer_of(out).as_deref(), Some("https://proxy.test/auth"));
    }

    #[tokio::test]
    async fn on_response__oauth_no_match_is_identity() {
        let out = run(oauth("http://other.example.com")).await;
        assert_eq!(issuer_of(out).as_deref(), Some("http://other.example.com"));
    }

    #[tokio::test]
    async fn on_response__raw_json_rewrites() {
        let resp = raw("application/json", r#"{"url":"http://upstream.test/path"}"#);
        let s = raw_body(run(resp).await).await;
        assert!(
            s.contains("https://proxy.test/path"),
            "expected proxy url in {s}"
        );
    }

    #[tokio::test]
    async fn on_response__raw_non_json_passthrough() {
        let resp = raw("text/html", "http://upstream.test/");
        let s = raw_body(run(resp).await).await;
        assert_eq!(s, "http://upstream.test/");
    }

    // NOTE(review): despite its name, this test exercises the `Upstream502`
    // variant, i.e. the catch-all passthrough arm of `on_response`.
    #[tokio::test]
    async fn on_response__mcp_buffered_passthrough() {
        let out = run(Response::Upstream502 { reason: "x".into() }).await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }
}