// cedros-admin-server 0.1.0
//
// Shared admin server shell for Cedros admin composition.
// Documentation
use std::sync::Arc;

use axum::{
    body::{to_bytes, Body},
    extract::{Path, Query, State},
    http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
    response::{IntoResponse, Response},
    routing::{any, get},
    Json, Router,
};
use reqwest::Client;

use crate::Config;

/// Optional in-process routers for the individual products.
///
/// `router_with_products` nests a router supplied here under the product's
/// path prefix and only falls back to the upstream URL from `Config` when the
/// corresponding field is `None`.
#[derive(Clone)]
pub struct ProductRouters {
    /// Router nested under `/login` when present.
    pub login: Option<Router>,
    /// Router nested under `/pay` when present.
    pub pay: Option<Router>,
}

impl ProductRouters {
    /// Returns a set with no embedded routers: both `login` and `pay` are `None`.
    pub fn empty() -> Self {
        let (login, pay) = (None, None);
        Self { login, pay }
    }
}

impl Default for ProductRouters {
    fn default() -> Self {
        Self::empty()
    }
}

/// State shared by the proxy handlers of one mounted upstream.
#[derive(Clone)]
struct ProxyState {
    /// HTTP client used to send the forwarded requests.
    client: Client,
    /// Upstream base URL that the remaining request path and query are appended to.
    base_url: String,
}

/// Builds the admin router without any embedded product routers.
///
/// Equivalent to calling `router_with_products` with
/// `ProductRouters::default()`, so a product prefix is only mounted when
/// `config` provides an upstream URL for it.
pub fn router(config: &Config) -> Router {
    let no_embedded_products = ProductRouters::default();
    router_with_products(config, no_embedded_products)
}

pub fn router_with_products(config: &Config, products: ProductRouters) -> Router {
    let mut app = Router::new().route("/health", get(health));

    if let Some(login_router) = products.login {
        app = app.nest("/login", login_router);
    } else if let Some(upstream) = config.login_upstream_url.clone() {
        app = app.nest("/login", proxy_router(upstream));
    }

    if let Some(pay_router) = products.pay {
        app = app.nest("/pay", pay_router);
    } else if let Some(upstream) = config.pay_upstream_url.clone() {
        app = app.nest("/pay", proxy_router(upstream));
    }

    app
}

/// Creates a router that forwards every request it receives to `base_url`.
///
/// Both the mount root (`/`) and any deeper path (`/{*path}`) are handled for
/// all HTTP methods; the handlers share one `ProxyState` behind an `Arc`.
fn proxy_router(base_url: String) -> Router {
    let client = Client::new();
    let shared = Arc::new(ProxyState { client, base_url });

    Router::new()
        .route("/", any(proxy_root))
        .route("/{*path}", any(proxy_path))
        .with_state(shared)
}

/// Liveness endpoint: always answers with a fixed JSON document naming the
/// service and reporting status `ok`.
async fn health() -> Json<serde_json::Value> {
    let payload = serde_json::json!({
        "status": "ok",
        "service": "cedros-admin-server"
    });
    Json(payload)
}

/// Handler for the mount root of a proxied product: forwards the request to
/// the upstream base URL with no additional path.
///
/// The query string is extracted into a `HashMap`, which holds a single value
/// per key.
async fn proxy_root(
    State(proxy): State<Arc<ProxyState>>,
    method: Method,
    headers: HeaderMap,
    Query(params): Query<std::collections::HashMap<String, String>>,
    body: Body,
) -> Response {
    let no_path: Option<String> = None;
    proxy_request(proxy, method, headers, params, no_path, body).await
}

/// Handler for any path below the mount root of a proxied product: forwards
/// the request to the upstream, appending the captured wildcard path.
///
/// The query string is extracted into a `HashMap`, which holds a single value
/// per key.
async fn proxy_path(
    State(proxy): State<Arc<ProxyState>>,
    method: Method,
    headers: HeaderMap,
    Path(captured): Path<String>,
    Query(params): Query<std::collections::HashMap<String, String>>,
    body: Body,
) -> Response {
    let sub_path = Some(captured);
    proxy_request(proxy, method, headers, params, sub_path, body).await
}

/// Forwards one incoming request to the upstream and relays its response.
///
/// The target URL is built from the state's base URL, the optional sub-path
/// and the query map. The request body is read fully into memory, the method
/// is converted to its reqwest counterpart, and every header accepted by
/// `forwardable_headers` is copied onto the outgoing request.
///
/// Failure mapping:
/// * body cannot be read or method cannot be converted -> `400 Bad Request`
/// * upstream request fails to send -> `502 Bad Gateway`
async fn proxy_request(
    state: Arc<ProxyState>,
    method: Method,
    headers: HeaderMap,
    query: std::collections::HashMap<String, String>,
    path: Option<String>,
    body: Body,
) -> Response {
    let target_url = build_target_url(&state.base_url, path.as_deref(), &query);

    // NOTE(review): `usize::MAX` places no cap on how much of the body is
    // buffered here — confirm that a limit is enforced elsewhere.
    let body_bytes = match to_bytes(body, usize::MAX).await {
        Ok(bytes) => bytes,
        Err(error) => {
            return (
                StatusCode::BAD_REQUEST,
                format!("failed to read request body: {error}"),
            )
                .into_response()
        }
    };

    // Converting through the textual form keeps this independent of whether
    // axum and reqwest share the same `Method` type.
    let reqwest_method = match reqwest::Method::from_bytes(method.as_str().as_bytes()) {
        Ok(method) => method,
        Err(error) => {
            return (
                StatusCode::BAD_REQUEST,
                format!("unsupported method: {error}"),
            )
                .into_response()
        }
    };

    // The buffered `Bytes` is handed over as-is; copying it into a `Vec` first
    // would duplicate the whole payload for no benefit.
    let mut request = state
        .client
        .request(reqwest_method, target_url)
        .body(body_bytes);
    for (name, value) in forwardable_headers(&headers) {
        request = request.header(name, value);
    }

    match request.send().await {
        Ok(response) => proxy_response(response).await,
        Err(error) => (
            StatusCode::BAD_GATEWAY,
            format!("upstream request failed: {error}"),
        )
            .into_response(),
    }
}

/// Builds the upstream URL for a proxied request.
///
/// * `base_url` — upstream base; trailing slashes are removed before joining.
/// * `path` — optional sub-path; leading slashes are removed. When empty or
///   `None`, the result is the trimmed base URL without a trailing slash.
/// * `query` — query parameters; emitted sorted by key so the output is
///   deterministic regardless of `HashMap` iteration order.
///
/// Each path segment and every query key/value is percent-encoded. The path
/// handed in by the handlers comes from axum's `Path` extractor, which has
/// already percent-decoded it, so characters such as `?`, `#`, `%` or spaces
/// must be re-encoded here; otherwise they would alter the structure of the
/// upstream URL (e.g. a decoded `?` would start the query string).
fn build_target_url(
    base_url: &str,
    path: Option<&str>,
    query: &std::collections::HashMap<String, String>,
) -> String {
    let trimmed_base = base_url.trim_end_matches('/');
    let path = path.unwrap_or("").trim_start_matches('/');

    let mut url = trimmed_base.to_string();
    if !path.is_empty() {
        // Encode segment by segment so the `/` separators themselves survive.
        // NOTE(review): `.` and `..` segments pass through unchanged; confirm
        // whether they should be rejected before reaching the upstream.
        for segment in path.split('/') {
            url.push('/');
            url.push_str(&urlencoding::encode(segment));
        }
    }

    if !query.is_empty() {
        let mut pairs: Vec<_> = query.iter().collect();
        pairs.sort_by(|(left, _), (right, _)| left.cmp(right));
        let query_string = pairs
            .into_iter()
            .map(|(key, value)| format!("{}={}", urlencoding::encode(key), urlencoding::encode(value)))
            .collect::<Vec<_>>()
            .join("&");
        url.push('?');
        url.push_str(&query_string);
    }

    url
}

/// Returns owned copies of the request headers that may be sent upstream.
///
/// Headers whose name matches an entry of the local block list
/// (case-insensitively) are dropped: the connection-level headers plus `host`
/// and `content-length`. All other headers are cloned in iteration order.
fn forwardable_headers(headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
    const HOP_BY_HOP: &[&str] = &[
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "host",
        "content-length",
    ];

    let mut forwarded = Vec::new();
    for (name, value) in headers.iter() {
        let is_blocked = HOP_BY_HOP
            .iter()
            .any(|candidate| name.as_str().eq_ignore_ascii_case(candidate));
        if !is_blocked {
            forwarded.push((name.clone(), value.clone()));
        }
    }
    forwarded
}

/// Converts an upstream reqwest response into an axum response.
///
/// The status code is carried over (falling back to `502 Bad Gateway` if it
/// cannot be represented), headers are copied except for the connection-level
/// ones listed below, and the body is read fully into memory before being
/// returned.
///
/// Failure mapping: if the upstream body cannot be read or the response
/// cannot be assembled, a `502 Bad Gateway` with a short message is returned.
async fn proxy_response(response: reqwest::Response) -> Response {
    // Headers that describe the upstream connection or body framing rather
    // than the payload. The body is re-framed here, and connection-level
    // headers such as `connection` or `upgrade` must not leak to the client.
    // Mirrors the set stripped from forwarded requests (minus `host`, which is
    // a request-only header).
    const NOT_FORWARDED: &[&str] = &[
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "content-length",
    ];

    let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
    let mut builder = Response::builder().status(status);

    for (name, value) in response.headers().iter() {
        let skip = NOT_FORWARDED
            .iter()
            .any(|blocked| name.as_str().eq_ignore_ascii_case(blocked));
        if skip {
            continue;
        }

        builder = builder.header(name, value);
    }

    match response.bytes().await {
        Ok(bytes) => builder
            .body(Body::from(bytes))
            .unwrap_or_else(|_| (StatusCode::BAD_GATEWAY, "failed to build proxy response").into_response()),
        Err(error) => (
            StatusCode::BAD_GATEWAY,
            format!("failed to read upstream response: {error}"),
        )
            .into_response(),
    }
}

// Integration-style tests: each test serves the router on an ephemeral local
// port and talks to it over real HTTP with reqwest.
#[cfg(test)]
mod tests {
    use axum::{routing::get, Router};
    use reqwest::StatusCode;
    use tokio::net::TcpListener;

    use super::{router, router_with_products, Config, ProductRouters};

    /// `/health` responds with 200 and `"status": "ok"` under the default config.
    #[tokio::test]
    async fn health_endpoint_is_available() {
        let config = Config::default();
        let app = router(&config);
        let address = spawn(app).await;

        let response = reqwest::get(format!("http://{address}/health")).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.json::<serde_json::Value>().await.unwrap()["status"], "ok");
    }

    /// A router passed in `ProductRouters::login` is reachable under `/login`.
    #[tokio::test]
    async fn embedded_product_router_is_mounted_under_login_prefix() {
        let config = Config::default();
        let products = ProductRouters {
            login: Some(Router::new().route("/admin/ping", get(|| async { "pong" }))),
            pay: None,
        };
        let app = router_with_products(&config, products);
        let address = spawn(app).await;

        let response = reqwest::get(format!("http://{address}/login/admin/ping"))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await.unwrap(), "pong");
    }

    /// With no embedded login router but a configured upstream URL, requests
    /// under `/login` are proxied to that upstream with the prefix removed.
    #[tokio::test]
    async fn standalone_proxy_forwards_requests_to_login_upstream() {
        // Stand-in upstream service running on its own ephemeral port.
        let upstream = Router::new().route("/admin/ping", get(|| async { "upstream-pong" }));
        let upstream_address = spawn(upstream).await;

        let config = Config {
            login_upstream_url: Some(format!("http://{upstream_address}")),
            ..Config::default()
        };
        let app = router(&config);
        let address = spawn(app).await;

        let response = reqwest::get(format!("http://{address}/login/admin/ping"))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await.unwrap(), "upstream-pong");
    }

    /// Serves `app` on `127.0.0.1` with an OS-assigned port in a background
    /// task and returns the bound address.
    async fn spawn(app: Router) -> std::net::SocketAddr {
        // Port 0 lets the OS pick a free port, so tests do not collide.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        address
    }
}