axum_reverse_proxy/
proxy.rs1use axum::body::Body;
2use http::Uri;
3use http::uri::Builder as UriBuilder;
4use hyper_util::client::legacy::{Client, connect::Connect};
5use std::convert::Infallible;
6use tracing::trace;
7
8use crate::forward::{ProxyConnector, create_http_connector, forward_request};
9
10#[derive(Clone)]
16pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
17 path: String,
18 target: String,
19 client: Client<C, Body>,
20}
21
/// [`ReverseProxy`] specialised to the crate's [`ProxyConnector`]; this is the
/// type constructed by `StandardReverseProxy::new`.
pub type StandardReverseProxy = ReverseProxy<ProxyConnector>;
23
24impl StandardReverseProxy {
25 pub fn new<S>(path: S, target: S) -> Self
40 where
41 S: Into<String>,
42 {
43 let client = Client::builder(hyper_util::rt::TokioExecutor::new())
44 .pool_idle_timeout(std::time::Duration::from_secs(60))
45 .pool_max_idle_per_host(32)
46 .retry_canceled_requests(true)
47 .set_host(true)
48 .build(create_http_connector());
49
50 Self::new_with_client(path, target, client)
51 }
52}
53
54impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
55 pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
85 where
86 S: Into<String>,
87 {
88 Self {
89 path: path.into(),
90 target: target.into(),
91 client,
92 }
93 }
94
95 pub fn path(&self) -> &str {
97 &self.path
98 }
99
100 pub fn target(&self) -> &str {
102 &self.target
103 }
104
105 pub async fn proxy_request(
107 &self,
108 req: axum::http::Request<Body>,
109 ) -> Result<axum::http::Response<Body>, Infallible> {
110 self.handle_request(req).await
111 }
112
113 async fn handle_request(
115 &self,
116 req: axum::http::Request<Body>,
117 ) -> Result<axum::http::Response<Body>, Infallible> {
118 trace!("Proxying request method={} uri={}", req.method(), req.uri());
119 trace!("Original headers headers={:?}", req.headers());
120
121 let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
123 let upstream_uri = self.transform_uri(path_q);
124
125 forward_request(upstream_uri, req, &self.client).await
127 }
128
129 fn transform_uri(&self, path_and_query: &str) -> Uri {
137 let base_path = self.path.trim_end_matches('/');
138
139 let target_uri: Uri = self
141 .target
142 .parse()
143 .expect("ReverseProxy target must be a valid URI");
144
145 let scheme = target_uri.scheme_str().unwrap_or("http");
146 let authority = target_uri
147 .authority()
148 .expect("ReverseProxy target must include authority (host)")
149 .as_str()
150 .to_string();
151
152 let target_has_trailing_slash =
154 target_uri.path().ends_with('/') && target_uri.path() != "/";
155
156 let target_base_path = {
158 let p = target_uri.path();
159 if p == "/" {
160 ""
161 } else {
162 p.trim_end_matches('/')
163 }
164 };
165
166 let (path_part, query_part) = match path_and_query.find('?') {
168 Some(i) => (&path_and_query[..i], Some(&path_and_query[i + 1..])),
169 None => (path_and_query, None),
170 };
171
172 let remaining_path = if path_part == "/" && !self.path.is_empty() {
174 ""
175 } else if !base_path.is_empty() && path_part.starts_with(base_path) {
176 let rem = &path_part[base_path.len()..];
177 if rem.is_empty() || rem.starts_with('/') {
178 rem
179 } else {
180 path_part
181 }
182 } else {
183 path_part
184 };
185
186 let joined_path = if remaining_path.is_empty() {
188 if target_base_path.is_empty() {
189 "/"
190 } else if target_has_trailing_slash {
191 "__TRAILING__"
193 } else {
194 target_base_path
195 }
196 } else {
197 if target_base_path.is_empty() {
199 remaining_path
200 } else {
201 "__JOIN__"
207 }
208 };
209
210 let final_path = if joined_path == "__JOIN__" {
212 let mut s = String::with_capacity(target_base_path.len() + remaining_path.len());
213 s.push_str(target_base_path);
214 s.push_str(remaining_path);
215 s
216 } else if joined_path == "__TRAILING__" {
217 let mut s = String::with_capacity(target_base_path.len() + 1);
218 s.push_str(target_base_path);
219 s.push('/');
220 s
221 } else {
222 joined_path.to_string()
223 };
224
225 let mut path_and_query_buf = final_path;
226 if let Some(q) = query_part {
227 path_and_query_buf.push('?');
228 path_and_query_buf.push_str(q);
229 }
230
231 UriBuilder::new()
233 .scheme(scheme)
234 .authority(authority.as_str())
235 .path_and_query(path_and_query_buf.as_str())
236 .build()
237 .expect("Failed to build upstream URI")
238 }
239}
240
241use std::{
242 future::Future,
243 pin::Pin,
244 task::{Context, Poll},
245};
246use tower::Service;
247
248impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
249where
250 C: Connect + Clone + Send + Sync + 'static,
251{
252 type Response = axum::http::Response<Body>;
253 type Error = Infallible;
254 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
255
256 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
257 Poll::Ready(Ok(()))
258 }
259
260 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
261 let this = self.clone();
262 Box::pin(async move { this.handle_request(req).await })
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    /// Builds a proxy mounted at `mount` that forwards to `target`, then checks
    /// every `(request path-and-query, expected upstream URI)` pair.
    fn check(mount: &str, target: &str, cases: &[(&str, &str)]) {
        let proxy = ReverseProxy::new(mount, target);
        for &(input, expected) in cases {
            assert_eq!(
                proxy.transform_uri(input),
                expected,
                "mount={:?} target={:?} input={:?}",
                mount,
                target,
                input
            );
        }
    }

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        check("/api/", "http://target", &[("/api/test", "http://target/test")]);
        check("/api", "http://target", &[("/api/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_root() {
        check("/", "http://target", &[("/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_with_query() {
        check(
            "/",
            "http://target",
            &[
                ("?query=test", "http://target?query=test"),
                ("/?query=test", "http://target/?query=test"),
                ("/test?query=test", "http://target/test?query=test"),
            ],
        );

        check(
            "/",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );

        check(
            "/",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api/?query=test"),
            ],
        );

        check(
            "/test",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );

        check(
            "/test",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("/something", "http://target/api/something"),
                ("/test/something", "http://target/api/something"),
            ],
        );
    }
}