axum_reverse_proxy/
proxy.rs

1use axum::body::Body;
2use http::uri::Builder as UriBuilder;
3use http::{StatusCode, Uri};
4use http_body_util::BodyExt;
5#[cfg(all(feature = "tls", not(feature = "native-tls")))]
6use hyper_rustls::HttpsConnector;
7#[cfg(feature = "native-tls")]
8use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
9use hyper_util::client::legacy::{
10    Client,
11    connect::{Connect, HttpConnector},
12};
13use std::convert::Infallible;
14use tracing::{error, trace};
15
16use crate::websocket;
17
18/// A reverse proxy that forwards HTTP requests to an upstream server.
19///
20/// The `ReverseProxy` struct handles the forwarding of HTTP requests from a specified path
21/// to a target upstream server. It manages its own HTTP client with configurable settings
22/// for connection pooling, timeouts, and retries.
23#[derive(Clone)]
24pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
25    path: String,
26    target: String,
27    client: Client<C, Body>,
28}
29
30#[cfg(all(feature = "tls", not(feature = "native-tls")))]
31pub type StandardReverseProxy = ReverseProxy<HttpsConnector<HttpConnector>>;
32#[cfg(feature = "native-tls")]
33pub type StandardReverseProxy = ReverseProxy<NativeTlsHttpsConnector<HttpConnector>>;
34#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
35pub type StandardReverseProxy = ReverseProxy<HttpConnector>;
36
37impl StandardReverseProxy {
38    /// Creates a new `ReverseProxy` instance.
39    ///
40    /// # Arguments
41    ///
42    /// * `path` - The base path to match incoming requests against (e.g., "/api")
43    /// * `target` - The upstream server URL to forward requests to (e.g., "https://api.example.com")
44    ///
45    /// # Example
46    ///
47    /// ```rust
48    /// use axum_reverse_proxy::ReverseProxy;
49    ///
50    /// let proxy = ReverseProxy::new("/api", "https://api.example.com");
51    /// ```
52    pub fn new<S>(path: S, target: S) -> Self
53    where
54        S: Into<String>,
55    {
56        let mut connector = HttpConnector::new();
57        connector.set_nodelay(true);
58        connector.enforce_http(false);
59        connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
60        connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
61        connector.set_reuse_address(true);
62
63        #[cfg(all(feature = "tls", not(feature = "native-tls")))]
64        let connector = {
65            use hyper_rustls::HttpsConnectorBuilder;
66            HttpsConnectorBuilder::new()
67                .with_webpki_roots()
68                .https_or_http()
69                .enable_http1()
70                .wrap_connector(connector)
71        };
72
73        #[cfg(feature = "native-tls")]
74        let connector = NativeTlsHttpsConnector::new_with_connector(connector);
75
76        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
77            .pool_idle_timeout(std::time::Duration::from_secs(60))
78            .pool_max_idle_per_host(32)
79            .retry_canceled_requests(true)
80            .set_host(true)
81            .build(connector);
82
83        Self::new_with_client(path, target, client)
84    }
85}
86
87impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
88    /// Creates a new `ReverseProxy` instance with a custom HTTP client.
89    ///
90    /// This method allows for more fine-grained control over the proxy behavior by accepting
91    /// a pre-configured HTTP client.
92    ///
93    /// # Arguments
94    ///
95    /// * `path` - The base path to match incoming requests against
96    /// * `target` - The upstream server URL to forward requests to
97    /// * `client` - A custom-configured HTTP client
98    ///
99    /// # Example
100    ///
101    /// ```rust
102    /// use axum_reverse_proxy::ReverseProxy;
103    /// use hyper_util::client::legacy::{Client, connect::HttpConnector};
104    /// use axum::body::Body;
105    /// use hyper_util::rt::TokioExecutor;
106    ///
107    /// let client = Client::builder(TokioExecutor::new())
108    ///     .pool_idle_timeout(std::time::Duration::from_secs(120))
109    ///     .build(HttpConnector::new());
110    ///
111    /// let proxy = ReverseProxy::new_with_client(
112    ///     "/api",
113    ///     "https://api.example.com",
114    ///     client,
115    /// );
116    /// ```
117    pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
118    where
119        S: Into<String>,
120    {
121        Self {
122            path: path.into(),
123            target: target.into(),
124            client,
125        }
126    }
127
128    /// Get the base path this proxy is configured to handle
129    pub fn path(&self) -> &str {
130        &self.path
131    }
132
133    /// Get the target URL this proxy forwards requests to
134    pub fn target(&self) -> &str {
135        &self.target
136    }
137
138    /// Handles the proxying of a single request to the upstream server.
139    pub async fn proxy_request(
140        &self,
141        req: axum::http::Request<Body>,
142    ) -> Result<axum::http::Response<Body>, Infallible> {
143        self.handle_request(req).await
144    }
145
146    /// Core proxy logic used by the [`tower::Service`] implementation.
147    async fn handle_request(
148        &self,
149        req: axum::http::Request<Body>,
150    ) -> Result<axum::http::Response<Body>, Infallible> {
151        trace!("Proxying request method={} uri={}", req.method(), req.uri());
152        trace!("Original headers headers={:?}", req.headers());
153
154        // Check if this is a WebSocket upgrade request
155        if websocket::is_websocket_upgrade(req.headers()) {
156            trace!("Detected WebSocket upgrade request");
157            // Build the upstream HTTP URI first, then let WS code map scheme
158            let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
159            let upstream_http_uri = self.transform_uri(path_q);
160            match websocket::handle_websocket_with_upstream_uri(req, upstream_http_uri).await {
161                Ok(response) => return Ok(response),
162                Err(e) => {
163                    error!("Failed to handle WebSocket upgrade: {}", e);
164                    return Ok(axum::http::Response::builder()
165                        .status(StatusCode::INTERNAL_SERVER_ERROR)
166                        .body(Body::from(format!("WebSocket upgrade failed: {e}")))
167                        .unwrap());
168                }
169            }
170        }
171
172        let forward_req = {
173            let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
174            let upstream_uri = self.transform_uri(path_q);
175
176            let mut builder = axum::http::Request::builder()
177                .method(req.method().clone())
178                .uri(upstream_uri.clone());
179
180            // Forward headers
181            for (key, value) in req.headers() {
182                if key != "host" {
183                    builder = builder.header(key, value);
184                }
185            }
186
187            // Take the request body
188            let (parts, body) = req.into_parts();
189            drop(parts);
190            builder.body(body).unwrap()
191        };
192
193        trace!(
194            "Forwarding headers forwarded_headers={:?}",
195            forward_req.headers()
196        );
197
198        match self.client.request(forward_req).await {
199            Ok(res) => {
200                trace!(
201                    "Received response status={} headers={:?} version={:?}",
202                    res.status(),
203                    res.headers(),
204                    res.version()
205                );
206
207                let (parts, body) = res.into_parts();
208                let body = Body::from_stream(body.into_data_stream());
209
210                let mut response = axum::http::Response::new(body);
211                *response.status_mut() = parts.status;
212                *response.version_mut() = parts.version;
213                *response.headers_mut() = parts.headers;
214                Ok(response)
215            }
216            Err(e) => {
217                let error_msg = e.to_string();
218                error!("Proxy error occurred err={}", error_msg);
219                Ok(axum::http::Response::builder()
220                    .status(StatusCode::BAD_GATEWAY)
221                    .body(Body::from(format!(
222                        "Failed to connect to upstream server: {error_msg}"
223                    )))
224                    .unwrap())
225            }
226        }
227    }
228
229    /// Transform an incoming request path+query into the target URI using http::Uri builder
230    ///
231    /// Rules:
232    /// - Trim target trailing slash for joining
233    /// - Strip proxy base path at a boundary (exact or followed by '/')
234    /// - If remainder is exactly '/' under a non-empty base, treat as empty
235    /// - Do not add a slash for query-only joins (avoid target '/?')
236    fn transform_uri(&self, path_and_query: &str) -> Uri {
237        let base_path = self.path.trim_end_matches('/');
238
239        // Parse target URI
240        let target_uri: Uri = self
241            .target
242            .parse()
243            .expect("ReverseProxy target must be a valid URI");
244
245        let scheme = target_uri.scheme_str().unwrap_or("http");
246        let authority = target_uri
247            .authority()
248            .expect("ReverseProxy target must include authority (host)")
249            .as_str()
250            .to_string();
251
252        // Check if target originally had a trailing slash
253        let target_has_trailing_slash =
254            target_uri.path().ends_with('/') && target_uri.path() != "/";
255
256        // Normalize target base path: drop trailing slash and treat "/" as empty
257        let target_base_path = {
258            let p = target_uri.path();
259            if p == "/" {
260                ""
261            } else {
262                p.trim_end_matches('/')
263            }
264        };
265
266        // Split incoming path and query
267        let (path_part, query_part) = match path_and_query.find('?') {
268            Some(i) => (&path_and_query[..i], Some(&path_and_query[i + 1..])),
269            None => (path_and_query, None),
270        };
271
272        // Compute remainder after stripping base when applicable
273        let remaining_path = if path_part == "/" && !self.path.is_empty() {
274            ""
275        } else if !base_path.is_empty() && path_part.starts_with(base_path) {
276            let rem = &path_part[base_path.len()..];
277            if rem.is_empty() || rem.starts_with('/') {
278                rem
279            } else {
280                path_part
281            }
282        } else {
283            path_part
284        };
285
286        // Join target base path with remainder
287        let joined_path = if remaining_path.is_empty() {
288            if target_base_path.is_empty() {
289                "/"
290            } else if target_has_trailing_slash {
291                // Preserve trailing slash from target when no remaining path
292                "__TRAILING__"
293            } else {
294                target_base_path
295            }
296        } else {
297            // remaining_path starts with '/'; concatenate without duplicating slash
298            if target_base_path.is_empty() {
299                remaining_path
300            } else {
301                // allocate a small string to join
302                // SAFETY: both parts are valid path slices
303                // Build into a String for path_and_query
304                // We will rebuild below
305                // Placeholder; real joining below
306                "__JOIN__"
307            }
308        };
309
310        // Build final path_and_query string explicitly to keep exact bytes
311        let final_path = if joined_path == "__JOIN__" {
312            let mut s = String::with_capacity(target_base_path.len() + remaining_path.len());
313            s.push_str(target_base_path);
314            s.push_str(remaining_path);
315            s
316        } else if joined_path == "__TRAILING__" {
317            let mut s = String::with_capacity(target_base_path.len() + 1);
318            s.push_str(target_base_path);
319            s.push('/');
320            s
321        } else {
322            joined_path.to_string()
323        };
324
325        let mut path_and_query_buf = final_path;
326        if let Some(q) = query_part {
327            path_and_query_buf.push('?');
328            path_and_query_buf.push_str(q);
329        }
330
331        // Build the full URI
332        UriBuilder::new()
333            .scheme(scheme)
334            .authority(authority.as_str())
335            .path_and_query(path_and_query_buf.as_str())
336            .build()
337            .expect("Failed to build upstream URI")
338    }
339}
340
341use std::{
342    future::Future,
343    pin::Pin,
344    task::{Context, Poll},
345};
346use tower::Service;
347
348impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
349where
350    C: Connect + Clone + Send + Sync + 'static,
351{
352    type Response = axum::http::Response<Body>;
353    type Error = Infallible;
354    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
355
356    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
357        Poll::Ready(Ok(()))
358    }
359
360    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
361        let this = self.clone();
362        Box::pin(async move { this.handle_request(req).await })
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    /// Builds a proxy mounted at `path` that forwards to `target`. Then asserts that
    /// `transform_uri` rewrites every `incoming` path+query in `cases` to its `expected` URI.
    fn assert_rewrites(path: &str, target: &str, cases: &[(&str, &str)]) {
        let proxy = ReverseProxy::new(path, target);
        for &(incoming, expected) in cases {
            assert_eq!(
                proxy.transform_uri(incoming),
                expected,
                "path={path:?} target={target:?} incoming={incoming:?}"
            );
        }
    }

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        assert_rewrites("/api/", "http://target", &[("/api/test", "http://target/test")]);
        assert_rewrites("/api", "http://target", &[("/api/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_root() {
        assert_rewrites("/", "http://target", &[("/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_with_query() {
        assert_rewrites(
            "/",
            "http://target",
            &[
                ("?query=test", "http://target?query=test"),
                ("/?query=test", "http://target/?query=test"),
                ("/test?query=test", "http://target/test?query=test"),
            ],
        );
        assert_rewrites(
            "/",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );
        assert_rewrites(
            "/",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api/?query=test"),
            ],
        );
        assert_rewrites(
            "/test",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );
        assert_rewrites(
            "/test",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("/something", "http://target/api/something"),
                ("/test/something", "http://target/api/something"),
            ],
        );
    }

    #[test]
    fn transform_uri_strips_base_only_at_segment_boundary() {
        assert_rewrites(
            "/api",
            "http://target",
            &[
                // A prefix that ends mid-segment is not a base-path match.
                ("/apiary", "http://target/apiary"),
                // The exact base path maps to the target root.
                ("/api", "http://target/"),
            ],
        );
    }

    #[test]
    fn transform_uri_target_path_port_and_scheme() {
        assert_rewrites(
            "/api",
            "http://target/v1",
            &[
                // A bare "/" under a non-empty base maps to the target path itself.
                ("/", "http://target/v1"),
                ("/api/users?id=1", "http://target/v1/users?id=1"),
            ],
        );
        assert_rewrites("/api", "http://localhost:8080", &[("/api/x", "http://localhost:8080/x")]);
        assert_rewrites(
            "/api",
            "https://secure.example.com/v1/",
            &[("/api", "https://secure.example.com/v1/")],
        );
    }
}