meld-server 0.1.0

Single-port REST + gRPC server framework with FastAPI-like DX for Rust.
use axum::{
    extract::{FromRequest, FromRequestParts, Path, Query},
    http::{request::Parts, StatusCode},
    response::IntoResponse,
    Json,
};
use meld_core::MeldError;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{json, Value};
use tonic::Status;
use validator::{Validate, ValidationErrors};

// One validation problem reported to a REST client: where it occurred, a
// message, and a machine-readable type.
// NOTE(review): plain `//` comments are used instead of `///` because doc
// comments on a `utoipa::ToSchema` type are presumably copied into the
// generated OpenAPI descriptions — confirm before converting them.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema, PartialEq, Eq)]
pub struct ApiValidationIssue {
    // Path to the offending value. Builders in this file fill it with
    // `[source, field]` (request validation) or `["domain"]` (domain errors).
    pub loc: Vec<String>,
    // Message text for the issue.
    pub msg: String,
    // Machine-readable issue kind. Serialized under the JSON key `type`,
    // which is a reserved word in Rust, hence the rename.
    #[serde(rename = "type")]
    pub issue_type: String,
}

// JSON body returned to REST clients when a request fails.
// NOTE(review): plain `//` comments are used instead of `///` because doc
// comments on a `utoipa::ToSchema` type are presumably copied into the
// generated OpenAPI descriptions — confirm before converting them.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ApiErrorResponse {
    // Machine-readable error code; the constructors in this file use
    // `validation_error`, `bad_request` and `internal_error`.
    pub code: String,
    // Message text describing the failure.
    pub message: String,
    // Per-issue list of validation problems; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<Vec<ApiValidationIssue>>,
    // Free-form JSON payload; request validation stores a
    // field-name -> messages object here. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Value>,
}

impl ApiErrorResponse {
    /// Builds a response that carries only a code and a message, leaving both
    /// `detail` and `details` unset.
    fn without_payload(code: &str, message: String) -> Self {
        Self {
            code: code.to_owned(),
            message,
            detail: None,
            details: None,
        }
    }

    /// Creates a `validation_error` response.
    ///
    /// `detail` and `details` are stored exactly as given; either may be `None`.
    pub fn validation(
        message: impl Into<String>,
        detail: Option<Vec<ApiValidationIssue>>,
        details: Option<Value>,
    ) -> Self {
        let code = String::from("validation_error");
        let message = message.into();
        Self {
            code,
            message,
            detail,
            details,
        }
    }

    /// Creates a `bad_request` response with the given message and no
    /// validation payload.
    pub fn bad_request(message: impl Into<String>) -> Self {
        Self::without_payload("bad_request", message.into())
    }

    /// Creates an `internal_error` response with a fixed, generic message.
    pub fn internal_server_error() -> Self {
        Self::without_payload("internal_error", String::from("internal server error"))
    }
}

/// Error value produced by the REST helpers and extractors in this file:
/// an HTTP status code paired with the JSON error body.
pub type ApiError = (StatusCode, Json<ApiErrorResponse>);

/// Converts `validator` errors into an [`ApiError`], labelling every issue
/// location with the generic source `"request"`.
///
/// Delegates to [`validation_error_with_source`].
pub fn validation_error(err: ValidationErrors) -> ApiError {
    const DEFAULT_SOURCE: &str = "request";
    validation_error_with_source(err, DEFAULT_SOURCE)
}

/// Converts the field-level errors in `err` into a `400 Bad Request`
/// [`ApiError`].
///
/// The response body carries two views of the same failures:
/// * `detail` — one [`ApiValidationIssue`] per error, with
///   `loc = [source, field]`, `msg` set to the error's message (falling back
///   to its code when no message is present) and `type` set to the error code;
/// * `details` — a JSON object mapping each field name to the list of its
///   messages.
///
/// Only what `ValidationErrors::field_errors()` returns is reported.
///
/// # Parameters
/// * `err` — the errors returned by `Validate::validate`.
/// * `source` — label placed first in each issue's `loc` (the extractors in
///   this file pass `"body"`, `"query"`, `"path"` or `"parts"`).
pub fn validation_error_with_source(err: ValidationErrors, source: &'static str) -> ApiError {
    // `field_errors()` returns a `HashMap`, so its iteration order is not
    // stable. Sort by field name so that `detail` is emitted in the same order
    // for the same input on every call.
    let mut by_field: Vec<_> = err.field_errors().into_iter().collect();
    by_field.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

    let mut fields = serde_json::Map::new();
    let mut detail = Vec::new();
    for (field, errors) in by_field {
        let field_name = field.to_string();
        let mut messages: Vec<String> = Vec::with_capacity(errors.len());
        for error in errors {
            // Resolve the message once and reuse it for both views.
            let msg = error
                .message
                .as_ref()
                .map_or_else(|| error.code.to_string(), |m| m.to_string());
            messages.push(msg.clone());
            detail.push(ApiValidationIssue {
                loc: vec![source.to_string(), field_name.clone()],
                msg,
                issue_type: error.code.to_string(),
            });
        }
        fields.insert(field_name, json!(messages));
    }

    (
        StatusCode::BAD_REQUEST,
        Json(ApiErrorResponse::validation(
            "request validation failed",
            Some(detail),
            Some(Value::Object(fields)),
        )),
    )
}

/// Builds a `400 Bad Request` [`ApiError`] whose body is
/// [`ApiErrorResponse::bad_request`] with the given message.
pub fn bad_request(message: impl Into<String>) -> ApiError {
    let body = ApiErrorResponse::bad_request(message);
    (StatusCode::BAD_REQUEST, Json(body))
}

/// Translates a domain-layer [`MeldError`] into the REST error shape.
///
/// * `MeldError::Validation` becomes `400 Bad Request` with a single issue
///   located at `["domain"]` and typed `domain_validation`; the domain message
///   is used both as the response message and as the issue message.
/// * `MeldError::Internal` is logged via `tracing::error!` and replaced with a
///   generic `500 Internal Server Error` body, so the original message is not
///   included in the response.
pub fn map_domain_error_to_rest(err: MeldError) -> ApiError {
    let (status, body) = match err {
        MeldError::Validation(message) => {
            let issue = ApiValidationIssue {
                loc: vec!["domain".to_string()],
                msg: message.clone(),
                issue_type: "domain_validation".to_string(),
            };
            let body = ApiErrorResponse::validation(message, Some(vec![issue]), None);
            (StatusCode::BAD_REQUEST, body)
        }
        MeldError::Internal(message) => {
            tracing::error!(error = %message, "internal domain error surfaced in REST handler");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                ApiErrorResponse::internal_server_error(),
            )
        }
    };
    (status, Json(body))
}

/// Translates a domain-layer [`MeldError`] into a gRPC [`Status`].
///
/// Validation failures are returned as `invalid_argument` carrying the domain
/// message. Internal failures are logged via `tracing::error!` and returned as
/// `internal` with a fixed generic message instead of the original one.
pub fn map_domain_error_to_grpc(err: MeldError) -> Status {
    let message = match err {
        // Validation messages are passed through to the caller as-is.
        MeldError::Validation(message) => return Status::invalid_argument(message),
        MeldError::Internal(message) => message,
    };
    tracing::error!(error = %message, "internal domain error surfaced in gRPC handler");
    Status::internal("internal server error")
}

/// Extractor that deserializes the request body as JSON into `T` and then
/// runs `T::validate()` on it.
///
/// Any rejection from the underlying `Json` extractor is reported through
/// [`bad_request`]; validation failures are reported through
/// [`validation_error_with_source`] with the source label `"body"`.
#[derive(Debug, Clone)]
pub struct ValidatedJson<T>(pub T);

#[axum::async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = ApiError;

    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
        // Step 1: parse the body; a parse failure short-circuits.
        let payload = match Json::<T>::from_request(req, state).await {
            Ok(Json(payload)) => payload,
            Err(err) => return Err(bad_request(format!("invalid json body: {err}"))),
        };

        // Step 2: only hand out the payload once it has passed validation.
        match payload.validate() {
            Ok(()) => Ok(Self(payload)),
            Err(err) => Err(validation_error_with_source(err, "body")),
        }
    }
}

/// Extractor that deserializes the query string into `T` and then runs
/// `T::validate()` on it.
///
/// Any rejection from the underlying `Query` extractor is reported through
/// [`bad_request`]; validation failures are reported through
/// [`validation_error_with_source`] with the source label `"query"`.
#[derive(Debug, Clone)]
pub struct ValidatedQuery<T>(pub T);

#[axum::async_trait]
impl<T, S> FromRequestParts<S> for ValidatedQuery<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = ApiError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Step 1: parse the query string; a parse failure short-circuits.
        let params = match Query::<T>::from_request_parts(parts, state).await {
            Ok(Query(params)) => params,
            Err(err) => return Err(bad_request(format!("invalid query: {err}"))),
        };

        // Step 2: only hand out the parameters once they have passed validation.
        match params.validate() {
            Ok(()) => Ok(Self(params)),
            Err(err) => Err(validation_error_with_source(err, "query")),
        }
    }
}

/// Extractor that deserializes the matched path parameters into `T` and then
/// runs `T::validate()` on them.
///
/// Any rejection from the underlying `Path` extractor is reported through
/// [`bad_request`]; validation failures are reported through
/// [`validation_error_with_source`] with the source label `"path"`.
#[derive(Debug, Clone)]
pub struct ValidatedPath<T>(pub T);

#[axum::async_trait]
impl<T, S> FromRequestParts<S> for ValidatedPath<T>
where
    T: DeserializeOwned + Validate + Send,
    S: Send + Sync,
{
    type Rejection = ApiError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Step 1: parse the path parameters; a parse failure short-circuits.
        let segments = match Path::<T>::from_request_parts(parts, state).await {
            Ok(Path(segments)) => segments,
            Err(err) => return Err(bad_request(format!("invalid path: {err}"))),
        };

        // Step 2: only hand out the parameters once they have passed validation.
        match segments.validate() {
            Ok(()) => Ok(Self(segments)),
            Err(err) => Err(validation_error_with_source(err, "path")),
        }
    }
}

/// Extractor that wraps any `FromRequestParts` type `T` and runs
/// `T::validate()` on the extracted value.
///
/// `T`'s own rejection must implement `Display`; it is formatted into a
/// [`bad_request`] error. Validation failures are reported through
/// [`validation_error_with_source`] with the source label `"parts"`.
#[derive(Debug, Clone)]
pub struct ValidatedParts<T>(pub T);

#[axum::async_trait]
impl<T, S> FromRequestParts<S> for ValidatedParts<T>
where
    T: FromRequestParts<S> + Validate,
    T::Rejection: std::fmt::Display,
    S: Send + Sync,
{
    type Rejection = ApiError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Step 1: run the wrapped extractor; its rejection short-circuits.
        let extracted = match T::from_request_parts(parts, state).await {
            Ok(extracted) => extracted,
            Err(err) => return Err(bad_request(format!("invalid request parts: {err}"))),
        };

        // Step 2: only hand out the value once it has passed validation.
        match extracted.validate() {
            Ok(()) => Ok(Self(extracted)),
            Err(err) => Err(validation_error_with_source(err, "parts")),
        }
    }
}

/// Renders the error payload as a JSON response body.
///
/// NOTE(review): no status code is set here, so the response presumably
/// carries `Json`'s default status (200 OK per the axum docs — confirm).
/// Return an [`ApiError`] tuple instead when a specific error status is needed.
impl IntoResponse for ApiErrorResponse {
    fn into_response(self) -> axum::response::Response {
        let body = Json(self);
        body.into_response()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use validator::Validate;

    // Minimal payload whose `name` must be at least three characters long.
    #[derive(Debug, serde::Deserialize, Validate)]
    struct BodyDto {
        #[validate(length(min = 3))]
        name: String,
    }

    // A failed field validation must surface as an issue located at
    // `[source, field]` inside a 400 `validation_error` body.
    #[test]
    fn validation_error_includes_fastapi_like_detail_shape() {
        let too_short = BodyDto {
            name: "ab".to_string(),
        };
        let errors = too_short.validate().expect_err("must fail");
        let (status, Json(body)) = validation_error_with_source(errors, "body");

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body.code, "validation_error");

        let expected_loc = vec!["body".to_string(), "name".to_string()];
        let issues = body.detail.expect("detail should exist");
        assert!(issues.iter().any(|issue| issue.loc == expected_loc));
    }

    // The internal message must not appear in the REST response.
    #[test]
    fn internal_domain_errors_are_sanitized_for_rest_clients() {
        let domain_error = MeldError::Internal("db exploded".to_string());
        let (status, Json(body)) = map_domain_error_to_rest(domain_error);

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body.code, "internal_error");
        assert_eq!(body.message, "internal server error");
        assert!(body.detail.is_none());
    }

    // The internal message must not appear in the gRPC status either.
    #[test]
    fn internal_domain_errors_are_sanitized_for_grpc_clients() {
        let domain_error = MeldError::Internal("db exploded".to_string());
        let status = map_domain_error_to_grpc(domain_error);

        assert_eq!(status.code(), tonic::Code::Internal);
        assert_eq!(status.message(), "internal server error");
    }
}