use std::sync::Arc;
use crate::*;
/// Ordered collection of token sources used to locate session tokens on a request.
///
/// `jwt_extractors` produce access-token claims of type `ClaimsType`;
/// `refresh_extractors` produce [`RefreshToken`] claims.
#[derive(Clone, Debug)]
pub struct Extractors<ClaimsType>
where
    ClaimsType: Claims + std::fmt::Debug,
{
    /// Sources for the access JWT, in registration order.
    pub(crate) jwt_extractors: Vec<Arc<dyn SessionExtractor<ClaimsType>>>,
    /// Sources for the refresh token, in registration order.
    pub(crate) refresh_extractors: Vec<Arc<dyn SessionExtractor<RefreshToken>>>,
}
impl<ClaimsType: Claims> Default for Extractors<ClaimsType> {
    /// Creates an empty configuration with no JWT and no refresh-token sources.
    fn default() -> Self {
        let jwt_extractors = Vec::new();
        let refresh_extractors = Vec::new();
        Self {
            jwt_extractors,
            refresh_extractors,
        }
    }
}
impl<ClaimsType: Claims> Extractors<ClaimsType> {
#[must_use]
pub fn with_refresh_cookie(mut self, name: &'static str) -> Self {
self.refresh_extractors
.push(Arc::new(CookieExtractor::<RefreshToken>::new(name)));
self
}
#[must_use]
pub fn with_refresh_header(mut self, name: &'static str) -> Self {
self.refresh_extractors
.push(Arc::new(HeaderExtractor::<RefreshToken>::new(name)));
self
}
#[must_use]
pub fn with_jwt_cookie(mut self, name: &'static str) -> Self {
self.jwt_extractors
.push(Arc::new(CookieExtractor::<ClaimsType>::new(name)));
self
}
#[must_use]
pub fn with_jwt_header(mut self, name: &'static str) -> Self {
self.jwt_extractors
.push(Arc::new(HeaderExtractor::<ClaimsType>::new(name)));
self
}
}
/// The part of an incoming request that a session extractor reads its token from.
///
/// Returned (together with a cookie/header name or body path) by
/// `SessionExtractor::extractor_key`. `Eq` is derived alongside `Hash` so the
/// kind can serve as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtractorKind {
    /// An HTTP request header.
    Header,
    /// A request cookie.
    Cookie,
    /// A URL parameter.
    UrlParam,
    /// A field inside the request body.
    ReqBody,
}
/// A source of session tokens on an incoming request.
///
/// Implementors decide where the raw token text comes from
/// (`extract_token_text`) and how they identify themselves (`extractor_key`).
/// Decoding the JWT and checking it against session storage are provided as
/// default methods that may be overridden.
#[async_trait(?Send)]
pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static + std::fmt::Debug {
    /// Looks for a token on `req` and, if it decodes and validates, inserts an
    /// `Authenticated` value (claims, encoding key, algorithm) into the request
    /// extensions.
    ///
    /// When `extract_token_text` yields nothing, returns `Ok(())` without
    /// touching the request extensions.
    ///
    /// # Errors
    ///
    /// Propagates errors from `decode` and `validate`; with the default
    /// implementations these are `Error::CantDecode`, `Error::LoadError` and
    /// `Error::DontMatch`.
    async fn extract_claims(
        &self,
        req: &mut ServiceRequest,
        jwt_encoding_key: Arc<EncodingKey>,
        jwt_decoding_key: Arc<DecodingKey>,
        algorithm: Algorithm,
        storage: SessionStorage,
    ) -> Result<(), Error> {
        let Some(as_str) = self.extract_token_text(req).await else {
            return Ok(());
        };
        let decoded_claims = self.decode(&as_str, jwt_decoding_key, algorithm)?;
        self.validate(&decoded_claims, storage).await?;
        req.extensions_mut().insert(Authenticated {
            claims: Arc::new(decoded_claims),
            jwt_encoding_key,
            algorithm,
        });
        Ok(())
    }

    /// Identifies this extractor by the kind of request part it reads and the
    /// name (or path) within that part.
    fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)>;

    /// Decodes `value` as a JWT signed with `algorithm` and returns its claims.
    ///
    /// Expiry (`exp`), not-before (`nbf`) and the required-claims list are
    /// deliberately disabled here, with zero leeway.
    /// NOTE(review): presumably token freshness is enforced via the storage
    /// lookup in `validate` — confirm.
    ///
    /// # Errors
    ///
    /// Returns `Error::CantDecode` if `jsonwebtoken::decode` rejects the token.
    fn decode(
        &self,
        value: &str,
        jwt_decoding_key: Arc<DecodingKey>,
        algorithm: Algorithm,
    ) -> Result<ClaimsType, Error> {
        let mut validation = Validation::new(algorithm);
        validation.validate_exp = false;
        validation.validate_nbf = false;
        validation.leeway = 0;
        validation.required_spec_claims.clear();
        decode::<ClaimsType>(value, &jwt_decoding_key, &validation)
            .map_err(|e| {
                #[cfg(feature = "use-tracing")]
                tracing::debug!("Failed to decode claims: {e:?}. {e}");
                // Without tracing the error detail is intentionally dropped;
                // consume it so non-tracing builds don't warn about `e`.
                #[cfg(not(feature = "use-tracing"))]
                let _ = e;
                Error::CantDecode
            })
            .map(|t| t.claims)
    }

    /// Loads the stored claims for `claims.jti()` and requires them to be equal
    /// to the decoded `claims`.
    ///
    /// # Errors
    ///
    /// - `Error::LoadError` if the storage lookup fails.
    /// - `Error::DontMatch` if the stored claims differ from `claims`.
    async fn validate(&self, claims: &ClaimsType, storage: SessionStorage) -> Result<(), Error> {
        let stored = storage
            .clone()
            .find_jwt::<ClaimsType>(claims.jti())
            .await
            .map_err(|e| {
                #[cfg(feature = "use-tracing")]
                tracing::debug!(
                    "Failed to load {} from storage: {e:?}",
                    std::any::type_name::<ClaimsType>()
                );
                // Without tracing the error detail is intentionally dropped;
                // consume it so non-tracing builds don't warn about `e`.
                #[cfg(not(feature = "use-tracing"))]
                let _ = e;
                Error::LoadError
            })?;
        if &stored != claims {
            #[cfg(feature = "use-tracing")]
            tracing::debug!("{claims:?} != {stored:?}");
            Err(Error::DontMatch)
        } else {
            Ok(())
        }
    }

    /// Returns the raw token text found on `req`, or `None` if this extractor's
    /// source is absent or unusable.
    async fn extract_token_text<'req>(
        &self,
        req: &'req mut ServiceRequest,
    ) -> Option<Cow<'req, str>>;
}
/// Session extractor that reads its token from a named request cookie.
#[derive(Debug)]
pub struct CookieExtractor<ClaimsType> {
    /// Zero-sized marker binding the extractor to `ClaimsType`.
    __ty: PhantomData<ClaimsType>,
    /// Name of the cookie carrying the token.
    cookie_name: &'static str,
}

impl<ClaimsType> CookieExtractor<ClaimsType>
where
    ClaimsType: Claims,
{
    /// Builds an extractor for the cookie called `cookie_name`.
    pub fn new(cookie_name: &'static str) -> Self {
        Self {
            cookie_name,
            __ty: PhantomData,
        }
    }
}
#[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
    /// Returns an owned copy of the cookie's value, or `None` when the request
    /// has no cookie named `cookie_name`.
    async fn extract_token_text<'req>(
        &self,
        req: &'req mut ServiceRequest,
    ) -> Option<Cow<'req, str>> {
        let cookie = req.cookie(self.cookie_name)?;
        Some(Cow::Owned(cookie.value().to_owned()))
    }

    /// Reports `(ExtractorKind::Cookie, cookie_name)`.
    fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)> {
        let name: Cow<'static, str> = Cow::Borrowed(self.cookie_name);
        Some((ExtractorKind::Cookie, name))
    }
}
/// Session extractor that reads its token from a named request header.
#[derive(Debug)]
pub struct HeaderExtractor<ClaimsType> {
    /// Zero-sized marker binding the extractor to `ClaimsType`.
    __ty: PhantomData<ClaimsType>,
    /// Name of the header carrying the token.
    header_name: &'static str,
}

impl<ClaimsType> HeaderExtractor<ClaimsType>
where
    ClaimsType: Claims,
{
    /// Builds an extractor for the header called `header_name`.
    pub fn new(header_name: &'static str) -> Self {
        Self {
            header_name,
            __ty: PhantomData,
        }
    }
}
#[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
    /// Returns an owned copy of the `header_name` header value.
    ///
    /// Yields `None` when the header is missing or when `HeaderValue::to_str`
    /// rejects its bytes.
    async fn extract_token_text<'req>(
        &self,
        req: &'req mut ServiceRequest,
    ) -> Option<Cow<'req, str>> {
        let raw = req.headers().get(self.header_name)?;
        let text = raw.to_str().ok()?;
        Some(Cow::Owned(text.to_owned()))
    }

    /// Reports `(ExtractorKind::Header, header_name)`.
    fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)> {
        let name: Cow<'static, str> = Cow::Borrowed(self.header_name);
        Some((ExtractorKind::Header, name))
    }
}
/// Session extractor that reads its token from a string field of a JSON
/// request body, addressed by a path of object keys.
#[derive(Debug)]
pub struct JsonExtractor<ClaimsType> {
    /// Zero-sized marker binding the extractor to `ClaimsType`.
    __ty: PhantomData<ClaimsType>,
    /// Object keys to follow from the body root; the last one names the token field.
    path: &'static [&'static str],
}

impl<ClaimsType> JsonExtractor<ClaimsType>
where
    ClaimsType: Claims,
{
    /// Builds an extractor that follows `path` through the JSON body.
    pub fn new(path: &'static [&'static str]) -> Self {
        Self {
            path,
            __ty: PhantomData,
        }
    }
}
#[async_trait(?Send)]
impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for JsonExtractor<ClaimsType> {
    /// Extracts the body as JSON and walks `self.path` (object keys) to a
    /// string value, returning an owned copy of it.
    ///
    /// Returns `None` in any of these cases:
    /// - the body cannot be extracted as `Json<serde_json::Value>`;
    /// - `path` is empty;
    /// - any segment is missing, or its parent is not an object;
    /// - the final value is not a string.
    ///
    /// NOTE(review): `ServiceRequest::extract` with a body extractor reads the
    /// request payload; downstream handlers may then see an empty body —
    /// confirm whether the payload needs to be restored.
    async fn extract_token_text<'req>(
        &self,
        req: &'req mut ServiceRequest,
    ) -> Option<Cow<'req, str>> {
        let Ok(body) = req
            .extract::<actix_web::web::Json<serde_json::Value>>()
            .await
        else {
            return None;
        };
        let json = body.into_inner();
        let (last, parents) = self.path.split_last()?;
        // Every intermediate segment must resolve. `try_fold` stops at the first
        // missing key instead of silently continuing from the previous level
        // (which the former `fold` + in-closure `?` did).
        let parent = parents
            .iter()
            .try_fold(&json, |node, piece| node.as_object()?.get(*piece))?;
        parent
            .as_object()?
            .get(*last)?
            .as_str()
            .map(|token| Cow::Owned(token.to_owned()))
    }

    /// Reports `(ExtractorKind::ReqBody, "<segment>.<segment>...")`.
    fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)> {
        Some((ExtractorKind::ReqBody, self.path.join(".").into()))
    }
}