#[cfg(feature = "actix")]
use crate::result::{ApiResult, ErrorCode};
#[cfg(feature = "actix")]
use actix_web::Error;
#[cfg(feature = "actix")]
use actix_web::HttpMessage;
#[cfg(feature = "actix")]
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready};
#[cfg(feature = "actix")]
use actix_web::{HttpResponse, Responder};
#[cfg(feature = "actix")]
use futures_util::future::LocalBoxFuture;
#[cfg(feature = "actix")]
use serde::Serialize;
#[cfg(feature = "actix")]
use std::rc::Rc;
#[cfg(feature = "actix")]
/// Translates a domain [`ErrorCode`] into the HTTP status code returned to the client.
///
/// The match is exhaustive on purpose: adding a new `ErrorCode` variant must force a
/// decision about its HTTP status here.
fn error_code_to_status(code: ErrorCode) -> actix_web::http::StatusCode {
    use actix_web::http::StatusCode as Status;

    match code {
        // 2xx
        ErrorCode::Success => Status::OK,
        // 4xx: client-side problems
        ErrorCode::BadRequest => Status::BAD_REQUEST,
        ErrorCode::Unauthorized => Status::UNAUTHORIZED,
        ErrorCode::Forbidden => Status::FORBIDDEN,
        ErrorCode::NotFound => Status::NOT_FOUND,
        ErrorCode::Conflict => Status::CONFLICT,
        ErrorCode::ValidationError => Status::UNPROCESSABLE_ENTITY,
        // 5xx: server-side problems
        ErrorCode::InternalError => Status::INTERNAL_SERVER_ERROR,
    }
}
#[cfg(feature = "actix")]
/// Trace identifier attached to each request by the API-result middleware.
///
/// It is taken from the request's configured trace headers when one is present, or is a
/// freshly generated UUIDv4 string otherwise. It is stored in the request extensions
/// and later copied into `ApiResult` responses that do not already carry a trace id.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ApiTraceId(pub String);
#[cfg(feature = "actix")]
/// Settings shared by the API-result middleware factory and the middleware it builds.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ApiResultMiddlewareConfig {
    /// Request header names consulted, in order, for an incoming trace id. The first
    /// header with a usable value wins; if none has one, a new UUIDv4 is generated.
    pub trace_id_headers: Vec<String>,
    /// NOTE(review): not read by the middleware or responder code in this file; confirm
    /// whether it is consumed elsewhere before relying on it.
    pub inject_timestamp: bool,
}
#[cfg(feature = "actix")]
impl Default for ApiResultMiddlewareConfig {
    /// Checks `X-Request-ID` first, then `X-Trace-ID`; timestamp injection is disabled.
    fn default() -> Self {
        let trace_id_headers: Vec<String> = ["X-Request-ID", "X-Trace-ID"]
            .into_iter()
            .map(String::from)
            .collect();
        Self {
            inject_timestamp: false,
            trace_id_headers,
        }
    }
}
#[cfg(feature = "actix")]
/// Middleware factory (an actix `Transform`) that wraps each service it transforms in an
/// `ApiResultMiddleware` carrying a copy of this layer's configuration.
///
/// `Clone` lets a single configured layer be reused across several registrations.
#[derive(Clone)]
pub struct ApiResultLayer {
    config: ApiResultMiddlewareConfig,
}
#[cfg(feature = "actix")]
impl ApiResultLayer {
    /// Creates a layer that hands a clone of `config` to every middleware it builds.
    ///
    /// Marked `#[must_use]`: constructing a layer and dropping it has no effect.
    #[must_use]
    pub fn new(config: ApiResultMiddlewareConfig) -> Self {
        Self { config }
    }
}
#[cfg(feature = "actix")]
impl Default for ApiResultLayer {
fn default() -> Self {
Self::new(ApiResultMiddlewareConfig::default())
}
}
#[cfg(feature = "actix")]
impl<S, B> Transform<S, ServiceRequest> for ApiResultLayer
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = ApiResultMiddleware<S>;
    type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in an [`ApiResultMiddleware`] holding its own copy of this layer's
    /// configuration. Construction cannot fail, so the returned future is already ready.
    fn new_transform(&self, service: S) -> Self::Future {
        let config = self.config.clone();
        let service = Rc::new(service);
        std::future::ready(Ok(ApiResultMiddleware { service, config }))
    }
}
#[cfg(feature = "actix")]
/// Per-service middleware built by `ApiResultLayer::new_transform`. It records an
/// [`ApiTraceId`] in each request's extensions before delegating to the wrapped service.
pub struct ApiResultMiddleware<S> {
    // The wrapped inner service (wrapped in `Rc::new` when the layer builds this value).
    service: Rc<S>,
    // This middleware's copy of the layer configuration (trace header names, etc.).
    config: ApiResultMiddlewareConfig,
}
#[cfg(feature = "actix")]
impl<S, B> Service<ServiceRequest> for ApiResultMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    /// Resolves a trace id for the request, stores it as an [`ApiTraceId`] request
    /// extension, and forwards the request to the wrapped service.
    ///
    /// The configured `trace_id_headers` are checked in order. The first one whose value
    /// is valid visible ASCII and not blank (after trimming surrounding whitespace) is
    /// used. If none qualifies, a new UUIDv4 string is generated.
    ///
    /// NOTE(review): `config.inject_timestamp` is not read here.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let trace_id = self
            .config
            .trace_id_headers
            .iter()
            .find_map(|name| {
                req.headers()
                    .get(name.as_str())
                    .and_then(|value| value.to_str().ok())
                    .map(str::trim)
                    // An empty header (e.g. `X-Request-ID:`) must not become the trace id;
                    // fall through to the next header or a generated id instead.
                    .filter(|value| !value.is_empty())
                    .map(str::to_owned)
            })
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
        req.extensions_mut().insert(ApiTraceId(trace_id));

        // The inner future is already `'static` and yields exactly our response/error
        // types, so box it directly. No `Rc` clone or pass-through `async` block is
        // needed, and the inner service is called right after readiness was checked.
        Box::pin(self.service.call(req))
    }
}
#[cfg(feature = "actix")]
impl<T: Serialize> Responder for ApiResult<T> {
type Body = actix_web::body::BoxBody;
fn respond_to(self, req: &actix_web::HttpRequest) -> actix_web::HttpResponse<Self::Body> {
let mut result = self;
if result.trace_id.is_none()
&& let Some(tid) = req.extensions().get::<ApiTraceId>()
{
result = result.with_trace_id(&tid.0);
}
let status = result
.code
.and_then(ErrorCode::from_i32)
.map(error_code_to_status)
.unwrap_or_else(|| {
if result.success {
actix_web::http::StatusCode::OK
} else {
actix_web::http::StatusCode::INTERNAL_SERVER_ERROR
}
});
HttpResponse::build(status).json(result)
}
}