axum_reverse_proxy/router_ext.rs
1//! Router extension for adding proxy routes with dynamic target resolution.
2//!
3//! This module provides the [`ProxyRouterExt`] trait which extends [`axum::Router`]
4//! with a convenient [`proxy_route`](ProxyRouterExt::proxy_route) method for adding
5//! proxy routes with static or dynamic target URLs.
6//!
7//! # Example
8//!
9//! ```rust
10//! use axum::Router;
11//! use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
12//!
13//! let app: Router = Router::new()
14//! // Static target
15//! .proxy_route("/api/{*rest}", "https://api.example.com")
16//! // Dynamic target with path parameter substitution
17//! .proxy_route("/users/{id}/profile", proxy_template("https://profiles.example.com/user/{id}"));
18//! ```
19
20use axum::{
21 Router,
22 body::Body,
23 extract::Path,
24 http::{Request, Response, StatusCode, Uri},
25 routing::any,
26};
27use http::uri::Builder as UriBuilder;
28use std::convert::Infallible;
29use tracing::{error, trace};
30
31use crate::forward::{ProxyClient, create_proxy_client, forward_request};
32
33/// A trait for resolving the target URL for a proxy request.
34///
35/// Implement this trait to provide custom target URL resolution logic.
36/// The resolver receives the full request and path parameters, allowing
37/// routing decisions based on headers, method, query parameters, etc.
38///
39/// # Built-in Implementations
40///
41/// - `String` and `&'static str`: Static target URLs (request/parameters are ignored)
42/// - [`TemplateTarget`]: Template-based URL with `{param}` substitution
43///
44/// # Example
45///
46/// ```rust
47/// use axum::body::Body;
48/// use axum::http::Request;
49/// use axum_reverse_proxy::TargetResolver;
50///
51/// #[derive(Clone)]
52/// struct HeaderBasedResolver {
53/// default_url: String,
54/// premium_url: String,
55/// }
56///
57/// impl TargetResolver for HeaderBasedResolver {
58/// fn resolve(&self, req: &Request<Body>, _params: &[(String, String)]) -> String {
59/// // Route premium users to a different backend
60/// if req.headers().get("x-premium-user").is_some() {
61/// self.premium_url.clone()
62/// } else {
63/// self.default_url.clone()
64/// }
65/// }
66/// }
67/// ```
68pub trait TargetResolver: Clone + Send + Sync + 'static {
69 /// Resolve the target URL based on the request and path parameters.
70 ///
71 /// # Arguments
72 ///
73 /// * `req` - The incoming HTTP request (headers, method, URI, etc.)
74 /// * `params` - Path parameters extracted from the request URL as key-value pairs
75 ///
76 /// # Returns
77 ///
78 /// The target URL as a string. This should be a valid URL including scheme and host.
79 fn resolve(&self, req: &Request<Body>, params: &[(String, String)]) -> String;
80}
81
82impl TargetResolver for String {
83 fn resolve(&self, _req: &Request<Body>, _params: &[(String, String)]) -> String {
84 self.clone()
85 }
86}
87
88impl TargetResolver for &'static str {
89 fn resolve(&self, _req: &Request<Body>, _params: &[(String, String)]) -> String {
90 (*self).to_string()
91 }
92}
93
/// A [`TargetResolver`] that builds the upstream URL by substituting path parameters into a
/// URL template.
///
/// Placeholders are written `{param_name}` and are replaced with the value of the path
/// parameter of the same name captured by the route. Placeholders naming a parameter the
/// route does not capture are left in the URL unchanged.
///
/// # Example
///
/// ```rust
/// use axum::Router;
/// use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
///
/// let app: Router = Router::new()
///     // Request to /videos/abc123/720p proxies to https://cdn.example.com/v/abc123/res_720p
///     .proxy_route("/videos/{id}/{quality}", proxy_template("https://cdn.example.com/v/{id}/res_{quality}"));
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TemplateTarget {
    /// URL template containing `{param}` placeholders.
    template: String,
}

impl TemplateTarget {
    /// Creates a template target from a URL template with `{param}` placeholders.
    ///
    /// The template is stored exactly as given and is not validated here; an invalid URL
    /// only surfaces when a request is proxied through it.
    ///
    /// # Arguments
    ///
    /// * `template` - A URL template such as `https://api.example.com/users/{id}`
    #[must_use]
    pub fn new(template: impl Into<String>) -> Self {
        Self {
            template: template.into(),
        }
    }
}
126
127impl TargetResolver for TemplateTarget {
128 fn resolve(&self, _req: &Request<Body>, params: &[(String, String)]) -> String {
129 let mut result = self.template.clone();
130 for (key, value) in params {
131 let placeholder = format!("{{{}}}", key);
132 result = result.replace(&placeholder, value);
133 }
134 result
135 }
136}
137
138/// Create a new [`TemplateTarget`] with the given URL template.
139///
140/// This is a convenience function for creating template-based target resolvers.
141///
142/// # Arguments
143///
144/// * `template` - A URL template with `{param}` placeholders
145///
146/// # Example
147///
148/// ```rust
149/// use axum::Router;
150/// use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
151///
152/// let app: Router = Router::new()
153/// .proxy_route("/api/{version}/{*path}", proxy_template("https://api.example.com/{version}/{path}"));
154/// ```
155pub fn proxy_template(template: impl Into<String>) -> TemplateTarget {
156 TemplateTarget::new(template)
157}
158
159/// Extension trait for [`axum::Router`] that adds proxy routing capabilities.
160///
161/// This trait provides a convenient way to add proxy routes to an Axum router
162/// with support for both static and dynamic target URLs.
163pub trait ProxyRouterExt<S> {
164 /// Add a proxy route that forwards requests to a target URL.
165 ///
166 /// The target can be:
167 /// - A static string (`&str` or `String`)
168 /// - A [`TemplateTarget`] for dynamic URL generation based on path parameters
169 /// - Any custom type implementing [`TargetResolver`]
170 ///
171 /// # Arguments
172 ///
173 /// * `path` - The route path pattern (e.g., `/api/{id}` or `/proxy/{*rest}`)
174 /// * `target` - The target URL or resolver
175 ///
176 /// # Example
177 ///
178 /// ```rust
179 /// use axum::Router;
180 /// use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
181 ///
182 /// let app: Router = Router::new()
183 /// // Static proxy
184 /// .proxy_route("/api/{*rest}", "https://api.example.com")
185 /// // Dynamic proxy with path substitution
186 /// .proxy_route("/users/{id}", proxy_template("https://users.example.com/{id}"));
187 /// ```
188 fn proxy_route<T: TargetResolver>(self, path: &str, target: T) -> Self;
189}
190
191impl<S> ProxyRouterExt<S> for Router<S>
192where
193 S: Clone + Send + Sync + 'static,
194{
195 fn proxy_route<T: TargetResolver>(self, path: &str, target: T) -> Self {
196 let client = create_proxy_client();
197
198 self.route(
199 path,
200 any(
201 move |Path(params): Path<Vec<(String, String)>>, req: Request<Body>| {
202 let target = target.clone();
203 let client = client.clone();
204 async move { proxy_request(target, params, req, client).await }
205 },
206 ),
207 )
208 }
209}
210
211async fn proxy_request<T: TargetResolver>(
212 target: T,
213 params: Vec<(String, String)>,
214 req: Request<Body>,
215 client: ProxyClient,
216) -> Result<Response<Body>, Infallible> {
217 let target_url = target.resolve(&req, ¶ms);
218 trace!("Proxying request to resolved target: {}", target_url);
219
220 // Parse target URL
221 let target_uri: Uri = match target_url.parse() {
222 Ok(uri) => uri,
223 Err(e) => {
224 error!("Invalid target URL '{}': {}", target_url, e);
225 return Ok(Response::builder()
226 .status(StatusCode::INTERNAL_SERVER_ERROR)
227 .body(Body::from(format!("Invalid target URL: {e}")))
228 .unwrap());
229 }
230 };
231
232 // Build the upstream URI, preserving query string from original request
233 let upstream_uri = build_upstream_uri(&target_uri, req.uri());
234
235 // Use shared forwarding logic
236 forward_request(upstream_uri, req, &client).await
237}
238
239/// Build the upstream URI from the target and original request.
240///
241/// If the original request has a query string and the target doesn't,
242/// the query string is appended to the target.
243fn build_upstream_uri(target: &Uri, original: &Uri) -> Uri {
244 let scheme = target.scheme_str().unwrap_or("http");
245 let authority = target
246 .authority()
247 .map(|a| a.as_str())
248 .unwrap_or("localhost");
249 let path = target.path();
250
251 // Combine query strings: prefer target's query, fall back to original's
252 let query = target.query().or_else(|| original.query());
253
254 let path_and_query = match query {
255 Some(q) => format!("{}?{}", path, q),
256 None => path.to_string(),
257 };
258
259 UriBuilder::new()
260 .scheme(scheme)
261 .authority(authority)
262 .path_and_query(path_and_query)
263 .build()
264 .expect("Failed to build upstream URI")
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// A throwaway request; none of the resolvers under test inspect it.
    fn blank_request() -> Request<Body> {
        Request::builder().uri("/test").body(Body::empty()).unwrap()
    }

    /// Converts borrowed `(name, value)` pairs into the owned form resolvers receive.
    fn owned_params(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// Resolves `resolver` against a blank request and the given parameters.
    fn resolve_with<R: TargetResolver>(resolver: &R, pairs: &[(&str, &str)]) -> String {
        resolver.resolve(&blank_request(), &owned_params(pairs))
    }

    /// Parses both URIs, builds the upstream URI and renders it as a string.
    fn upstream_for(target: &str, original: &str) -> String {
        let target: Uri = target.parse().unwrap();
        let original: Uri = original.parse().unwrap();
        build_upstream_uri(&target, &original).to_string()
    }

    #[test]
    fn test_static_string_resolver() {
        let resolver = String::from("https://example.com");
        assert_eq!(
            resolve_with(&resolver, &[("id", "123")]),
            "https://example.com"
        );
    }

    #[test]
    fn test_static_str_resolver() {
        let resolver: &'static str = "https://example.com";
        assert_eq!(
            resolve_with(&resolver, &[("id", "123")]),
            "https://example.com"
        );
    }

    #[test]
    fn test_template_resolver_single_param() {
        let resolver = proxy_template("https://example.com/users/{id}");
        assert_eq!(
            resolve_with(&resolver, &[("id", "123")]),
            "https://example.com/users/123"
        );
    }

    #[test]
    fn test_template_resolver_multiple_params() {
        let resolver = proxy_template("https://cdn.example.com/{id}/quality_{quality}");
        assert_eq!(
            resolve_with(&resolver, &[("id", "video123"), ("quality", "720p")]),
            "https://cdn.example.com/video123/quality_720p"
        );
    }

    #[test]
    fn test_template_resolver_missing_param() {
        // A placeholder with no matching parameter stays in the output untouched.
        let resolver = proxy_template("https://example.com/{id}/{missing}");
        assert_eq!(
            resolve_with(&resolver, &[("id", "123")]),
            "https://example.com/123/{missing}"
        );
    }

    #[test]
    fn test_template_resolver_no_params() {
        let resolver = proxy_template("https://example.com/static/path");
        assert_eq!(
            resolve_with(&resolver, &[("id", "123")]),
            "https://example.com/static/path"
        );
    }

    #[test]
    fn test_build_upstream_uri_with_target_query() {
        // The target's own query string wins over the request's.
        assert_eq!(
            upstream_for("https://example.com/path?foo=bar", "/request?baz=qux"),
            "https://example.com/path?foo=bar"
        );
    }

    #[test]
    fn test_build_upstream_uri_with_original_query() {
        // Without a target query, the request's query string is carried over.
        assert_eq!(
            upstream_for("https://example.com/path", "/request?baz=qux"),
            "https://example.com/path?baz=qux"
        );
    }

    #[test]
    fn test_build_upstream_uri_no_query() {
        assert_eq!(
            upstream_for("https://example.com/path", "/request"),
            "https://example.com/path"
        );
    }
}