meld-server 0.1.0

Single-port REST + gRPC server framework with FastAPI-like DX for Rust.
use std::{env, str::FromStr};

use axum::{
    extract::{Request, State},
    http::{header, HeaderMap, HeaderValue, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
    Json,
};
use meld_core::auth::{validate_bearer_jwt, AuthPrincipal, JwtValidationConfig};
use tonic::Status;

use crate::api::ApiErrorResponse;

/// Runtime settings for bearer-JWT authentication.
///
/// `Debug` is implemented by hand so that the JWT secret is never written to logs; only
/// whether a secret is present is shown.
#[derive(Clone, Default)]
pub struct AuthRuntimeConfig {
    /// When `false`, every request is accepted as an anonymous principal.
    pub enabled: bool,
    /// Secret handed to the JWT validator; required when `enabled` is `true`.
    pub jwt_secret: Option<String>,
    /// Issuer the token must carry, if any.
    pub expected_issuer: Option<String>,
    /// Audience the token must carry, if any.
    pub expected_audience: Option<String>,
}

impl std::fmt::Debug for AuthRuntimeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthRuntimeConfig")
            .field("enabled", &self.enabled)
            // Never print the secret itself: only whether one is configured.
            .field("jwt_secret", &self.jwt_secret.as_ref().map(|_| "<redacted>"))
            .field("expected_issuer", &self.expected_issuer)
            .field("expected_audience", &self.expected_audience)
            .finish()
    }
}

impl AuthRuntimeConfig {
    /// Builds the configuration from environment variables.
    ///
    /// Each `MELD_AUTH_*` variable takes precedence over its legacy `ALLOY_AUTH_*`
    /// counterpart. Authentication stays disabled unless one of the `*_AUTH_ENABLED`
    /// variables yields `true`.
    pub fn from_env() -> Self {
        Self {
            enabled: read_env_bool_with_fallback("MELD_AUTH_ENABLED", "ALLOY_AUTH_ENABLED")
                .unwrap_or(false),
            jwt_secret: read_env_string_with_fallback(
                "MELD_AUTH_JWT_SECRET",
                "ALLOY_AUTH_JWT_SECRET",
            ),
            expected_issuer: read_env_string_with_fallback("MELD_AUTH_ISSUER", "ALLOY_AUTH_ISSUER"),
            expected_audience: read_env_string_with_fallback(
                "MELD_AUTH_AUDIENCE",
                "ALLOY_AUTH_AUDIENCE",
            ),
        }
    }

    /// Principal handed out for every request while authentication is disabled.
    fn anonymous_principal() -> AuthPrincipal {
        AuthPrincipal {
            subject: "anonymous".to_string(),
            issuer: None,
            audience: vec![],
            scopes: vec![],
        }
    }

    /// Assembles the settings passed to `validate_bearer_jwt`.
    ///
    /// # Errors
    ///
    /// Returns [`AuthRejection::Misconfigured`] when no secret is configured, or when the
    /// configured secret is empty or whitespace-only.
    fn jwt_validation_config(&self) -> Result<JwtValidationConfig, AuthRejection> {
        // A blank shared secret is trivially known. Presumably this is a symmetric (HMAC)
        // key (confirm in meld_core), in which case anyone could mint accepted tokens. The
        // fields are public and `Default` exists, so `Some("")` can reach this point; fail
        // closed instead.
        let secret = self
            .jwt_secret
            .as_deref()
            .filter(|s| !s.trim().is_empty())
            .ok_or_else(|| {
                AuthRejection::Misconfigured("MELD_AUTH_JWT_SECRET is missing".to_string())
            })?;

        Ok(JwtValidationConfig {
            secret: secret.to_string(),
            expected_issuer: self.expected_issuer.clone(),
            expected_audience: self.expected_audience.clone(),
        })
    }

    /// Authenticates a raw `Authorization` value such as `"Bearer <jwt>"`.
    ///
    /// When auth is disabled this always succeeds with an anonymous principal.
    ///
    /// # Errors
    ///
    /// - [`AuthRejection::InvalidToken`] if the value is not a bearer credential or the JWT
    ///   fails validation.
    /// - [`AuthRejection::Misconfigured`] if no usable secret is configured.
    pub fn authenticate_authorization_value_str(
        &self,
        auth_value: &str,
    ) -> Result<AuthPrincipal, AuthRejection> {
        if !self.enabled {
            return Ok(Self::anonymous_principal());
        }
        let token = parse_bearer_token(auth_value)?;
        let validation_cfg = self.jwt_validation_config()?;
        validate_bearer_jwt(token, &validation_cfg)
            .map_err(|err| AuthRejection::InvalidToken(err.to_string()))
    }

    /// Authenticates an optional `Authorization` header value.
    ///
    /// When auth is disabled the header is ignored and an anonymous principal is returned.
    ///
    /// # Errors
    ///
    /// - [`AuthRejection::MissingAuthorization`] if no header value is present.
    /// - [`AuthRejection::InvalidToken`] if `HeaderValue::to_str` fails on the value.
    /// - Any error from [`Self::authenticate_authorization_value_str`].
    pub fn authenticate_header_value(
        &self,
        auth_value: Option<&HeaderValue>,
    ) -> Result<AuthPrincipal, AuthRejection> {
        if !self.enabled {
            return Ok(Self::anonymous_principal());
        }

        let value = auth_value
            .ok_or(AuthRejection::MissingAuthorization)?
            .to_str()
            .map_err(|_| {
                AuthRejection::InvalidToken("authorization header is invalid".to_string())
            })?;

        self.authenticate_authorization_value_str(value)
    }

    /// Authenticates a request using its `Authorization` header.
    ///
    /// # Errors
    ///
    /// Same as [`Self::authenticate_header_value`].
    pub fn authenticate_headers(
        &self,
        headers: &HeaderMap,
    ) -> Result<AuthPrincipal, AuthRejection> {
        self.authenticate_header_value(headers.get(header::AUTHORIZATION))
    }
}

/// Why a request failed authentication.
///
/// Convert it with `into_rest_response` for HTTP or `into_grpc_status` for gRPC.
#[derive(Clone, Debug)]
pub enum AuthRejection {
    /// No `Authorization` header was supplied.
    MissingAuthorization,
    /// The credential was malformed or failed validation; carries a client-facing message.
    InvalidToken(String),
    /// The server's auth settings are unusable; carries a description of the problem.
    Misconfigured(String),
}

impl AuthRejection {
    /// Renders the rejection as an HTTP response with a JSON `ApiErrorResponse` body.
    ///
    /// A missing or invalid credential becomes `401 Unauthorized` with code `"unauthorized"`.
    /// A misconfiguration becomes `500 Internal Server Error` with code `"internal_error"`.
    /// The carried message is forwarded to the client verbatim.
    pub fn into_rest_response(self) -> Response {
        let (status, code, message) = match self {
            Self::MissingAuthorization => (
                StatusCode::UNAUTHORIZED,
                "unauthorized",
                "missing bearer token".to_string(),
            ),
            Self::InvalidToken(message) => (StatusCode::UNAUTHORIZED, "unauthorized", message),
            Self::Misconfigured(message) => {
                (StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message)
            }
        };

        let body = ApiErrorResponse {
            code: code.to_string(),
            message,
            detail: None,
            details: None,
        };
        (status, Json(body)).into_response()
    }

    /// Renders the rejection as a gRPC status.
    ///
    /// Credential problems become `UNAUTHENTICATED`; misconfiguration becomes `INTERNAL`.
    pub fn into_grpc_status(self) -> Status {
        match self {
            Self::Misconfigured(detail) => Status::internal(detail),
            Self::InvalidToken(detail) => Status::unauthenticated(detail),
            Self::MissingAuthorization => Status::unauthenticated("missing bearer token"),
        }
    }
}

/// Axum middleware that authenticates each request before it reaches the handler.
///
/// On success, the resulting `AuthPrincipal` is stored in the request extensions and the
/// request is forwarded. On failure, the rejection is returned as the response and the
/// handler is not run.
pub async fn rest_auth_middleware(
    State(cfg): State<AuthRuntimeConfig>,
    mut req: Request,
    next: Next,
) -> Response {
    let principal = match cfg.authenticate_headers(req.headers()) {
        Ok(principal) => principal,
        Err(rejection) => return rejection.into_rest_response(),
    };

    req.extensions_mut().insert(principal);
    next.run(req).await
}

/// Extracts the token from an `Authorization` value of the form `Bearer <token>`.
///
/// The scheme is matched case-insensitively. Whitespace around the whole value and between
/// the scheme and the token is ignored. The token is returned as a slice of `value`.
///
/// # Errors
///
/// Returns [`AuthRejection::InvalidToken`] when the scheme is not `Bearer` or the token is
/// empty.
pub fn parse_bearer_token(value: &str) -> Result<&str, AuthRejection> {
    // Captures nothing, so the closure is `Copy` and can be used twice below.
    let invalid = || {
        AuthRejection::InvalidToken("authorization header must be Bearer <token>".to_string())
    };

    // Trim first: stray leading whitespace (e.g. from gRPC metadata) would otherwise become
    // part of the scheme and turn a well-formed credential into a rejection.
    let (scheme, token) = value
        .trim()
        .split_once(|c: char| c.is_ascii_whitespace())
        .ok_or_else(invalid)?;
    let token = token.trim();

    if !scheme.eq_ignore_ascii_case("bearer") || token.is_empty() {
        return Err(invalid());
    }

    Ok(token)
}

/// Reads environment variable `name` as a boolean flag.
///
/// Matching ignores surrounding whitespace and ASCII case. Accepted spellings are
/// `true`/`false`, `1`/`0`, `yes`/`no` and `on`/`off`. Returns `None` when the variable is
/// unset, not valid Unicode, or holds any other value.
fn read_env_bool(name: &str) -> Option<bool> {
    env::var(name).ok().and_then(|raw| parse_bool_flag(&raw))
}

/// Parses a human-friendly boolean flag; `None` if `raw` is not a recognised spelling.
fn parse_bool_flag(raw: &str) -> Option<bool> {
    // `bool::from_str` only accepts exact lowercase `true`/`false`. Without normalising and
    // the extra spellings, `MELD_AUTH_ENABLED=1` or `=TRUE` would be silently ignored and
    // leave authentication disabled.
    let normalized = raw.trim().to_ascii_lowercase();
    bool::from_str(&normalized)
        .ok()
        .or_else(|| match normalized.as_str() {
            "1" | "yes" | "on" => Some(true),
            "0" | "no" | "off" => Some(false),
            _ => None,
        })
}

fn read_env_bool_with_fallback(primary: &str, legacy: &str) -> Option<bool> {
    read_env_bool(primary).or_else(|| read_env_bool(legacy))
}

/// Reads a string setting from `primary`, falling back to the legacy variable `legacy`.
///
/// A variable that is unset, not valid Unicode, or blank (empty or whitespace-only) counts
/// as absent. As a result, `MELD_AUTH_JWT_SECRET=""` yields `None` instead of an empty
/// secret, and a blank primary still falls through to the legacy name. Non-blank values are
/// returned verbatim, without trimming.
fn read_env_string_with_fallback(primary: &str, legacy: &str) -> Option<String> {
    let read = |name: &str| env::var(name).ok().filter(|value| !value.trim().is_empty());
    read(primary).or_else(|| read(legacy))
}