1use std::{net::SocketAddr, rc::Rc, sync::Arc};
2
3use actix_web::{
4 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
5 HttpMessage as _,
6};
7use chrono::{DateTime, Utc};
8use futures::future::{ready, LocalBoxFuture, Ready};
9
10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11#[derive(Debug, Clone)]
12pub struct Extension {
13 pub start_time: DateTime<Utc>,
15
16 #[cfg(feature = "i18n")]
17 pub lang: String,
19
20 #[cfg(feature = "traceid")]
21 pub trace_id: String,
22
23 pub real_ip: SocketAddr,
24}
25
26pub type RealIPFunc = Rc<dyn Fn(&ServiceRequest) -> SocketAddr>;
27pub type LangFunc = Rc<dyn Fn(&ServiceRequest) -> Option<String>>;
28
29pub struct Middleware {
30 real_ip: RealIPFunc,
31 #[cfg(feature = "traceid")]
32 trace_header: Rc<Option<String>>,
33 #[cfg(feature = "i18n")]
34 lang: LangFunc,
35}
36
37impl Middleware {
38 fn default_real_ip(req: &ServiceRequest) -> SocketAddr {
39 req.peer_addr().unwrap()
40 }
41
42 #[cfg(feature = "i18n")]
43 fn default_lang(_: &ServiceRequest) -> Option<String> {
44 None
45 }
46
47 pub fn new() -> Self {
48 Self {
49 real_ip: Rc::new(Self::default_real_ip),
50 #[cfg(feature = "traceid")]
51 trace_header: Rc::new(None),
52 #[cfg(feature = "i18n")]
53 lang: Rc::new(Self::default_lang),
54 }
55 }
56
57 #[cfg(feature = "traceid")]
58 pub fn trace_header<S>(mut self, s: S) -> Self
59 where
60 S: Into<String>,
61 {
62 self.trace_header = Rc::new(Some(s.into()));
63 self
64 }
65
66 pub fn real_ip<F>(mut self, f: F) -> Self
67 where
68 F: Fn(&ServiceRequest) -> SocketAddr + 'static,
69 {
70 self.real_ip = Rc::new(f);
71 self
72 }
73
74 #[cfg(feature = "i18n")]
75 pub fn lang<F>(mut self, f: F) -> Self
76 where
77 F: Fn(&ServiceRequest) -> Option<String> + 'static,
78 {
79 self.lang = Rc::new(f);
80 self
81 }
82}
83
84impl<S, B> Transform<S, ServiceRequest> for Middleware
85where
86 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
87 S::Future: 'static,
88 B: 'static,
89{
90 type Response = ServiceResponse<B>;
91 type Error = actix_web::Error;
92 type InitError = ();
93 type Transform = MiddlewareService<S>;
94 type Future = Ready<Result<Self::Transform, Self::InitError>>;
95
96 fn new_transform(&self, service: S) -> Self::Future {
97 ready(Ok(MiddlewareService {
98 service: Rc::new(service),
99 real_ip: self.real_ip.clone(),
100 #[cfg(feature = "traceid")]
101 trace_header: self.trace_header.clone(),
102 #[cfg(feature = "i18n")]
103 lang: self.lang.clone(),
104 }))
105 }
106}
107
108pub struct MiddlewareService<S> {
109 service: Rc<S>,
110 real_ip: RealIPFunc,
111 #[cfg(feature = "traceid")]
112 trace_header: Rc<Option<String>>,
113 #[cfg(feature = "i18n")]
114 lang: LangFunc,
115}
116
117impl<S, B> Service<ServiceRequest> for MiddlewareService<S>
118where
119 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
120 S::Future: 'static,
121 B: 'static,
122{
123 type Response = ServiceResponse<B>;
124 type Error = actix_web::Error;
125 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
126
127 forward_ready!(service);
128
129 fn call(&self, req: ServiceRequest) -> Self::Future {
130 #[cfg(feature = "i18n")]
131 let state = req
132 .app_data::<actix_web::web::Data<crate::state::GlobalState>>()
133 .unwrap();
134 #[cfg(feature = "traceid")]
135 let trace_id = req
136 .extensions()
137 .get::<tracing_actix_web::RequestId>()
138 .unwrap()
139 .to_string();
140 let ext = Extension {
141 start_time: Utc::now(),
142 #[cfg(feature = "i18n")]
143 lang: (self.lang)(&req).unwrap_or_else(|| state.locale.default.clone()),
144 #[cfg(feature = "traceid")]
145 trace_id: trace_id.clone(),
146 real_ip: (self.real_ip)(&req),
147 };
148 #[cfg(feature = "traceid")]
149 let header = self.trace_header.clone();
150 req.extensions_mut().insert(Arc::new(ext));
151
152 #[cfg(not(feature = "traceid"))]
153 return Box::pin(self.service.call(req));
154 #[cfg(feature = "traceid")]
155 {
156 use futures::FutureExt;
157 use std::str::FromStr;
158 Box::pin(self.service.call(req).map(move |x| {
159 if let Some(header) = header.as_ref() {
160 x.map(|mut x| {
161 x.headers_mut().insert(
162 actix_web::http::header::HeaderName::from_str(header).unwrap(),
163 actix_web::http::header::HeaderValue::from_str(&trace_id).unwrap(),
164 );
165 x
166 })
167 } else {
168 x
169 }
170 }))
171 }
172 }
173}