axum_reverse_proxy/
router_ext.rs

1//! Router extension for adding proxy routes with dynamic target resolution.
2//!
3//! This module provides the [`ProxyRouterExt`] trait which extends [`axum::Router`]
4//! with a convenient [`proxy_route`](ProxyRouterExt::proxy_route) method for adding
5//! proxy routes with static or dynamic target URLs.
6//!
7//! # Example
8//!
9//! ```rust
10//! use axum::Router;
11//! use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
12//!
13//! let app: Router = Router::new()
14//!     // Static target
15//!     .proxy_route("/api/{*rest}", "https://api.example.com")
16//!     // Dynamic target with path parameter substitution
17//!     .proxy_route("/users/{id}/profile", proxy_template("https://profiles.example.com/user/{id}"));
18//! ```
19
20use axum::{
21    Router,
22    body::Body,
23    extract::Path,
24    http::{Request, Response, StatusCode, Uri},
25    routing::any,
26};
27use http::uri::Builder as UriBuilder;
28use std::convert::Infallible;
29use tracing::{error, trace};
30
31use crate::forward::{ProxyClient, create_proxy_client, forward_request};
32
33/// A trait for resolving the target URL for a proxy request.
34///
35/// Implement this trait to provide custom target URL resolution logic.
36/// The resolver receives the full request and path parameters, allowing
37/// routing decisions based on headers, method, query parameters, etc.
38///
39/// # Built-in Implementations
40///
41/// - `String` and `&'static str`: Static target URLs (request/parameters are ignored)
42/// - [`TemplateTarget`]: Template-based URL with `{param}` substitution
43///
44/// # Example
45///
46/// ```rust
47/// use axum::body::Body;
48/// use axum::http::Request;
49/// use axum_reverse_proxy::TargetResolver;
50///
51/// #[derive(Clone)]
52/// struct HeaderBasedResolver {
53///     default_url: String,
54///     premium_url: String,
55/// }
56///
57/// impl TargetResolver for HeaderBasedResolver {
58///     fn resolve(&self, req: &Request<Body>, _params: &[(String, String)]) -> String {
59///         // Route premium users to a different backend
60///         if req.headers().get("x-premium-user").is_some() {
61///             self.premium_url.clone()
62///         } else {
63///             self.default_url.clone()
64///         }
65///     }
66/// }
67/// ```
68pub trait TargetResolver: Clone + Send + Sync + 'static {
69    /// Resolve the target URL based on the request and path parameters.
70    ///
71    /// # Arguments
72    ///
73    /// * `req` - The incoming HTTP request (headers, method, URI, etc.)
74    /// * `params` - Path parameters extracted from the request URL as key-value pairs
75    ///
76    /// # Returns
77    ///
78    /// The target URL as a string. This should be a valid URL including scheme and host.
79    fn resolve(&self, req: &Request<Body>, params: &[(String, String)]) -> String;
80}
81
82impl TargetResolver for String {
83    fn resolve(&self, _req: &Request<Body>, _params: &[(String, String)]) -> String {
84        self.clone()
85    }
86}
87
88impl TargetResolver for &'static str {
89    fn resolve(&self, _req: &Request<Body>, _params: &[(String, String)]) -> String {
90        (*self).to_string()
91    }
92}
93
/// A template-based target resolver that substitutes path parameters into a URL template.
///
/// Template placeholders use the format `{param_name}` and are replaced with the
/// corresponding path parameter values from the request.
///
/// The type derives `Debug`, `PartialEq` and `Eq` in addition to `Clone`, so it
/// can be logged, embedded in other `Debug` types, and compared in tests.
///
/// # Example
///
/// ```rust
/// use axum::Router;
/// use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
///
/// let app: Router = Router::new()
///     // Request to /videos/abc123/720p proxies to https://cdn.example.com/v/abc123/res_720p
///     .proxy_route("/videos/{id}/{quality}", proxy_template("https://cdn.example.com/v/{id}/res_{quality}"));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TemplateTarget {
    /// The URL template, stored exactly as supplied, `{param}` placeholders included.
    template: String,
}
113
114impl TemplateTarget {
115    /// Create a new template target with the given URL template.
116    ///
117    /// # Arguments
118    ///
119    /// * `template` - A URL template with `{param}` placeholders
120    pub fn new(template: impl Into<String>) -> Self {
121        Self {
122            template: template.into(),
123        }
124    }
125}
126
127impl TargetResolver for TemplateTarget {
128    fn resolve(&self, _req: &Request<Body>, params: &[(String, String)]) -> String {
129        let mut result = self.template.clone();
130        for (key, value) in params {
131            let placeholder = format!("{{{}}}", key);
132            result = result.replace(&placeholder, value);
133        }
134        result
135    }
136}
137
138/// Create a new [`TemplateTarget`] with the given URL template.
139///
140/// This is a convenience function for creating template-based target resolvers.
141///
142/// # Arguments
143///
144/// * `template` - A URL template with `{param}` placeholders
145///
146/// # Example
147///
148/// ```rust
149/// use axum::Router;
150/// use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
151///
152/// let app: Router = Router::new()
153///     .proxy_route("/api/{version}/{*path}", proxy_template("https://api.example.com/{version}/{path}"));
154/// ```
155pub fn proxy_template(template: impl Into<String>) -> TemplateTarget {
156    TemplateTarget::new(template)
157}
158
159/// Extension trait for [`axum::Router`] that adds proxy routing capabilities.
160///
161/// This trait provides a convenient way to add proxy routes to an Axum router
162/// with support for both static and dynamic target URLs.
163pub trait ProxyRouterExt<S> {
164    /// Add a proxy route that forwards requests to a target URL.
165    ///
166    /// The target can be:
167    /// - A static string (`&str` or `String`)
168    /// - A [`TemplateTarget`] for dynamic URL generation based on path parameters
169    /// - Any custom type implementing [`TargetResolver`]
170    ///
171    /// # Arguments
172    ///
173    /// * `path` - The route path pattern (e.g., `/api/{id}` or `/proxy/{*rest}`)
174    /// * `target` - The target URL or resolver
175    ///
176    /// # Example
177    ///
178    /// ```rust
179    /// use axum::Router;
180    /// use axum_reverse_proxy::{ProxyRouterExt, proxy_template};
181    ///
182    /// let app: Router = Router::new()
183    ///     // Static proxy
184    ///     .proxy_route("/api/{*rest}", "https://api.example.com")
185    ///     // Dynamic proxy with path substitution
186    ///     .proxy_route("/users/{id}", proxy_template("https://users.example.com/{id}"));
187    /// ```
188    fn proxy_route<T: TargetResolver>(self, path: &str, target: T) -> Self;
189}
190
191impl<S> ProxyRouterExt<S> for Router<S>
192where
193    S: Clone + Send + Sync + 'static,
194{
195    fn proxy_route<T: TargetResolver>(self, path: &str, target: T) -> Self {
196        let client = create_proxy_client();
197
198        self.route(
199            path,
200            any(
201                move |Path(params): Path<Vec<(String, String)>>, req: Request<Body>| {
202                    let target = target.clone();
203                    let client = client.clone();
204                    async move { proxy_request(target, params, req, client).await }
205                },
206            ),
207        )
208    }
209}
210
211async fn proxy_request<T: TargetResolver>(
212    target: T,
213    params: Vec<(String, String)>,
214    req: Request<Body>,
215    client: ProxyClient,
216) -> Result<Response<Body>, Infallible> {
217    let target_url = target.resolve(&req, &params);
218    trace!("Proxying request to resolved target: {}", target_url);
219
220    // Parse target URL
221    let target_uri: Uri = match target_url.parse() {
222        Ok(uri) => uri,
223        Err(e) => {
224            error!("Invalid target URL '{}': {}", target_url, e);
225            return Ok(Response::builder()
226                .status(StatusCode::INTERNAL_SERVER_ERROR)
227                .body(Body::from(format!("Invalid target URL: {e}")))
228                .unwrap());
229        }
230    };
231
232    // Build the upstream URI, preserving query string from original request
233    let upstream_uri = build_upstream_uri(&target_uri, req.uri());
234
235    // Use shared forwarding logic
236    forward_request(upstream_uri, req, &client).await
237}
238
239/// Build the upstream URI from the target and original request.
240///
241/// If the original request has a query string and the target doesn't,
242/// the query string is appended to the target.
243fn build_upstream_uri(target: &Uri, original: &Uri) -> Uri {
244    let scheme = target.scheme_str().unwrap_or("http");
245    let authority = target
246        .authority()
247        .map(|a| a.as_str())
248        .unwrap_or("localhost");
249    let path = target.path();
250
251    // Combine query strings: prefer target's query, fall back to original's
252    let query = target.query().or_else(|| original.query());
253
254    let path_and_query = match query {
255        Some(q) => format!("{}?{}", path, q),
256        None => path.to_string(),
257    };
258
259    UriBuilder::new()
260        .scheme(scheme)
261        .authority(authority)
262        .path_and_query(path_and_query)
263        .build()
264        .expect("Failed to build upstream URI")
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal request handed to resolvers under test.
    fn dummy_request() -> Request<Body> {
        Request::builder().uri("/test").body(Body::empty()).unwrap()
    }

    /// Turn borrowed `(name, value)` pairs into the owned form resolvers receive.
    fn owned_params(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    /// Parse a URI literal, panicking on malformed test input.
    fn uri(text: &str) -> Uri {
        text.parse().unwrap()
    }

    #[test]
    fn test_static_string_resolver() {
        let resolver = String::from("https://example.com");
        let params = owned_params(&[("id", "123")]);
        let resolved = resolver.resolve(&dummy_request(), &params);
        assert_eq!(resolved, "https://example.com");
    }

    #[test]
    fn test_static_str_resolver() {
        let resolver: &'static str = "https://example.com";
        let params = owned_params(&[("id", "123")]);
        let resolved = resolver.resolve(&dummy_request(), &params);
        assert_eq!(resolved, "https://example.com");
    }

    #[test]
    fn test_template_resolver_single_param() {
        let resolver = proxy_template("https://example.com/users/{id}");
        let params = owned_params(&[("id", "123")]);
        let resolved = resolver.resolve(&dummy_request(), &params);
        assert_eq!(resolved, "https://example.com/users/123");
    }

    #[test]
    fn test_template_resolver_multiple_params() {
        let resolver = proxy_template("https://cdn.example.com/{id}/quality_{quality}");
        let params = owned_params(&[("id", "video123"), ("quality", "720p")]);
        let resolved = resolver.resolve(&dummy_request(), &params);
        assert_eq!(resolved, "https://cdn.example.com/video123/quality_720p");
    }

    #[test]
    fn test_template_resolver_missing_param() {
        let resolver = proxy_template("https://example.com/{id}/{missing}");
        let params = owned_params(&[("id", "123")]);
        let resolved = resolver.resolve(&dummy_request(), &params);
        // A placeholder without a matching parameter stays in the output untouched
        assert_eq!(resolved, "https://example.com/123/{missing}");
    }

    #[test]
    fn test_template_resolver_no_params() {
        let resolver = proxy_template("https://example.com/static/path");
        let params = owned_params(&[("id", "123")]);
        let resolved = resolver.resolve(&dummy_request(), &params);
        assert_eq!(resolved, "https://example.com/static/path");
    }

    #[test]
    fn test_build_upstream_uri_with_target_query() {
        let upstream = build_upstream_uri(
            &uri("https://example.com/path?foo=bar"),
            &uri("/request?baz=qux"),
        );
        // The target's own query is kept; the request's query is dropped
        assert_eq!(upstream.to_string(), "https://example.com/path?foo=bar");
    }

    #[test]
    fn test_build_upstream_uri_with_original_query() {
        let upstream = build_upstream_uri(
            &uri("https://example.com/path"),
            &uri("/request?baz=qux"),
        );
        // Without a target query, the request's query is carried over
        assert_eq!(upstream.to_string(), "https://example.com/path?baz=qux");
    }

    #[test]
    fn test_build_upstream_uri_no_query() {
        let upstream = build_upstream_uri(&uri("https://example.com/path"), &uri("/request"));
        assert_eq!(upstream.to_string(), "https://example.com/path");
    }
}
364}