actix_jwt_authc/
authentication.rs

1use std::collections::HashSet;
2use std::future::{ready, Ready};
3use std::marker::PhantomData;
4use std::rc::Rc;
5use std::sync::Arc;
6
7#[cfg(feature = "session")]
8use actix_session::SessionExt;
9use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
10use actix_web::{rt, FromRequest, HttpMessage};
11use derive_more::Display;
12use futures_util::future::LocalBoxFuture;
13use futures_util::{FutureExt, Stream, StreamExt};
14use jsonwebtoken::{decode, DecodingKey, Validation};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use tokio::sync::RwLock;
18#[cfg(feature = "tracing")]
19use tracing::{info, trace};
20
21use crate::errors::Error;
22
23/// A "must-be-authenticated" type wrapper, which, when added as a parameter on a route
24/// handler, will result in an 401 response if a given request cannot be authenticated.
25///
26/// It is generic on the claims type to allow developers to specify their own JWT claims type.
27///
28/// If [AuthenticateMiddleware] has been attached as middle to a [actix_web::App], this type will be
29/// injected into authenticatable-requests.
30#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
31pub struct Authenticated<T> {
32    pub jwt: JWT,
33    pub claims: T,
34}
35
36/// A "might-be-authenticated" type wrapper.
37///
38/// It is generic on the claims type to allow developers to specify their own JWT claims type.
39///
40/// If [AuthenticateMiddleware] has been attached as middle to a [actix_web::App], this type will be
41/// injected into authenticatable-requests.
42#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
43pub enum MaybeAuthenticated<T> {
44    Just(Authenticated<T>),
45    None,
46}
47
48impl<T> MaybeAuthenticated<T> {
49    pub fn into_option(self) -> Option<Authenticated<T>> {
50        self.into()
51    }
52}
53
54impl<T> From<MaybeAuthenticated<T>> for Option<Authenticated<T>> {
55    fn from(maybe_authenticated: MaybeAuthenticated<T>) -> Self {
56        match maybe_authenticated {
57            MaybeAuthenticated::Just(v) => Some(v),
58            MaybeAuthenticated::None => None,
59        }
60    }
61}
62
63impl<T> FromRequest for Authenticated<T>
64where
65    T: Clone + 'static,
66{
67    type Error = Error;
68    type Future = Ready<Result<Self, Self::Error>>;
69
70    fn from_request(
71        req: &actix_web::HttpRequest,
72        _payload: &mut actix_web::dev::Payload,
73    ) -> Self::Future {
74        let value = req.extensions().get::<Authenticated<T>>().cloned();
75        let result = match value {
76            Some(v) => Ok(v),
77            None => Err(Error::Unauthenticated),
78        };
79        ready(result)
80    }
81}
82
83impl<T> FromRequest for MaybeAuthenticated<T>
84where
85    T: Clone + 'static,
86{
87    type Error = Error;
88    type Future = Ready<Result<Self, Self::Error>>;
89
90    fn from_request(
91        req: &actix_web::HttpRequest,
92        _payload: &mut actix_web::dev::Payload,
93    ) -> Self::Future {
94        let value = req.extensions().get::<Authenticated<T>>().cloned();
95        let result = match value {
96            Some(v) => Ok(MaybeAuthenticated::Just(v)),
97            None => Ok(MaybeAuthenticated::None),
98        };
99        ready(result)
100    }
101}
102
103/// A wrapper around JWTs
104#[derive(Hash, PartialEq, Eq, Clone, Debug, Display, Serialize, Deserialize)]
105pub struct JWT(pub String);
106
/// Newtype holding the key under which a JWT is looked up in an [actix_session::Session].
#[cfg(feature = "session")]
#[cfg_attr(docsrs, doc(cfg(feature = "session")))]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct JWTSessionKey(pub String);
112
113/// Describes changes to invalidated tokens
114#[derive(Clone, Eq, PartialEq, Debug)]
115pub enum InvalidatedTokensEvent {
116    /// A full reload of invalidated [JWT]s
117    Full(HashSet<JWT>),
118
119    /// A batched "diff" invalidated [JWT]s
120    Diff {
121        add: Option<HashSet<JWT>>,
122        remove: Option<HashSet<JWT>>,
123    },
124
125    /// Add a single invalidated [JWT]
126    Add(JWT),
127
128    /// Remove a single invalidated [JWT]
129    Remove(JWT),
130}
131
132#[derive(Eq, PartialEq, Debug)]
133struct InvalidatedJWTsState(HashSet<JWT>);
134
135impl InvalidatedJWTsState {
136    fn new() -> InvalidatedJWTsState {
137        InvalidatedJWTsState(HashSet::new())
138    }
139}
140
141// <-- Middleware
142
/// Settings for [AuthenticateMiddlewareFactory]. These determine how the authentication middleware
/// will work.
#[derive(Clone)]
pub struct AuthenticateMiddlewareSettings {
    /// JWT Decoding Key; used to ensure that JWTs were signed by a trusted source
    pub jwt_decoding_key: DecodingKey,

    /// JWT validation configuration options
    pub jwt_validator: Validation,

    /// Optional key for extracting a JWT out of a request's Session.
    ///
    /// If not provided, the middleware will not attempt to extract JWTs from Sessions.
    #[cfg(feature = "session")]
    #[cfg_attr(docsrs, doc(cfg(feature = "session")))]
    pub jwt_session_key: Option<JWTSessionKey>,

    /// Optional prefixes for extracting a JWT out of the Authorization header.
    ///
    /// The values provided should not have any extra leading or trailing spaces (e.g. "Bearer" or
    /// "ApiKey" will suffice if you expect headers to look like "Authorization: Bearer {JWT}" or
    /// "Authorization: ApiKey {JWT}"); the factory appends the separating space itself.
    ///
    /// If not provided, the middleware will not attempt to extract JWTs from the Authorization
    /// header.
    pub jwt_authorization_header_prefixes: Option<Vec<String>>,
}
170
171/// A factory for the authentication middleware.
172///
173/// This is meant to be instantiated once during bootstrap and *cloned* to the app factory
174/// closure. That way, there is a single set of invalidated JWTs held in memory, refreshed by
175/// a single periodic timer.
176///
177/// Cloning is cheap because internally this uses [Arc]s to hold state.
178#[derive(Clone)]
179pub struct AuthenticateMiddlewareFactory<ClaimsType> {
180    invalidated_jwts_state: Arc<RwLock<InvalidatedJWTsState>>,
181    jwt_decoding_key: Arc<DecodingKey>,
182    #[cfg(feature = "session")]
183    jwt_session_key: Option<Arc<JWTSessionKey>>,
184    jwt_authorization_header_prefixes: Option<Arc<Vec<String>>>,
185    jwt_validator: Arc<Validation>,
186    _claims_type_marker: PhantomData<ClaimsType>,
187}
188
189impl<ClaimsType> AuthenticateMiddlewareFactory<ClaimsType>
190where
191    ClaimsType: DeserializeOwned + 'static,
192{
193    /// Takes a [futures_util::Stream] of [InvalidatedTokensEvent]s and returns a [AuthenticateMiddlewareFactory]
194    /// that knows how consume the stream to populate an in-memory set of invalidated JWTs that is
195    /// then passed on to the [AuthenticateMiddleware] that it spawns.
196    pub fn new<S>(
197        invalidated_jwts_events: S,
198        settings: AuthenticateMiddlewareSettings,
199    ) -> AuthenticateMiddlewareFactory<ClaimsType>
200    where
201        S: Stream<Item = InvalidatedTokensEvent> + Unpin + 'static,
202    {
203        let invalidated_jwts_state = Arc::new(RwLock::new(InvalidatedJWTsState::new()));
204
205        #[cfg(feature = "tracing")]
206        info!("Kicking off invalidated JWT reload loop");
207        rt::spawn(reload_from_stream(
208            invalidated_jwts_events,
209            invalidated_jwts_state.clone(),
210        ));
211
212        AuthenticateMiddlewareFactory::<ClaimsType> {
213            invalidated_jwts_state,
214            jwt_decoding_key: Arc::new(settings.jwt_decoding_key),
215            #[cfg(feature = "session")]
216            jwt_session_key: settings.jwt_session_key.map(Arc::new),
217            jwt_authorization_header_prefixes: settings.jwt_authorization_header_prefixes.map(
218                |prefixes| {
219                    Arc::new(
220                        prefixes
221                            .iter()
222                            .map(|prefix| format!("{} ", prefix))
223                            .collect(),
224                    )
225                },
226            ),
227            jwt_validator: Arc::new(settings.jwt_validator),
228            _claims_type_marker: PhantomData,
229        }
230    }
231}
232
233#[cfg_attr(
234    feature = "tracing",
235    tracing::instrument(level = "trace", skip(events, invalidated_jwts_set))
236)]
237async fn reload_from_stream<S>(
238    mut events: S,
239    invalidated_jwts_set: Arc<RwLock<InvalidatedJWTsState>>,
240) where
241    S: Stream<Item = InvalidatedTokensEvent> + Unpin,
242{
243    while let Some(invalidated_jwt_event) = events.next().await {
244        #[cfg(feature = "tracing")]
245        trace!("Received invalidated JWTs event");
246        let mut invalidated_state = invalidated_jwts_set.write().await;
247        match invalidated_jwt_event {
248            InvalidatedTokensEvent::Full(all) => {
249                #[cfg(feature = "tracing")]
250                trace!(count = all.len(), "Received invalidated JWTs with full set");
251                invalidated_state.0 = all;
252            }
253            InvalidatedTokensEvent::Diff { add, remove } => {
254                #[cfg(feature = "tracing")]
255                trace!("Received invalidated JWTs diff");
256                if let Some(to_remove) = remove {
257                    #[cfg(feature = "tracing")]
258                    trace!(
259                        remove_count = to_remove.len(),
260                        "Received invalidated JWTs diff, with removals"
261                    );
262                    for to_remove in to_remove.iter() {
263                        invalidated_state.0.remove(to_remove);
264                    }
265                }
266                if let Some(to_add) = add {
267                    #[cfg(feature = "tracing")]
268                    trace!(
269                        add_count = to_add.len(),
270                        "Received invalidated JWTs diff, with additions"
271                    );
272                    invalidated_state.0.extend(to_add);
273                }
274            }
275            InvalidatedTokensEvent::Remove(jwt) => {
276                #[cfg(feature = "tracing")]
277                trace!("Received Invalidated token to remove");
278                invalidated_state.0.remove(&jwt);
279            }
280            InvalidatedTokensEvent::Add(jwt) => {
281                #[cfg(feature = "tracing")]
282                trace!("Received Invalidated token to add");
283                invalidated_state.0.insert(jwt);
284            }
285        }
286    }
287}
288
289impl<S, B, ClaimsType> Transform<S, ServiceRequest> for AuthenticateMiddlewareFactory<ClaimsType>
290where
291    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
292    ClaimsType: DeserializeOwned + 'static,
293{
294    type Response = ServiceResponse<B>;
295    type Error = actix_web::Error;
296    type Transform = AuthenticateMiddleware<S, ClaimsType>;
297    type InitError = ();
298    type Future = Ready<Result<Self::Transform, Self::InitError>>;
299
300    fn new_transform(&self, service: S) -> Self::Future {
301        ready(Ok(AuthenticateMiddleware {
302            invalidated_jwts_state: self.invalidated_jwts_state.clone(),
303            service: Rc::new(service),
304            jwt_decoding_key: self.jwt_decoding_key.clone(),
305            #[cfg(feature = "session")]
306            jwt_session_key: self.jwt_session_key.clone(),
307            jwt_authorization_header_prefixes: self.jwt_authorization_header_prefixes.clone(),
308            jwt_validator: self.jwt_validator.clone(),
309            _claims_type_marker: PhantomData,
310        }))
311    }
312}
313
/// The actual middleware that extracts JWTs from requests, validates them, and injects the
/// resulting [Authenticated] value into the request's extensions.
pub struct AuthenticateMiddleware<S, ClaimsType> {
    /// Shared set of invalidated JWTs; a request carrying one of these is rejected.
    invalidated_jwts_state: Arc<RwLock<InvalidatedJWTsState>>,
    /// The wrapped service that requests are forwarded to.
    service: Rc<S>,
    /// Key used to verify JWT signatures.
    jwt_decoding_key: Arc<DecodingKey>,
    /// Session key to read a JWT from, if session extraction is configured.
    #[cfg(feature = "session")]
    jwt_session_key: Option<Arc<JWTSessionKey>>,
    /// Authorization header prefixes to strip, if header extraction is configured.
    jwt_authorization_header_prefixes: Option<Arc<Vec<String>>>,
    /// Validation options applied when decoding JWTs.
    jwt_validator: Arc<Validation>,
    /// Ties the middleware to the claims type without storing a value of it.
    _claims_type_marker: PhantomData<ClaimsType>,
}
326
327impl<S, B, ClaimsType> Service<ServiceRequest> for AuthenticateMiddleware<S, ClaimsType>
328where
329    ClaimsType: DeserializeOwned + 'static,
330    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
331{
332    type Response = ServiceResponse<B>;
333    type Error = actix_web::Error;
334    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
335
336    forward_ready!(service);
337
338    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, req)))]
339    fn call(&self, req: ServiceRequest) -> Self::Future {
340        let svc = self.service.clone();
341        let invalidated_jwts_state = self.invalidated_jwts_state.clone();
342        let jwt_decoding_key = self.jwt_decoding_key.clone();
343        #[cfg(feature = "session")]
344        let jwt_session_key = self.jwt_session_key.clone();
345        let jwt_authorization_header_prefixes = self.jwt_authorization_header_prefixes.clone();
346        let validation = self.jwt_validator.clone();
347        async move {
348            authenticate::<S, B, ClaimsType>(
349                svc,
350                req,
351                invalidated_jwts_state,
352                &jwt_decoding_key,
353                #[cfg(feature = "session")]
354                jwt_session_key,
355                jwt_authorization_header_prefixes,
356                &validation,
357            )
358            .await
359        }
360        .boxed_local()
361    }
362}
363
364#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
365async fn authenticate<S, B, ClaimsType>(
366    svc: Rc<S>,
367    req: ServiceRequest,
368    invalidated_jwts_state: Arc<RwLock<InvalidatedJWTsState>>,
369    jwt_decoding_key: &DecodingKey,
370    #[cfg(feature = "session")] jwt_session_key: Option<Arc<JWTSessionKey>>,
371    jwt_authorization_header_prefixes: Option<Arc<Vec<String>>>,
372    validation: &Validation,
373) -> Result<ServiceResponse<B>, actix_web::Error>
374where
375    ClaimsType: DeserializeOwned + 'static,
376    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
377{
378    #[cfg(feature = "tracing")]
379    trace!("Attempting to authenticate");
380    let maybe_jwt_from_auth_header =
381        jwt_authorization_header_prefixes.and_then(|prefixes| extract_bearer_jwt(&req, &prefixes));
382    #[cfg(feature = "session")]
383    let maybe_extracted_jwt = maybe_jwt_from_auth_header
384        .or_else(|| jwt_session_key.and_then(|key| extract_session_jwt(&req, &key)));
385    #[cfg(not(feature = "session"))]
386    let maybe_extracted_jwt = maybe_jwt_from_auth_header;
387    if let Some(jwt) = maybe_extracted_jwt {
388        #[cfg(feature = "tracing")]
389        trace!(jwt = ?jwt, "JWT extracted");
390        let jwt_str = jwt.0.as_str();
391        if invalidated_jwts_state.read().await.0.contains(&jwt) {
392            #[cfg(feature = "tracing")]
393            trace!(jwt= ?jwt, "Invalidated JWT detected");
394            Err(Error::InvalidSession(format!(
395                "Invalidated session. JWT [{jwt}] was already invalidated"
396            )))?;
397        } else {
398            let decoded_claims = decode::<ClaimsType>(jwt_str, jwt_decoding_key, validation)
399                .map_err(|e| {
400                    let error_message = e.to_string();
401                    #[cfg(feature = "tracing")]
402                    trace!("Claims failed decoding because of [{}]", error_message);
403                    Error::InvalidSession(error_message)
404                })?;
405            #[cfg(feature = "tracing")]
406            trace!("Claims successfully decoded");
407
408            req.extensions_mut().insert(Authenticated {
409                jwt,
410                claims: decoded_claims.claims,
411            });
412        }
413    }
414    let res = svc.call(req).await?;
415    Ok(res)
416}
417
418#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace"))]
419fn extract_bearer_jwt(req: &ServiceRequest, auth_prefixes: &[String]) -> Option<JWT> {
420    let authorisation_header = req.headers().get("Authorization")?;
421    let as_str = authorisation_header.to_str().ok()?;
422    let jwt_str = auth_prefixes
423        .iter()
424        .filter_map(|prefix| as_str.strip_prefix(prefix))
425        .next()?;
426    Some(JWT(jwt_str.to_string()))
427}
428
/// Extracts a [JWT] from the request's session, reading the string stored under
/// `jwt_session_key`.
///
/// Returns `None` when the session lookup fails or no value is stored under that key.
#[cfg(feature = "session")]
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(jwt_session_key))
)]
fn extract_session_jwt(req: &ServiceRequest, jwt_session_key: &JWTSessionKey) -> Option<JWT> {
    req.get_session()
        .get::<String>(&jwt_session_key.0)
        .ok()
        .flatten()
        .map(JWT)
}
439
440//     Middleware -->
441#[cfg(test)]
442mod tests {
443    use std::ops::Add;
444    use std::sync::Arc;
445    use std::time::Duration;
446
447    #[cfg(feature = "session")]
448    use actix_session::storage::CookieSessionStore;
449    #[cfg(feature = "session")]
450    use actix_session::Session;
451    #[cfg(feature = "session")]
452    use actix_session::SessionMiddleware;
453    #[cfg(feature = "session")]
454    use actix_web::cookie::Key;
455    use actix_web::web::Data;
456    use actix_web::{get, test, App, HttpResponse};
457    use dashmap::DashSet;
458    use futures::channel::{mpsc, mpsc::Sender};
459    use futures::SinkExt;
460    use jsonwebtoken::*;
461    use ring::rand::SystemRandom;
462    use ring::signature::{Ed25519KeyPair, KeyPair};
463    use serde::{Deserialize, Serialize};
464    use time::ext::*;
465    use time::OffsetDateTime;
466    use tokio::sync::Mutex;
467    #[cfg(feature = "tracing")]
468    use tracing::error;
469    use uuid::Uuid;
470
471    use super::*;
472
#[test]
async fn test_reload_from_stream_full_replace() {
    // A `Full` event must replace the whole invalidated set.
    let full_invalidated_set: HashSet<JWT> = ["1", "2", "3"]
        .iter()
        .map(|id| JWT(id.to_string()))
        .collect();
    let state = Arc::new(RwLock::new(InvalidatedJWTsState::new()));

    let (mut tx, rx) = futures::channel::mpsc::channel(100);
    tx.send(InvalidatedTokensEvent::Full(full_invalidated_set.clone()))
        .await
        .unwrap();
    // Closing the channel ends the stream, so the reload loop returns once the buffered
    // event has been applied — no timing-dependent sleep is needed.
    drop(tx);
    reload_from_stream(rx, state.clone()).await;

    assert_eq!(full_invalidated_set, state.read().await.0);
}
493
#[test]
async fn test_reload_from_stream_full_add() {
    // An `Add` event must insert exactly one token into the invalidated set.
    let state = Arc::new(RwLock::new(InvalidatedJWTsState::new()));

    let (mut tx, rx) = futures::channel::mpsc::channel(100);
    tx.send(InvalidatedTokensEvent::Add(JWT("1".to_string())))
        .await
        .unwrap();
    // Closing the channel ends the stream, so the reload loop returns once the buffered
    // event has been applied — no timing-dependent sleep is needed.
    drop(tx);
    reload_from_stream(rx, state.clone()).await;

    let mut expected_invalidated_set = HashSet::new();
    expected_invalidated_set.insert(JWT("1".to_string()));

    assert_eq!(expected_invalidated_set, state.read().await.0);
}
510
#[test]
async fn test_reload_from_stream_full_remove() {
    // A `Remove` event must delete the given token from the invalidated set.
    let mut current_state = HashSet::new();
    current_state.insert(JWT("1".to_string()));
    let state = Arc::new(RwLock::new(InvalidatedJWTsState(current_state)));

    let (mut tx, rx) = futures::channel::mpsc::channel(100);
    tx.send(InvalidatedTokensEvent::Remove(JWT("1".to_string())))
        .await
        .unwrap();
    // Closing the channel ends the stream, so the reload loop returns once the buffered
    // event has been applied — no timing-dependent sleep is needed.
    drop(tx);
    reload_from_stream(rx, state.clone()).await;

    assert!(state.read().await.0.is_empty());
}
527
#[test]
async fn test_reload_stream_diff() {
    // `Diff` events must apply their removals and additions to the existing set.
    let jwts = |ids: &[&str]| -> HashSet<JWT> {
        ids.iter().map(|id| JWT(id.to_string())).collect()
    };

    let state = Arc::new(RwLock::new(InvalidatedJWTsState(jwts(&["1", "2", "3"]))));

    // First diff: drop "1", add "4".
    let diff_1 = InvalidatedTokensEvent::Diff {
        add: Some(jwts(&["4"])),
        remove: Some(jwts(&["1"])),
    };
    let (mut tx, rx) = futures::channel::mpsc::channel(100);
    tx.send(diff_1).await.unwrap();
    // Closing the channel ends the stream, so the reload loop returns once the buffered
    // event has been applied — no timing-dependent sleep is needed.
    drop(tx);
    reload_from_stream(rx, state.clone()).await;

    assert_eq!(jwts(&["2", "3", "4"]), state.read().await.0);

    // Second diff: removals only, emptying the set.
    let diff_2 = InvalidatedTokensEvent::Diff {
        add: None,
        remove: Some(jwts(&["2", "3", "4"])),
    };
    let (mut tx, rx) = futures::channel::mpsc::channel(100);
    tx.send(diff_2).await.unwrap();
    drop(tx);
    reload_from_stream(rx, state.clone()).await;

    assert!(state.read().await.0.is_empty());
}
577
#[test]
async fn test_extract_bearer_jwt_none() {
    // No Authorization header at all: nothing to extract.
    let prefixes = ["Bearer ".to_string()];
    let request = test::TestRequest::default().to_srv_request();
    assert!(extract_bearer_jwt(&request, &prefixes).is_none());
}
584
#[test]
async fn test_extract_bearer_jwt_some() {
    // A header with the configured prefix yields the token that follows it.
    let prefixes = ["Bearer ".to_string()];
    let request = test::TestRequest::default()
        .insert_header(("Authorization", "Bearer XYZ"))
        .to_srv_request();
    let extracted = extract_bearer_jwt(&request, &prefixes);
    assert_eq!(Some(JWT("XYZ".to_string())), extracted);
}
593
#[test]
async fn test_extract_bearer_jwt_some_all_prefix_prefix() {
    // Every configured prefix is honoured, regardless of its position in the list.
    let prefixes = ["Bearer ".to_string(), "ApiKey ".to_string()];
    let expected = Some(JWT("XYZ".to_string()));
    for auth_header in ["ApiKey XYZ", "Bearer XYZ"] {
        let request = test::TestRequest::default()
            .insert_header(("Authorization", auth_header))
            .to_srv_request();
        let extracted = extract_bearer_jwt(&request, &prefixes);
        assert_eq!(expected, extracted);
    }
}
607
#[test]
async fn test_extract_bearer_jwt_wrong_prefix() {
    // A header whose prefix is not configured yields nothing.
    let prefixes = ["ApiKey ".to_string()];
    let request = test::TestRequest::default()
        .insert_header(("Authorization", "Bearer XYZ"))
        .to_srv_request();
    assert!(extract_bearer_jwt(&request, &prefixes).is_none());
}
616
#[cfg(feature = "session")]
#[test]
async fn test_extract_session_jwt_none() {
    // A request without any session data yields no JWT.
    let request = test::TestRequest::default().to_srv_request();
    let key = JWTSessionKey("sesh".to_string());
    assert!(extract_session_jwt(&request, &key).is_none());
}
625
#[test]
async fn integration_test_no_session_should_reject() {
    // An unauthenticated request to a must-be-authenticated route is answered with 401.
    let fixture = build_fixture(JWTTtl::default()).await.unwrap();
    let app = fixture.app;

    let request = test::TestRequest::get().uri("/session").to_request();
    let response = test::call_service(&app, request).await;
    assert_eq!(actix_http::StatusCode::UNAUTHORIZED, response.status());

    let body: crate::errors::ErrorResponse = test::read_body_json(response).await;
    let expected_message = crate::errors::Error::Unauthenticated.to_string();
    assert_eq!(expected_message, body.message);
}
639
#[test]
async fn integration_test_no_session_maybe_authenticated() {
    // An unauthenticated request to a might-be-authenticated route still succeeds.
    let fixture = build_fixture(JWTTtl::default()).await.unwrap();
    let app = fixture.app;

    let request = test::TestRequest::get().uri("/maybe_session").to_request();
    let response = test::call_service(&app, request).await;
    assert_eq!(actix_http::StatusCode::OK, response.status());

    let body: MessageResponse = test::read_body_json(response).await;
    assert_eq!("No session for you !", body.message.as_str());
}
650
#[test]
async fn integration_test_with_authentication() {
    // Logging in and then presenting the issued credentials yields the same claims back.
    let fixture = build_fixture(JWTTtl::default()).await.unwrap();
    let app = fixture.app;

    let login_request = test::TestRequest::get().uri("/login").to_request();
    let login_resp = test::call_service(&app, login_request).await;
    assert_eq!(actix_http::StatusCode::OK, login_resp.status());

    // With sessions, the credentials travel in the cookies set by the login response; they
    // must be copied before the response body is consumed.
    #[cfg(feature = "session")]
    let session_req = login_resp.response().cookies().fold(
        test::TestRequest::get().uri("/session"),
        |request, cookie| request.cookie(cookie),
    );
    let login_response: LoginResponse = test::read_body_json(login_resp).await;
    // Without sessions, the token from the login body is sent as a Bearer header instead.
    #[cfg(not(feature = "session"))]
    let session_req = test::TestRequest::get().uri("/session").insert_header((
        "Authorization",
        format!("Bearer {}", login_response.bearer_token),
    ));

    let session_resp = test::call_service(&app, session_req.to_request()).await;
    assert_eq!(actix_http::StatusCode::OK, session_resp.status());

    let session_response: Authenticated<Claims> = test::read_body_json(session_resp).await;
    assert_eq!(login_response.claims, session_response.claims);
}
685
#[test]
async fn integration_test_with_expired_authentication() {
    // A token that has expired by the time it is presented is rejected with 401.
    let fixture = build_fixture(JWTTtl(1.nanoseconds())).await.unwrap();
    let app = fixture.app;

    let login_request = test::TestRequest::get().uri("/login").to_request();
    let login_resp = test::call_service(&app, login_request).await;
    assert_eq!(actix_http::StatusCode::OK, login_resp.status());

    // Give the (very short-lived) token time to expire.
    tokio::time::sleep(Duration::from_secs(2)).await;

    #[cfg(feature = "session")]
    let rejection = {
        let session_req = login_resp.response().cookies().fold(
            test::TestRequest::get().uri("/session"),
            |request, cookie| request.cookie(cookie),
        );
        let rejection = app.call(session_req.to_request()).await.err().unwrap();
        let _login_response: LoginResponse = test::read_body_json(login_resp).await;
        rejection
    };
    #[cfg(not(feature = "session"))]
    let rejection = {
        let login_response: LoginResponse = test::read_body_json(login_resp).await;
        let session_req = test::TestRequest::get().uri("/session").insert_header((
            "Authorization",
            format!("Bearer {}", login_response.bearer_token),
        ));
        app.call(session_req.to_request()).await.err().unwrap()
    };

    // The middleware fails the call with an error; render it the way actix would.
    let session_resp = ServiceResponse::new(
        test::TestRequest::get().uri("/session").to_http_request(),
        rejection.error_response(),
    );
    assert_eq!(actix_http::StatusCode::UNAUTHORIZED, session_resp.status());

    let body: crate::errors::ErrorResponse = test::read_body_json(session_resp).await;
    assert_eq!("Invalid session [ExpiredSignature]", body.message.as_str());
}
729
#[test]
async fn integration_test_with_invalidated_authentication() {
    // After logging out, the previously issued credentials are rejected as invalidated.
    // NOTE(review): the fixture is built with a 1 ns TTL, so the `/logout` call below appears
    // to depend on the token not yet being treated as expired — confirm this is intended.
    let fixture = build_fixture(JWTTtl(1.nanoseconds())).await.unwrap();
    let app = fixture.app;

    let login_request = test::TestRequest::get().uri("/login").to_request();
    let login_resp = test::call_service(&app, login_request).await;
    assert_eq!(actix_http::StatusCode::OK, login_resp.status());

    #[cfg(feature = "session")]
    let (logout_req, session_req) = {
        let with_login_cookies = |uri: &str| {
            login_resp
                .response()
                .cookies()
                .fold(test::TestRequest::get().uri(uri), |request, cookie| {
                    request.cookie(cookie)
                })
        };
        (with_login_cookies("/logout"), with_login_cookies("/session"))
    };
    #[cfg(not(feature = "session"))]
    let (logout_req, session_req) = {
        let login_response: LoginResponse = test::read_body_json(login_resp).await;
        let with_bearer_token = |uri: &str| {
            test::TestRequest::get().uri(uri).insert_header((
                "Authorization",
                format!("Bearer {}", login_response.bearer_token),
            ))
        };
        (with_bearer_token("/logout"), with_bearer_token("/session"))
    };

    let logout_resp = test::call_service(&app, logout_req.to_request()).await;
    assert_eq!(actix_http::StatusCode::OK, logout_resp.status());

    // Leave time for the invalidation to reach the middleware's in-memory set.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let rejection: actix_web::Error = app.call(session_req.to_request()).await.err().unwrap();
    // The middleware fails the call with an error; render it the way actix would.
    let session_resp = ServiceResponse::new(
        test::TestRequest::get().uri("/session").to_http_request(),
        rejection.error_response(),
    );
    assert_eq!(actix_http::StatusCode::UNAUTHORIZED, session_resp.status());

    let body: crate::errors::ErrorResponse = test::read_body_json(session_resp).await;
    assert!(body
        .message
        .as_str()
        .starts_with("Invalid session [Invalidated session"));
}
789
/// Logs in, invalidates the issued JWT directly through the store (without going through
/// `/logout`), then checks that `/session` rejects the token with a 401 whose message
/// reports an invalidated session.
#[test]
async fn integration_test_with_remotely_invalidated_session() {
    let TestFixture {
        invalidated_jwts_store,
        app,
    } = build_fixture(JWTTtl(1.nanoseconds())).await.unwrap();

    let login_request = test::TestRequest::get().uri("/login").to_request();
    let login_resp = test::call_service(&app, login_request).await;
    assert_eq!(actix_http::StatusCode::OK, login_resp.status());

    // With sessions enabled the JWT travels in the session cookie, so the follow-up request
    // carries over every cookie set by the login response.
    #[cfg(feature = "session")]
    let (login_body, session_req) = {
        let with_cookies = login_resp.response().cookies().fold(
            test::TestRequest::get().uri("/session"),
            |pending, cookie| pending.cookie(cookie),
        );
        let body: LoginResponse = test::read_body_json(login_resp).await;
        (body, with_cookies)
    };
    // Without sessions the JWT is sent as a bearer token in the Authorization header.
    #[cfg(not(feature = "session"))]
    let (login_body, session_req) = {
        let body: LoginResponse = test::read_body_json(login_resp).await;
        let header_value = format!("Bearer {}", body.bearer_token);
        let with_header = test::TestRequest::get()
            .uri("/session")
            .insert_header(("Authorization", header_value));
        (body, with_header)
    };

    // Invalidate the token straight in the store, bypassing the `/logout` route.
    let LoginResponse {
        bearer_token,
        claims,
    } = login_body;
    invalidated_jwts_store
        .add_to_invalidated(Authenticated {
            jwt: JWT(bearer_token),
            claims,
        })
        .await;

    // Pause before issuing the request so the invalidation event sent on the channel can be
    // consumed first.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let rejection: actix_web::Error = app.call(session_req.to_request()).await.err().unwrap();
    // Wrap the error's response in a `ServiceResponse` so the test helpers can read it.
    let placeholder_request = test::TestRequest::get().uri("/session").to_http_request();
    let session_resp = ServiceResponse::new(placeholder_request, rejection.error_response());

    assert_eq!(actix_http::StatusCode::UNAUTHORIZED, session_resp.status());
    let error_body: crate::errors::ErrorResponse = test::read_body_json(session_resp).await;
    assert!(error_body
        .message
        .as_str()
        .starts_with("Invalid session [Invalidated session"))
}
848
849    struct TestFixture<T> {
850        invalidated_jwts_store: InvalidatedJWTStore,
851        app: T,
852    }
853
854    /// Builds a server app, almost exactly the same as inmemory example, just with ultra-fast loops
855    /// and no tracing
856    async fn build_fixture(
857        jwt_ttl: JWTTtl,
858    ) -> Result<
859        TestFixture<
860            impl Service<actix_http::Request, Response = ServiceResponse, Error = actix_web::Error>,
861        >,
862        Box<dyn std::error::Error>,
863    > {
864        let jwt_signing_keys = JwtSigningKeys::generate()?;
865        #[cfg(feature = "session")]
866        let jwt_session_key = JWTSessionKey("jwt-session".to_string());
867
868        let mut validator = Validation::new(JWT_SIGNING_ALGO);
869        validator.leeway = 1;
870
871        let auth_middleware_settings = AuthenticateMiddlewareSettings {
872            jwt_decoding_key: jwt_signing_keys.decoding_key,
873            #[cfg(feature = "session")]
874            jwt_session_key: Some(jwt_session_key.clone()),
875            jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]),
876            jwt_validator: validator,
877        };
878
879        let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream();
880        let auth_middleware_factory =
881            AuthenticateMiddlewareFactory::<Claims>::new(stream, auth_middleware_settings.clone());
882
883        #[cfg(feature = "session")]
884        let session_encryption_key = Key::generate();
885
886        let app = {
887            #[cfg(feature = "session")]
888            let app_t = App::new()
889                .app_data(Data::new(jwt_session_key.clone()))
890                .app_data(Data::new(invalidated_jwts_store.clone()))
891                .app_data(Data::new(jwt_signing_keys.encoding_key.clone()))
892                .app_data(Data::new(jwt_ttl.clone()))
893                .wrap(auth_middleware_factory.clone())
894                .wrap(
895                    SessionMiddleware::builder(
896                        CookieSessionStore::default(),
897                        session_encryption_key.clone(),
898                    )
899                    .cookie_secure(false)
900                    .cookie_http_only(true)
901                    .build(),
902                );
903            #[cfg(not(feature = "session"))]
904            let app_t = App::new()
905                .app_data(Data::new(invalidated_jwts_store.clone()))
906                .app_data(Data::new(jwt_signing_keys.encoding_key.clone()))
907                .app_data(Data::new(jwt_ttl.clone()))
908                .wrap(auth_middleware_factory.clone());
909            test::init_service(
910                app_t
911                    .service(login)
912                    .service(logout)
913                    .service(session_info)
914                    .service(maybe_session_info),
915            )
916        }
917        .await;
918        Ok(TestFixture {
919            invalidated_jwts_store: invalidated_jwts_store.clone(),
920            app,
921        })
922    }
923
924    // <-- Routes
925    #[get("/login")]
926    async fn login(
927        jwt_encoding_key: Data<EncodingKey>,
928        #[cfg(feature = "session")] jwt_session_key: Data<JWTSessionKey>,
929        jwt_ttl: Data<JWTTtl>,
930        #[cfg(feature = "session")] session: Session,
931    ) -> Result<HttpResponse, Error> {
932        let sub = format!("{}", Uuid::new_v4().as_u128());
933        let iat = OffsetDateTime::now_utc().unix_timestamp() as usize;
934        let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0);
935        let exp = expires_at.unix_timestamp() as usize;
936
937        let jwt_claims = Claims { iat, exp, sub };
938        let jwt_token = encode(
939            &Header::new(JWT_SIGNING_ALGO),
940            &jwt_claims,
941            &jwt_encoding_key,
942        )
943        .map_err(|_| Error::InternalError)?;
944        #[cfg(feature = "session")]
945        session
946            .insert(&jwt_session_key.0, &jwt_token)
947            .map_err(|_| Error::InternalError)?;
948        let login_response = LoginResponse {
949            bearer_token: jwt_token,
950            claims: jwt_claims,
951        };
952
953        Ok(HttpResponse::Ok().json(login_response))
954    }
955
956    #[get("/session")]
957    async fn session_info(authenticated: Authenticated<Claims>) -> Result<HttpResponse, Error> {
958        Ok(HttpResponse::Ok().json(authenticated))
959    }
960
961    #[get("/maybe_session")]
962    async fn maybe_session_info(
963        maybe_authenticated: MaybeAuthenticated<Claims>,
964    ) -> Result<HttpResponse, Error> {
965        if let Some(authenticated) = maybe_authenticated.into_option() {
966            Ok(HttpResponse::Ok().json(authenticated))
967        } else {
968            Ok(HttpResponse::Ok().json(MessageResponse {
969                message: "No session for you !".to_string(),
970            }))
971        }
972    }
973
974    #[get("/logout")]
975    async fn logout(
976        invalidated_jwts: Data<InvalidatedJWTStore>,
977        authenticated: Authenticated<Claims>,
978        #[cfg(feature = "session")] session: Session,
979    ) -> Result<HttpResponse, Error> {
980        #[cfg(feature = "session")]
981        session.clear();
982        invalidated_jwts.add_to_invalidated(authenticated).await;
983        Ok(HttpResponse::Ok().json(EmptyResponse {}))
984    }
985    //    Routes -->
986
987    const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;
988
989    // Holds a map of encoded JWT -> expiries
990    #[derive(Clone)]
991    struct InvalidatedJWTStore {
992        store: Arc<DashSet<JWT>>,
993        tx: Arc<Mutex<Sender<InvalidatedTokensEvent>>>,
994    }
995
996    impl InvalidatedJWTStore {
997        fn new_with_stream() -> (
998            InvalidatedJWTStore,
999            impl Stream<Item = InvalidatedTokensEvent>,
1000        ) {
1001            let invalidated = Arc::new(DashSet::new());
1002            let (tx, rx) = mpsc::channel(100);
1003            let tx_to_hold = Arc::new(Mutex::new(tx));
1004            (
1005                InvalidatedJWTStore {
1006                    store: invalidated,
1007                    tx: tx_to_hold,
1008                },
1009                rx,
1010            )
1011        }
1012
1013        async fn add_to_invalidated(&self, authenticated: Authenticated<Claims>) {
1014            self.store.insert(authenticated.jwt.clone());
1015            let mut tx = self.tx.lock().await;
1016            if let Err(_e) = tx
1017                .send(InvalidatedTokensEvent::Add(authenticated.jwt))
1018                .await
1019            {
1020                #[cfg(feature = "tracing")]
1021                error!(error = ?_e, "Failed to send update on adding to invalidated")
1022            }
1023        }
1024    }
1025
1026    struct JwtSigningKeys {
1027        encoding_key: EncodingKey,
1028        decoding_key: DecodingKey,
1029    }
1030
1031    impl JwtSigningKeys {
1032        fn generate() -> Result<Self, Box<dyn std::error::Error>> {
1033            let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?;
1034            let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?;
1035            let encoding_key = EncodingKey::from_ed_der(doc.as_ref());
1036            let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref());
1037            Ok(JwtSigningKeys {
1038                encoding_key,
1039                decoding_key,
1040            })
1041        }
1042    }
1043
1044    // <-- Responses
1045
1046    #[derive(Clone, Copy)]
1047    struct JWTTtl(time::Duration);
1048
1049    impl Default for JWTTtl {
1050        fn default() -> Self {
1051            JWTTtl(1.days())
1052        }
1053    }
1054
1055    #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
1056    struct Claims {
1057        exp: usize,
1058        iat: usize,
1059        sub: String,
1060    }
1061
1062    #[derive(Debug, Serialize, Deserialize)]
1063    struct LoginResponse {
1064        bearer_token: String,
1065        claims: Claims,
1066    }
1067
1068    #[derive(Serialize, Deserialize)]
1069    struct EmptyResponse {}
1070
1071    #[derive(Serialize, Deserialize)]
1072    struct MessageResponse {
1073        message: String,
1074    }
1075
1076    //     Responses -->
1077}