axum_reverse_proxy/
proxy.rs1use axum::body::Body;
2use http::StatusCode;
3use http_body_util::BodyExt;
4#[cfg(all(feature = "tls", not(feature = "native-tls")))]
5use hyper_rustls::HttpsConnector;
6#[cfg(feature = "native-tls")]
7use hyper_tls::HttpsConnector as NativeTlsHttpsConnector;
8use hyper_util::client::legacy::{
9 Client,
10 connect::{Connect, HttpConnector},
11};
12use std::convert::Infallible;
13use tracing::{error, trace};
14
15use crate::websocket;
16
17#[derive(Clone)]
23pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
24 path: String,
25 target: String,
26 client: Client<C, Body>,
27}
28
29#[cfg(all(feature = "tls", not(feature = "native-tls")))]
30pub type StandardReverseProxy = ReverseProxy<HttpsConnector<HttpConnector>>;
31#[cfg(feature = "native-tls")]
32pub type StandardReverseProxy = ReverseProxy<NativeTlsHttpsConnector<HttpConnector>>;
33#[cfg(all(not(feature = "tls"), not(feature = "native-tls")))]
34pub type StandardReverseProxy = ReverseProxy<HttpConnector>;
35
36impl StandardReverseProxy {
37 pub fn new<S>(path: S, target: S) -> Self
52 where
53 S: Into<String>,
54 {
55 let mut connector = HttpConnector::new();
56 connector.set_nodelay(true);
57 connector.enforce_http(false);
58 connector.set_keepalive(Some(std::time::Duration::from_secs(60)));
59 connector.set_connect_timeout(Some(std::time::Duration::from_secs(10)));
60 connector.set_reuse_address(true);
61
62 #[cfg(all(feature = "tls", not(feature = "native-tls")))]
63 let connector = {
64 use hyper_rustls::HttpsConnectorBuilder;
65 HttpsConnectorBuilder::new()
66 .with_native_roots()
67 .unwrap()
68 .https_or_http()
69 .enable_http1()
70 .wrap_connector(connector)
71 };
72
73 #[cfg(feature = "native-tls")]
74 let connector = NativeTlsHttpsConnector::new_with_connector(connector);
75
76 let client = Client::builder(hyper_util::rt::TokioExecutor::new())
77 .pool_idle_timeout(std::time::Duration::from_secs(60))
78 .pool_max_idle_per_host(32)
79 .retry_canceled_requests(true)
80 .set_host(true)
81 .build(connector);
82
83 Self::new_with_client(path, target, client)
84 }
85}
86
87impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
88 pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
118 where
119 S: Into<String>,
120 {
121 Self {
122 path: path.into(),
123 target: target.into(),
124 client,
125 }
126 }
127
128 pub fn path(&self) -> &str {
130 &self.path
131 }
132
133 pub fn target(&self) -> &str {
135 &self.target
136 }
137
138 pub async fn proxy_request(
140 &self,
141 req: axum::http::Request<Body>,
142 ) -> Result<axum::http::Response<Body>, Infallible> {
143 self.handle_request(req).await
144 }
145
146 async fn handle_request(
148 &self,
149 req: axum::http::Request<Body>,
150 ) -> Result<axum::http::Response<Body>, Infallible> {
151 trace!("Proxying request method={} uri={}", req.method(), req.uri());
152 trace!("Original headers headers={:?}", req.headers());
153
154 if websocket::is_websocket_upgrade(req.headers()) {
156 trace!("Detected WebSocket upgrade request");
157 match websocket::handle_websocket(req, &self.target).await {
158 Ok(response) => return Ok(response),
159 Err(e) => {
160 error!("Failed to handle WebSocket upgrade: {}", e);
161 return Ok(axum::http::Response::builder()
162 .status(StatusCode::INTERNAL_SERVER_ERROR)
163 .body(Body::from(format!("WebSocket upgrade failed: {e}")))
164 .unwrap());
165 }
166 }
167 }
168
169 let forward_req = {
170 let mut builder =
171 axum::http::Request::builder()
172 .method(req.method().clone())
173 .uri(self.transform_uri(
174 req.uri().path_and_query().map(|x| x.as_str()).unwrap_or(""),
175 ));
176
177 for (key, value) in req.headers() {
179 if key != "host" {
180 builder = builder.header(key, value);
181 }
182 }
183
184 let (parts, body) = req.into_parts();
186 drop(parts);
187 builder.body(body).unwrap()
188 };
189
190 trace!(
191 "Forwarding headers forwarded_headers={:?}",
192 forward_req.headers()
193 );
194
195 match self.client.request(forward_req).await {
196 Ok(res) => {
197 trace!(
198 "Received response status={} headers={:?} version={:?}",
199 res.status(),
200 res.headers(),
201 res.version()
202 );
203
204 let (parts, body) = res.into_parts();
205 let body = Body::from_stream(body.into_data_stream());
206
207 let mut response = axum::http::Response::new(body);
208 *response.status_mut() = parts.status;
209 *response.version_mut() = parts.version;
210 *response.headers_mut() = parts.headers;
211 Ok(response)
212 }
213 Err(e) => {
214 let error_msg = e.to_string();
215 error!("Proxy error occurred err={}", error_msg);
216 Ok(axum::http::Response::builder()
217 .status(StatusCode::BAD_GATEWAY)
218 .body(Body::from(format!(
219 "Failed to connect to upstream server: {error_msg}"
220 )))
221 .unwrap())
222 }
223 }
224 }
225
226 fn transform_uri(&self, path: &str) -> String {
228 let target = self.target.trim_end_matches('/');
229 let base_path = self.path.trim_end_matches('/');
230
231 let remaining = if path == "/" && !self.path.is_empty() {
232 ""
233 } else if let Some(stripped) = path.strip_prefix(base_path) {
234 stripped
235 } else {
236 path
237 };
238
239 let mut uri = String::with_capacity(target.len() + remaining.len());
240 uri.push_str(target);
241 uri.push_str(remaining);
242 uri
243 }
244}
245
246use std::{
247 future::Future,
248 pin::Pin,
249 task::{Context, Poll},
250};
251use tower::Service;
252
253impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
254where
255 C: Connect + Clone + Send + Sync + 'static,
256{
257 type Response = axum::http::Response<Body>;
258 type Error = Infallible;
259 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
260
261 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262 Poll::Ready(Ok(()))
263 }
264
265 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
266 let this = self.clone();
267 Box::pin(async move { this.handle_request(req).await })
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    /// Upstream URI computed for `request_path` by a proxy mounted at `mount`
    /// in front of `http://target`.
    fn upstream_uri(mount: &str, request_path: &str) -> String {
        ReverseProxy::new(mount, "http://target").transform_uri(request_path)
    }

    /// The mount prefix is stripped whether or not it ends with a slash.
    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        for mount in ["/api/", "/api"] {
            assert_eq!(upstream_uri(mount, "/api/test"), "http://target/test");
        }
    }

    /// A proxy mounted at the root forwards paths unchanged.
    #[test]
    fn transform_uri_root() {
        assert_eq!(upstream_uri("/", "/test"), "http://target/test");
    }
}