anycms-core 0.5.4

A unified API response library supporting multiple Rust web frameworks
Documentation
#[cfg(feature = "axum")]
use crate::result::{ApiResult, ErrorCode};
#[cfg(feature = "axum")]
use axum::{
    body::Body,
    extract::State,
    http::{HeaderMap, Request, StatusCode},
    middleware::Next,
    response::{IntoResponse, Json, Response},
};
#[cfg(feature = "axum")]
use serde::Serialize;

/// Translates an [`ErrorCode`] into the HTTP status used for the response.
///
/// The match is intentionally exhaustive (no wildcard arm) so that adding a
/// new `ErrorCode` variant fails to compile until a status is chosen for it.
#[cfg(feature = "axum")]
fn error_code_to_status(code: ErrorCode) -> StatusCode {
    use ErrorCode as Code;

    match code {
        // 2xx
        Code::Success => StatusCode::OK,
        // 4xx: the request itself was at fault
        Code::BadRequest => StatusCode::BAD_REQUEST,
        Code::Unauthorized => StatusCode::UNAUTHORIZED,
        Code::Forbidden => StatusCode::FORBIDDEN,
        Code::NotFound => StatusCode::NOT_FOUND,
        Code::Conflict => StatusCode::CONFLICT,
        Code::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
        // 5xx
        Code::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
    }
}

// ============================================================
// Middleware Configuration
// ============================================================

/// Configuration for the axum `api_result_middleware`.
///
/// Construct it with [`Default::default`] and override individual fields
/// with struct-update syntax, then hand it to
/// `axum::middleware::from_fn_with_state` together with
/// `api_result_middleware_with_config`.
///
/// The type is `Debug` so it can be logged at start-up, and `PartialEq`/`Eq`
/// so configurations can be compared in tests.
#[cfg(feature = "axum")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApiResultMiddlewareConfig {
    /// Header names to check for an existing trace ID (checked in order).
    ///
    /// Default: `["x-request-id", "x-trace-id"]`
    pub trace_id_headers: Vec<String>,

    /// Whether to auto-inject the current Unix-millisecond timestamp.
    ///
    /// Default: `false`
    pub inject_timestamp: bool,
}

#[cfg(feature = "axum")]
impl Default for ApiResultMiddlewareConfig {
    /// Returns the zero-config defaults: look for a trace ID in
    /// `x-request-id`, then `x-trace-id`, and do not inject a timestamp.
    fn default() -> Self {
        // Listed in lookup priority order.
        let trace_id_headers = ["x-request-id", "x-trace-id"]
            .iter()
            .map(|name| String::from(*name))
            .collect();

        Self {
            trace_id_headers,
            inject_timestamp: false,
        }
    }
}

// ============================================================
// Middleware (from_fn_with_state pattern)
// ============================================================

/// Axum middleware that auto-injects `trace_id` (and optionally `timestamp`)
/// into JSON API responses.
///
/// Since axum's `IntoResponse` trait does not have access to the request,
/// this middleware intercepts the response body, detects JSON responses that
/// look like `ApiResult` (a JSON object containing a `"success"` field), and
/// patches in `traceId` / `timestamp`.
///
/// # Behavior
///
/// - Responses whose `content-type` does not contain `application/json` are
///   returned untouched, without reading the body.
/// - Responses whose body is already known to be larger than 64 KiB are
///   returned untouched instead of being buffered.
/// - JSON that is not an `ApiResult`-shaped object, or that needs no patching,
///   is returned with its original bytes (it is never re-serialized).
/// - The trace ID is the first non-empty value found in the request headers
///   named by `config.trace_id_headers` (checked in order); when none is
///   present a fresh UUID v4 is generated. It is only injected when the body
///   has neither a `traceId` nor a `trace_id` key. The key written is
///   `trace_id` with the `snake-case` feature and `traceId` otherwise.
/// - With `config.inject_timestamp`, `timestamp` is set to the current Unix
///   time in milliseconds, replacing any existing value.
/// - When the body is rewritten, any `content-length` header is removed so the
///   server recomputes it for the new body.
/// - If the body cannot be read (it errors, or a streamed body grows past the
///   64 KiB limit), the original payload is lost; the middleware then answers
///   with an empty `500 Internal Server Error` rather than an empty body under
///   the handler's original status.
///
/// # Usage
///
/// ```ignore
/// use axum::Router;
/// use axum::routing::get;
/// use axum::middleware;
/// use anycms_core::axum::{ApiResultMiddlewareConfig, api_result_middleware_with_config};
///
/// let config = ApiResultMiddlewareConfig::default();
/// let app = Router::new()
///     .route("/users", get(list_users))
///     .layer(middleware::from_fn_with_state(
///         config,
///         api_result_middleware_with_config,
///     ));
/// ```
#[cfg(feature = "axum")]
pub async fn api_result_middleware_with_config(
    State(config): State<ApiResultMiddlewareConfig>,
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Brings `size_hint` into scope for `Body`.
    use axum::body::HttpBody as _;
    use axum::http::header::CONTENT_LENGTH;

    // Largest response body that will be buffered in memory for patching.
    const MAX_PATCHABLE_BODY_BYTES: usize = 64 * 1024;

    // Picks the first non-empty trace ID from the candidate request headers,
    // falling back to a freshly generated UUID v4.
    fn resolve_trace_id(candidates: &[String], headers: &HeaderMap) -> String {
        candidates
            .iter()
            .find_map(|name| {
                headers
                    .get(name.as_str())
                    .and_then(|value| value.to_str().ok())
                    .filter(|value| !value.is_empty())
                    .map(str::to_owned)
            })
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
    }

    // Current Unix time in milliseconds; 0 if the clock is before the epoch,
    // saturating instead of wrapping if the value does not fit in an i64.
    fn unix_millis() -> i64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX))
            .unwrap_or(0)
    }

    let response = next.run(request).await;

    // Only process JSON responses
    let is_json = response
        .headers()
        .get("content-type")
        .and_then(|value| value.to_str().ok())
        .map_or(false, |content_type| content_type.contains("application/json"));
    if !is_json {
        return response;
    }

    // Buffering consumes the body, so a body that is already known to exceed
    // the limit must be passed through *before* attempting to read it.
    // (usize -> u64 is lossless on every supported target.)
    if response.body().size_hint().lower() > MAX_PATCHABLE_BODY_BYTES as u64 {
        return response;
    }

    // Read response body
    let (mut parts, body) = response.into_parts();
    let bytes = match axum::body::to_bytes(body, MAX_PATCHABLE_BODY_BYTES).await {
        Ok(buffered) => buffered,
        Err(_) => {
            // The payload is gone at this point; report the failure instead
            // of sending an empty body under the handler's original status.
            parts.status = StatusCode::INTERNAL_SERVER_ERROR;
            parts.headers.remove(CONTENT_LENGTH);
            return Response::from_parts(parts, Body::empty());
        }
    };

    // Parse JSON and inject fields if it looks like an ApiResult
    let mut json: serde_json::Value = match serde_json::from_slice(&bytes) {
        Ok(value) => value,
        // Not valid JSON, return as-is
        Err(_) => return Response::from_parts(parts, Body::from(bytes)),
    };

    let mut patched = false;
    if let Some(obj) = json.as_object_mut() {
        // Only patch responses that look like ApiResult (have "success" field)
        if obj.contains_key("success") {
            // Inject trace_id if not already present under either spelling
            if !obj.contains_key("traceId") && !obj.contains_key("trace_id") {
                let key = if cfg!(feature = "snake-case") {
                    "trace_id"
                } else {
                    "traceId"
                };
                let trace_id = resolve_trace_id(&config.trace_id_headers, &headers);
                obj.insert(key.to_string(), serde_json::Value::String(trace_id));
                patched = true;
            }

            // Optionally inject timestamp (same key in both naming styles)
            if config.inject_timestamp {
                obj.insert(
                    "timestamp".to_string(),
                    serde_json::Value::Number(unix_millis().into()),
                );
                patched = true;
            }
        }
    }

    // Nothing changed: hand back the original bytes so unrelated JSON is not
    // reformatted (key order, number representation) by a round trip.
    if !patched {
        return Response::from_parts(parts, Body::from(bytes));
    }

    match serde_json::to_vec(&json) {
        Ok(new_body) => {
            // The body length changed; drop any explicit length so it is
            // recomputed from the new body.
            parts.headers.remove(CONTENT_LENGTH);
            Response::from_parts(parts, Body::from(new_body))
        }
        // Re-serialization failed: fall back to the unmodified payload.
        Err(_) => Response::from_parts(parts, Body::from(bytes)),
    }
}

/// Zero-config variant of `api_result_middleware_with_config`.
///
/// Builds an `ApiResultMiddlewareConfig::default()` for the request and
/// delegates to the configurable middleware, so it can be installed with
/// plain `middleware::from_fn` and no state.
///
/// # Usage
///
/// ```ignore
/// use axum::Router;
/// use axum::routing::get;
/// use axum::middleware;
/// use anycms_core::axum::api_result_middleware;
///
/// let app = Router::new()
///     .route("/users", get(list_users))
///     .layer(middleware::from_fn(api_result_middleware));
/// ```
#[cfg(feature = "axum")]
pub async fn api_result_middleware(
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    let default_state = State(ApiResultMiddlewareConfig::default());
    api_result_middleware_with_config(default_state, headers, request, next).await
}

// ============================================================
// IntoResponse implementation
// ============================================================

/// Lets handlers return an `ApiResult<T>` directly from axum routes.
///
/// The HTTP status is derived from `code` when it maps to a known
/// `ErrorCode`; otherwise it falls back to `200 OK` for successful results
/// and `500 Internal Server Error` for failed ones. The result itself is
/// serialized as the JSON body.
#[cfg(feature = "axum")]
impl<T: Serialize> IntoResponse for ApiResult<T> {
    fn into_response(self) -> Response {
        let known_code = self.code.and_then(ErrorCode::from_i32);

        let status = match known_code {
            Some(code) => error_code_to_status(code),
            // No recognizable code: decide from the success flag alone.
            None if self.success => StatusCode::OK,
            None => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (status, Json(self)).into_response()
    }
}