use std::sync::Arc;

use axum::{
    body::{to_bytes, Body},
    extract::{Path, Query, State},
    http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri},
    response::{IntoResponse, Response},
    routing::{any, get},
    Json, Router,
};
use reqwest::Client;

use crate::Config;
/// Routers for products that are embedded directly in this process.
///
/// In [`router_with_products`], a `Some` router is nested under its prefix
/// (`/login` or `/pay`). A `None` slot falls back to proxying to the upstream
/// URL configured for that product, if one is set.
#[derive(Clone, Default)]
pub struct ProductRouters {
    /// Router nested under `/login`.
    pub login: Option<Router>,
    /// Router nested under `/pay`.
    pub pay: Option<Router>,
}

impl ProductRouters {
    /// Returns a value with no embedded products (both slots are `None`).
    pub fn empty() -> Self {
        Self::default()
    }
}
/// State shared (behind an `Arc`) by the proxy handlers of one upstream.
#[derive(Clone)]
struct ProxyState {
    /// HTTP client used for every request forwarded to this upstream.
    client: Client,
    /// Upstream base URL; the request path and query are appended to it.
    base_url: String,
}
/// Builds the admin server router without any embedded product routers.
///
/// This is shorthand for [`router_with_products`] with an empty
/// [`ProductRouters`]. The `/login` and `/pay` prefixes are therefore served
/// only when `config` supplies the matching upstream URL, and requests under
/// them are proxied there.
pub fn router(config: &Config) -> Router {
    let no_products = ProductRouters::default();
    router_with_products(config, no_products)
}
pub fn router_with_products(config: &Config, products: ProductRouters) -> Router {
let mut app = Router::new().route("/health", get(health));
if let Some(login_router) = products.login {
app = app.nest("/login", login_router);
} else if let Some(upstream) = config.login_upstream_url.clone() {
app = app.nest("/login", proxy_router(upstream));
}
if let Some(pay_router) = products.pay {
app = app.nest("/pay", pay_router);
} else if let Some(upstream) = config.pay_upstream_url.clone() {
app = app.nest("/pay", proxy_router(upstream));
}
app
}
/// Builds a router that forwards every request it receives to `base_url`.
///
/// The bare mount point is handled by [`proxy_root`] and any deeper path by
/// [`proxy_path`], for all HTTP methods. Each call creates its own
/// `reqwest::Client`.
fn proxy_router(base_url: String) -> Router {
    let shared = Arc::new(ProxyState {
        base_url,
        client: Client::new(),
    });

    Router::new()
        .route("/", any(proxy_root))
        .route("/{*path}", any(proxy_path))
        .with_state(shared)
}
/// Liveness probe that always answers with a static JSON payload identifying
/// this service.
async fn health() -> Json<serde_json::Value> {
    let payload = serde_json::json!({
        "status": "ok",
        "service": "cedros-admin-server"
    });
    Json(payload)
}
async fn proxy_root(
State(state): State<Arc<ProxyState>>,
method: Method,
headers: HeaderMap,
Query(query): Query<std::collections::HashMap<String, String>>,
body: Body,
) -> Response {
proxy_request(state, method, headers, query, None, body).await
}
async fn proxy_path(
State(state): State<Arc<ProxyState>>,
method: Method,
headers: HeaderMap,
Path(path): Path<String>,
Query(query): Query<std::collections::HashMap<String, String>>,
body: Body,
) -> Response {
proxy_request(state, method, headers, query, Some(path), body).await
}
async fn proxy_request(
state: Arc<ProxyState>,
method: Method,
headers: HeaderMap,
query: std::collections::HashMap<String, String>,
path: Option<String>,
body: Body,
) -> Response {
let target_url = build_target_url(&state.base_url, path.as_deref(), &query);
let body_bytes = match to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(error) => {
return (
StatusCode::BAD_REQUEST,
format!("failed to read request body: {error}"),
)
.into_response()
}
};
let reqwest_method = match reqwest::Method::from_bytes(method.as_str().as_bytes()) {
Ok(method) => method,
Err(error) => {
return (
StatusCode::BAD_REQUEST,
format!("unsupported method: {error}"),
)
.into_response()
}
};
let mut request = state.client.request(reqwest_method, target_url).body(body_bytes.to_vec());
for (name, value) in forwardable_headers(&headers) {
request = request.header(name, value);
}
match request.send().await {
Ok(response) => proxy_response(response).await,
Err(error) => (
StatusCode::BAD_GATEWAY,
format!("upstream request failed: {error}"),
)
.into_response(),
}
}
fn build_target_url(
base_url: &str,
path: Option<&str>,
query: &std::collections::HashMap<String, String>,
) -> String {
let trimmed_base = base_url.trim_end_matches('/');
let path = path.unwrap_or("").trim_start_matches('/');
let mut url = if path.is_empty() {
trimmed_base.to_string()
} else {
format!("{trimmed_base}/{path}")
};
if !query.is_empty() {
let mut pairs: Vec<_> = query.iter().collect();
pairs.sort_by(|(left, _), (right, _)| left.cmp(right));
let query_string = pairs
.into_iter()
.map(|(key, value)| format!("{}={}", urlencoding::encode(key), urlencoding::encode(value)))
.collect::<Vec<_>>()
.join("&");
url.push('?');
url.push_str(&query_string);
}
url
}
/// Headers scoped to a single connection (RFC 9110 §7.6.1), plus `host` and
/// `content-length`, which the outgoing HTTP stack recomputes.
///
/// None of these are copied between the client-side and upstream-side
/// messages, in either direction.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    "host",
    "content-length",
];

/// Collects the field names listed in `Connection` header values.
///
/// A sender may nominate extra hop-by-hop headers this way
/// (`Connection: close, X-Hop`), and a proxy must drop those too. Values that
/// are not valid UTF-8 and empty list items are ignored.
fn connection_tokens<'a>(values: impl Iterator<Item = &'a [u8]>) -> Vec<String> {
    values
        .filter_map(|raw| std::str::from_utf8(raw).ok())
        .flat_map(|list| list.split(','))
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(str::to_owned)
        .collect()
}

/// Returns `true` if header `name` must not cross the proxy.
///
/// That is the case when it is always hop-by-hop ([`HOP_BY_HOP`]) or when it
/// appears in `nominated`. Comparison is ASCII case-insensitive.
fn is_hop_by_hop(name: &str, nominated: &[String]) -> bool {
    HOP_BY_HOP
        .iter()
        .any(|blocked| name.eq_ignore_ascii_case(blocked))
        || nominated
            .iter()
            .any(|token| name.eq_ignore_ascii_case(token))
}

/// Returns the client request headers that should be sent upstream.
///
/// Drops [`HOP_BY_HOP`] headers and any header the request's own `Connection`
/// header nominates. Every other header, including repeated ones, is kept in
/// iteration order.
fn forwardable_headers(headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
    let nominated = connection_tokens(
        headers
            .get_all("connection")
            .iter()
            .map(|value| value.as_bytes()),
    );
    headers
        .iter()
        .filter(|(name, _)| !is_hop_by_hop(name.as_str(), &nominated))
        .map(|(name, value)| (name.clone(), value.clone()))
        .collect()
}

/// Converts an upstream `reqwest` response into an axum response.
///
/// - The status is copied, falling back to `502` if it cannot be represented.
/// - End-to-end headers are copied. Hop-by-hop headers are dropped, including
///   `content-length`, `transfer-encoding` and any the upstream's
///   `Connection` header nominates, so framing is left to this server.
/// - The body is fully buffered before it is returned.
///
/// Responds `502 Bad Gateway` if the upstream body cannot be read or the
/// response cannot be assembled.
async fn proxy_response(response: reqwest::Response) -> Response {
    let status =
        StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
    let nominated = connection_tokens(
        response
            .headers()
            .get_all("connection")
            .iter()
            .map(|value| value.as_bytes()),
    );

    let mut builder = Response::builder().status(status);
    for (name, value) in response.headers().iter() {
        if !is_hop_by_hop(name.as_str(), &nominated) {
            builder = builder.header(name, value);
        }
    }

    match response.bytes().await {
        Ok(bytes) => builder.body(Body::from(bytes)).unwrap_or_else(|_| {
            (StatusCode::BAD_GATEWAY, "failed to build proxy response").into_response()
        }),
        Err(error) => (
            StatusCode::BAD_GATEWAY,
            format!("failed to read upstream response: {error}"),
        )
            .into_response(),
    }
}
#[cfg(test)]
mod tests {
    use axum::{routing::get, Router};
    use reqwest::StatusCode;
    use tokio::net::TcpListener;

    use super::{router, router_with_products, Config, ProductRouters};

    /// `/health` answers 200 with `"status": "ok"` under the default config.
    #[tokio::test]
    async fn health_endpoint_is_available() {
        let addr = spawn(router(&Config::default())).await;

        let resp = reqwest::get(format!("http://{addr}/health")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body: serde_json::Value = resp.json().await.unwrap();
        assert_eq!(body["status"], "ok");
    }

    /// An embedded login router is reachable below the `/login` prefix.
    #[tokio::test]
    async fn embedded_product_router_is_mounted_under_login_prefix() {
        let login = Router::new().route("/admin/ping", get(|| async { "pong" }));
        let products = ProductRouters {
            login: Some(login),
            pay: None,
        };
        let addr = spawn(router_with_products(&Config::default(), products)).await;

        let resp = reqwest::get(format!("http://{addr}/login/admin/ping"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.text().await.unwrap(), "pong");
    }

    /// Without an embedded router, `/login/...` is proxied to the upstream URL.
    #[tokio::test]
    async fn standalone_proxy_forwards_requests_to_login_upstream() {
        let upstream = Router::new().route("/admin/ping", get(|| async { "upstream-pong" }));
        let upstream_addr = spawn(upstream).await;

        let config = Config {
            login_upstream_url: Some(format!("http://{upstream_addr}")),
            ..Config::default()
        };
        let addr = spawn(router(&config)).await;

        let resp = reqwest::get(format!("http://{addr}/login/admin/ping"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.text().await.unwrap(), "upstream-pong");
    }

    /// Serves `app` on an ephemeral localhost port in a background task and
    /// returns the bound address.
    async fn spawn(app: Router) -> std::net::SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
        addr
    }
}