axum_reverse_proxy/
proxy.rs

1use axum::body::Body;
2use http::StatusCode;
3use http_body_util::BodyExt;
4#[cfg(all(feature = "tls", not(feature = "native-tls")))]
5use hyper_rustls::HttpsConnector;
6#[cfg(feature = "native-tls")]
7use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
8use hyper_util::client::legacy::{
9    Client,
10    connect::{Connect, HttpConnector},
11};
12use std::convert::Infallible;
13use tracing::{error, trace};
14
15use crate::websocket;
16
17/// A reverse proxy that forwards HTTP requests to an upstream server.
18///
19/// The `ReverseProxy` struct handles the forwarding of HTTP requests from a specified path
20/// to a target upstream server. It manages its own HTTP client with configurable settings
21/// for connection pooling, timeouts, and retries.
22#[derive(Clone)]
23pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
24    path: String,
25    target: String,
26    client: Client<C, Body>,
27}
28
29#[cfg(all(feature = "tls", not(feature = "native-tls")))]
30pub type StandardReverseProxy = ReverseProxy<HttpsConnector<HttpConnector>>;
31#[cfg(feature = "native-tls")]
32pub type StandardReverseProxy = ReverseProxy<NativeTlsHttpsConnector<HttpConnector>>;
33#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
34pub type StandardReverseProxy = ReverseProxy<HttpConnector>;
35
36impl StandardReverseProxy {
37    /// Creates a new `ReverseProxy` instance.
38    ///
39    /// # Arguments
40    ///
41    /// * `path` - The base path to match incoming requests against (e.g., "/api")
42    /// * `target` - The upstream server URL to forward requests to (e.g., "https://api.example.com")
43    ///
44    /// # Example
45    ///
46    /// ```rust
47    /// use axum_reverse_proxy::ReverseProxy;
48    ///
49    /// let proxy = ReverseProxy::new("/api", "https://api.example.com");
50    /// ```
51    pub fn new<S>(path: S, target: S) -> Self
52    where
53        S: Into<String>,
54    {
55        let mut connector = HttpConnector::new();
56        connector.set_nodelay(true);
57        connector.enforce_http(false);
58        connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
59        connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
60        connector.set_reuse_address(true);
61
62        #[cfg(all(feature = "tls", not(feature = "native-tls")))]
63        let connector = {
64            use hyper_rustls::HttpsConnectorBuilder;
65            HttpsConnectorBuilder::new()
66                .with_native_roots()
67                .unwrap()
68                .https_or_http()
69                .enable_http1()
70                .wrap_connector(connector)
71        };
72
73        #[cfg(feature = "native-tls")]
74        let connector = NativeTlsHttpsConnector::new_with_connector(connector);
75
76        let client = Client::builder(hyper_util::rt::TokioExecutor::new())
77            .pool_idle_timeout(std::time::Duration::from_secs(60))
78            .pool_max_idle_per_host(32)
79            .retry_canceled_requests(true)
80            .set_host(true)
81            .build(connector);
82
83        Self::new_with_client(path, target, client)
84    }
85}
86
87impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
88    /// Creates a new `ReverseProxy` instance with a custom HTTP client.
89    ///
90    /// This method allows for more fine-grained control over the proxy behavior by accepting
91    /// a pre-configured HTTP client.
92    ///
93    /// # Arguments
94    ///
95    /// * `path` - The base path to match incoming requests against
96    /// * `target` - The upstream server URL to forward requests to
97    /// * `client` - A custom-configured HTTP client
98    ///
99    /// # Example
100    ///
101    /// ```rust
102    /// use axum_reverse_proxy::ReverseProxy;
103    /// use hyper_util::client::legacy::{Client, connect::HttpConnector};
104    /// use axum::body::Body;
105    /// use hyper_util::rt::TokioExecutor;
106    ///
107    /// let client = Client::builder(TokioExecutor::new())
108    ///     .pool_idle_timeout(std::time::Duration::from_secs(120))
109    ///     .build(HttpConnector::new());
110    ///
111    /// let proxy = ReverseProxy::new_with_client(
112    ///     "/api",
113    ///     "https://api.example.com",
114    ///     client,
115    /// );
116    /// ```
117    pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
118    where
119        S: Into<String>,
120    {
121        Self {
122            path: path.into(),
123            target: target.into(),
124            client,
125        }
126    }
127
128    /// Get the base path this proxy is configured to handle
129    pub fn path(&self) -> &str {
130        &self.path
131    }
132
133    /// Get the target URL this proxy forwards requests to
134    pub fn target(&self) -> &str {
135        &self.target
136    }
137
138    /// Handles the proxying of a single request to the upstream server.
139    pub async fn proxy_request(
140        &self,
141        req: axum::http::Request<Body>,
142    ) -> Result<axum::http::Response<Body>, Infallible> {
143        self.handle_request(req).await
144    }
145
146    /// Core proxy logic used by the [`tower::Service`] implementation.
147    async fn handle_request(
148        &self,
149        req: axum::http::Request<Body>,
150    ) -> Result<axum::http::Response<Body>, Infallible> {
151        trace!("Proxying request method={} uri={}", req.method(), req.uri());
152        trace!("Original headers headers={:?}", req.headers());
153
154        // Check if this is a WebSocket upgrade request
155        if websocket::is_websocket_upgrade(req.headers()) {
156            trace!("Detected WebSocket upgrade request");
157            match websocket::handle_websocket(req, &self.target).await {
158                Ok(response) => return Ok(response),
159                Err(e) => {
160                    error!("Failed to handle WebSocket upgrade: {}", e);
161                    return Ok(axum::http::Response::builder()
162                        .status(StatusCode::INTERNAL_SERVER_ERROR)
163                        .body(Body::from(format!("WebSocket upgrade failed: {e}")))
164                        .unwrap());
165                }
166            }
167        }
168
169        let forward_req = {
170            let mut builder =
171                axum::http::Request::builder()
172                    .method(req.method().clone())
173                    .uri(self.transform_uri(
174                        req.uri().path_and_query().map(|x| x.as_str()).unwrap_or(""),
175                    ));
176
177            // Forward headers
178            for (key, value) in req.headers() {
179                if key != "host" {
180                    builder = builder.header(key, value);
181                }
182            }
183
184            // Take the request body
185            let (parts, body) = req.into_parts();
186            drop(parts);
187            builder.body(body).unwrap()
188        };
189
190        trace!(
191            "Forwarding headers forwarded_headers={:?}",
192            forward_req.headers()
193        );
194
195        match self.client.request(forward_req).await {
196            Ok(res) => {
197                trace!(
198                    "Received response status={} headers={:?} version={:?}",
199                    res.status(),
200                    res.headers(),
201                    res.version()
202                );
203
204                let (parts, body) = res.into_parts();
205                let body = Body::from_stream(body.into_data_stream());
206
207                let mut response = axum::http::Response::new(body);
208                *response.status_mut() = parts.status;
209                *response.version_mut() = parts.version;
210                *response.headers_mut() = parts.headers;
211                Ok(response)
212            }
213            Err(e) => {
214                let error_msg = e.to_string();
215                error!("Proxy error occurred err={}", error_msg);
216                Ok(axum::http::Response::builder()
217                    .status(StatusCode::BAD_GATEWAY)
218                    .body(Body::from(format!(
219                        "Failed to connect to upstream server: {error_msg}"
220                    )))
221                    .unwrap())
222            }
223        }
224    }
225
226    /// Transform an incoming request path into the target URI
227    fn transform_uri(&self, path: &str) -> String {
228        let target = self.target.trim_end_matches('/');
229        let base_path = self.path.trim_end_matches('/');
230
231        let remaining = if path == "/" && !self.path.is_empty() {
232            ""
233        } else if let Some(stripped) = path.strip_prefix(base_path) {
234            stripped
235        } else {
236            path
237        };
238
239        let mut uri = String::with_capacity(target.len() + remaining.len());
240        uri.push_str(target);
241        uri.push_str(remaining);
242        uri
243    }
244}
245
246use std::{
247    future::Future,
248    pin::Pin,
249    task::{Context, Poll},
250};
251use tower::Service;
252
253impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
254where
255    C: Connect + Clone + Send + Sync + 'static,
256{
257    type Response = axum::http::Response<Body>;
258    type Error = Infallible;
259    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
260
261    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262        Poll::Ready(Ok(()))
263    }
264
265    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
266        let this = self.clone();
267        Box::pin(async move { this.handle_request(req).await })
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        let proxy = ReverseProxy::new("/api/", "http://target");
        assert_eq!(proxy.transform_uri("/api/test"), "http://target/test");

        let proxy_no_slash = ReverseProxy::new("/api", "http://target");
        assert_eq!(
            proxy_no_slash.transform_uri("/api/test"),
            "http://target/test"
        );
    }

    #[test]
    fn transform_uri_root() {
        let proxy = ReverseProxy::new("/", "http://target");
        assert_eq!(proxy.transform_uri("/test"), "http://target/test");
        assert_eq!(proxy.transform_uri("/"), "http://target");
    }

    #[test]
    fn transform_uri_exact_base_path() {
        let proxy = ReverseProxy::new("/api", "http://target");
        assert_eq!(proxy.transform_uri("/api"), "http://target");
        assert_eq!(proxy.transform_uri("/"), "http://target");
    }

    #[test]
    fn transform_uri_preserves_query_string() {
        let proxy = ReverseProxy::new("/api", "http://target");
        assert_eq!(
            proxy.transform_uri("/api/items?page=2"),
            "http://target/items?page=2"
        );
        assert_eq!(proxy.transform_uri("/api?page=2"), "http://target?page=2");
    }

    #[test]
    fn transform_uri_passes_through_unmatched_paths() {
        let proxy = ReverseProxy::new("/api", "http://target/");
        assert_eq!(
            proxy.transform_uri("/other/path"),
            "http://target/other/path"
        );
    }
}