1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
use axum::body::Body;
use http::Uri;
use http::uri::Builder as UriBuilder;
use hyper_util::client::legacy::{Client, connect::Connect};
use std::convert::Infallible;
use tracing::trace;
use crate::forward::{ProxyConnector, create_http_connector, forward_request};
/// Reverse proxy that relays HTTP requests arriving under a base path to an upstream server.
///
/// Every instance owns its own HTTP [`Client`], so connection pooling, timeouts and retry
/// behaviour are configured per proxy (see the constructors).
#[derive(Clone)]
pub struct ReverseProxy<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    /// Base path that incoming requests are matched against (for example `/api`).
    path: String,
    /// Upstream URL that matching requests are forwarded to.
    target: String,
    /// HTTP client used to talk to the upstream server.
    client: Client<C, Body>,
}
/// A [`ReverseProxy`] that uses the crate's default [`ProxyConnector`].
pub type StandardReverseProxy = ReverseProxy<ProxyConnector>;

impl StandardReverseProxy {
    /// Creates a new `ReverseProxy` instance with a default-configured HTTP client.
    ///
    /// The client is built on a Tokio executor and configured with a 60 second pool idle
    /// timeout, at most 32 idle connections per host, `retry_canceled_requests(true)` and
    /// `set_host(true)`, using the connector returned by `create_http_connector()`.
    ///
    /// # Arguments
    ///
    /// * `path` - The base path to match incoming requests against (e.g., "/api")
    /// * `target` - The upstream server URL to forward requests to (e.g., "https://api.example.com")
    ///
    /// # Example
    ///
    /// ```rust
    /// use axum_reverse_proxy::ReverseProxy;
    ///
    /// let proxy = ReverseProxy::new("/api", "https://api.example.com");
    /// ```
    pub fn new<S>(path: S, target: S) -> Self
    where
        S: Into<String>,
    {
        let executor = hyper_util::rt::TokioExecutor::new();
        let idle_timeout = std::time::Duration::from_secs(60);
        let http_client = Client::builder(executor)
            .pool_idle_timeout(idle_timeout)
            .pool_max_idle_per_host(32)
            .retry_canceled_requests(true)
            .set_host(true)
            .build(create_http_connector());
        Self::new_with_client(path, target, http_client)
    }
}
impl<C: Connect + Clone + Send + Sync + 'static> ReverseProxy<C> {
    /// Creates a new `ReverseProxy` instance with a custom HTTP client.
    ///
    /// This method allows for more fine-grained control over the proxy behavior by accepting
    /// a pre-configured HTTP client.
    ///
    /// # Arguments
    ///
    /// * `path` - The base path to match incoming requests against
    /// * `target` - The upstream server URL to forward requests to
    /// * `client` - A custom-configured HTTP client
    ///
    /// # Example
    ///
    /// ```rust
    /// use axum_reverse_proxy::ReverseProxy;
    /// use hyper_util::client::legacy::{Client, connect::HttpConnector};
    /// use axum::body::Body;
    /// use hyper_util::rt::TokioExecutor;
    ///
    /// let client = Client::builder(TokioExecutor::new())
    ///     .pool_idle_timeout(std::time::Duration::from_secs(120))
    ///     .build(HttpConnector::new());
    ///
    /// let proxy = ReverseProxy::new_with_client(
    ///     "/api",
    ///     "https://api.example.com",
    ///     client,
    /// );
    /// ```
    pub fn new_with_client<S>(path: S, target: S, client: Client<C, Body>) -> Self
    where
        S: Into<String>,
    {
        Self {
            path: path.into(),
            target: target.into(),
            client,
        }
    }

    /// Returns the base path this proxy is configured to handle.
    pub fn path(&self) -> &str {
        &self.path
    }

    /// Returns the target URL this proxy forwards requests to.
    pub fn target(&self) -> &str {
        &self.target
    }

    /// Handles the proxying of a single request to the upstream server.
    ///
    /// # Panics
    ///
    /// Panics if the configured target is not a valid URI or lacks an authority (host).
    pub async fn proxy_request(
        &self,
        req: axum::http::Request<Body>,
    ) -> Result<axum::http::Response<Body>, Infallible> {
        self.handle_request(req).await
    }

    /// Core proxy logic used by the [`tower::Service`] implementation.
    ///
    /// Rewrites the request URI onto the upstream target and delegates the actual exchange to
    /// the shared `forward_request` helper.
    async fn handle_request(
        &self,
        req: axum::http::Request<Body>,
    ) -> Result<axum::http::Response<Body>, Infallible> {
        trace!("Proxying request method={} uri={}", req.method(), req.uri());
        // Requests without a path component (e.g. authority-form) map to an empty path.
        let path_and_query = req
            .uri()
            .path_and_query()
            .map_or("", |pq| pq.as_str());
        let upstream_uri = self.transform_uri(path_and_query);
        forward_request(upstream_uri, req, &self.client).await
    }

    /// Maps an incoming request path (plus optional query) onto the upstream target URI.
    ///
    /// Rules:
    /// - The proxy base path is compared without its trailing slash. It is stripped only at a
    ///   segment boundary (the remainder is empty or starts with `/`), so `/apix` is not
    ///   treated as being under `/api`.
    /// - An incoming path of exactly `/` is treated as empty when the proxy path is non-empty.
    /// - The target path is joined without its trailing slash, and a target path of `/` counts
    ///   as empty.
    /// - When nothing remains to append, a trailing slash on the target (e.g. `/api/`) is kept.
    /// - Query-only input (`?q=1`) is appended directly to the target path (`/api?q=1`, not
    ///   `/api/?q=1`).
    /// - A non-empty remainder without a leading `/` is joined with one, so path segments
    ///   never fuse (`/api` + `x` becomes `/api/x`).
    ///
    /// # Panics
    ///
    /// Panics if `self.target` does not parse as a URI, has no authority (host), or if the
    /// assembled URI is rejected by the builder.
    // NOTE(review): the target is re-parsed on every call. Parsing once at construction time
    // would avoid this per-request work and surface configuration errors earlier.
    fn transform_uri(&self, path_and_query: &str) -> Uri {
        let base_path = self.path.trim_end_matches('/');

        let target_uri: Uri = self
            .target
            .parse()
            .expect("ReverseProxy target must be a valid URI");
        let scheme = target_uri.scheme_str().unwrap_or("http");
        let authority = target_uri
            .authority()
            .expect("ReverseProxy target must include authority (host)")
            .as_str();

        let target_path = target_uri.path();
        // Only meaningful for non-root targets such as `http://host/api/`.
        let target_has_trailing_slash = target_path.ends_with('/') && target_path != "/";
        let target_base_path = if target_path == "/" {
            ""
        } else {
            target_path.trim_end_matches('/')
        };

        // Split at the first '?' into path and (optional) raw query.
        let (path_part, query_part) = match path_and_query.split_once('?') {
            Some((path, query)) => (path, Some(query)),
            None => (path_and_query, None),
        };

        // What is left of the incoming path after removing the proxy's base path.
        let remaining_path = if path_part == "/" && !self.path.is_empty() {
            ""
        } else if base_path.is_empty() {
            path_part
        } else {
            match path_part.strip_prefix(base_path) {
                Some(rest) if rest.is_empty() || rest.starts_with('/') => rest,
                _ => path_part,
            }
        };

        let mut path_and_query_buf = String::with_capacity(
            target_base_path.len()
                + remaining_path.len()
                + 2
                + query_part.map_or(0, |q| q.len() + 1),
        );
        if remaining_path.is_empty() {
            if target_base_path.is_empty() {
                path_and_query_buf.push('/');
            } else {
                path_and_query_buf.push_str(target_base_path);
                if target_has_trailing_slash {
                    path_and_query_buf.push('/');
                }
            }
        } else {
            path_and_query_buf.push_str(target_base_path);
            // Incoming origin-form paths already start with '/'; guard the rare case that
            // does not (e.g. `*`) so the segments stay separated.
            if !remaining_path.starts_with('/') {
                path_and_query_buf.push('/');
            }
            path_and_query_buf.push_str(remaining_path);
        }
        if let Some(query) = query_part {
            path_and_query_buf.push('?');
            path_and_query_buf.push_str(query);
        }

        UriBuilder::new()
            .scheme(scheme)
            .authority(authority)
            .path_and_query(path_and_query_buf.as_str())
            .build()
            .expect("Failed to build upstream URI")
    }
}
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;
impl<C> Service<axum::http::Request<Body>> for ReverseProxy<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    type Response = axum::http::Response<Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports readiness; this implementation performs no readiness checks of its own.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Proxies `req` upstream.
    ///
    /// The returned boxed future owns a clone of the proxy, so it does not borrow `self`.
    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
        let proxy = self.clone();
        let response_future = async move { proxy.handle_request(req).await };
        Box::pin(response_future)
    }
}
#[cfg(test)]
mod tests {
    use super::StandardReverseProxy as ReverseProxy;

    /// Builds one proxy for `(path, target)` and checks each `(input, expected)` pair against it.
    fn check(path: &str, target: &str, cases: &[(&str, &str)]) {
        let proxy = ReverseProxy::new(path, target);
        for &(input, expected) in cases {
            assert_eq!(
                proxy.transform_uri(input),
                expected,
                "path={:?} target={:?} input={:?}",
                path,
                target,
                input
            );
        }
    }

    #[test]
    fn transform_uri_with_and_without_trailing_slash() {
        check("/api/", "http://target", &[("/api/test", "http://target/test")]);
        check("/api", "http://target", &[("/api/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_root() {
        check("/", "http://target", &[("/test", "http://target/test")]);
    }

    #[test]
    fn transform_uri_with_query() {
        check(
            "/",
            "http://target",
            &[
                ("?query=test", "http://target?query=test"),
                ("/?query=test", "http://target/?query=test"),
                ("/test?query=test", "http://target/test?query=test"),
            ],
        );
        check(
            "/",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );
        check(
            "/",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/test?query=test"),
                ("?query=test", "http://target/api/?query=test"),
            ],
        );
        check(
            "/test",
            "http://target/api",
            &[
                ("/test?query=test", "http://target/api?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("?query=test", "http://target/api?query=test"),
            ],
        );
        check(
            "/test",
            "http://target/api/",
            &[
                ("/test?query=test", "http://target/api/?query=test"),
                ("/test/?query=test", "http://target/api/?query=test"),
                ("/something", "http://target/api/something"),
                ("/test/something", "http://target/api/something"),
            ],
        );
    }
}