lmrc-http-common 0.3.16

Common HTTP utilities and patterns for LMRC Stack applications
Documentation
//! Request extractors with automatic validation
//!
//! This module provides Axum extractors that deserialize request data and, when the
//! `validation` feature is enabled, automatically validate it using the `validator` crate.
//!
//! ## Example
//!
//! ```rust
//! use axum::{Router, routing::post};
//! use lmrc_http_common::extractors::ValidatedJson;
//! use serde::Deserialize;
//! use validator::Validate;
//!
//! #[derive(Debug, Deserialize, Validate)]
//! struct CreateUser {
//!     #[validate(length(min = 3, max = 50))]
//!     username: String,
//!     #[validate(email)]
//!     email: String,
//! }
//!
//! async fn create_user(
//!     ValidatedJson(payload): ValidatedJson<CreateUser>
//! ) -> &'static str {
//!     // payload is automatically validated!
//!     "User created"
//! }
//!
//! # async fn example() {
//! let app: Router = Router::new().route("/users", post(create_user));
//! # }
//! ```

use crate::error::HttpError;
use async_trait::async_trait;
use axum::{
    extract::{FromRequest, Request},
    Json,
};
use serde::de::DeserializeOwned;

#[cfg(feature = "validation")]
use validator::Validate;

/// JSON request-body extractor that validates the payload after deserializing it.
///
/// The body is deserialized into `T` with axum's `Json` extractor. When the crate's
/// `validation` feature is enabled, `T` must also implement `validator::Validate`,
/// and the deserialized value is validated before the handler runs.
///
/// # Rejections
///
/// Rejects with [`HttpError`]:
///
/// - `HttpError::BadRequest` (400 Bad Request), with a message prefixed by
///   `"Invalid JSON: "`, whenever axum's `Json` extractor rejects the request
///   (for example, when the body is not valid JSON).
/// - `HttpError::ValidationError` (422 Unprocessable Entity), with a `"; "`-separated
///   list of `field: message` entries, when validation fails (`validation` feature only).
///
/// Without the `validation` feature this extractor only deserializes the body. No
/// validation is performed, and `T` needs no `Validate` implementation.
///
/// ## Example
///
/// ```rust
/// use lmrc_http_common::extractors::ValidatedJson;
/// use serde::Deserialize;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct SignupRequest {
///     #[validate(length(min = 3))]
///     username: String,
///     #[validate(email)]
///     email: String,
/// }
///
/// async fn signup(
///     ValidatedJson(req): ValidatedJson<SignupRequest>
/// ) -> &'static str {
///     "Success"
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);

/// Deserializes the request body as JSON, then runs `T`'s validation rules.
///
/// Any rejection from axum's `Json` extractor is turned into `HttpError::BadRequest`.
/// A failed `Validate::validate` call is turned into `HttpError::ValidationError`,
/// carrying the text produced by `format_validation_errors`.
#[cfg(feature = "validation")]
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = HttpError;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // Parse the body first; any extractor failure is reported as a client error.
        let payload = match Json::<T>::from_request(req, state).await {
            Ok(Json(parsed)) => parsed,
            Err(rejection) => {
                return Err(HttpError::BadRequest(format!("Invalid JSON: {}", rejection)));
            }
        };

        // Then enforce the `#[validate(...)]` rules declared on `T`.
        if let Err(report) = payload.validate() {
            return Err(HttpError::ValidationError(format_validation_errors(&report)));
        }

        Ok(Self(payload))
    }
}

/// Deserializes the request body as JSON without validating it.
///
/// This impl is compiled when the `validation` feature is disabled. Any rejection
/// from axum's `Json` extractor is turned into `HttpError::BadRequest`.
#[cfg(not(feature = "validation"))]
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = HttpError;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        Json::<T>::from_request(req, state)
            .await
            .map(|Json(inner)| Self(inner))
            .map_err(|rejection| HttpError::BadRequest(format!("Invalid JSON: {}", rejection)))
    }
}

/// Query-string extractor that validates the parameters after deserializing them.
///
/// The query string is deserialized into `T` with axum's `Query` extractor. When the
/// crate's `validation` feature is enabled, `T` must also implement
/// `validator::Validate`, and the deserialized value is validated before the handler runs.
///
/// # Rejections
///
/// Rejects with [`HttpError`]:
///
/// - `HttpError::BadRequest` (400 Bad Request), with a message prefixed by
///   `"Invalid query parameters: "`, when the query string cannot be deserialized
///   into `T`.
/// - `HttpError::ValidationError` (422 Unprocessable Entity), with a `"; "`-separated
///   list of `field: message` entries, when validation fails (`validation` feature only).
///
/// Without the `validation` feature this extractor only deserializes the query string.
/// No validation is performed, and `T` needs no `Validate` implementation.
///
/// ## Example
///
/// ```rust
/// use lmrc_http_common::extractors::ValidatedQuery;
/// use serde::Deserialize;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct Pagination {
///     #[validate(range(min = 1, max = 100))]
///     page: u32,
///     #[validate(range(min = 1, max = 100))]
///     per_page: u32,
/// }
///
/// async fn list_items(
///     ValidatedQuery(params): ValidatedQuery<Pagination>
/// ) -> &'static str {
///     "Items list"
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedQuery<T>(pub T);

// `ValidatedQuery` reads only the URI, so it is implemented as a `FromRequestParts`
// extractor, mirroring axum's own `Query`. Implementing `FromRequest` would force it
// to be the last handler argument and would make it impossible to combine with a
// body extractor such as `ValidatedJson`. Axum's blanket
// `impl<T: FromRequestParts<S>> FromRequest<S, _> for T` keeps it usable in the last
// position as well. Trait paths are fully qualified because only `FromRequest` is
// imported at the top of this module.

/// Deserializes the query string into `T`, then runs `T`'s validation rules.
///
/// Any rejection from axum's `Query` extractor is turned into `HttpError::BadRequest`.
/// A failed `Validate::validate` call is turned into `HttpError::ValidationError`.
#[cfg(feature = "validation")]
#[async_trait]
impl<T, S> axum::extract::FromRequestParts<S> for ValidatedQuery<T>
where
    T: DeserializeOwned + Validate,
    S: Send + Sync,
{
    type Rejection = HttpError;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let axum::extract::Query(value) =
            <axum::extract::Query<T> as axum::extract::FromRequestParts<S>>::from_request_parts(
                parts, state,
            )
            .await
            .map_err(|e| HttpError::BadRequest(format!("Invalid query parameters: {}", e)))?;

        value
            .validate()
            .map_err(|e| HttpError::ValidationError(format_validation_errors(&e)))?;

        Ok(ValidatedQuery(value))
    }
}

/// Deserializes the query string into `T` without validating it.
///
/// This impl is compiled when the `validation` feature is disabled. Any rejection
/// from axum's `Query` extractor is turned into `HttpError::BadRequest`.
#[cfg(not(feature = "validation"))]
#[async_trait]
impl<T, S> axum::extract::FromRequestParts<S> for ValidatedQuery<T>
where
    T: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = HttpError;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let axum::extract::Query(value) =
            <axum::extract::Query<T> as axum::extract::FromRequestParts<S>>::from_request_parts(
                parts, state,
            )
            .await
            .map_err(|e| HttpError::BadRequest(format!("Invalid query parameters: {}", e)))?;

        Ok(ValidatedQuery(value))
    }
}

/// Format validation errors into a user-friendly string.
///
/// Each field-level error becomes one `"<field>: <message>"` entry. When an error has
/// no custom message, `"<field>: validation failed (<code>)"` is used instead. Entries
/// are joined with `"; "`. If there are no field-level errors at all, the generic
/// `"Validation failed"` is returned.
///
/// Fields are emitted in ascending name order. Errors for the same field keep the
/// order in which `validator` reported them.
///
/// NOTE(review): only `ValidationErrors::field_errors()` is consulted. Errors on
/// nested structs or lists presumably fall back to `"Validation failed"`; confirm.
#[cfg(feature = "validation")]
fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
    // `field_errors()` returns a `HashMap`, whose iteration order is randomized per
    // process. Sort by field name so that the same invalid payload always produces
    // the same message, which keeps API responses, tests and logs reproducible.
    let mut fields: Vec<_> = errors.field_errors().into_iter().collect();
    fields.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

    let entries: Vec<String> = fields
        .iter()
        .flat_map(|(field, field_errors)| {
            field_errors.iter().map(move |error| match &error.message {
                Some(msg) => format!("{}: {}", field, msg),
                None => format!("{}: validation failed ({})", field, error.code),
            })
        })
        .collect();

    if entries.is_empty() {
        "Validation failed".to_string()
    } else {
        entries.join("; ")
    }
}

/// Router-level tests for the validating extractors.
///
/// They check the status codes produced by `ValidatedJson` and `ValidatedQuery` for
/// valid input (200), input that fails validation (422) and input that cannot be
/// deserialized (400).
#[cfg(all(test, feature = "validation"))]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::{get, post},
        Router,
    };
    use serde::{Deserialize, Serialize};
    use tower::ServiceExt;
    use validator::Validate;

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestPayload {
        #[validate(length(min = 3, max = 10))]
        name: String,
        #[validate(range(min = 18, max = 100))]
        age: u32,
    }

    #[derive(Debug, Deserialize, Validate)]
    struct TestQuery {
        #[validate(range(min = 1, max = 100))]
        page: u32,
    }

    async fn test_handler(ValidatedJson(payload): ValidatedJson<TestPayload>) -> StatusCode {
        assert_eq!(payload.name.len(), 5);
        StatusCode::OK
    }

    async fn query_handler(ValidatedQuery(params): ValidatedQuery<TestQuery>) -> StatusCode {
        // Only reached when the extractor accepted the request, so the range rule held.
        assert!((1..=100).contains(&params.page));
        StatusCode::OK
    }

    /// Serializes a `TestPayload` with the given fields to a JSON string.
    fn payload_json(name: &str, age: u32) -> String {
        serde_json::to_string(&TestPayload {
            name: name.to_string(),
            age,
        })
        .unwrap()
    }

    /// POSTs `body` as `application/json` to a router that uses `ValidatedJson`,
    /// and returns the response status.
    async fn post_json(body: impl Into<Body>) -> StatusCode {
        let app = Router::new().route("/", post(test_handler));

        app.oneshot(
            Request::builder()
                .method("POST")
                .uri("/")
                .header("content-type", "application/json")
                .body(body.into())
                .unwrap(),
        )
        .await
        .unwrap()
        .status()
    }

    /// GETs `/?{query}` from a router that uses `ValidatedQuery`, and returns the
    /// response status.
    async fn get_with_query(query: &str) -> StatusCode {
        let app = Router::new().route("/", get(query_handler));

        app.oneshot(
            Request::builder()
                .uri(format!("/?{}", query))
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap()
        .status()
    }

    #[tokio::test]
    async fn test_validated_json_success() {
        assert_eq!(post_json(payload_json("Alice", 25)).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_validated_json_validation_error() {
        // "AB" violates the `length(min = 3)` rule.
        assert_eq!(
            post_json(payload_json("AB", 25)).await,
            StatusCode::UNPROCESSABLE_ENTITY
        );
    }

    #[tokio::test]
    async fn test_validated_json_invalid_json() {
        assert_eq!(post_json("invalid json").await, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_validated_query_success() {
        assert_eq!(get_with_query("page=5").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_validated_query_validation_error() {
        // 0 violates the `range(min = 1)` rule.
        assert_eq!(
            get_with_query("page=0").await,
            StatusCode::UNPROCESSABLE_ENTITY
        );
    }

    #[tokio::test]
    async fn test_validated_query_invalid_parameters() {
        // "abc" cannot be deserialized into `u32`.
        assert_eq!(get_with_query("page=abc").await, StatusCode::BAD_REQUEST);
    }
}