use axum::{
Router,
body::Body,
extract::Path,
http::{Request, Response, StatusCode, Uri},
routing::any,
};
use http::uri::Builder as UriBuilder;
use std::convert::Infallible;
use tracing::{error, trace};
use crate::forward::{ProxyClient, create_proxy_client, forward_request};
/// Strategy for choosing the upstream URL that a proxied request is sent to.
///
/// An implementation is handed the incoming request together with the matched route's
/// path parameters as `(name, value)` pairs, and returns the target URL as a string.
/// The string is parsed as a URI by the caller.
pub trait TargetResolver: Clone + Send + Sync + 'static {
    /// Returns the upstream URL for `request`, given the route's `path_params`.
    fn resolve(&self, request: &Request<Body>, path_params: &[(String, String)]) -> String;
}
/// An owned `String` is a fixed target: every request goes to the same URL, regardless of
/// the request or its path parameters.
impl TargetResolver for String {
    fn resolve(&self, _req: &Request<Body>, _params: &[(String, String)]) -> String {
        self.to_owned()
    }
}
/// A string literal is a fixed target: the request and path parameters are ignored and
/// the literal is returned as an owned `String`.
impl TargetResolver for &'static str {
    fn resolve(&self, _req: &Request<Body>, _params: &[(String, String)]) -> String {
        String::from(*self)
    }
}
/// A target URL built from a template string containing `{name}` placeholders.
///
/// Placeholders are filled in from the matched route's path parameters by the
/// `TargetResolver` implementation for this type.
#[derive(Clone, Debug)]
pub struct TemplateTarget {
    /// The raw template, e.g. `https://example.com/users/{id}`.
    template: String,
}
impl TemplateTarget {
    /// Creates a resolver from a URL template containing `{name}` placeholders.
    pub fn new(template: impl Into<String>) -> Self {
        let template: String = template.into();
        Self { template }
    }
}
impl TargetResolver for TemplateTarget {
    /// Expands each `{name}` placeholder in the template with the matching route parameter.
    ///
    /// The template is scanned once, left to right. Text inserted from a parameter value is
    /// copied verbatim and never re-scanned, so a value such as `"{other}"` is not itself
    /// expanded by a later parameter.
    ///
    /// Unmatched placeholders are kept unchanged in the output. So are braces that do not
    /// form a `{name}` pair; for example `{{id}}` yields a literal `{`, the value of `id`,
    /// then `}`. If a key occurs more than once in `params`, the first occurrence is used.
    ///
    /// Values are inserted as-is (no percent-encoding).
    fn resolve(&self, _req: &Request<Body>, params: &[(String, String)]) -> String {
        let mut out = String::with_capacity(self.template.len());
        let mut rest = self.template.as_str();

        while let Some(open) = rest.find('{') {
            // Literal text before the brace is copied unchanged.
            out.push_str(&rest[..open]);
            let tail = &rest[open + 1..];

            // Stop at the first brace of either kind. This way `{{id}}` is read as a literal
            // `{` followed by the placeholder `{id}`.
            match tail.find(|c: char| c == '{' || c == '}') {
                Some(close) if tail.as_bytes()[close] == b'}' => {
                    let key = &tail[..close];
                    match params.iter().find(|(name, _)| name.as_str() == key) {
                        Some((_, value)) => out.push_str(value),
                        // Unknown placeholder: keep `{key}` exactly as written.
                        None => out.push_str(&rest[open..=open + 1 + close]),
                    }
                    rest = &tail[close + 1..];
                }
                _ => {
                    // No well-formed placeholder starts here; emit the brace literally.
                    out.push('{');
                    rest = tail;
                }
            }
        }

        out.push_str(rest);
        out
    }
}
/// Shorthand for [`TemplateTarget::new`]: builds a template-based proxy target whose
/// `{name}` placeholders are filled from route path parameters.
pub fn proxy_template(template: impl Into<String>) -> TemplateTarget {
    let owned: String = template.into();
    TemplateTarget::new(owned)
}
/// Extension trait for registering reverse-proxy routes on a router.
pub trait ProxyRouterExt<S> {
    /// Registers a route at `path` whose requests are proxied to the URL produced by `target`.
    fn proxy_route<T>(self, path: &str, target: T) -> Self
    where
        T: TargetResolver;
}
impl<S> ProxyRouterExt<S> for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Mounts a handler at `path` for any HTTP method (via `any`).
    ///
    /// The handler extracts the route's path parameters, resolves the upstream with
    /// `target`, and forwards the request through `proxy_request`.
    fn proxy_route<T: TargetResolver>(self, path: &str, target: T) -> Self {
        // Created once per `proxy_route` call and cloned into every request. Whether clones
        // share a connection pool depends on `ProxyClient` (not visible here).
        let shared_client = create_proxy_client();

        let handler = move |Path(path_params): Path<Vec<(String, String)>>,
                            incoming: Request<Body>| {
            let resolver = target.clone();
            let http_client = shared_client.clone();
            async move { proxy_request(resolver, path_params, incoming, http_client).await }
        };

        self.route(path, any(handler))
    }
}
/// Resolves `target` for this request and forwards the request upstream.
///
/// The resolved string is parsed as a [`Uri`]. If that fails, the request is not forwarded:
/// the error is logged, and a `500 Internal Server Error` response whose body describes the
/// parse error is returned.
///
/// Otherwise [`build_upstream_uri`] combines the target with the incoming request URI, and
/// [`forward_request`] sends the request using `client`.
///
/// The error type is [`Infallible`], so this function never returns `Err`.
async fn proxy_request<T: TargetResolver>(
    target: T,
    params: Vec<(String, String)>,
    req: Request<Body>,
    client: ProxyClient,
) -> Result<Response<Body>, Infallible> {
    let target_url = target.resolve(&req, &params);
    trace!("Proxying request to resolved target: {}", target_url);

    let target_uri: Uri = match target_url.parse() {
        Ok(uri) => uri,
        Err(e) => {
            error!("Invalid target URL '{}': {}", target_url, e);
            // Construct the response directly rather than through `Response::builder()`,
            // so the error path contains no `unwrap`.
            let mut response = Response::new(Body::from(format!("Invalid target URL: {e}")));
            *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
            return Ok(response);
        }
    };

    let upstream_uri = build_upstream_uri(&target_uri, req.uri());
    forward_request(upstream_uri, req, &client).await
}
/// Combines a resolved target with the incoming request URI into the URI sent upstream.
///
/// Scheme, authority and path all come from `target`. A missing scheme falls back to
/// `http`, and a missing authority falls back to `localhost`.
///
/// The query string is taken from `target` when it has one, otherwise from `original`; the
/// two are never merged. The path of `original` is not used.
///
/// # Panics
///
/// Panics if the assembled parts do not form a valid URI.
fn build_upstream_uri(target: &Uri, original: &Uri) -> Uri {
    let path_and_query = if let Some(q) = target.query().or_else(|| original.query()) {
        format!("{}?{}", target.path(), q)
    } else {
        target.path().to_string()
    };

    UriBuilder::new()
        .scheme(target.scheme_str().unwrap_or("http"))
        .authority(target.authority().map_or("localhost", |a| a.as_str()))
        .path_and_query(path_and_query)
        .build()
        .expect("Failed to build upstream URI")
}
#[cfg(test)]
mod tests {
    // Unit tests for the target resolvers and upstream-URI construction.
    use super::*;

    /// A minimal request; the resolvers under test ignore its contents.
    fn dummy_request() -> Request<Body> {
        Request::builder().uri("/test").body(Body::empty()).unwrap()
    }

    #[test]
    fn test_static_string_resolver() {
        let resolver = "https://example.com".to_string();
        let req = dummy_request();
        let params = vec![("id".to_string(), "123".to_string())];
        assert_eq!(resolver.resolve(&req, &params), "https://example.com");
    }

    #[test]
    fn test_static_str_resolver() {
        let resolver: &'static str = "https://example.com";
        let req = dummy_request();
        let params = vec![("id".to_string(), "123".to_string())];
        assert_eq!(resolver.resolve(&req, &params), "https://example.com");
    }

    #[test]
    fn test_template_resolver_single_param() {
        let resolver = proxy_template("https://example.com/users/{id}");
        let req = dummy_request();
        let params = vec![("id".to_string(), "123".to_string())];
        assert_eq!(
            resolver.resolve(&req, &params),
            "https://example.com/users/123"
        );
    }

    #[test]
    fn test_template_resolver_multiple_params() {
        let resolver = proxy_template("https://cdn.example.com/{id}/quality_{quality}");
        let req = dummy_request();
        let params = vec![
            ("id".to_string(), "video123".to_string()),
            ("quality".to_string(), "720p".to_string()),
        ];
        assert_eq!(
            resolver.resolve(&req, &params),
            "https://cdn.example.com/video123/quality_720p"
        );
    }

    #[test]
    fn test_template_resolver_missing_param() {
        let resolver = proxy_template("https://example.com/{id}/{missing}");
        let req = dummy_request();
        let params = vec![("id".to_string(), "123".to_string())];
        assert_eq!(
            resolver.resolve(&req, &params),
            "https://example.com/123/{missing}"
        );
    }

    #[test]
    fn test_template_resolver_no_params() {
        let resolver = proxy_template("https://example.com/static/path");
        let req = dummy_request();
        let params = vec![("id".to_string(), "123".to_string())];
        assert_eq!(
            resolver.resolve(&req, &params),
            "https://example.com/static/path"
        );
    }

    #[test]
    fn test_build_upstream_uri_with_target_query() {
        let target: Uri = "https://example.com/path?foo=bar".parse().unwrap();
        let original: Uri = "/request?baz=qux".parse().unwrap();
        let result = build_upstream_uri(&target, &original);
        assert_eq!(result.to_string(), "https://example.com/path?foo=bar");
    }

    #[test]
    fn test_build_upstream_uri_with_original_query() {
        let target: Uri = "https://example.com/path".parse().unwrap();
        let original: Uri = "/request?baz=qux".parse().unwrap();
        let result = build_upstream_uri(&target, &original);
        assert_eq!(result.to_string(), "https://example.com/path?baz=qux");
    }

    #[test]
    fn test_build_upstream_uri_no_query() {
        let target: Uri = "https://example.com/path".parse().unwrap();
        let original: Uri = "/request".parse().unwrap();
        let result = build_upstream_uri(&target, &original);
        assert_eq!(result.to_string(), "https://example.com/path");
    }

    #[test]
    fn test_build_upstream_uri_defaults_for_relative_target() {
        // A target with neither scheme nor authority falls back to http://localhost.
        let target: Uri = "/path".parse().unwrap();
        let original: Uri = "/request".parse().unwrap();
        let result = build_upstream_uri(&target, &original);
        assert_eq!(result.to_string(), "http://localhost/path");
    }
}