1use std::collections::HashSet;
2use std::future::{ready, Ready};
3use std::marker::PhantomData;
4use std::rc::Rc;
5use std::sync::Arc;
6
7#[cfg(feature = "session")]
8use actix_session::SessionExt;
9use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
10use actix_web::{rt, FromRequest, HttpMessage};
11use derive_more::Display;
12use futures_util::future::LocalBoxFuture;
13use futures_util::{FutureExt, Stream, StreamExt};
14use jsonwebtoken::{decode, DecodingKey, Validation};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use tokio::sync::RwLock;
18#[cfg(feature = "tracing")]
19use tracing::{info, trace};
20
21use crate::errors::Error;
22
/// Successful authentication result for a request.
///
/// The authentication middleware inserts one of these into the request extensions after the
/// request's JWT passed the invalidation check and was decoded. Handlers obtain it as an
/// extractor; extraction fails with [`Error::Unauthenticated`] when no value was inserted
/// (i.e. the request carried no JWT).
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Authenticated<T> {
    /// The raw token the request was authenticated with.
    pub jwt: JWT,
    /// Claims decoded from `jwt`.
    pub claims: T,
}
35
/// Optional counterpart of [`Authenticated`] for routes that also serve anonymous callers.
///
/// Extracting it always succeeds: `Just` carries the value the middleware stored, and `None`
/// means no [`Authenticated`] value was stored for the request. Use
/// [`MaybeAuthenticated::into_option`] to convert to a standard [`Option`].
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub enum MaybeAuthenticated<T> {
    /// The request was authenticated.
    Just(Authenticated<T>),
    /// The request carried no JWT.
    None,
}
47
48impl<T> MaybeAuthenticated<T> {
49 pub fn into_option(self) -> Option<Authenticated<T>> {
50 self.into()
51 }
52}
53
54impl<T> From<MaybeAuthenticated<T>> for Option<Authenticated<T>> {
55 fn from(maybe_authenticated: MaybeAuthenticated<T>) -> Self {
56 match maybe_authenticated {
57 MaybeAuthenticated::Just(v) => Some(v),
58 MaybeAuthenticated::None => None,
59 }
60 }
61}
62
63impl<T> FromRequest for Authenticated<T>
64where
65 T: Clone + 'static,
66{
67 type Error = Error;
68 type Future = Ready<Result<Self, Self::Error>>;
69
70 fn from_request(
71 req: &actix_web::HttpRequest,
72 _payload: &mut actix_web::dev::Payload,
73 ) -> Self::Future {
74 let value = req.extensions().get::<Authenticated<T>>().cloned();
75 let result = match value {
76 Some(v) => Ok(v),
77 None => Err(Error::Unauthenticated),
78 };
79 ready(result)
80 }
81}
82
83impl<T> FromRequest for MaybeAuthenticated<T>
84where
85 T: Clone + 'static,
86{
87 type Error = Error;
88 type Future = Ready<Result<Self, Self::Error>>;
89
90 fn from_request(
91 req: &actix_web::HttpRequest,
92 _payload: &mut actix_web::dev::Payload,
93 ) -> Self::Future {
94 let value = req.extensions().get::<Authenticated<T>>().cloned();
95 let result = match value {
96 Some(v) => Ok(MaybeAuthenticated::Just(v)),
97 None => Ok(MaybeAuthenticated::None),
98 };
99 ready(result)
100 }
101}
102
/// A raw, still-encoded JSON Web Token, as extracted from the `Authorization` header or the
/// session.
///
/// NOTE(review): both the derived `Debug` and the `derive_more` `Display` render the full token
/// string. Take care when formatting values of this type into logs.
#[derive(Hash, PartialEq, Eq, Clone, Debug, Display, Serialize, Deserialize)]
pub struct JWT(pub String);
106
/// Name of the session entry holding the JWT (only with the `session` feature).
///
/// The middleware reads the entry stored under this key from the actix session as a `String`.
#[cfg(feature = "session")]
#[cfg_attr(docsrs, doc(cfg(feature = "session")))]
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct JWTSessionKey(pub String);
112
/// An update to the set of invalidated (revoked) JWTs.
///
/// Events are fed through the stream passed to `AuthenticateMiddlewareFactory::new`. A
/// background loop applies them to the set the middleware checks before decoding a token.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum InvalidatedTokensEvent {
    /// Replace the entire invalidated set with the given set.
    Full(HashSet<JWT>),

    /// Incremental change. Tokens in `remove` are dropped first, then tokens in `add` are
    /// inserted, so a token listed in both ends up invalidated.
    Diff {
        /// Tokens to mark as invalidated, if any.
        add: Option<HashSet<JWT>>,
        /// Tokens to take out of the invalidated set, if any.
        remove: Option<HashSet<JWT>>,
    },

    /// Mark a single token as invalidated.
    Add(JWT),

    /// Take a single token out of the invalidated set.
    Remove(JWT),
}
131
/// The set of currently invalidated JWTs. A request presenting one of these tokens is rejected
/// before its token is decoded.
#[derive(Eq, PartialEq, Debug)]
struct InvalidatedJWTsState(HashSet<JWT>);
134
135impl InvalidatedJWTsState {
136 fn new() -> InvalidatedJWTsState {
137 InvalidatedJWTsState(HashSet::new())
138 }
139}
140
/// Configuration consumed by `AuthenticateMiddlewareFactory::new`.
#[derive(Clone)]
pub struct AuthenticateMiddlewareSettings {
    /// Key used to verify JWT signatures.
    pub jwt_decoding_key: DecodingKey,

    /// Validation rules passed to `jsonwebtoken::decode` for every token.
    pub jwt_validator: Validation,

    /// Session entry to read the JWT from when the `Authorization` header yields no token.
    /// `None` disables session-based extraction.
    #[cfg(feature = "session")]
    #[cfg_attr(docsrs, doc(cfg(feature = "session")))]
    pub jwt_session_key: Option<JWTSessionKey>,

    /// Accepted `Authorization` header schemes, e.g. `"Bearer"`. The factory appends a single
    /// space to each entry, so pass the scheme without a trailing space. Matching is an exact,
    /// case-sensitive prefix match. `None` disables header-based extraction.
    pub jwt_authorization_header_prefixes: Option<Vec<String>>,
}
170
171#[derive(Clone)]
179pub struct AuthenticateMiddlewareFactory<ClaimsType> {
180 invalidated_jwts_state: Arc<RwLock<InvalidatedJWTsState>>,
181 jwt_decoding_key: Arc<DecodingKey>,
182 #[cfg(feature = "session")]
183 jwt_session_key: Option<Arc<JWTSessionKey>>,
184 jwt_authorization_header_prefixes: Option<Arc<Vec<String>>>,
185 jwt_validator: Arc<Validation>,
186 _claims_type_marker: PhantomData<ClaimsType>,
187}
188
189impl<ClaimsType> AuthenticateMiddlewareFactory<ClaimsType>
190where
191 ClaimsType: DeserializeOwned + 'static,
192{
193 pub fn new<S>(
197 invalidated_jwts_events: S,
198 settings: AuthenticateMiddlewareSettings,
199 ) -> AuthenticateMiddlewareFactory<ClaimsType>
200 where
201 S: Stream<Item = InvalidatedTokensEvent> + Unpin + 'static,
202 {
203 let invalidated_jwts_state = Arc::new(RwLock::new(InvalidatedJWTsState::new()));
204
205 #[cfg(feature = "tracing")]
206 info!("Kicking off invalidated JWT reload loop");
207 rt::spawn(reload_from_stream(
208 invalidated_jwts_events,
209 invalidated_jwts_state.clone(),
210 ));
211
212 AuthenticateMiddlewareFactory::<ClaimsType> {
213 invalidated_jwts_state,
214 jwt_decoding_key: Arc::new(settings.jwt_decoding_key),
215 #[cfg(feature = "session")]
216 jwt_session_key: settings.jwt_session_key.map(Arc::new),
217 jwt_authorization_header_prefixes: settings.jwt_authorization_header_prefixes.map(
218 |prefixes| {
219 Arc::new(
220 prefixes
221 .iter()
222 .map(|prefix| format!("{} ", prefix))
223 .collect(),
224 )
225 },
226 ),
227 jwt_validator: Arc::new(settings.jwt_validator),
228 _claims_type_marker: PhantomData,
229 }
230 }
231}
232
233#[cfg_attr(
234 feature = "tracing",
235 tracing::instrument(level = "trace", skip(events, invalidated_jwts_set))
236)]
237async fn reload_from_stream<S>(
238 mut events: S,
239 invalidated_jwts_set: Arc<RwLock<InvalidatedJWTsState>>,
240) where
241 S: Stream<Item = InvalidatedTokensEvent> + Unpin,
242{
243 while let Some(invalidated_jwt_event) = events.next().await {
244 #[cfg(feature = "tracing")]
245 trace!("Received invalidated JWTs event");
246 let mut invalidated_state = invalidated_jwts_set.write().await;
247 match invalidated_jwt_event {
248 InvalidatedTokensEvent::Full(all) => {
249 #[cfg(feature = "tracing")]
250 trace!(count = all.len(), "Received invalidated JWTs with full set");
251 invalidated_state.0 = all;
252 }
253 InvalidatedTokensEvent::Diff { add, remove } => {
254 #[cfg(feature = "tracing")]
255 trace!("Received invalidated JWTs diff");
256 if let Some(to_remove) = remove {
257 #[cfg(feature = "tracing")]
258 trace!(
259 remove_count = to_remove.len(),
260 "Received invalidated JWTs diff, with removals"
261 );
262 for to_remove in to_remove.iter() {
263 invalidated_state.0.remove(to_remove);
264 }
265 }
266 if let Some(to_add) = add {
267 #[cfg(feature = "tracing")]
268 trace!(
269 add_count = to_add.len(),
270 "Received invalidated JWTs diff, with additions"
271 );
272 invalidated_state.0.extend(to_add);
273 }
274 }
275 InvalidatedTokensEvent::Remove(jwt) => {
276 #[cfg(feature = "tracing")]
277 trace!("Received Invalidated token to remove");
278 invalidated_state.0.remove(&jwt);
279 }
280 InvalidatedTokensEvent::Add(jwt) => {
281 #[cfg(feature = "tracing")]
282 trace!("Received Invalidated token to add");
283 invalidated_state.0.insert(jwt);
284 }
285 }
286 }
287}
288
289impl<S, B, ClaimsType> Transform<S, ServiceRequest> for AuthenticateMiddlewareFactory<ClaimsType>
290where
291 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
292 ClaimsType: DeserializeOwned + 'static,
293{
294 type Response = ServiceResponse<B>;
295 type Error = actix_web::Error;
296 type Transform = AuthenticateMiddleware<S, ClaimsType>;
297 type InitError = ();
298 type Future = Ready<Result<Self::Transform, Self::InitError>>;
299
300 fn new_transform(&self, service: S) -> Self::Future {
301 ready(Ok(AuthenticateMiddleware {
302 invalidated_jwts_state: self.invalidated_jwts_state.clone(),
303 service: Rc::new(service),
304 jwt_decoding_key: self.jwt_decoding_key.clone(),
305 #[cfg(feature = "session")]
306 jwt_session_key: self.jwt_session_key.clone(),
307 jwt_authorization_header_prefixes: self.jwt_authorization_header_prefixes.clone(),
308 jwt_validator: self.jwt_validator.clone(),
309 _claims_type_marker: PhantomData,
310 }))
311 }
312}
313
314pub struct AuthenticateMiddleware<S, ClaimsType> {
317 invalidated_jwts_state: Arc<RwLock<InvalidatedJWTsState>>,
318 service: Rc<S>,
319 jwt_decoding_key: Arc<DecodingKey>,
320 #[cfg(feature = "session")]
321 jwt_session_key: Option<Arc<JWTSessionKey>>,
322 jwt_authorization_header_prefixes: Option<Arc<Vec<String>>>,
323 jwt_validator: Arc<Validation>,
324 _claims_type_marker: PhantomData<ClaimsType>,
325}
326
327impl<S, B, ClaimsType> Service<ServiceRequest> for AuthenticateMiddleware<S, ClaimsType>
328where
329 ClaimsType: DeserializeOwned + 'static,
330 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
331{
332 type Response = ServiceResponse<B>;
333 type Error = actix_web::Error;
334 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
335
336 forward_ready!(service);
337
338 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, req)))]
339 fn call(&self, req: ServiceRequest) -> Self::Future {
340 let svc = self.service.clone();
341 let invalidated_jwts_state = self.invalidated_jwts_state.clone();
342 let jwt_decoding_key = self.jwt_decoding_key.clone();
343 #[cfg(feature = "session")]
344 let jwt_session_key = self.jwt_session_key.clone();
345 let jwt_authorization_header_prefixes = self.jwt_authorization_header_prefixes.clone();
346 let validation = self.jwt_validator.clone();
347 async move {
348 authenticate::<S, B, ClaimsType>(
349 svc,
350 req,
351 invalidated_jwts_state,
352 &jwt_decoding_key,
353 #[cfg(feature = "session")]
354 jwt_session_key,
355 jwt_authorization_header_prefixes,
356 &validation,
357 )
358 .await
359 }
360 .boxed_local()
361 }
362}
363
364#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
365async fn authenticate<S, B, ClaimsType>(
366 svc: Rc<S>,
367 req: ServiceRequest,
368 invalidated_jwts_state: Arc<RwLock<InvalidatedJWTsState>>,
369 jwt_decoding_key: &DecodingKey,
370 #[cfg(feature = "session")] jwt_session_key: Option<Arc<JWTSessionKey>>,
371 jwt_authorization_header_prefixes: Option<Arc<Vec<String>>>,
372 validation: &Validation,
373) -> Result<ServiceResponse<B>, actix_web::Error>
374where
375 ClaimsType: DeserializeOwned + 'static,
376 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
377{
378 #[cfg(feature = "tracing")]
379 trace!("Attempting to authenticate");
380 let maybe_jwt_from_auth_header =
381 jwt_authorization_header_prefixes.and_then(|prefixes| extract_bearer_jwt(&req, &prefixes));
382 #[cfg(feature = "session")]
383 let maybe_extracted_jwt = maybe_jwt_from_auth_header
384 .or_else(|| jwt_session_key.and_then(|key| extract_session_jwt(&req, &key)));
385 #[cfg(not(feature = "session"))]
386 let maybe_extracted_jwt = maybe_jwt_from_auth_header;
387 if let Some(jwt) = maybe_extracted_jwt {
388 #[cfg(feature = "tracing")]
389 trace!(jwt = ?jwt, "JWT extracted");
390 let jwt_str = jwt.0.as_str();
391 if invalidated_jwts_state.read().await.0.contains(&jwt) {
392 #[cfg(feature = "tracing")]
393 trace!(jwt= ?jwt, "Invalidated JWT detected");
394 Err(Error::InvalidSession(format!(
395 "Invalidated session. JWT [{jwt}] was already invalidated"
396 )))?;
397 } else {
398 let decoded_claims = decode::<ClaimsType>(jwt_str, jwt_decoding_key, validation)
399 .map_err(|e| {
400 let error_message = e.to_string();
401 #[cfg(feature = "tracing")]
402 trace!("Claims failed decoding because of [{}]", error_message);
403 Error::InvalidSession(error_message)
404 })?;
405 #[cfg(feature = "tracing")]
406 trace!("Claims successfully decoded");
407
408 req.extensions_mut().insert(Authenticated {
409 jwt,
410 claims: decoded_claims.claims,
411 });
412 }
413 }
414 let res = svc.call(req).await?;
415 Ok(res)
416}
417
418#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace"))]
419fn extract_bearer_jwt(req: &ServiceRequest, auth_prefixes: &[String]) -> Option<JWT> {
420 let authorisation_header = req.headers().get("Authorization")?;
421 let as_str = authorisation_header.to_str().ok()?;
422 let jwt_str = auth_prefixes
423 .iter()
424 .filter_map(|prefix| as_str.strip_prefix(prefix))
425 .next()?;
426 Some(JWT(jwt_str.to_string()))
427}
428
/// Reads the JWT stored as a `String` under `jwt_session_key` in the request's session.
///
/// Returns `None` when the entry is absent or cannot be read as a `String`.
///
/// The span skips all arguments. The request's `Debug` output includes its headers (notably
/// the session `Cookie`), which is a credential.
#[cfg(feature = "session")]
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
fn extract_session_jwt(req: &ServiceRequest, jwt_session_key: &JWTSessionKey) -> Option<JWT> {
    req.get_session()
        .get::<String>(&jwt_session_key.0)
        .ok()
        .flatten()
        .map(JWT)
}
439
/// Unit tests for the reload loop and token extraction, plus end-to-end tests that run the
/// middleware inside a test actix app (see `build_fixture`).
#[cfg(test)]
mod tests {
    use std::ops::Add;
    use std::sync::Arc;
    use std::time::Duration;

    #[cfg(feature = "session")]
    use actix_session::storage::CookieSessionStore;
    #[cfg(feature = "session")]
    use actix_session::Session;
    #[cfg(feature = "session")]
    use actix_session::SessionMiddleware;
    #[cfg(feature = "session")]
    use actix_web::cookie::Key;
    use actix_web::web::Data;
    // `test` brings in the `actix_web::test` helper module. NOTE(review): it presumably also
    // brings in actix-web's async test attribute. The built-in `#[test]` cannot be applied to
    // the `async fn`s below, so `#[test]` here must resolve to this import.
    use actix_web::{get, test, App, HttpResponse};
    use dashmap::DashSet;
    use futures::channel::{mpsc, mpsc::Sender};
    use futures::SinkExt;
    use jsonwebtoken::*;
    use ring::rand::SystemRandom;
    use ring::signature::{Ed25519KeyPair, KeyPair};
    use serde::{Deserialize, Serialize};
    use time::ext::*;
    use time::OffsetDateTime;
    use tokio::sync::Mutex;
    #[cfg(feature = "tracing")]
    use tracing::error;
    use uuid::Uuid;

    use super::*;

    // reload_from_stream tests: each one spawns the reload loop, pushes event(s) through a
    // channel, then sleeps so the spawned task can apply them before the state is inspected.

    /// A `Full` event replaces the (empty) state with exactly the given set.
    #[test]
    async fn test_reload_from_stream_full_replace() {
        let mut full_invalidated_set = HashSet::new();
        full_invalidated_set.insert(JWT("1".to_string()));
        full_invalidated_set.insert(JWT("2".to_string()));
        full_invalidated_set.insert(JWT("3".to_string()));

        let full = InvalidatedTokensEvent::Full(full_invalidated_set.clone());

        let state = Arc::new(RwLock::new(InvalidatedJWTsState::new()));

        let (mut tx, rx) = futures::channel::mpsc::channel(100);

        actix_web::rt::spawn(reload_from_stream(rx, state.clone()));

        tx.send(full).await.unwrap();
        tokio::time::sleep(Duration::from_secs(2)).await;

        assert_eq!(full_invalidated_set, state.read().await.0);
    }

    /// An `Add` event inserts a single token.
    #[test]
    async fn test_reload_from_stream_full_add() {
        let state = Arc::new(RwLock::new(InvalidatedJWTsState::new()));

        let (mut tx, rx) = futures::channel::mpsc::channel(100);
        actix_web::rt::spawn(reload_from_stream(rx, state.clone()));

        let add = InvalidatedTokensEvent::Add(JWT("1".to_string()));
        tx.send(add).await.unwrap();
        tokio::time::sleep(Duration::from_secs(2)).await;

        let mut expected_invalidated_set = HashSet::new();
        expected_invalidated_set.insert(JWT("1".to_string()));

        assert_eq!(expected_invalidated_set, state.read().await.0);
    }

    /// A `Remove` event deletes a single token from a pre-populated state.
    #[test]
    async fn test_reload_from_stream_full_remove() {
        let mut current_state = HashSet::new();
        current_state.insert(JWT("1".to_string()));

        let state = Arc::new(RwLock::new(InvalidatedJWTsState(current_state)));
        let (mut tx, rx) = futures::channel::mpsc::channel(100);

        actix_web::rt::spawn(reload_from_stream(rx, state.clone()));

        let remove = InvalidatedTokensEvent::Remove(JWT("1".to_string()));
        tx.send(remove).await.unwrap();
        tokio::time::sleep(Duration::from_secs(2)).await;

        assert!(state.read().await.0.is_empty());
    }

    /// `Diff` events: first a mixed add/remove, then a remove-only diff that empties the set.
    #[test]
    async fn test_reload_stream_diff() {
        let mut full_invalidated_set = HashSet::new();
        full_invalidated_set.insert(JWT("1".to_string()));
        full_invalidated_set.insert(JWT("2".to_string()));
        full_invalidated_set.insert(JWT("3".to_string()));

        let mut add_set = HashSet::new();
        add_set.insert(JWT("4".to_string()));

        let mut remove_set = HashSet::new();
        remove_set.insert(JWT("1".to_string()));

        let diff_1 = InvalidatedTokensEvent::Diff {
            add: Some(add_set),
            remove: Some(remove_set),
        };

        let state = Arc::new(RwLock::new(InvalidatedJWTsState(full_invalidated_set)));

        let (mut tx, rx) = futures::channel::mpsc::channel(100);

        actix_web::rt::spawn(reload_from_stream(rx, state.clone()));

        tx.send(diff_1).await.unwrap();
        tokio::time::sleep(Duration::from_secs(2)).await;

        let mut expected_invalidated_set = HashSet::new();
        expected_invalidated_set.insert(JWT("2".to_string()));
        expected_invalidated_set.insert(JWT("3".to_string()));
        expected_invalidated_set.insert(JWT("4".to_string()));

        assert_eq!(expected_invalidated_set, state.read().await.0);

        let mut remove_set_2 = HashSet::new();
        remove_set_2.insert(JWT("2".to_string()));
        remove_set_2.insert(JWT("3".to_string()));
        remove_set_2.insert(JWT("4".to_string()));

        let diff_2 = InvalidatedTokensEvent::Diff {
            add: None,
            remove: Some(remove_set_2),
        };

        tx.send(diff_2).await.unwrap();
        tokio::time::sleep(Duration::from_secs(2)).await;

        assert!(state.read().await.0.is_empty());
    }

    // extract_bearer_jwt tests. The prefixes are passed with their trailing space, matching
    // what the factory stores.

    /// No `Authorization` header yields no token.
    #[test]
    async fn test_extract_bearer_jwt_none() {
        let req = test::TestRequest::default().to_srv_request();
        let resp = extract_bearer_jwt(&req, vec!["Bearer ".to_string()].as_slice());
        assert!(resp.is_none());
    }

    /// A matching prefix is stripped, leaving the token.
    #[test]
    async fn test_extract_bearer_jwt_some() {
        let req = test::TestRequest::default()
            .insert_header(("Authorization", "Bearer XYZ"))
            .to_srv_request();
        let resp = extract_bearer_jwt(&req, vec!["Bearer ".to_string()].as_slice());
        assert_eq!(Some(JWT("XYZ".to_string())), resp);
    }

    /// Any of several configured prefixes is accepted.
    #[test]
    async fn test_extract_bearer_jwt_some_all_prefix_prefix() {
        for auth_header in ["ApiKey XYZ", "Bearer XYZ"] {
            let req = test::TestRequest::default()
                .insert_header(("Authorization", auth_header))
                .to_srv_request();
            let resp = extract_bearer_jwt(
                &req,
                vec!["Bearer ".to_string(), "ApiKey ".to_string()].as_slice(),
            );
            assert_eq!(Some(JWT("XYZ".to_string())), resp);
        }
    }

    /// A header whose scheme is not configured yields no token.
    #[test]
    async fn test_extract_bearer_jwt_wrong_prefix() {
        let req = test::TestRequest::default()
            .insert_header(("Authorization", "Bearer XYZ"))
            .to_srv_request();
        let resp = extract_bearer_jwt(&req, vec!["ApiKey ".to_string()].as_slice());
        assert!(resp.is_none());
    }

    /// An empty session yields no token.
    #[cfg(feature = "session")]
    #[test]
    async fn test_extract_session_jwt_none() {
        let session_key = JWTSessionKey("sesh".to_string());
        let req = test::TestRequest::default().to_srv_request();
        let resp = extract_session_jwt(&req, &session_key);
        assert!(resp.is_none());
    }

    /// An unauthenticated request to an `Authenticated` route gets 401 with the
    /// `Unauthenticated` message.
    #[test]
    async fn integration_test_no_session_should_reject() {
        let fixture = build_fixture(JWTTtl::default()).await.unwrap();
        let app = fixture.app;
        let req = test::TestRequest::get().uri("/session").to_request();
        let resp = test::call_service(&app, req).await;
        assert_eq!(actix_http::StatusCode::UNAUTHORIZED, resp.status());
        let error_response: crate::errors::ErrorResponse = test::read_body_json(resp).await;
        assert_eq!(
            format!("{}", crate::errors::Error::Unauthenticated),
            error_response.message
        )
    }

    /// An unauthenticated request to a `MaybeAuthenticated` route still succeeds.
    #[test]
    async fn integration_test_no_session_maybe_authenticated() {
        let fixture = build_fixture(JWTTtl::default()).await.unwrap();
        let app = fixture.app;
        let req = test::TestRequest::get().uri("/maybe_session").to_request();
        let resp = test::call_service(&app, req).await;
        assert_eq!(actix_http::StatusCode::OK, resp.status());
        let message_response: MessageResponse = test::read_body_json(resp).await;
        assert_eq!("No session for you !", message_response.message.as_str())
    }

    /// Log in, then call `/session` with the credential and check the same claims come back.
    /// With the `session` feature the session cookies carry the JWT; otherwise a Bearer header
    /// does.
    #[test]
    async fn integration_test_with_authentication() {
        let fixture = build_fixture(JWTTtl::default()).await.unwrap();
        let app = fixture.app;

        let login_resp = {
            let req = test::TestRequest::get().uri("/login").to_request();
            test::call_service(&app, req).await
        };
        assert_eq!(actix_http::StatusCode::OK, login_resp.status());
        // Cookies are copied before `read_body_json` consumes the login response.
        #[cfg(feature = "session")]
        let (login_response, session_req) = {
            let mut req = test::TestRequest::get().uri("/session");
            for c in login_resp.response().cookies() {
                req = req.cookie(c);
            }

            let login_response: LoginResponse = test::read_body_json(login_resp).await;
            (login_response, req)
        };
        #[cfg(not(feature = "session"))]
        let (login_response, session_req) = {
            let login_response: LoginResponse = test::read_body_json(login_resp).await;
            let req = test::TestRequest::get().uri("/session").insert_header((
                "Authorization",
                format!("Bearer {}", login_response.bearer_token),
            ));
            (login_response, req)
        };
        let session_resp = test::call_service(&app, session_req.to_request()).await;
        assert_eq!(actix_http::StatusCode::OK, session_resp.status());
        let session_response: Authenticated<Claims> = test::read_body_json(session_resp).await;
        assert_eq!(login_response.claims, session_response.claims);
    }

    /// A token issued with a 1ns TTL is rejected as expired after waiting past the validator's
    /// 1s leeway. The middleware returns an `Err`, so `app.call` is used and the error is
    /// rendered into a response by hand.
    #[test]
    async fn integration_test_with_expired_authentication() {
        let fixture = build_fixture(JWTTtl(1.nanoseconds())).await.unwrap();
        let app = fixture.app;

        let login_resp = {
            let req = test::TestRequest::get().uri("/login").to_request();
            test::call_service(&app, req).await
        };
        assert_eq!(actix_http::StatusCode::OK, login_resp.status());
        tokio::time::sleep(Duration::from_secs(2)).await;
        #[cfg(feature = "session")]
        let (_login_response, session_resp) = {
            let mut req = test::TestRequest::get().uri("/session");
            for c in login_resp.response().cookies() {
                req = req.cookie(c);
            }
            let resp = app.call(req.to_request()).await.err().unwrap();
            let login_response: LoginResponse = test::read_body_json(login_resp).await;
            (login_response, resp)
        };
        #[cfg(not(feature = "session"))]
        let (_login_response, session_resp) = {
            let login_response: LoginResponse = test::read_body_json(login_resp).await;
            let req = test::TestRequest::get().uri("/session").insert_header((
                "Authorization",
                format!("Bearer {}", login_response.bearer_token),
            ));
            let resp = app.call(req.to_request()).await.err().unwrap();
            (login_response, resp)
        };
        let session_resp = ServiceResponse::new(
            test::TestRequest::get().uri("/session").to_http_request(),
            session_resp.error_response(),
        );
        assert_eq!(actix_http::StatusCode::UNAUTHORIZED, session_resp.status());
        let session_response: crate::errors::ErrorResponse =
            test::read_body_json(session_resp).await;
        assert_eq!(
            "Invalid session [ExpiredSignature]",
            session_response.message.as_str()
        )
    }

    /// After `/logout` pushes the token into the invalidated set, reusing it is rejected as an
    /// invalidated session.
    ///
    /// NOTE(review): this relies on logout happening within the validator's 1s leeway, despite
    /// the 1ns TTL.
    #[test]
    async fn integration_test_with_invalidated_authentication() {
        let fixture = build_fixture(JWTTtl(1.nanoseconds())).await.unwrap();
        let app = fixture.app;

        let login_resp = {
            let req = test::TestRequest::get().uri("/login").to_request();
            test::call_service(&app, req).await
        };
        assert_eq!(actix_http::StatusCode::OK, login_resp.status());

        #[cfg(feature = "session")]
        let (logout_req, session_req) = {
            let mut logout_req = test::TestRequest::get().uri("/logout");
            for c in login_resp.response().cookies() {
                logout_req = logout_req.cookie(c);
            }

            let mut session_req = test::TestRequest::get().uri("/session");
            for c in login_resp.response().cookies() {
                session_req = session_req.cookie(c);
            }
            (logout_req, session_req)
        };
        #[cfg(not(feature = "session"))]
        let (logout_req, session_req) = {
            let login_response: LoginResponse = test::read_body_json(login_resp).await;
            let session_req = test::TestRequest::get().uri("/session").insert_header((
                "Authorization",
                format!("Bearer {}", login_response.bearer_token),
            ));
            let logout_req = test::TestRequest::get().uri("/logout").insert_header((
                "Authorization",
                format!("Bearer {}", login_response.bearer_token),
            ));
            (logout_req, session_req)
        };
        let logout_resp = test::call_service(&app, logout_req.to_request()).await;
        assert_eq!(actix_http::StatusCode::OK, logout_resp.status());

        // Give the reload task time to apply the `Add` event sent by `/logout`.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let session_resp: actix_web::Error =
            { app.call(session_req.to_request()).await.err().unwrap() };
        let session_resp = {
            ServiceResponse::new(
                test::TestRequest::get().uri("/session").to_http_request(),
                session_resp.error_response(),
            )
        };
        assert_eq!(actix_http::StatusCode::UNAUTHORIZED, session_resp.status());

        let session_response: crate::errors::ErrorResponse =
            test::read_body_json(session_resp).await;
        assert!(session_response
            .message
            .as_str()
            .starts_with("Invalid session [Invalidated session"))
    }

    /// Same as above, but the token is invalidated out-of-band through the store, not through
    /// `/logout`. The invalidation check runs before decoding, so it wins even if the token
    /// has expired.
    #[test]
    async fn integration_test_with_remotely_invalidated_session() {
        let fixture = build_fixture(JWTTtl(1.nanoseconds())).await.unwrap();
        let app = fixture.app;

        let login_resp = {
            let req = test::TestRequest::get().uri("/login").to_request();
            test::call_service(&app, req).await
        };
        assert_eq!(actix_http::StatusCode::OK, login_resp.status());
        #[cfg(feature = "session")]
        let (login_response, session_req) = {
            let mut req = test::TestRequest::get().uri("/session");
            for c in login_resp.response().cookies() {
                req = req.cookie(c);
            }

            let login_response: LoginResponse = test::read_body_json(login_resp).await;
            (login_response, req)
        };
        #[cfg(not(feature = "session"))]
        let (login_response, session_req) = {
            let login_response: LoginResponse = test::read_body_json(login_resp).await;
            let req = test::TestRequest::get().uri("/session").insert_header((
                "Authorization",
                format!("Bearer {}", login_response.bearer_token),
            ));
            (login_response, req)
        };

        let authenticated = Authenticated {
            jwt: JWT(login_response.bearer_token),
            claims: login_response.claims,
        };
        fixture
            .invalidated_jwts_store
            .add_to_invalidated(authenticated)
            .await;

        // Give the reload task time to apply the `Add` event.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let session_resp: actix_web::Error =
            { app.call(session_req.to_request()).await.err().unwrap() };
        let session_resp = {
            ServiceResponse::new(
                test::TestRequest::get().uri("/session").to_http_request(),
                session_resp.error_response(),
            )
        };

        assert_eq!(actix_http::StatusCode::UNAUTHORIZED, session_resp.status());
        let session_response: crate::errors::ErrorResponse =
            test::read_body_json(session_resp).await;
        assert!(session_response
            .message
            .as_str()
            .starts_with("Invalid session [Invalidated session"))
    }

    /// Initialized test app plus the store that feeds its invalidation stream.
    struct TestFixture<T> {
        invalidated_jwts_store: InvalidatedJWTStore,
        app: T,
    }

    /// Builds a test app wired with:
    /// - fresh Ed25519 keys,
    /// - a validator with 1s leeway,
    /// - the `"Bearer"` header prefix, and
    /// - with the `session` feature, cookie sessions under the `jwt-session` key.
    ///
    /// Tokens issued by `/login` live for `jwt_ttl`.
    async fn build_fixture(
        jwt_ttl: JWTTtl,
    ) -> Result<
        TestFixture<
            impl Service<actix_http::Request, Response = ServiceResponse, Error = actix_web::Error>,
        >,
        Box<dyn std::error::Error>,
    > {
        let jwt_signing_keys = JwtSigningKeys::generate()?;
        #[cfg(feature = "session")]
        let jwt_session_key = JWTSessionKey("jwt-session".to_string());

        let mut validator = Validation::new(JWT_SIGNING_ALGO);
        validator.leeway = 1;

        let auth_middleware_settings = AuthenticateMiddlewareSettings {
            jwt_decoding_key: jwt_signing_keys.decoding_key,
            #[cfg(feature = "session")]
            jwt_session_key: Some(jwt_session_key.clone()),
            jwt_authorization_header_prefixes: Some(vec!["Bearer".to_string()]),
            jwt_validator: validator,
        };

        let (invalidated_jwts_store, stream) = InvalidatedJWTStore::new_with_stream();
        let auth_middleware_factory =
            AuthenticateMiddlewareFactory::<Claims>::new(stream, auth_middleware_settings.clone());

        #[cfg(feature = "session")]
        let session_encryption_key = Key::generate();

        let app = {
            // The session middleware is wrapped after (i.e. outside) the auth middleware, so
            // the session is loaded before the auth middleware reads it.
            #[cfg(feature = "session")]
            let app_t = App::new()
                .app_data(Data::new(jwt_session_key.clone()))
                .app_data(Data::new(invalidated_jwts_store.clone()))
                .app_data(Data::new(jwt_signing_keys.encoding_key.clone()))
                .app_data(Data::new(jwt_ttl.clone()))
                .wrap(auth_middleware_factory.clone())
                .wrap(
                    SessionMiddleware::builder(
                        CookieSessionStore::default(),
                        session_encryption_key.clone(),
                    )
                    .cookie_secure(false)
                    .cookie_http_only(true)
                    .build(),
                );
            #[cfg(not(feature = "session"))]
            let app_t = App::new()
                .app_data(Data::new(invalidated_jwts_store.clone()))
                .app_data(Data::new(jwt_signing_keys.encoding_key.clone()))
                .app_data(Data::new(jwt_ttl.clone()))
                .wrap(auth_middleware_factory.clone());
            test::init_service(
                app_t
                    .service(login)
                    .service(logout)
                    .service(session_info)
                    .service(maybe_session_info),
            )
        }
        .await;
        Ok(TestFixture {
            invalidated_jwts_store: invalidated_jwts_store.clone(),
            app,
        })
    }

    /// Issues a signed JWT with random `sub` and returns it with its claims. With the
    /// `session` feature, the token is also stored in the session.
    #[get("/login")]
    async fn login(
        jwt_encoding_key: Data<EncodingKey>,
        #[cfg(feature = "session")] jwt_session_key: Data<JWTSessionKey>,
        jwt_ttl: Data<JWTTtl>,
        #[cfg(feature = "session")] session: Session,
    ) -> Result<HttpResponse, Error> {
        let sub = format!("{}", Uuid::new_v4().as_u128());
        let iat = OffsetDateTime::now_utc().unix_timestamp() as usize;
        let expires_at = OffsetDateTime::now_utc().add(jwt_ttl.0);
        let exp = expires_at.unix_timestamp() as usize;

        let jwt_claims = Claims { iat, exp, sub };
        let jwt_token = encode(
            &Header::new(JWT_SIGNING_ALGO),
            &jwt_claims,
            &jwt_encoding_key,
        )
        .map_err(|_| Error::InternalError)?;
        #[cfg(feature = "session")]
        session
            .insert(&jwt_session_key.0, &jwt_token)
            .map_err(|_| Error::InternalError)?;
        let login_response = LoginResponse {
            bearer_token: jwt_token,
            claims: jwt_claims,
        };

        Ok(HttpResponse::Ok().json(login_response))
    }

    /// Echoes the authenticated value; requires authentication.
    #[get("/session")]
    async fn session_info(authenticated: Authenticated<Claims>) -> Result<HttpResponse, Error> {
        Ok(HttpResponse::Ok().json(authenticated))
    }

    /// Echoes the authenticated value if present, otherwise a fixed message.
    #[get("/maybe_session")]
    async fn maybe_session_info(
        maybe_authenticated: MaybeAuthenticated<Claims>,
    ) -> Result<HttpResponse, Error> {
        if let Some(authenticated) = maybe_authenticated.into_option() {
            Ok(HttpResponse::Ok().json(authenticated))
        } else {
            Ok(HttpResponse::Ok().json(MessageResponse {
                message: "No session for you !".to_string(),
            }))
        }
    }

    /// Clears the session (with the `session` feature) and invalidates the caller's token.
    #[get("/logout")]
    async fn logout(
        invalidated_jwts: Data<InvalidatedJWTStore>,
        authenticated: Authenticated<Claims>,
        #[cfg(feature = "session")] session: Session,
    ) -> Result<HttpResponse, Error> {
        #[cfg(feature = "session")]
        session.clear();
        invalidated_jwts.add_to_invalidated(authenticated).await;
        Ok(HttpResponse::Ok().json(EmptyResponse {}))
    }
    const JWT_SIGNING_ALGO: Algorithm = Algorithm::EdDSA;

    /// In-memory invalidation store. It records each invalidated JWT locally and forwards it
    /// as an `InvalidatedTokensEvent::Add` on the channel feeding the middleware's reload
    /// stream.
    #[derive(Clone)]
    struct InvalidatedJWTStore {
        // NOTE(review): only written in this module, never read.
        store: Arc<DashSet<JWT>>,
        tx: Arc<Mutex<Sender<InvalidatedTokensEvent>>>,
    }

    impl InvalidatedJWTStore {
        /// Returns the store and the receiving end of its event channel (capacity 100).
        fn new_with_stream() -> (
            InvalidatedJWTStore,
            impl Stream<Item = InvalidatedTokensEvent>,
        ) {
            let invalidated = Arc::new(DashSet::new());
            let (tx, rx) = mpsc::channel(100);
            let tx_to_hold = Arc::new(Mutex::new(tx));
            (
                InvalidatedJWTStore {
                    store: invalidated,
                    tx: tx_to_hold,
                },
                rx,
            )
        }

        /// Records the token and sends an `Add` event. A send failure is only logged (with
        /// the `tracing` feature).
        async fn add_to_invalidated(&self, authenticated: Authenticated<Claims>) {
            self.store.insert(authenticated.jwt.clone());
            let mut tx = self.tx.lock().await;
            if let Err(_e) = tx
                .send(InvalidatedTokensEvent::Add(authenticated.jwt))
                .await
            {
                #[cfg(feature = "tracing")]
                error!(error = ?_e, "Failed to send update on adding to invalidated")
            }
        }
    }

    /// Freshly generated Ed25519 key pair: the PKCS#8 document signs and the public key
    /// verifies.
    struct JwtSigningKeys {
        encoding_key: EncodingKey,
        decoding_key: DecodingKey,
    }

    impl JwtSigningKeys {
        fn generate() -> Result<Self, Box<dyn std::error::Error>> {
            let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?;
            let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?;
            let encoding_key = EncodingKey::from_ed_der(doc.as_ref());
            let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref());
            Ok(JwtSigningKeys {
                encoding_key,
                decoding_key,
            })
        }
    }

    /// Lifetime of tokens issued by `/login`; defaults to one day.
    #[derive(Clone, Copy)]
    struct JWTTtl(time::Duration);

    impl Default for JWTTtl {
        fn default() -> Self {
            JWTTtl(1.days())
        }
    }

    /// Test claims: expiry and issued-at as Unix seconds, plus subject.
    #[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
    struct Claims {
        exp: usize,
        iat: usize,
        sub: String,
    }

    /// Body returned by `/login`.
    #[derive(Debug, Serialize, Deserialize)]
    struct LoginResponse {
        bearer_token: String,
        claims: Claims,
    }

    /// Body returned by `/logout`.
    #[derive(Serialize, Deserialize)]
    struct EmptyResponse {}

    /// Body returned by `/maybe_session` for anonymous callers.
    #[derive(Serialize, Deserialize)]
    struct MessageResponse {
        message: String,
    }
}