axum_reverse_proxy/
proxy.rs1use axum::body::Body;
2use http::uri::Builder as UriBuilder;
3use http::{StatusCode, Uri};
4use http_body_util::BodyExt;
5#[cfg(all(feature = "tls", not(feature = "native-tls")))]
6use hyper_rustls::HttpsConnector;
7#[cfg(feature = "native-tls")]
8use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
9use hyper_util::client::legacy::{
10 Client,
11 connect::{Connect, HttpConnector},
12};
13use std::convert::Infallible;
14use tracing::{error, trace};
15
16use crate::websocket;
17
18#[derive(Clone)]
24pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
25 path: String,
26 target: String,
27 client: Client<C, Body>,
28}
29
30#[cfg(all(feature = "tls", not(feature = "native-tls")))]
31pub type StandardReverseProxy = ReverseProxy<HttpsConnector<HttpConnector>>;
32#[cfg(feature = "native-tls")]
33pub type StandardReverseProxy = ReverseProxy<NativeTlsHttpsConnector<HttpConnector>>;
34#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
35pub type StandardReverseProxy = ReverseProxy<HttpConnector>;
36
37impl StandardReverseProxy {
38 pub fn new<S>(path: S, target: S) -> Self
53 where
54 S: Into<String>,
55 {
56 let mut connector = HttpConnector::new();
57 connector.set_nodelay(true);
58 connector.enforce_http(false);
59 connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
60 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
61 connector.set_reuse_address(true);
62
63 #[cfg(all(feature = "tls", not(feature = "native-tls")))]
64 let connector = {
65 use hyper_rustls::HttpsConnectorBuilder;
66 HttpsConnectorBuilder::new()
67 .with_webpki_roots()
68 .https_or_http()
69 .enable_http1()
70 .wrap_connector(connector)
71 };
72
73 #[cfg(feature = "native-tls")]
74 let connector = NativeTlsHttpsConnector::new_with_connector(connector);
75
76 let client = Client::builder(hyper_util::rt::TokioExecutor::new())
77 .pool_idle_timeout(std::time::Duration::from_secs(60))
78 .pool_max_idle_per_host(32)
79 .retry_canceled_requests(true)
80 .set_host(true)
81 .build(connector);
82
83 Self::new_with_client(path, target, client)
84 }
85}
86
87impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
88 pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
118 where
119 S: Into<String>,
120 {
121 Self {
122 path: path.into(),
123 target: target.into(),
124 client,
125 }
126 }
127
128 pub fn path(&self) -> &str {
130 &self.path
131 }
132
133 pub fn target(&self) -> &str {
135 &self.target
136 }
137
138 pub async fn proxy_request(
140 &self,
141 req: axum::http::Request<Body>,
142 ) -> Result<axum::http::Response<Body>, Infallible> {
143 self.handle_request(req).await
144 }
145
146 async fn handle_request(
148 &self,
149 req: axum::http::Request<Body>,
150 ) -> Result<axum::http::Response<Body>, Infallible> {
151 trace!("Proxying request method={} uri={}", req.method(), req.uri());
152 trace!("Original headers headers={:?}", req.headers());
153
154 if websocket::is_websocket_upgrade(req.headers()) {
156 trace!("Detected WebSocket upgrade request");
157 let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
159 let upstream_http_uri = self.transform_uri(path_q);
160 match websocket::handle_websocket_with_upstream_uri(req, upstream_http_uri).await {
161 Ok(response) => return Ok(response),
162 Err(e) => {
163 error!("Failed to handle WebSocket upgrade: {}", e);
164 return Ok(axum::http::Response::builder()
165 .status(StatusCode::INTERNAL_SERVER_ERROR)
166 .body(Body::from(format!("WebSocket upgrade failed: {e}")))
167 .unwrap());
168 }
169 }
170 }
171
172 let forward_req = {
173 let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
174 let upstream_uri = self.transform_uri(path_q);
175
176 let mut builder = axum::http::Request::builder()
177 .method(req.method().clone())
178 .uri(upstream_uri.clone());
179
180 for (key, value) in req.headers() {
182 if key != "host" {
183 builder = builder.header(key, value);
184 }
185 }
186
187 let (parts, body) = req.into_parts();
189 drop(parts);
190 builder.body(body).unwrap()
191 };
192
193 trace!(
194 "Forwarding headers forwarded_headers={:?}",
195 forward_req.headers()
196 );
197
198 match self.client.request(forward_req).await {
199 Ok(res) => {
200 trace!(
201 "Received response status={} headers={:?} version={:?}",
202 res.status(),
203 res.headers(),
204 res.version()
205 );
206
207 let (parts, body) = res.into_parts();
208 let body = Body::from_stream(body.into_data_stream());
209
210 let mut response = axum::http::Response::new(body);
211 *response.status_mut() = parts.status;
212 *response.version_mut() = parts.version;
213 *response.headers_mut() = parts.headers;
214 Ok(response)
215 }
216 Err(e) => {
217 let error_msg = e.to_string();
218 error!("Proxy error occurred err={}", error_msg);
219 Ok(axum::http::Response::builder()
220 .status(StatusCode::BAD_GATEWAY)
221 .body(Body::from(format!(
222 "Failed to connect to upstream server: {error_msg}"
223 )))
224 .unwrap())
225 }
226 }
227 }
228
229 fn transform_uri(&self, path_and_query: &str) -> Uri {
237 let base_path = self.path.trim_end_matches('/');
238
239 let target_uri: Uri = self
241 .target
242 .parse()
243 .expect("ReverseProxy target must be a valid URI");
244
245 let scheme = target_uri.scheme_str().unwrap_or("http");
246 let authority = target_uri
247 .authority()
248 .expect("ReverseProxy target must include authority (host)")
249 .as_str()
250 .to_string();
251
252 let target_has_trailing_slash =
254 target_uri.path().ends_with('/') && target_uri.path() != "/";
255
256 let target_base_path = {
258 let p = target_uri.path();
259 if p == "/" {
260 ""
261 } else {
262 p.trim_end_matches('/')
263 }
264 };
265
266 let (path_part, query_part) = match path_and_query.find('?') {
268 Some(i) => (&path_and_query[..i], Some(&path_and_query[i + 1..])),
269 None => (path_and_query, None),
270 };
271
272 let remaining_path = if path_part == "/" && !self.path.is_empty() {
274 ""
275 } else if !base_path.is_empty() && path_part.starts_with(base_path) {
276 let rem = &path_part[base_path.len()..];
277 if rem.is_empty() || rem.starts_with('/') {
278 rem
279 } else {
280 path_part
281 }
282 } else {
283 path_part
284 };
285
286 let joined_path = if remaining_path.is_empty() {
288 if target_base_path.is_empty() {
289 "/"
290 } else if target_has_trailing_slash {
291 "__TRAILING__"
293 } else {
294 target_base_path
295 }
296 } else {
297 if target_base_path.is_empty() {
299 remaining_path
300 } else {
301 "__JOIN__"
307 }
308 };
309
310 let final_path = if joined_path == "__JOIN__" {
312 let mut s = String::with_capacity(target_base_path.len() + remaining_path.len());
313 s.push_str(target_base_path);
314 s.push_str(remaining_path);
315 s
316 } else if joined_path == "__TRAILING__" {
317 let mut s = String::with_capacity(target_base_path.len() + 1);
318 s.push_str(target_base_path);
319 s.push('/');
320 s
321 } else {
322 joined_path.to_string()
323 };
324
325 let mut path_and_query_buf = final_path;
326 if let Some(q) = query_part {
327 path_and_query_buf.push('?');
328 path_and_query_buf.push_str(q);
329 }
330
331 UriBuilder::new()
333 .scheme(scheme)
334 .authority(authority.as_str())
335 .path_and_query(path_and_query_buf.as_str())
336 .build()
337 .expect("Failed to build upstream URI")
338 }
339}
340
341use std::{
342 future::Future,
343 pin::Pin,
344 task::{Context, Poll},
345};
346use tower::Service;
347
348impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
349where
350 C: Connect + Clone + Send + Sync + 'static,
351{
352 type Response = axum::http::Response<Body>;
353 type Error = Infallible;
354 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
355
356 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
357 Poll::Ready(Ok(()))
358 }
359
360 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
361 let this = self.clone();
362 Box::pin(async move { this.handle_request(req).await })
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    /// Builds a proxy mounted at `mount` for `target` and checks that every
    /// `(request, expected)` pair is rewritten as expected.
    fn assert_rewrites(mount: &str, target: &str, cases: &[(&str, &str)]) {
        let proxy = ReverseProxy::new(mount, target);
        for &(request, expected) in cases {
            assert_eq!(
                proxy.transform_uri(request),
                expected,
                "mount={mount:?} target={target:?} request={request:?}"
            );
        }
    }

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        // The mount path behaves the same with or without a trailing slash.
        for mount in ["/api/", "/api"] {
            assert_rewrites(
                mount,
                "http://target",
                &[("/api/test", "http://target/test")],
            );
        }
    }

    #[test]
    fn transform_uri_root() {
        assert_rewrites("/", "http://target", &[("/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_with_query() {
        // Root mount, target without a path.
        assert_rewrites(
            "/",
            "http://target",
            &[
                ("?query=test", "http://target?query=test"),
                ("/?query=test", "http://target/?query=test"),
                ("/test?query=test", "http://target/test?query=test"),
            ],
        );

        // Root mount, target path without a trailing slash.
        assert_rewrites(
            "/",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );

        // Root mount, target path with a trailing slash.
        assert_rewrites(
            "/",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api/?query=test"),
            ],
        );

        // Prefixed mount, target path without a trailing slash.
        assert_rewrites(
            "/test",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );

        // Prefixed mount, target path with a trailing slash.
        assert_rewrites(
            "/test",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("/something", "http://target/api/something"),
                ("/test/something", "http://target/api/something"),
            ],
        );
    }
}