axum_reverse_proxy/
proxy.rs1use axum::body::Body;
2use http::Uri;
3use http::uri::Builder as UriBuilder;
4use hyper_util::client::legacy::{Client, connect::Connect};
5use std::convert::Infallible;
6use tracing::trace;
7
8use crate::forward::{ProxyConnector, create_http_connector, forward_request};
9
10#[derive(Clone)]
16pub struct ReverseProxy<C: Connect + Clone + Send + Sync + 'static> {
17 path: String,
18 target: String,
19 client: Client<C, Body>,
20}
21
22pub type StandardReverseProxy = ReverseProxy<ProxyConnector>;
23
24impl StandardReverseProxy {
25 pub fn new<S>(path: S, target: S) -> Self
40 where
41 S: Into<String>,
42 {
43 let client = Client::builder(hyper_util::rt::TokioExecutor::new())
44 .pool_idle_timeout(std::time::Duration::from_secs(60))
45 .pool_max_idle_per_host(32)
46 .retry_canceled_requests(true)
47 .set_host(true)
48 .build(create_http_connector());
49
50 Self::new_with_client(path, target, client)
51 }
52}
53
54impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
55 pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
85 where
86 S: Into<String>,
87 {
88 Self {
89 path: path.into(),
90 target: target.into(),
91 client,
92 }
93 }
94
95 pub fn path(&self) -> &str {
97 &self.path
98 }
99
100 pub fn target(&self) -> &str {
102 &self.target
103 }
104
105 pub async fn proxy_request(
107 &self,
108 req: axum::http::Request<Body>,
109 ) -> Result<axum::http::Response<Body>, Infallible> {
110 self.handle_request(req).await
111 }
112
113 async fn handle_request(
115 &self,
116 req: axum::http::Request<Body>,
117 ) -> Result<axum::http::Response<Body>, Infallible> {
118 trace!("Proxying request method={} uri={}", req.method(), req.uri());
119
120 let path_q = req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("");
122 let upstream_uri = self.transform_uri(path_q);
123
124 forward_request(upstream_uri, req, &self.client).await
126 }
127
128 fn transform_uri(&self, path_and_query: &str) -> Uri {
136 let base_path = self.path.trim_end_matches('/');
137
138 let target_uri: Uri = self
140 .target
141 .parse()
142 .expect("ReverseProxy target must be a valid URI");
143
144 let scheme = target_uri.scheme_str().unwrap_or("http");
145 let authority = target_uri
146 .authority()
147 .expect("ReverseProxy target must include authority (host)")
148 .as_str()
149 .to_string();
150
151 let target_has_trailing_slash =
153 target_uri.path().ends_with('/') && target_uri.path() != "/";
154
155 let target_base_path = {
157 let p = target_uri.path();
158 if p == "/" {
159 ""
160 } else {
161 p.trim_end_matches('/')
162 }
163 };
164
165 let (path_part, query_part) = match path_and_query.find('?') {
167 Some(i) => (&path_and_query[..i], Some(&path_and_query[i + 1..])),
168 None => (path_and_query, None),
169 };
170
171 let remaining_path = if path_part == "/" && !self.path.is_empty() {
173 ""
174 } else if !base_path.is_empty() && path_part.starts_with(base_path) {
175 let rem = &path_part[base_path.len()..];
176 if rem.is_empty() || rem.starts_with('/') {
177 rem
178 } else {
179 path_part
180 }
181 } else {
182 path_part
183 };
184
185 let joined_path = if remaining_path.is_empty() {
187 if target_base_path.is_empty() {
188 "/"
189 } else if target_has_trailing_slash {
190 "__TRAILING__"
192 } else {
193 target_base_path
194 }
195 } else {
196 if target_base_path.is_empty() {
198 remaining_path
199 } else {
200 "__JOIN__"
206 }
207 };
208
209 let final_path = if joined_path == "__JOIN__" {
211 let mut s = String::with_capacity(target_base_path.len() + remaining_path.len());
212 s.push_str(target_base_path);
213 s.push_str(remaining_path);
214 s
215 } else if joined_path == "__TRAILING__" {
216 let mut s = String::with_capacity(target_base_path.len() + 1);
217 s.push_str(target_base_path);
218 s.push('/');
219 s
220 } else {
221 joined_path.to_string()
222 };
223
224 let mut path_and_query_buf = final_path;
225 if let Some(q) = query_part {
226 path_and_query_buf.push('?');
227 path_and_query_buf.push_str(q);
228 }
229
230 UriBuilder::new()
232 .scheme(scheme)
233 .authority(authority.as_str())
234 .path_and_query(path_and_query_buf.as_str())
235 .build()
236 .expect("Failed to build upstream URI")
237 }
238}
239
240use std::{
241 future::Future,
242 pin::Pin,
243 task::{Context, Poll},
244};
245use tower::Service;
246
247impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
248where
249 C: Connect + Clone + Send + Sync + 'static,
250{
251 type Response = axum::http::Response<Body>;
252 type Error = Infallible;
253 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
254
255 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
256 Poll::Ready(Ok(()))
257 }
258
259 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
260 let this = self.clone();
261 Box::pin(async move { this.handle_request(req).await })
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    /// Builds a proxy mounted at `path` that targets `target`, then checks each
    /// `(request path-and-query, expected upstream URI)` pair in order.
    fn check(path: &str, target: &str, cases: &[(&str, &str)]) {
        let proxy = ReverseProxy::new(path, target);
        for &(input, expected) in cases {
            assert_eq!(
                proxy.transform_uri(input),
                expected,
                "mount={:?} target={:?} input={:?}",
                path,
                target,
                input
            );
        }
    }

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        // A trailing slash on the mount path must not change prefix stripping.
        check("/api/", "http://target", &[("/api/test", "http://target/test")]);
        check("/api", "http://target", &[("/api/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_root() {
        check("/", "http://target", &[("/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_with_query() {
        // Root mount, target without a path.
        check(
            "/",
            "http://target",
            &[
                ("?query=test", "http://target?query=test"),
                ("/?query=test", "http://target/?query=test"),
                ("/test?query=test", "http://target/test?query=test"),
            ],
        );

        // Root mount, target path without a trailing slash.
        check(
            "/",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );

        // Root mount, target path with a trailing slash.
        check(
            "/",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api/?query=test"),
            ],
        );

        // Non-root mount, target path without a trailing slash.
        check(
            "/test",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );

        // Non-root mount, target path with a trailing slash.
        check(
            "/test",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("/something", "http://target/api/something"),
                ("/test/something", "http://target/api/something"),
            ],
        );
    }
}