#[cfg(feature = "axum")]
use crate::result::{ApiResult, ErrorCode};
#[cfg(feature = "axum")]
use axum::{
body::Body,
extract::State,
http::{HeaderMap, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Json, Response},
};
#[cfg(feature = "axum")]
use serde::Serialize;
/// Translates an application-level [`ErrorCode`] into the HTTP status code
/// that is attached to the outgoing response.
///
/// The match is intentionally exhaustive (no wildcard arm) so that adding a
/// new `ErrorCode` variant forces this mapping to be revisited.
#[cfg(feature = "axum")]
fn error_code_to_status(code: ErrorCode) -> StatusCode {
    match code {
        // Successful outcome.
        ErrorCode::Success => StatusCode::OK,

        // Failures attributed to the server.
        ErrorCode::InternalError => StatusCode::INTERNAL_SERVER_ERROR,

        // Failures attributed to the request.
        ErrorCode::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
        ErrorCode::Conflict => StatusCode::CONFLICT,
        ErrorCode::NotFound => StatusCode::NOT_FOUND,
        ErrorCode::Forbidden => StatusCode::FORBIDDEN,
        ErrorCode::Unauthorized => StatusCode::UNAUTHORIZED,
        ErrorCode::BadRequest => StatusCode::BAD_REQUEST,
    }
}
/// Settings consumed by [`api_result_middleware_with_config`].
///
/// The value is cloned per request when it is used as axum state, hence the
/// `Clone` derive. `Debug` is derived so the configuration can be logged or
/// embedded in other `Debug` types.
#[cfg(feature = "axum")]
#[derive(Debug, Clone)]
pub struct ApiResultMiddlewareConfig {
    /// Request header names that are probed, in order, for an existing trace
    /// identifier. The first header that yields a usable value wins.
    pub trace_id_headers: Vec<String>,
    /// When `true`, a `timestamp` field (milliseconds since the Unix epoch)
    /// is written into the response JSON.
    pub inject_timestamp: bool,
}
#[cfg(feature = "axum")]
impl Default for ApiResultMiddlewareConfig {
    /// Builds the default configuration: trace ids are looked up in
    /// `x-request-id` first and `x-trace-id` second, and timestamp injection
    /// is turned off.
    fn default() -> Self {
        let trace_id_headers = ["x-request-id", "x-trace-id"]
            .iter()
            .copied()
            .map(String::from)
            .collect();

        Self {
            trace_id_headers,
            inject_timestamp: false,
        }
    }
}
/// Axum middleware that enriches JSON responses carrying a `success` key
/// with a trace identifier and, optionally, a timestamp.
///
/// Behavior, in order:
///
/// 1. The inner service is run. Responses whose `content-type` does not
///    contain `application/json` are returned untouched.
/// 2. Responses that declare a `content-length` above the 64 KiB rewrite
///    limit are returned untouched, so large payloads are never buffered or
///    dropped.
/// 3. The body is buffered (up to the limit) and parsed. Bodies that are not
///    valid JSON, or that are not a JSON object containing a `success` key,
///    are forwarded byte-for-byte.
/// 4. If the object has neither `traceId` nor `trace_id`, a trace id is
///    inserted. It is taken from the first configured request header with a
///    non-empty, string-convertible value, falling back to a fresh UUID v4.
///    The key is `trace_id` when the `snake-case` feature is enabled and
///    `traceId` otherwise.
/// 5. If `config.inject_timestamp` is set, `timestamp` is written with the
///    current time in milliseconds since the Unix epoch (overwriting any
///    existing value).
///
/// If the body cannot be read at all (stream error, or a body without a
/// declared length that turns out to exceed the limit) the payload is
/// unrecoverable, so a `500 Internal Server Error` with an empty body is
/// returned instead of an empty body under the original status.
#[cfg(feature = "axum")]
pub async fn api_result_middleware_with_config(
    State(config): State<ApiResultMiddlewareConfig>,
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Largest response body that will be buffered for rewriting.
    const MAX_REWRITE_BYTES: usize = 1024 * 64;
    #[cfg(not(feature = "snake-case"))]
    const TRACE_ID_KEY: &str = "traceId";
    #[cfg(feature = "snake-case")]
    const TRACE_ID_KEY: &str = "trace_id";

    let response = next.run(request).await;

    let is_json = response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .map_or(false, |value| value.contains("application/json"));
    if !is_json {
        return response;
    }

    // Skip bodies that announce themselves as too large *before* consuming
    // them; once `to_bytes` fails the payload cannot be recovered.
    let declared_len = response
        .headers()
        .get(CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.parse::<u64>().ok());
    if matches!(declared_len, Some(len) if len > MAX_REWRITE_BYTES as u64) {
        return response;
    }

    let (mut parts, body) = response.into_parts();
    let bytes = match axum::body::to_bytes(body, MAX_REWRITE_BYTES).await {
        Ok(bytes) => bytes,
        Err(_) => {
            // The body is gone; do not pretend the original status still
            // describes what the client receives.
            parts.status = StatusCode::INTERNAL_SERVER_ERROR;
            parts.headers.remove(CONTENT_LENGTH);
            parts.headers.remove(CONTENT_TYPE);
            return Response::from_parts(parts, Body::empty());
        }
    };

    let mut json: serde_json::Value = match serde_json::from_slice(&bytes) {
        Ok(value) => value,
        // Declared as JSON but not parseable: forward the payload untouched.
        Err(_) => return Response::from_parts(parts, Body::from(bytes)),
    };

    let mut modified = false;
    if let Some(envelope) = json
        .as_object_mut()
        .filter(|object| object.contains_key("success"))
    {
        if !envelope.contains_key("traceId") && !envelope.contains_key("trace_id") {
            // Resolved lazily so no UUID is generated for responses that
            // already carry a trace id or are not rewritten at all.
            let trace_id = config
                .trace_id_headers
                .iter()
                .find_map(|name| {
                    headers
                        .get(name.as_str())
                        .and_then(|value| value.to_str().ok())
                        .filter(|value| !value.is_empty())
                        .map(str::to_owned)
                })
                .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
            envelope.insert(
                TRACE_ID_KEY.to_string(),
                serde_json::Value::String(trace_id),
            );
            modified = true;
        }

        if config.inject_timestamp {
            // A clock set before the epoch yields 0; an out-of-range value
            // saturates instead of silently wrapping.
            let millis = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|elapsed| i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX))
                .unwrap_or(0);
            envelope.insert(
                "timestamp".to_string(),
                serde_json::Value::Number(millis.into()),
            );
            modified = true;
        }
    }

    if !modified {
        // Nothing was injected: hand back the exact original bytes rather
        // than a re-serialized (possibly reordered/reformatted) document.
        return Response::from_parts(parts, Body::from(bytes));
    }

    match serde_json::to_vec(&json) {
        Ok(rewritten) => {
            // The inner response's Content-Length describes the old body;
            // drop it so the server derives the length from the new one.
            parts.headers.remove(CONTENT_LENGTH);
            Response::from_parts(parts, Body::from(rewritten))
        }
        // Serialization failed: fall back to the unmodified payload, whose
        // original headers are still accurate.
        Err(_) => Response::from_parts(parts, Body::from(bytes)),
    }
}
/// Convenience entry point that runs [`api_result_middleware_with_config`]
/// with [`ApiResultMiddlewareConfig::default`], for routers that do not carry
/// a configuration in their state.
#[cfg(feature = "axum")]
pub async fn api_result_middleware(
    headers: HeaderMap,
    request: Request<Body>,
    next: Next,
) -> Response {
    let default_state = State(ApiResultMiddlewareConfig::default());
    api_result_middleware_with_config(default_state, headers, request, next).await
}
#[cfg(feature = "axum")]
impl<T: Serialize> IntoResponse for ApiResult<T> {
    /// Renders the result as a JSON body with a matching HTTP status.
    ///
    /// The status comes from `code` when it is present and recognised by
    /// `ErrorCode::from_i32`. Otherwise it falls back to `200 OK` for
    /// successful results and `500 Internal Server Error` for failures.
    fn into_response(self) -> Response {
        let status = match self.code.and_then(ErrorCode::from_i32) {
            Some(known) => error_code_to_status(known),
            None if self.success => StatusCode::OK,
            None => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (status, Json(self)).into_response()
    }
}