actix_jwt_session/
middleware.rs

1//! Create session storage and build middleware factory
2
3use std::future::{ready, Ready};
4use std::rc::Rc;
5use std::sync::Arc;
6
7// pub use actix_web::cookie::time::{Duration, OffsetDateTime};
8use actix_web::dev::Transform;
9use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse};
10use futures_util::future::LocalBoxFuture;
11use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey};
12
13use crate::*;
14
/// Session middleware factory builder
///
/// Collects the JWT keys, signing algorithm, session storage and token
/// extractors, and turns them into a [SessionMiddlewareFactory] via
/// [SessionMiddlewareBuilder::finish].
///
/// It should be constructed with [SessionMiddlewareFactory::build] (explicit
/// keys) or [SessionMiddlewareFactory::build_ed_dsa] (EdDSA keys obtained from
/// `JwtSigningKeys::load_or_create`).
pub struct SessionMiddlewareBuilder<ClaimsType: Claims> {
    /// Key handed to extractors for encoding JWTs.
    pub(crate) jwt_encoding_key: Arc<EncodingKey>,
    /// Key handed to extractors for decoding JWTs.
    pub(crate) jwt_decoding_key: Arc<DecodingKey>,
    /// Algorithm used together with the keys above.
    pub(crate) algorithm: Algorithm,
    /// Session storage. Starts as `None`; `finish` panics if it is still unset.
    pub(crate) storage: Option<SessionStorage>,
    /// How access and refresh tokens are located; starts as `Extractors::default()`.
    pub(crate) extractors: Extractors<ClaimsType>,
}
25impl<ClaimsType: Claims> SessionMiddlewareBuilder<ClaimsType> {
26    #[doc(hidden)]
27    pub(crate) fn new(
28        jwt_encoding_key: Arc<EncodingKey>,
29        jwt_decoding_key: Arc<DecodingKey>,
30        algorithm: Algorithm,
31    ) -> Self {
32        Self {
33            jwt_encoding_key: jwt_encoding_key.clone(),
34            jwt_decoding_key,
35            algorithm,
36            storage: None,
37            extractors: Extractors::default(),
38        }
39    }
40
41    pub(crate) fn auto_ed_dsa() -> Self {
42        let keys = JwtSigningKeys::load_or_create();
43        Self::new(
44            Arc::new(keys.encoding_key),
45            Arc::new(keys.decoding_key),
46            Algorithm::EdDSA,
47        )
48    }
49
50    /// Set session storage to given instance. Good if for some reason you need
51    /// to share 1 storage with multiple instances of session middleware
52    #[must_use]
53    pub fn with_storage(mut self, storage: SessionStorage) -> Self {
54        self.storage = Some(storage);
55        self
56    }
57
58    /// Set how session and refresh token should be extracted
59    #[must_use]
60    pub fn with_extractors(mut self, extractors: Extractors<ClaimsType>) -> Self {
61        self.extractors = extractors;
62        self
63    }
64
65    /// Builds middleware factory and returns session storage with factory
66    pub fn finish(self) -> (SessionStorage, SessionMiddlewareFactory<ClaimsType>) {
67        let Self {
68            storage,
69            jwt_encoding_key,
70            jwt_decoding_key,
71            algorithm,
72            extractors,
73            ..
74        } = self;
75        let storage = storage
76            .expect("Session storage must be constracted from pool or set from existing storage");
77        (
78            storage.clone(),
79            SessionMiddlewareFactory {
80                jwt_encoding_key,
81                jwt_decoding_key,
82                algorithm,
83                storage,
84                extractors,
85            },
86        )
87    }
88}
89
/// Factory that wraps the inner service with the session middleware
/// (see its `Transform` implementation).
///
/// Cloning is intended to be cheap. The JWT keys are shared through `Arc` and
/// the algorithm is copied. The storage and extractors are cloned
/// (NOTE(review): presumably cheap handle clones; confirm in their definitions).
///
/// Example:
///
/// ```
/// use std::sync::Arc;
/// use actix_jwt_session::*;
///
/// # async fn create<AppClaims: actix_jwt_session::Claims>() {
/// // create redis connection
/// let redis = {
///     use deadpool_redis::{Config, Runtime};
///     let cfg = Config::from_url("redis://localhost:6379");
///     let pool = cfg.create_pool(Some(Runtime::Tokio1)).unwrap();
///     pool
/// };
///
/// // load or create new keys in `./config`
/// let keys = JwtSigningKeys::load_or_create();
///
/// // create new [SessionStorage] and [SessionMiddlewareFactory]
/// let (storage, factory) = SessionMiddlewareFactory::<AppClaims>::build(
///     Arc::new(keys.encoding_key),
///     Arc::new(keys.decoding_key),
///     Algorithm::EdDSA
/// )
/// // pass redis connection
/// .with_redis_pool(redis.clone())
/// .with_extractors(
///     Extractors::default()
///     // Check if header "Authorization" exists and contains Bearer with encoded JWT
///     .with_jwt_header("Authorization")
///     // Check if cookie "acx-a" exists and contains encoded JWT
///     .with_jwt_cookie("acx-a")
///     // Check if header "ACX-Refresh" exists and contains encoded refresh JWT
///     .with_refresh_header("ACX-Refresh")
///     // Check if cookie "acx-r" exists and contains encoded refresh JWT
///     .with_refresh_cookie("acx-r")
/// )
/// .finish();
/// # }
/// ```
#[derive(Clone)]
pub struct SessionMiddlewareFactory<ClaimsType: Claims> {
    /// Encoding key, cloned into every middleware instance.
    pub(crate) jwt_encoding_key: Arc<EncodingKey>,
    /// Decoding key, cloned into every middleware instance.
    pub(crate) jwt_decoding_key: Arc<DecodingKey>,
    /// Algorithm passed to the extractors.
    pub(crate) algorithm: Algorithm,
    /// Session storage passed to the extractors.
    pub(crate) storage: SessionStorage,
    /// Access-token and refresh-token extractor chains.
    pub(crate) extractors: Extractors<ClaimsType>,
}
142
143impl<ClaimsType: Claims> SessionMiddlewareFactory<ClaimsType> {
144    pub fn build_ed_dsa() -> SessionMiddlewareBuilder<ClaimsType> {
145        SessionMiddlewareBuilder::auto_ed_dsa()
146    }
147
148    pub fn build(
149        jwt_encoding_key: Arc<EncodingKey>,
150        jwt_decoding_key: Arc<DecodingKey>,
151        algorithm: Algorithm,
152    ) -> SessionMiddlewareBuilder<ClaimsType> {
153        SessionMiddlewareBuilder::new(jwt_encoding_key, jwt_decoding_key, algorithm)
154    }
155}
156
157impl<S, ClaimsType> Transform<S, ServiceRequest> for SessionMiddlewareFactory<ClaimsType>
158where
159    S: Service<ServiceRequest, Error = actix_web::Error, Response = ServiceResponse> + 'static,
160    ClaimsType: Claims,
161{
162    type Response = ServiceResponse;
163    type Error = actix_web::Error;
164    type Transform = SessionMiddleware<S, ClaimsType>;
165    type InitError = ();
166    type Future = Ready<Result<Self::Transform, Self::InitError>>;
167
168    fn new_transform(&self, service: S) -> Self::Future {
169        ready(Ok(SessionMiddleware {
170            service: Rc::new(service),
171            storage: self.storage.clone(),
172            jwt_encoding_key: self.jwt_encoding_key.clone(),
173            jwt_decoding_key: self.jwt_decoding_key.clone(),
174            algorithm: self.algorithm,
175            extractors: self.extractors.clone(),
176        }))
177    }
178}
179
/// Middleware instance created by `SessionMiddlewareFactory::new_transform`.
///
/// On each call it runs the configured extractors against the request, then
/// forwards the request to `service`.
#[doc(hidden)]
pub struct SessionMiddleware<S, ClaimsType>
where
    ClaimsType: Claims,
{
    /// The wrapped (inner) service. It is `Rc` so the boxed future can hold its own handle.
    pub(crate) service: Rc<S>,
    /// Encoding key passed to every extractor call.
    pub(crate) jwt_encoding_key: Arc<EncodingKey>,
    /// Decoding key passed to every extractor call.
    pub(crate) jwt_decoding_key: Arc<DecodingKey>,
    /// Algorithm passed to every extractor call.
    pub(crate) algorithm: Algorithm,
    /// Session storage passed to every extractor call.
    pub(crate) storage: SessionStorage,
    /// Access-token and refresh-token extractor chains.
    pub(crate) extractors: Extractors<ClaimsType>,
}
192
193impl<S, ClaimsType: Claims> SessionMiddleware<S, ClaimsType> {
194    async fn extract_token<C: Claims>(
195        req: &mut ServiceRequest,
196        jwt_encoding_key: Arc<EncodingKey>,
197        jwt_decoding_key: Arc<DecodingKey>,
198        algorithm: Algorithm,
199        storage: SessionStorage,
200        extractors: &[Arc<dyn SessionExtractor<C>>],
201    ) -> Result<(), Error> {
202        let mut last_error = None;
203        for extractor in extractors.iter() {
204            match extractor
205                .extract_claims(
206                    req,
207                    jwt_encoding_key.clone(),
208                    jwt_decoding_key.clone(),
209                    algorithm,
210                    storage.clone(),
211                )
212                .await
213            {
214                Ok(_) => break,
215                Err(e) => {
216                    last_error = Some(e);
217                }
218            };
219        }
220        if let Some(e) = last_error {
221            return Err(e)?;
222        }
223        Ok(())
224    }
225}
226
227impl<S, ClaimsType> Service<ServiceRequest> for SessionMiddleware<S, ClaimsType>
228where
229    ClaimsType: Claims,
230    S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error> + 'static,
231{
232    type Response = ServiceResponse;
233    type Error = actix_web::Error;
234    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
235
236    forward_ready!(service);
237
238    fn call(&self, mut req: ServiceRequest) -> Self::Future {
239        use futures_lite::FutureExt;
240
241        let svc = self.service.clone();
242        let jwt_decoding_key = self.jwt_decoding_key.clone();
243        let jwt_encoding_key = self.jwt_encoding_key.clone();
244        let algorithm = self.algorithm;
245        let storage = self.storage.clone();
246        let extractors = self.extractors.clone();
247
248        async move {
249            if !extractors.jwt_extractors.is_empty() {
250                Self::extract_token(
251                    &mut req,
252                    jwt_encoding_key.clone(),
253                    jwt_decoding_key.clone(),
254                    algorithm,
255                    storage.clone(),
256                    &extractors.jwt_extractors,
257                )
258                .await?;
259            }
260            if !extractors.refresh_extractors.is_empty() {
261                Self::extract_token(
262                    &mut req,
263                    jwt_encoding_key,
264                    jwt_decoding_key,
265                    algorithm,
266                    storage,
267                    &extractors.refresh_extractors,
268                )
269                .await?;
270            }
271            let res = svc.call(req).await?;
272            Ok(res)
273        }
274        .boxed_local()
275    }
276}