1use std::future::{ready, Ready};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use actix_web::dev::Transform;
9use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse};
10use futures_util::future::LocalBoxFuture;
11use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey};
12
13use crate::*;
14
15pub struct SessionMiddlewareBuilder<ClaimsType: Claims> {
19 pub(crate) jwt_encoding_key: Arc<EncodingKey>,
20 pub(crate) jwt_decoding_key: Arc<DecodingKey>,
21 pub(crate) algorithm: Algorithm,
22 pub(crate) storage: Option<SessionStorage>,
23 pub(crate) extractors: Extractors<ClaimsType>,
24}
25impl<ClaimsType: Claims> SessionMiddlewareBuilder<ClaimsType> {
26 #[doc(hidden)]
27 pub(crate) fn new(
28 jwt_encoding_key: Arc<EncodingKey>,
29 jwt_decoding_key: Arc<DecodingKey>,
30 algorithm: Algorithm,
31 ) -> Self {
32 Self {
33 jwt_encoding_key: jwt_encoding_key.clone(),
34 jwt_decoding_key,
35 algorithm,
36 storage: None,
37 extractors: Extractors::default(),
38 }
39 }
40
41 pub(crate) fn auto_ed_dsa() -> Self {
42 let keys = JwtSigningKeys::load_or_create();
43 Self::new(
44 Arc::new(keys.encoding_key),
45 Arc::new(keys.decoding_key),
46 Algorithm::EdDSA,
47 )
48 }
49
50 #[must_use]
53 pub fn with_storage(mut self, storage: SessionStorage) -> Self {
54 self.storage = Some(storage);
55 self
56 }
57
58 #[must_use]
60 pub fn with_extractors(mut self, extractors: Extractors<ClaimsType>) -> Self {
61 self.extractors = extractors;
62 self
63 }
64
65 pub fn finish(self) -> (SessionStorage, SessionMiddlewareFactory<ClaimsType>) {
67 let Self {
68 storage,
69 jwt_encoding_key,
70 jwt_decoding_key,
71 algorithm,
72 extractors,
73 ..
74 } = self;
75 let storage = storage
76 .expect("Session storage must be constracted from pool or set from existing storage");
77 (
78 storage.clone(),
79 SessionMiddlewareFactory {
80 jwt_encoding_key,
81 jwt_decoding_key,
82 algorithm,
83 storage,
84 extractors,
85 },
86 )
87 }
88}
89
90#[derive(Clone)]
135pub struct SessionMiddlewareFactory<ClaimsType: Claims> {
136 pub(crate) jwt_encoding_key: Arc<EncodingKey>,
137 pub(crate) jwt_decoding_key: Arc<DecodingKey>,
138 pub(crate) algorithm: Algorithm,
139 pub(crate) storage: SessionStorage,
140 pub(crate) extractors: Extractors<ClaimsType>,
141}
142
143impl<ClaimsType: Claims> SessionMiddlewareFactory<ClaimsType> {
144 pub fn build_ed_dsa() -> SessionMiddlewareBuilder<ClaimsType> {
145 SessionMiddlewareBuilder::auto_ed_dsa()
146 }
147
148 pub fn build(
149 jwt_encoding_key: Arc<EncodingKey>,
150 jwt_decoding_key: Arc<DecodingKey>,
151 algorithm: Algorithm,
152 ) -> SessionMiddlewareBuilder<ClaimsType> {
153 SessionMiddlewareBuilder::new(jwt_encoding_key, jwt_decoding_key, algorithm)
154 }
155}
156
157impl<S, ClaimsType> Transform<S, ServiceRequest> for SessionMiddlewareFactory<ClaimsType>
158where
159 S: Service<ServiceRequest, Error = actix_web::Error, Response = ServiceResponse> + 'static,
160 ClaimsType: Claims,
161{
162 type Response = ServiceResponse;
163 type Error = actix_web::Error;
164 type Transform = SessionMiddleware<S, ClaimsType>;
165 type InitError = ();
166 type Future = Ready<Result<Self::Transform, Self::InitError>>;
167
168 fn new_transform(&self, service: S) -> Self::Future {
169 ready(Ok(SessionMiddleware {
170 service: Rc::new(service),
171 storage: self.storage.clone(),
172 jwt_encoding_key: self.jwt_encoding_key.clone(),
173 jwt_decoding_key: self.jwt_decoding_key.clone(),
174 algorithm: self.algorithm,
175 extractors: self.extractors.clone(),
176 }))
177 }
178}
179
180#[doc(hidden)]
181pub struct SessionMiddleware<S, ClaimsType>
182where
183 ClaimsType: Claims,
184{
185 pub(crate) service: Rc<S>,
186 pub(crate) jwt_encoding_key: Arc<EncodingKey>,
187 pub(crate) jwt_decoding_key: Arc<DecodingKey>,
188 pub(crate) algorithm: Algorithm,
189 pub(crate) storage: SessionStorage,
190 pub(crate) extractors: Extractors<ClaimsType>,
191}
192
193impl<S, ClaimsType: Claims> SessionMiddleware<S, ClaimsType> {
194 async fn extract_token<C: Claims>(
195 req: &mut ServiceRequest,
196 jwt_encoding_key: Arc<EncodingKey>,
197 jwt_decoding_key: Arc<DecodingKey>,
198 algorithm: Algorithm,
199 storage: SessionStorage,
200 extractors: &[Arc<dyn SessionExtractor<C>>],
201 ) -> Result<(), Error> {
202 let mut last_error = None;
203 for extractor in extractors.iter() {
204 match extractor
205 .extract_claims(
206 req,
207 jwt_encoding_key.clone(),
208 jwt_decoding_key.clone(),
209 algorithm,
210 storage.clone(),
211 )
212 .await
213 {
214 Ok(_) => break,
215 Err(e) => {
216 last_error = Some(e);
217 }
218 };
219 }
220 if let Some(e) = last_error {
221 return Err(e)?;
222 }
223 Ok(())
224 }
225}
226
227impl<S, ClaimsType> Service<ServiceRequest> for SessionMiddleware<S, ClaimsType>
228where
229 ClaimsType: Claims,
230 S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
231{
232 type Response = ServiceResponse;
233 type Error = actix_web::Error;
234 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
235
236 forward_ready!(service);
237
238 fn call(&self, mut req: ServiceRequest) -> Self::Future {
239 use futures_lite::FutureExt;
240
241 let svc = self.service.clone();
242 let jwt_decoding_key = self.jwt_decoding_key.clone();
243 let jwt_encoding_key = self.jwt_encoding_key.clone();
244 let algorithm = self.algorithm;
245 let storage = self.storage.clone();
246 let extractors = self.extractors.clone();
247
248 async move {
249 if !extractors.jwt_extractors.is_empty() {
250 Self::extract_token(
251 &mut req,
252 jwt_encoding_key.clone(),
253 jwt_decoding_key.clone(),
254 algorithm,
255 storage.clone(),
256 &extractors.jwt_extractors,
257 )
258 .await?;
259 }
260 if !extractors.refresh_extractors.is_empty() {
261 Self::extract_token(
262 &mut req,
263 jwt_encoding_key,
264 jwt_decoding_key,
265 algorithm,
266 storage,
267 &extractors.refresh_extractors,
268 )
269 .await?;
270 }
271 let res = svc.call(req).await?;
272 Ok(res)
273 }
274 .boxed_local()
275 }
276}