use crate::*;
pub use actix_web::cookie::time::{Duration, OffsetDateTime};
use actix_web::dev::Transform;
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse};
use futures_util::future::LocalBoxFuture;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey};
use std::future::{ready, Ready};
use std::rc::Rc;
use std::sync::Arc;
/// Builder that collects the JWT keys/algorithm, the session storage and the
/// token extractors, and turns them into a [`SessionMiddlewareFactory`] via
/// `finish`.
pub struct SessionMiddlewareBuilder<ClaimsType: Claims> {
/// Key used to encode JWTs; shared via `Arc` with the factory and every middleware.
pub(crate) jwt_encoding_key: Arc<EncodingKey>,
/// Key used to decode JWTs; shared via `Arc` with the factory and every middleware.
pub(crate) jwt_decoding_key: Arc<DecodingKey>,
/// JWT algorithm passed through to the extractors.
pub(crate) algorithm: Algorithm,
/// Session storage; `None` until `with_storage` is called. `finish` panics if still `None`.
pub(crate) storage: Option<SessionStorage>,
/// Extractors for the claims-carrying JWT, kept in registration order.
pub(crate) jwt_extractors: Vec<Box<dyn SessionExtractor<ClaimsType>>>,
/// Extractors for the refresh token, kept in registration order.
pub(crate) refresh_extractors: Vec<Box<dyn SessionExtractor<RefreshToken>>>,
}
impl<ClaimsType: Claims> SessionMiddlewareBuilder<ClaimsType> {
    /// Creates a builder with the given JWT keys and algorithm, no storage and
    /// no extractors.
    #[doc(hidden)]
    pub(crate) fn new(
        jwt_encoding_key: Arc<EncodingKey>,
        jwt_decoding_key: Arc<DecodingKey>,
        algorithm: Algorithm,
    ) -> Self {
        Self {
            // The Arc is already owned here; no need to bump the refcount.
            jwt_encoding_key,
            jwt_decoding_key,
            algorithm,
            storage: None,
            jwt_extractors: Vec::new(),
            refresh_extractors: Vec::new(),
        }
    }

    /// Sets the session storage used by the middleware.
    ///
    /// Calling this more than once replaces the previously set storage.
    #[must_use]
    pub fn with_storage(mut self, storage: SessionStorage) -> Self {
        self.storage = Some(storage);
        self
    }

    /// Registers a [`CookieExtractor`] for the refresh token under the cookie
    /// called `name`.
    ///
    /// Refresh extractors are tried in the order they were registered.
    #[must_use]
    pub fn with_refresh_cookie(mut self, name: &'static str) -> Self {
        self.refresh_extractors
            .push(Box::new(CookieExtractor::<RefreshToken>::new(name)));
        self
    }

    /// Registers a [`HeaderExtractor`] for the refresh token under the header
    /// called `name`.
    ///
    /// Refresh extractors are tried in the order they were registered.
    #[must_use]
    pub fn with_refresh_header(mut self, name: &'static str) -> Self {
        self.refresh_extractors
            .push(Box::new(HeaderExtractor::<RefreshToken>::new(name)));
        self
    }

    /// Registers a [`CookieExtractor`] for the claims JWT under the cookie
    /// called `name`.
    ///
    /// JWT extractors are tried in the order they were registered.
    #[must_use]
    pub fn with_jwt_cookie(mut self, name: &'static str) -> Self {
        self.jwt_extractors
            .push(Box::new(CookieExtractor::<ClaimsType>::new(name)));
        self
    }

    /// Registers a [`HeaderExtractor`] for the claims JWT under the header
    /// called `name`.
    ///
    /// JWT extractors are tried in the order they were registered.
    #[must_use]
    pub fn with_jwt_header(mut self, name: &'static str) -> Self {
        self.jwt_extractors
            .push(Box::new(HeaderExtractor::<ClaimsType>::new(name)));
        self
    }

    /// Consumes the builder and returns the configured storage together with
    /// the middleware factory (an actix-web [`Transform`]).
    ///
    /// The returned [`SessionStorage`] is a clone of the one moved into the
    /// factory. The extractor lists are frozen behind `Arc`s so every
    /// middleware instance created by the factory shares them.
    ///
    /// # Panics
    ///
    /// Panics if no storage was configured with [`Self::with_storage`].
    #[must_use]
    pub fn finish(self) -> (SessionStorage, SessionMiddlewareFactory<ClaimsType>) {
        // Bind every field explicitly (no `..`) so a newly added field cannot
        // be silently dropped here.
        let Self {
            storage,
            jwt_encoding_key,
            jwt_decoding_key,
            algorithm,
            jwt_extractors,
            refresh_extractors,
        } = self;
        let storage = storage
            .expect("Session storage must be constructed from pool or set from existing storage");
        (
            storage.clone(),
            SessionMiddlewareFactory {
                jwt_encoding_key,
                jwt_decoding_key,
                algorithm,
                storage,
                jwt_extractors: Arc::new(jwt_extractors),
                refresh_extractors: Arc::new(refresh_extractors),
            },
        )
    }
}
/// Middleware factory produced by `SessionMiddlewareBuilder::finish`.
///
/// Implements actix-web's `Transform`, creating one `SessionMiddleware` per
/// wrapped service. Keys and extractor lists are behind `Arc`, so clones of the
/// factory share them.
#[derive(Clone)]
pub struct SessionMiddlewareFactory<ClaimsType: Claims> {
/// Key used to encode JWTs.
pub(crate) jwt_encoding_key: Arc<EncodingKey>,
/// Key used to decode JWTs.
pub(crate) jwt_decoding_key: Arc<DecodingKey>,
/// JWT algorithm passed through to the extractors.
pub(crate) algorithm: Algorithm,
/// Session storage cloned into every middleware instance.
pub(crate) storage: SessionStorage,
/// Shared, ordered list of claims-JWT extractors.
pub(crate) jwt_extractors: Arc<Vec<Box<dyn SessionExtractor<ClaimsType>>>>,
/// Shared, ordered list of refresh-token extractors.
pub(crate) refresh_extractors: Arc<Vec<Box<dyn SessionExtractor<RefreshToken>>>>,
}
impl<ClaimsType: Claims> SessionMiddlewareFactory<ClaimsType> {
    /// Entry point for configuring the session middleware.
    ///
    /// Returns a [`SessionMiddlewareBuilder`] seeded with the JWT keys and
    /// algorithm. Add storage and extractors with its `with_*` methods, then
    /// call `finish` to obtain the factory.
    pub fn build(
        jwt_encoding_key: Arc<EncodingKey>,
        jwt_decoding_key: Arc<DecodingKey>,
        algorithm: Algorithm,
    ) -> SessionMiddlewareBuilder<ClaimsType> {
        SessionMiddlewareBuilder::<ClaimsType>::new(
            jwt_encoding_key,
            jwt_decoding_key,
            algorithm,
        )
    }
}
impl<S, B, ClaimsType> Transform<S, ServiceRequest> for SessionMiddlewareFactory<ClaimsType>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
ClaimsType: Claims,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Transform = SessionMiddleware<S, ClaimsType>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(SessionMiddleware {
service: Rc::new(service),
storage: self.storage.clone(),
jwt_encoding_key: self.jwt_encoding_key.clone(),
jwt_decoding_key: self.jwt_decoding_key.clone(),
algorithm: self.algorithm,
jwt_extractors: self.jwt_extractors.clone(),
refresh_extractors: self.refresh_extractors.clone(),
}))
}
}
/// Per-service middleware created by `SessionMiddlewareFactory::new_transform`.
/// On each request it runs the configured JWT and refresh-token extractors,
/// then forwards the request to the wrapped `service`.
#[doc(hidden)]
pub struct SessionMiddleware<S, ClaimsType>
where
ClaimsType: Claims,
{
/// Wrapped inner service; held in an `Rc` so `call` can move a handle into the returned future.
pub(crate) service: Rc<S>,
/// Key used to encode JWTs.
pub(crate) jwt_encoding_key: Arc<EncodingKey>,
/// Key used to decode JWTs.
pub(crate) jwt_decoding_key: Arc<DecodingKey>,
/// JWT algorithm passed through to the extractors.
pub(crate) algorithm: Algorithm,
/// Session storage handed to each extractor call.
pub(crate) storage: SessionStorage,
/// Shared, ordered list of claims-JWT extractors.
pub(crate) jwt_extractors: Arc<Vec<Box<dyn SessionExtractor<ClaimsType>>>>,
/// Shared, ordered list of refresh-token extractors.
pub(crate) refresh_extractors: Arc<Vec<Box<dyn SessionExtractor<RefreshToken>>>>,
}
impl<S, ClaimsType: Claims> SessionMiddleware<S, ClaimsType> {
    /// Runs `extractors` against `req` in order until one of them succeeds.
    ///
    /// Each extractor receives its own clone of the keys and storage. Returns
    /// `Ok(())` as soon as any extractor succeeds, and also when `extractors`
    /// is empty.
    ///
    /// # Errors
    ///
    /// Returns the error from the last extractor only if *every* extractor
    /// failed. Errors from earlier extractors are discarded once a later one
    /// succeeds. Previously a stale error survived the successful `break` and
    /// was returned anyway.
    async fn extract_token<C: Claims>(
        req: &mut ServiceRequest,
        jwt_encoding_key: Arc<EncodingKey>,
        jwt_decoding_key: Arc<DecodingKey>,
        algorithm: Algorithm,
        storage: SessionStorage,
        extractors: &[Box<dyn SessionExtractor<C>>],
    ) -> Result<(), Error> {
        let mut last_error = None;
        for extractor in extractors {
            match extractor
                .extract_claims(
                    req,
                    jwt_encoding_key.clone(),
                    jwt_decoding_key.clone(),
                    algorithm,
                    storage.clone(),
                )
                .await
            {
                // First success wins; failures from earlier extractors are irrelevant.
                Ok(_) => return Ok(()),
                Err(e) => last_error = Some(e),
            }
        }
        match last_error {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }
}
impl<S, B, ClaimsType> Service<ServiceRequest> for SessionMiddleware<S, ClaimsType>
where
ClaimsType: Claims,
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, mut req: ServiceRequest) -> Self::Future {
use futures_lite::FutureExt;
let svc = self.service.clone();
let jwt_decoding_key = self.jwt_decoding_key.clone();
let jwt_encoding_key = self.jwt_encoding_key.clone();
let algorithm = self.algorithm;
let storage = self.storage.clone();
let jwt_extractors = self.jwt_extractors.clone();
let refresh_extractors = self.refresh_extractors.clone();
async move {
if !jwt_extractors.is_empty() {
Self::extract_token(
&mut req,
jwt_encoding_key.clone(),
jwt_decoding_key.clone(),
algorithm,
storage.clone(),
&jwt_extractors,
)
.await?;
}
if !refresh_extractors.is_empty() {
Self::extract_token(
&mut req,
jwt_encoding_key,
jwt_decoding_key,
algorithm,
storage,
&refresh_extractors,
)
.await?;
}
let res = svc.call(req).await?;
Ok(res)
}
.boxed_local()
}
}