meld-server 0.1.0

Single-port REST + gRPC server framework with FastAPI-like DX for Rust.
use std::{env, str::FromStr, time::Duration};

use axum::{
    error_handling::HandleErrorLayer,
    http::{HeaderName, HeaderValue, StatusCode},
    BoxError, Router,
};
use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, ServiceBuilder};
use tower_http::{
    cors::{Any, CorsLayer},
    limit::RequestBodyLimitLayer,
    request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
    trace::TraceLayer,
};

const REQUEST_ID_HEADER: &str = "x-request-id";
const DEFAULT_TIMEOUT_SECONDS: u64 = 15;
const DEFAULT_MAX_IN_FLIGHT_REQUESTS: usize = 1024;
const DEFAULT_REQUEST_BODY_LIMIT_BYTES: usize = 1_048_576;

#[derive(Debug, Clone, Default)]
pub enum CorsAllowOrigins {
    #[default]
    None,
    Any,
    List(Vec<HeaderValue>),
}

#[derive(Debug, Clone)]
pub struct MiddlewareConfig {
    pub timeout_seconds: u64,
    pub max_in_flight_requests: usize,
    pub max_request_body_bytes: usize,
    pub cors_allow_origins: CorsAllowOrigins,
}

impl Default for MiddlewareConfig {
    fn default() -> Self {
        Self {
            timeout_seconds: DEFAULT_TIMEOUT_SECONDS,
            max_in_flight_requests: DEFAULT_MAX_IN_FLIGHT_REQUESTS,
            max_request_body_bytes: DEFAULT_REQUEST_BODY_LIMIT_BYTES,
            cors_allow_origins: CorsAllowOrigins::None,
        }
    }
}

impl MiddlewareConfig {
    pub fn from_env() -> Self {
        Self {
            timeout_seconds: read_env_with_fallback(
                "MELD_TIMEOUT_SECONDS",
                "ALLOY_TIMEOUT_SECONDS",
            )
            .unwrap_or(DEFAULT_TIMEOUT_SECONDS),
            max_in_flight_requests: read_env_with_fallback(
                "MELD_MAX_IN_FLIGHT_REQUESTS",
                "ALLOY_MAX_IN_FLIGHT_REQUESTS",
            )
            .unwrap_or(DEFAULT_MAX_IN_FLIGHT_REQUESTS),
            max_request_body_bytes: read_env_with_fallback(
                "MELD_REQUEST_BODY_LIMIT_BYTES",
                "ALLOY_REQUEST_BODY_LIMIT_BYTES",
            )
            .unwrap_or(DEFAULT_REQUEST_BODY_LIMIT_BYTES),
            cors_allow_origins: parse_cors_allow_origins(
                env::var("MELD_CORS_ALLOW_ORIGINS")
                    .ok()
                    .or_else(|| env::var("ALLOY_CORS_ALLOW_ORIGINS").ok()),
            ),
        }
    }
}

pub fn apply_shared_middleware(app: Router, config: &MiddlewareConfig) -> Router {
    let app = match &config.cors_allow_origins {
        CorsAllowOrigins::None => app,
        CorsAllowOrigins::Any => app.layer(CorsLayer::new().allow_origin(Any)),
        CorsAllowOrigins::List(origins) => {
            app.layer(CorsLayer::new().allow_origin(origins.clone()))
        }
    };

    app.layer(
        ServiceBuilder::new()
            .layer(HandleErrorLayer::new(handle_middleware_error))
            .layer(TraceLayer::new_for_http())
            .layer(PropagateRequestIdLayer::new(header_name()))
            .layer(SetRequestIdLayer::new(header_name(), MakeRequestUuid))
            .layer(RequestBodyLimitLayer::new(config.max_request_body_bytes))
            .layer(TimeoutLayer::new(Duration::from_secs(
                config.timeout_seconds,
            )))
            .layer(ConcurrencyLimitLayer::new(config.max_in_flight_requests)),
    )
}

async fn handle_middleware_error(error: BoxError) -> (StatusCode, String) {
    if error.is::<tower::timeout::error::Elapsed>() {
        return (StatusCode::REQUEST_TIMEOUT, "request timed out".to_string());
    }

    tracing::error!(error = %error, "unhandled middleware error");
    (
        StatusCode::INTERNAL_SERVER_ERROR,
        "internal server error".to_string(),
    )
}

fn header_name() -> HeaderName {
    HeaderName::from_static(REQUEST_ID_HEADER)
}

fn read_env<T>(name: &str) -> Option<T>
where
    T: FromStr,
{
    env::var(name).ok().and_then(|raw| raw.parse::<T>().ok())
}

fn read_env_with_fallback<T>(primary: &str, legacy: &str) -> Option<T>
where
    T: FromStr,
{
    read_env(primary).or_else(|| read_env(legacy))
}

fn parse_cors_allow_origins(raw: Option<String>) -> CorsAllowOrigins {
    let Some(raw) = raw else {
        return CorsAllowOrigins::None;
    };

    let raw = raw.trim();
    if raw.is_empty() {
        return CorsAllowOrigins::None;
    }
    if raw == "*" {
        return CorsAllowOrigins::Any;
    }

    let origins = raw
        .split(',')
        .map(str::trim)
        .filter(|origin| !origin.is_empty())
        .filter_map(|origin| match HeaderValue::from_str(origin) {
            Ok(value) => Some(value),
            Err(err) => {
                tracing::warn!(origin = %origin, error = %err, "ignoring invalid cors origin");
                None
            }
        })
        .collect::<Vec<_>>();

    if origins.is_empty() {
        CorsAllowOrigins::None
    } else {
        CorsAllowOrigins::List(origins)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{header::ORIGIN, Request},
        routing::{get, post},
    };
    use tower::util::ServiceExt;

    #[test]
    fn default_config_is_reasonable() {
        let config = MiddlewareConfig::default();
        assert_eq!(config.timeout_seconds, DEFAULT_TIMEOUT_SECONDS);
        assert_eq!(
            config.max_in_flight_requests,
            DEFAULT_MAX_IN_FLIGHT_REQUESTS
        );
        assert_eq!(
            config.max_request_body_bytes,
            DEFAULT_REQUEST_BODY_LIMIT_BYTES
        );
        assert!(matches!(config.cors_allow_origins, CorsAllowOrigins::None));
    }

    #[test]
    fn parse_cors_allow_origins_supports_wildcard_and_allowlist() {
        assert!(matches!(
            parse_cors_allow_origins(Some("*".to_string())),
            CorsAllowOrigins::Any
        ));

        let parsed =
            parse_cors_allow_origins(Some("https://one.example, https://two.example".to_string()));
        match parsed {
            CorsAllowOrigins::List(origins) => assert_eq!(origins.len(), 2),
            _ => panic!("expected allowlist"),
        }
    }

    #[tokio::test]
    async fn default_cors_policy_does_not_emit_allow_origin_header() {
        let app = apply_shared_middleware(
            Router::new().route("/health", get(|| async { "ok" })),
            &MiddlewareConfig::default(),
        );

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/health")
                    .header(ORIGIN, "https://example.com")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .expect("request should complete");

        assert!(response
            .headers()
            .get("access-control-allow-origin")
            .is_none());
    }

    #[tokio::test]
    async fn request_body_limit_rejects_oversized_payload() {
        let config = MiddlewareConfig {
            max_request_body_bytes: 8,
            ..MiddlewareConfig::default()
        };

        let app = apply_shared_middleware(
            Router::new().route("/echo", post(|body: String| async move { body })),
            &config,
        );

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/echo")
                    .body(Body::from("0123456789012345"))
                    .unwrap(),
            )
            .await
            .expect("request should complete");
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn middleware_internal_error_response_is_generic() {
        let (status, body) =
            handle_middleware_error(std::io::Error::other("sensitive").into()).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body, "internal server error");
    }
}