axum_reverse_proxy/
proxy.rs

1use axum::body::Body;
2use http::Uri;
3use http::uri::Builder as UriBuilder;
4use hyper_util::client::legacy::{Client, connect::Connect};
5use std::convert::Infallible;
6use tracing::trace;
7
8use crate::forward::{ProxyConnector, create_http_connector, forward_request};
9
10/// A reverse proxy that forwards HTTP requests to an upstream server.
11///
12/// The `ReverseProxy` struct handles the forwarding of HTTP requests from a specified path
13/// to a target upstream server. It manages its own HTTP client with configurable settings
14/// for connection pooling, timeouts, and retries.
15#[derive(Clone)]
16pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
17    path: String,
18    target: String,
19    client: Client<C, Body>,
20}
21
22pub type StandardReverseProxy = ReverseProxy<ProxyConnector>;
23
24impl StandardReverseProxy {
25    /// Creates a new `ReverseProxy` instance.
26    ///
27    /// # Arguments
28    ///
29    /// * `path` - The base path to match incoming requests against (e.g., "/api")
30    /// * `target` - The upstream server URL to forward requests to (e.g., "https://api.example.com")
31    ///
32    /// # Example
33    ///
34    /// ```rust
35    /// use axum_reverse_proxy::ReverseProxy;
36    ///
37    /// let proxy = ReverseProxy::new("/api", "https://api.example.com");
38    /// ```
39    pub fn new<S>(path: S, target: S) -> Self
40    where
41        S: Into<String>,
42    {
43        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
44            .pool_idle_timeout(std::time::Duration::from_secs(60))
45            .pool_max_idle_per_host(32)
46            .retry_canceled_requests(true)
47            .set_host(true)
48            .build(create_http_connector());
49
50        Self::new_with_client(path, target, client)
51    }
52}
53
54impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
55    /// Creates a new `ReverseProxy` instance with a custom HTTP client.
56    ///
57    /// This method allows for more fine-grained control over the proxy behavior by accepting
58    /// a pre-configured HTTP client.
59    ///
60    /// # Arguments
61    ///
62    /// * `path` - The base path to match incoming requests against
63    /// * `target` - The upstream server URL to forward requests to
64    /// * `client` - A custom-configured HTTP client
65    ///
66    /// # Example
67    ///
68    /// ```rust
69    /// use axum_reverse_proxy::ReverseProxy;
70    /// use hyper_util::client::legacy::{Client, connect::HttpConnector};
71    /// use axum::body::Body;
72    /// use hyper_util::rt::TokioExecutor;
73    ///
74    /// let client = Client::builder(TokioExecutor::new())
75    ///     .pool_idle_timeout(std::time::Duration::from_secs(120))
76    ///     .build(HttpConnector::new());
77    ///
78    /// let proxy = ReverseProxy::new_with_client(
79    ///     "/api",
80    ///     "https://api.example.com",
81    ///     client,
82    /// );
83    /// ```
84    pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
85    where
86        S: Into<String>,
87    {
88        Self {
89            path: path.into(),
90            target: target.into(),
91            client,
92        }
93    }
94
95    /// Get the base path this proxy is configured to handle
96    pub fn path(&self) -> &str {
97        &self.path
98    }
99
100    /// Get the target URL this proxy forwards requests to
101    pub fn target(&self) -> &str {
102        &self.target
103    }
104
105    /// Handles the proxying of a single request to the upstream server.
106    pub async fn proxy_request(
107        &self,
108        req: axum::http::Request<Body>,
109    ) -> Result<axum::http::Response<Body>, Infallible> {
110        self.handle_request(req).await
111    }
112
113    /// Core proxy logic used by the [`tower::Service`] implementation.
114    async fn handle_request(
115        &self,
116        req: axum::http::Request<Body>,
117    ) -> Result<axum::http::Response<Body>, Infallible> {
118        trace!("Proxying request method={} uri={}", req.method(), req.uri());
119        trace!("Original headers headers={:?}", req.headers());
120
121        // Transform the URI to the upstream target
122        let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
123        let upstream_uri = self.transform_uri(path_q);
124
125        // Use shared forwarding logic
126        forward_request(upstream_uri, req, &self.client).await
127    }
128
129    /// Transform an incoming request path+query into the target URI using http::Uri builder
130    ///
131    /// Rules:
132    /// - Trim target trailing slash for joining
133    /// - Strip proxy base path at a boundary (exact or followed by '/')
134    /// - If remainder is exactly '/' under a non-empty base, treat as empty
135    /// - Do not add a slash for query-only joins (avoid target '/?')
136    fn transform_uri(&self, path_and_query: &str) -> Uri {
137        let base_path = self.path.trim_end_matches('/');
138
139        // Parse target URI
140        let target_uri: Uri = self
141            .target
142            .parse()
143            .expect("ReverseProxy target must be a valid URI");
144
145        let scheme = target_uri.scheme_str().unwrap_or("http");
146        let authority = target_uri
147            .authority()
148            .expect("ReverseProxy target must include authority (host)")
149            .as_str()
150            .to_string();
151
152        // Check if target originally had a trailing slash
153        let target_has_trailing_slash =
154            target_uri.path().ends_with('/') && target_uri.path() != "/";
155
156        // Normalize target base path: drop trailing slash and treat "/" as empty
157        let target_base_path = {
158            let p = target_uri.path();
159            if p == "/" {
160                ""
161            } else {
162                p.trim_end_matches('/')
163            }
164        };
165
166        // Split incoming path and query
167        let (path_part, query_part) = match path_and_query.find('?') {
168            Some(i) => (&path_and_query[..i], Some(&path_and_query[i + 1..])),
169            None => (path_and_query, None),
170        };
171
172        // Compute remainder after stripping base when applicable
173        let remaining_path = if path_part == "/" && !self.path.is_empty() {
174            ""
175        } else if !base_path.is_empty() && path_part.starts_with(base_path) {
176            let rem = &path_part[base_path.len()..];
177            if rem.is_empty() || rem.starts_with('/') {
178                rem
179            } else {
180                path_part
181            }
182        } else {
183            path_part
184        };
185
186        // Join target base path with remainder
187        let joined_path = if remaining_path.is_empty() {
188            if target_base_path.is_empty() {
189                "/"
190            } else if target_has_trailing_slash {
191                // Preserve trailing slash from target when no remaining path
192                "__TRAILING__"
193            } else {
194                target_base_path
195            }
196        } else {
197            // remaining_path starts with '/'; concatenate without duplicating slash
198            if target_base_path.is_empty() {
199                remaining_path
200            } else {
201                // allocate a small string to join
202                // SAFETY: both parts are valid path slices
203                // Build into a String for path_and_query
204                // We will rebuild below
205                // Placeholder; real joining below
206                "__JOIN__"
207            }
208        };
209
210        // Build final path_and_query string explicitly to keep exact bytes
211        let final_path = if joined_path == "__JOIN__" {
212            let mut s = String::with_capacity(target_base_path.len() + remaining_path.len());
213            s.push_str(target_base_path);
214            s.push_str(remaining_path);
215            s
216        } else if joined_path == "__TRAILING__" {
217            let mut s = String::with_capacity(target_base_path.len() + 1);
218            s.push_str(target_base_path);
219            s.push('/');
220            s
221        } else {
222            joined_path.to_string()
223        };
224
225        let mut path_and_query_buf = final_path;
226        if let Some(q) = query_part {
227            path_and_query_buf.push('?');
228            path_and_query_buf.push_str(q);
229        }
230
231        // Build the full URI
232        UriBuilder::new()
233            .scheme(scheme)
234            .authority(authority.as_str())
235            .path_and_query(path_and_query_buf.as_str())
236            .build()
237            .expect("Failed to build upstream URI")
238    }
239}
240
241use std::{
242    future::Future,
243    pin::Pin,
244    task::{Context, Poll},
245};
246use tower::Service;
247
248impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
249where
250    C: Connect + Clone + Send + Sync + 'static,
251{
252    type Response = axum::http::Response<Body>;
253    type Error = Infallible;
254    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
255
256    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
257        Poll::Ready(Ok(()))
258    }
259
260    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
261        let this = self.clone();
262        Box::pin(async move { this.handle_request(req).await })
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    /// Asserts that `proxy` rewrites `incoming` to `expected`; failures report the caller's line.
    #[track_caller]
    fn assert_maps(proxy: &ReverseProxy, incoming: &str, expected: &str) {
        assert_eq!(proxy.transform_uri(incoming), expected);
    }

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        for base in ["/api/", "/api"] {
            let proxy = ReverseProxy::new(base, "http://target");
            assert_maps(&proxy, "/api/test", "http://target/test");
        }
    }

    #[test]
    fn transform_uri_root() {
        let proxy = ReverseProxy::new("/", "http://target");
        assert_maps(&proxy, "/test", "http://target/test");
    }

    #[test]
    fn transform_uri_with_query() {
        // (proxy base path, upstream target, [(incoming path+query, expected upstream URI)])
        let scenarios: &[(&str, &str, &[(&str, &str)])] = &[
            (
                "/",
                "http://target",
                &[
                    ("?query=test", "http://target?query=test"),
                    ("/?query=test", "http://target/?query=test"),
                    ("/test?query=test", "http://target/test?query=test"),
                ],
            ),
            (
                "/",
                "http://target/api",
                &[
                    ("/test?query=test", "http://target/api/test?query=test"),
                    ("?query=test", "http://target/api?query=test"),
                ],
            ),
            (
                "/",
                "http://target/api/",
                &[
                    ("/test?query=test", "http://target/api/test?query=test"),
                    ("?query=test", "http://target/api/?query=test"),
                ],
            ),
            (
                "/test",
                "http://target/api",
                &[
                    ("/test?query=test", "http://target/api?query=test"),
                    ("/test/?query=test", "http://target/api/?query=test"),
                    ("?query=test", "http://target/api?query=test"),
                ],
            ),
            (
                "/test",
                "http://target/api/",
                &[
                    ("/test?query=test", "http://target/api/?query=test"),
                    ("/test/?query=test", "http://target/api/?query=test"),
                    ("/something", "http://target/api/something"),
                    ("/test/something", "http://target/api/something"),
                ],
            ),
        ];

        for &(base, target, expectations) in scenarios {
            let proxy = ReverseProxy::new(base, target);
            for &(incoming, expected) in expectations {
                assert_maps(&proxy, incoming, expected);
            }
        }
    }
}