Skip to main content

axum_reverse_proxy/
proxy.rs

1use axum::body::Body;
2use http::Uri;
3use http::uri::Builder as UriBuilder;
4use hyper_util::client::legacy::{Client, connect::Connect};
5use std::convert::Infallible;
6use tracing::trace;
7
8use crate::forward::{ProxyConnector, create_http_connector, forward_request};
9
10/// A reverse proxy that forwards HTTP requests to an upstream server.
11///
12/// The `ReverseProxy` struct handles the forwarding of HTTP requests from a specified path
13/// to a target upstream server. It manages its own HTTP client with configurable settings
14/// for connection pooling, timeouts, and retries.
15#[derive(Clone)]
16pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
17    path: String,
18    target: String,
19    client: Client<C, Body>,
20}
21
22pub type StandardReverseProxy = ReverseProxy<ProxyConnector>;
23
24impl StandardReverseProxy {
25    /// Creates a new `ReverseProxy` instance.
26    ///
27    /// # Arguments
28    ///
29    /// * `path` - The base path to match incoming requests against (e.g., "/api")
30    /// * `target` - The upstream server URL to forward requests to (e.g., "https://api.example.com")
31    ///
32    /// # Example
33    ///
34    /// ```rust
35    /// use axum_reverse_proxy::ReverseProxy;
36    ///
37    /// let proxy = ReverseProxy::new("/api", "https://api.example.com");
38    /// ```
39    pub fn new<S>(path: S, target: S) -> Self
40    where
41        S: Into<String>,
42    {
43        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
44            .pool_idle_timeout(std::time::Duration::from_secs(60))
45            .pool_max_idle_per_host(32)
46            .retry_canceled_requests(true)
47            .set_host(true)
48            .build(create_http_connector());
49
50        Self::new_with_client(path, target, client)
51    }
52}
53
54impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
55    /// Creates a new `ReverseProxy` instance with a custom HTTP client.
56    ///
57    /// This method allows for more fine-grained control over the proxy behavior by accepting
58    /// a pre-configured HTTP client.
59    ///
60    /// # Arguments
61    ///
62    /// * `path` - The base path to match incoming requests against
63    /// * `target` - The upstream server URL to forward requests to
64    /// * `client` - A custom-configured HTTP client
65    ///
66    /// # Example
67    ///
68    /// ```rust
69    /// use axum_reverse_proxy::ReverseProxy;
70    /// use hyper_util::client::legacy::{Client, connect::HttpConnector};
71    /// use axum::body::Body;
72    /// use hyper_util::rt::TokioExecutor;
73    ///
74    /// let client = Client::builder(TokioExecutor::new())
75    ///     .pool_idle_timeout(std::time::Duration::from_secs(120))
76    ///     .build(HttpConnector::new());
77    ///
78    /// let proxy = ReverseProxy::new_with_client(
79    ///     "/api",
80    ///     "https://api.example.com",
81    ///     client,
82    /// );
83    /// ```
84    pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
85    where
86        S: Into<String>,
87    {
88        Self {
89            path: path.into(),
90            target: target.into(),
91            client,
92        }
93    }
94
95    /// Get the base path this proxy is configured to handle
96    pub fn path(&self) -> &str {
97        &self.path
98    }
99
100    /// Get the target URL this proxy forwards requests to
101    pub fn target(&self) -> &str {
102        &self.target
103    }
104
105    /// Handles the proxying of a single request to the upstream server.
106    pub async fn proxy_request(
107        &self,
108        req: axum::http::Request<Body>,
109    ) -> Result<axum::http::Response<Body>, Infallible> {
110        self.handle_request(req).await
111    }
112
113    /// Core proxy logic used by the [`tower::Service`] implementation.
114    async fn handle_request(
115        &self,
116        req: axum::http::Request<Body>,
117    ) -> Result<axum::http::Response<Body>, Infallible> {
118        trace!("Proxying request method={} uri={}", req.method(), req.uri());
119
120        // Transform the URI to the upstream target
121        let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
122        let upstream_uri = self.transform_uri(path_q);
123
124        // Use shared forwarding logic
125        forward_request(upstream_uri, req, &self.client).await
126    }
127
128    /// Transform an incoming request path+query into the target URI using http::Uri builder
129    ///
130    /// Rules:
131    /// - Trim target trailing slash for joining
132    /// - Strip proxy base path at a boundary (exact or followed by '/')
133    /// - If remainder is exactly '/' under a non-empty base, treat as empty
134    /// - Do not add a slash for query-only joins (avoid target '/?')
135    fn transform_uri(&self, path_and_query: &str) -> Uri {
136        let base_path = self.path.trim_end_matches('/');
137
138        // Parse target URI
139        let target_uri: Uri = self
140            .target
141            .parse()
142            .expect("ReverseProxy target must be a valid URI");
143
144        let scheme = target_uri.scheme_str().unwrap_or("http");
145        let authority = target_uri
146            .authority()
147            .expect("ReverseProxy target must include authority (host)")
148            .as_str()
149            .to_string();
150
151        // Check if target originally had a trailing slash
152        let target_has_trailing_slash =
153            target_uri.path().ends_with('/') && target_uri.path() != "/";
154
155        // Normalize target base path: drop trailing slash and treat "/" as empty
156        let target_base_path = {
157            let p = target_uri.path();
158            if p == "/" {
159                ""
160            } else {
161                p.trim_end_matches('/')
162            }
163        };
164
165        // Split incoming path and query
166        let (path_part, query_part) = match path_and_query.find('?') {
167            Some(i) => (&path_and_query[..i], Some(&path_and_query[i + 1..])),
168            None => (path_and_query, None),
169        };
170
171        // Compute remainder after stripping base when applicable
172        let remaining_path = if path_part == "/" && !self.path.is_empty() {
173            ""
174        } else if !base_path.is_empty() && path_part.starts_with(base_path) {
175            let rem = &path_part[base_path.len()..];
176            if rem.is_empty() || rem.starts_with('/') {
177                rem
178            } else {
179                path_part
180            }
181        } else {
182            path_part
183        };
184
185        // Join target base path with remainder
186        let joined_path = if remaining_path.is_empty() {
187            if target_base_path.is_empty() {
188                "/"
189            } else if target_has_trailing_slash {
190                // Preserve trailing slash from target when no remaining path
191                "__TRAILING__"
192            } else {
193                target_base_path
194            }
195        } else {
196            // remaining_path starts with '/'; concatenate without duplicating slash
197            if target_base_path.is_empty() {
198                remaining_path
199            } else {
200                // allocate a small string to join
201                // SAFETY: both parts are valid path slices
202                // Build into a String for path_and_query
203                // We will rebuild below
204                // Placeholder; real joining below
205                "__JOIN__"
206            }
207        };
208
209        // Build final path_and_query string explicitly to keep exact bytes
210        let final_path = if joined_path == "__JOIN__" {
211            let mut s = String::with_capacity(target_base_path.len() + remaining_path.len());
212            s.push_str(target_base_path);
213            s.push_str(remaining_path);
214            s
215        } else if joined_path == "__TRAILING__" {
216            let mut s = String::with_capacity(target_base_path.len() + 1);
217            s.push_str(target_base_path);
218            s.push('/');
219            s
220        } else {
221            joined_path.to_string()
222        };
223
224        let mut path_and_query_buf = final_path;
225        if let Some(q) = query_part {
226            path_and_query_buf.push('?');
227            path_and_query_buf.push_str(q);
228        }
229
230        // Build the full URI
231        UriBuilder::new()
232            .scheme(scheme)
233            .authority(authority.as_str())
234            .path_and_query(path_and_query_buf.as_str())
235            .build()
236            .expect("Failed to build upstream URI")
237    }
238}
239
240use std::{
241    future::Future,
242    pin::Pin,
243    task::{Context, Poll},
244};
245use tower::Service;
246
247impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
248where
249    C: Connect + Clone + Send + Sync + 'static,
250{
251    type Response = axum::http::Response<Body>;
252    type Error = Infallible;
253    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
254
255    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
256        Poll::Ready(Ok(()))
257    }
258
259    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
260        let this = self.clone();
261        Box::pin(async move { this.handle_request(req).await })
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        let proxy = ReverseProxy::new("/api/", "http://target");
        assert_eq!(proxy.transform_uri("/api/test"), "http://target/test");

        let proxy_no_slash = ReverseProxy::new("/api", "http://target");
        assert_eq!(
            proxy_no_slash.transform_uri("/api/test"),
            "http://target/test"
        );
    }

    #[test]
    fn transform_uri_root() {
        let proxy = ReverseProxy::new("/", "http://target");
        assert_eq!(proxy.transform_uri("/test"), "http://target/test");
    }

    #[test]
    fn transform_uri_with_query() {
        let proxy_root = ReverseProxy::new("/", "http://target");

        assert_eq!(
            proxy_root.transform_uri("?query=test"),
            "http://target?query=test"
        );
        assert_eq!(
            proxy_root.transform_uri("/?query=test"),
            "http://target/?query=test"
        );
        assert_eq!(
            proxy_root.transform_uri("/test?query=test"),
            "http://target/test?query=test"
        );

        let proxy_root_no_slash = ReverseProxy::new("/", "http://target/api");
        assert_eq!(
            proxy_root_no_slash.transform_uri("/test?query=test"),
            "http://target/api/test?query=test"
        );
        assert_eq!(
            proxy_root_no_slash.transform_uri("?query=test"),
            "http://target/api?query=test"
        );

        let proxy_root_slash = ReverseProxy::new("/", "http://target/api/");
        assert_eq!(
            proxy_root_slash.transform_uri("/test?query=test"),
            "http://target/api/test?query=test"
        );
        assert_eq!(
            proxy_root_slash.transform_uri("?query=test"),
            "http://target/api/?query=test"
        );

        let proxy_no_slash = ReverseProxy::new("/test", "http://target/api");
        assert_eq!(
            proxy_no_slash.transform_uri("/test?query=test"),
            "http://target/api?query=test"
        );
        assert_eq!(
            proxy_no_slash.transform_uri("/test/?query=test"),
            "http://target/api/?query=test"
        );
        assert_eq!(
            proxy_no_slash.transform_uri("?query=test"),
            "http://target/api?query=test"
        );

        let proxy_with_slash = ReverseProxy::new("/test", "http://target/api/");
        assert_eq!(
            proxy_with_slash.transform_uri("/test?query=test"),
            "http://target/api/?query=test"
        );
        assert_eq!(
            proxy_with_slash.transform_uri("/test/?query=test"),
            "http://target/api/?query=test"
        );
        assert_eq!(
            proxy_with_slash.transform_uri("/something"),
            "http://target/api/something"
        );
        assert_eq!(
            proxy_with_slash.transform_uri("/test/something"),
            "http://target/api/something"
        );
    }

    #[test]
    fn transform_uri_strips_base_only_at_segment_boundary() {
        let proxy = ReverseProxy::new("/api", "http://target");

        // "/apiv2" merely shares a textual prefix with "/api"; it must pass through intact.
        assert_eq!(
            proxy.transform_uri("/apiv2/users"),
            "http://target/apiv2/users"
        );
        assert_eq!(proxy.transform_uri("/api"), "http://target/");
        assert_eq!(proxy.transform_uri("/api/users"), "http://target/users");
    }

    #[test]
    fn transform_uri_preserves_target_trailing_slash() {
        let proxy_slash = ReverseProxy::new("/api", "http://target/base/");
        assert_eq!(proxy_slash.transform_uri("/api"), "http://target/base/");
        assert_eq!(
            proxy_slash.transform_uri("/api/users"),
            "http://target/base/users"
        );

        let proxy_no_slash = ReverseProxy::new("/api", "http://target/base");
        assert_eq!(proxy_no_slash.transform_uri("/api"), "http://target/base");
    }

    #[test]
    fn transform_uri_keeps_scheme_and_port() {
        let proxy = ReverseProxy::new("/api", "https://target:8443");
        assert_eq!(
            proxy.transform_uri("/api/x?y=1"),
            "https://target:8443/x?y=1"
        );
    }
}