// actix_cloud/request.rs

use std::{net::SocketAddr, rc::Rc, sync::Arc};

use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    HttpMessage as _,
};
use chrono::{DateTime, Utc};
use futures::future::{ready, LocalBoxFuture, Ready};

10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11#[derive(Debug, Clone)]
12pub struct Extension {
13    /// Request start time.
14    pub start_time: DateTime<Utc>,
15
16    #[cfg(feature = "i18n")]
17    /// Request language.
18    pub lang: String,
19
20    #[cfg(feature = "traceid")]
21    pub trace_id: String,
22
23    pub real_ip: SocketAddr,
24}
25
26pub type RealIPFunc = Rc<dyn Fn(&ServiceRequest) -> SocketAddr>;
27pub type LangFunc = Rc<dyn Fn(&ServiceRequest) -> Option<String>>;
28
29pub struct Middleware {
30    real_ip: RealIPFunc,
31    #[cfg(feature = "traceid")]
32    trace_header: Rc<Option<String>>,
33    #[cfg(feature = "i18n")]
34    lang: LangFunc,
35}
36
37impl Middleware {
38    fn default_real_ip(req: &ServiceRequest) -> SocketAddr {
39        req.peer_addr().unwrap()
40    }
41
42    #[cfg(feature = "i18n")]
43    fn default_lang(_: &ServiceRequest) -> Option<String> {
44        None
45    }
46
47    pub fn new() -> Self {
48        Self {
49            real_ip: Rc::new(Self::default_real_ip),
50            #[cfg(feature = "traceid")]
51            trace_header: Rc::new(None),
52            #[cfg(feature = "i18n")]
53            lang: Rc::new(Self::default_lang),
54        }
55    }
56
57    #[cfg(feature = "traceid")]
58    pub fn trace_header<S>(mut self, s: S) -> Self
59    where
60        S: Into<String>,
61    {
62        self.trace_header = Rc::new(Some(s.into()));
63        self
64    }
65
66    pub fn real_ip<F>(mut self, f: F) -> Self
67    where
68        F: Fn(&ServiceRequest) -> SocketAddr + 'static,
69    {
70        self.real_ip = Rc::new(f);
71        self
72    }
73
74    #[cfg(feature = "i18n")]
75    pub fn lang<F>(mut self, f: F) -> Self
76    where
77        F: Fn(&ServiceRequest) -> Option<String> + 'static,
78    {
79        self.lang = Rc::new(f);
80        self
81    }
82}
83
84impl<S, B> Transform<S, ServiceRequest> for Middleware
85where
86    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
87    S::Future: 'static,
88    B: 'static,
89{
90    type Response = ServiceResponse<B>;
91    type Error = actix_web::Error;
92    type InitError = ();
93    type Transform = MiddlewareService<S>;
94    type Future = Ready<Result<Self::Transform, Self::InitError>>;
95
96    fn new_transform(&self, service: S) -> Self::Future {
97        ready(Ok(MiddlewareService {
98            service: Rc::new(service),
99            real_ip: self.real_ip.clone(),
100            #[cfg(feature = "traceid")]
101            trace_header: self.trace_header.clone(),
102            #[cfg(feature = "i18n")]
103            lang: self.lang.clone(),
104        }))
105    }
106}
107
108pub struct MiddlewareService<S> {
109    service: Rc<S>,
110    real_ip: RealIPFunc,
111    #[cfg(feature = "traceid")]
112    trace_header: Rc<Option<String>>,
113    #[cfg(feature = "i18n")]
114    lang: LangFunc,
115}
116
117impl<S, B> Service<ServiceRequest> for MiddlewareService<S>
118where
119    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
120    S::Future: 'static,
121    B: 'static,
122{
123    type Response = ServiceResponse<B>;
124    type Error = actix_web::Error;
125    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
126
127    forward_ready!(service);
128
129    fn call(&self, req: ServiceRequest) -> Self::Future {
130        #[cfg(feature = "i18n")]
131        let state = req
132            .app_data::<actix_web::web::Data<crate::state::GlobalState>>()
133            .unwrap();
134        #[cfg(feature = "traceid")]
135        let trace_id = req
136            .extensions()
137            .get::<tracing_actix_web::RequestId>()
138            .unwrap()
139            .to_string();
140        let ext = Extension {
141            start_time: Utc::now(),
142            #[cfg(feature = "i18n")]
143            lang: (self.lang)(&req).unwrap_or_else(|| state.locale.default.clone()),
144            #[cfg(feature = "traceid")]
145            trace_id: trace_id.clone(),
146            real_ip: (self.real_ip)(&req),
147        };
148        #[cfg(feature = "traceid")]
149        let header = self.trace_header.clone();
150        req.extensions_mut().insert(Arc::new(ext));
151
152        #[cfg(not(feature = "traceid"))]
153        return Box::pin(self.service.call(req));
154        #[cfg(feature = "traceid")]
155        {
156            use futures::FutureExt;
157            use std::str::FromStr;
158            Box::pin(self.service.call(req).map(move |x| {
159                if let Some(header) = header.as_ref() {
160                    x.map(|mut x| {
161                        x.headers_mut().insert(
162                            actix_web::http::header::HeaderName::from_str(header).unwrap(),
163                            actix_web::http::header::HeaderValue::from_str(&trace_id).unwrap(),
164                        );
165                        x
166                    })
167                } else {
168                    x
169                }
170            }))
171        }
172    }
173}