// actix-cloud 0.5.1
//
// Actix Cloud is an all-in-one web framework based on Actix Web.
use std::{net::SocketAddr, rc::Rc, sync::Arc};

use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    HttpMessage as _,
};
use chrono::{DateTime, Utc};
use futures::future::{ready, LocalBoxFuture, Ready};

/// Per-request data that [`Middleware`] stores in the request extensions as
/// `Arc<Extension>`.
///
/// Which fields exist depends on the enabled crate features (`i18n`, `traceid`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct Extension {
    /// Request start time, captured with `Utc::now()` when the middleware handles the request.
    pub start_time: DateTime<Utc>,

    #[cfg(feature = "i18n")]
    /// Request language: the value chosen by the configured [`LangFunc`], or the default
    /// locale from the global state when that callback returns `None`.
    pub lang: String,

    #[cfg(feature = "traceid")]
    /// Trace id of the request, taken from `tracing_actix_web::RequestId`.
    pub trace_id: String,

    /// Client address as resolved by the configured [`RealIPFunc`] (by default the
    /// directly connected peer address).
    pub real_ip: SocketAddr,
}

/// Callback that resolves the client address stored in [`Extension::real_ip`].
pub type RealIPFunc = Rc<dyn Fn(&ServiceRequest) -> SocketAddr>;
/// Callback that picks the request language; `None` means "use the default locale".
pub type LangFunc = Rc<dyn Fn(&ServiceRequest) -> Option<String>>;

/// Actix middleware factory that attaches an `Arc<Extension>` to every request.
///
/// Configure it with the builder methods (`real_ip`, and feature-gated `trace_header` /
/// `lang`) and register it with `App::wrap`.
pub struct Middleware {
    // Resolves `Extension::real_ip` for each request.
    real_ip: RealIPFunc,
    // Response header name under which the trace id is echoed; `None` disables it.
    #[cfg(feature = "traceid")]
    trace_header: Rc<Option<String>>,
    // Chooses `Extension::lang`; `None` falls back to the default locale.
    #[cfg(feature = "i18n")]
    lang: LangFunc,
}

impl Default for Middleware {
    fn default() -> Self {
        Self::new()
    }
}

impl Middleware {
    /// Default [`RealIPFunc`]: the directly connected peer address reported by actix.
    ///
    /// When actix cannot report a peer address (per the actix-web docs, e.g. a request built
    /// with `test::TestRequest` without a peer address), this falls back to the unspecified
    /// address `0.0.0.0:0` instead of panicking inside the request path.
    fn default_real_ip(req: &ServiceRequest) -> SocketAddr {
        req.peer_addr().unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)))
    }

    /// Default [`LangFunc`]: never selects a language, so the request falls back to the
    /// default locale.
    #[cfg(feature = "i18n")]
    fn default_lang(_: &ServiceRequest) -> Option<String> {
        None
    }

    /// Creates a middleware with the default hooks:
    /// * real IP = peer address (see `default_real_ip`);
    /// * no trace-id response header (`traceid` feature);
    /// * no per-request language override (`i18n` feature).
    pub fn new() -> Self {
        Self {
            real_ip: Rc::new(Self::default_real_ip),
            #[cfg(feature = "traceid")]
            trace_header: Rc::new(None),
            #[cfg(feature = "i18n")]
            lang: Rc::new(Self::default_lang),
        }
    }

    /// Sets the response header name under which each request's trace id is echoed back.
    #[cfg(feature = "traceid")]
    pub fn trace_header<S>(mut self, s: S) -> Self
    where
        S: Into<String>,
    {
        self.trace_header = Rc::new(Some(s.into()));
        self
    }

    /// Replaces the callback that resolves [`Extension::real_ip`] for each request.
    pub fn real_ip<F>(mut self, f: F) -> Self
    where
        F: Fn(&ServiceRequest) -> SocketAddr + 'static,
    {
        self.real_ip = Rc::new(f);
        self
    }

    /// Replaces the callback that chooses [`Extension::lang`]. Returning `None` falls back
    /// to the default locale from the global state.
    #[cfg(feature = "i18n")]
    pub fn lang<F>(mut self, f: F) -> Self
    where
        F: Fn(&ServiceRequest) -> Option<String> + 'static,
    {
        self.lang = Rc::new(f);
        self
    }
}

/// Builds a [`MiddlewareService`] around the inner service, sharing the configured hooks
/// through reference-counted handles (no hook is deep-copied).
impl<S, B> Transform<S, ServiceRequest> for Middleware
where
    B: 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
    S::Future: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type InitError = ();
    type Transform = MiddlewareService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        let wrapped = MiddlewareService {
            service: Rc::new(service),
            real_ip: Rc::clone(&self.real_ip),
            #[cfg(feature = "traceid")]
            trace_header: Rc::clone(&self.trace_header),
            #[cfg(feature = "i18n")]
            lang: Rc::clone(&self.lang),
        };
        ready(Ok(wrapped))
    }
}

/// Service produced by [`Middleware`]'s `Transform::new_transform`; wraps the inner
/// service `S` and holds shared handles to the middleware's hooks.
pub struct MiddlewareService<S> {
    // The wrapped (inner) service that requests are forwarded to.
    service: Rc<S>,
    // Resolves `Extension::real_ip`.
    real_ip: RealIPFunc,
    // Optional response header name for echoing the trace id.
    #[cfg(feature = "traceid")]
    trace_header: Rc<Option<String>>,
    // Chooses `Extension::lang`.
    #[cfg(feature = "i18n")]
    lang: LangFunc,
}

impl<S, B> Service<ServiceRequest> for MiddlewareService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    /// Inserts an `Arc<Extension>` into the request extensions and forwards the request to
    /// the inner service.
    ///
    /// With the `traceid` feature and a configured trace header, the trace id is also set on
    /// the successful response under that header. If the configured name (or the trace id)
    /// is not a valid header name/value, the header is skipped rather than panicking on
    /// every response.
    ///
    /// # Panics
    ///
    /// * `traceid`: if no `tracing_actix_web::RequestId` is present in the request extensions.
    /// * `i18n`: if the language callback returns `None` and no `Data<GlobalState>` is
    ///   registered as app data.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        #[cfg(feature = "traceid")]
        let trace_id = req
            .extensions()
            .get::<tracing_actix_web::RequestId>()
            .expect("tracing_actix_web::RequestId missing from request extensions")
            .to_string();

        // Resolve the response header once per request; an invalid configured name is
        // ignored instead of panicking for every response.
        #[cfg(feature = "traceid")]
        let trace_header = (*self.trace_header).as_deref().and_then(|name| {
            let name = actix_web::http::header::HeaderName::from_bytes(name.as_bytes()).ok()?;
            let value = actix_web::http::header::HeaderValue::from_str(&trace_id).ok()?;
            Some((name, value))
        });

        let ext = Extension {
            start_time: Utc::now(),
            #[cfg(feature = "i18n")]
            lang: (self.lang)(&req).unwrap_or_else(|| {
                // The global state is only needed when the callback picked no language.
                req.app_data::<actix_web::web::Data<crate::state::GlobalState>>()
                    .expect("Data<GlobalState> must be registered as app data")
                    .locale
                    .default
                    .clone()
            }),
            #[cfg(feature = "traceid")]
            trace_id,
            real_ip: (self.real_ip)(&req),
        };
        req.extensions_mut().insert(Arc::new(ext));

        let fut = self.service.call(req);

        #[cfg(feature = "traceid")]
        {
            if let Some((name, value)) = trace_header {
                use futures::FutureExt as _;
                return Box::pin(fut.map(move |res| {
                    res.map(|mut resp| {
                        resp.headers_mut().insert(name, value);
                        resp
                    })
                }));
            }
        }

        Box::pin(fut)
    }
}