use std::borrow::Cow;
use std::marker::PhantomData;
use std::sync::Arc;
pub use actix_web::cookie::time::{Duration, OffsetDateTime};
use actix_web::dev::ServiceRequest;
use actix_web::{FromRequest, HttpMessage, HttpResponse};
use async_trait::async_trait;
use derive_more::{Constructor, Deref};
pub use jsonwebtoken::Algorithm;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Validation};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
pub use uuid::Uuid;
/// Lifetime of an access token (JWT), measured from the refresh token's `issues_at`.
///
/// Serialized transparently as the inner [`Duration`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Deref, Constructor)]
#[serde(transparent)]
pub struct JwtTtl(pub Duration);
/// Lifetime of a refresh token, measured from its `issues_at`.
///
/// `SessionStorage` also passes it to the storage backend as the expiry of the whole
/// session record. Serialized transparently as the inner [`Duration`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Deref, Constructor)]
#[serde(transparent)]
pub struct RefreshTtl(pub Duration);
// Header and cookie names for transporting the tokens. Their consumers (middleware /
// extractors) live outside this file.
// NOTE(review): REFRESH_HEADER_NAME and REFRESH_COOKIE_NAME share the value "ACX-Refresh".
/// Header name used for the access token (JWT).
pub static JWT_HEADER_NAME: &str = "Authorization";
/// Header name used for the refresh token.
pub static REFRESH_HEADER_NAME: &str = "ACX-Refresh";
/// Cookie name used for the access token (JWT).
pub static JWT_COOKIE_NAME: &str = "ACX-Auth";
/// Cookie name used for the refresh token.
pub static REFRESH_COOKIE_NAME: &str = "ACX-Refresh";
/// Application claims that can be stored in a session and signed as a JWT.
///
/// Implementors are serialized to JSON inside the session record and signed via
/// `Authenticated::encode`, hence the serde bounds.
pub trait Claims:
PartialEq + DeserializeOwned + Serialize + Clone + Send + Sync + std::fmt::Debug + 'static
{
/// Unique id of this token; sessions are stored and looked up by this id.
fn jti(&self) -> uuid::Uuid;
/// Subject of the token; `SessionStorage::store` copies it into the refresh token's `aud`.
fn subject(&self) -> &str;
}
/// Claims of the refresh token issued together with every access token.
///
/// Several fields are serialized under JWT registered claim names via `serde(rename)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshToken {
/// When the pair was issued, or last refreshed; both TTLs are measured from here.
#[serde(rename = "iat")]
pub issues_at: OffsetDateTime,
/// Access token id as a hyphenated UUID string (serialized as `sub`).
#[serde(rename = "sub")]
access_jti: String,
/// Lifetime of the access token, counted from `issues_at`.
pub access_ttl: JwtTtl,
/// Id of this refresh token.
pub refresh_jti: uuid::Uuid,
/// Lifetime of this refresh token, counted from `issues_at`.
pub refresh_ttl: RefreshTtl,
/// Set by this crate to the refresh TTL length in whole seconds.
/// NOTE(review): RFC 7519 `exp` is normally an absolute Unix timestamp; this value is a
/// duration. Confirm before enabling `exp` validation on refresh tokens.
#[serde(rename = "exp")]
pub expiration_time: u64,
/// `SessionStorage::store` always sets this to `0`.
#[serde(rename = "nbf")]
pub not_before: u64,
/// Copy of the access claims' `subject()`.
#[serde(rename = "aud")]
pub audience: String,
/// `SessionStorage::store` sets this to the access token id.
#[serde(rename = "iss")]
pub issuer: String,
}
impl PartialEq for RefreshToken {
    /// Field-wise equality that leaves out `issues_at` and `access_ttl`.
    ///
    /// NOTE(review): confirm that ignoring `access_ttl` is intentional; `issues_at` is
    /// rewritten on every refresh.
    fn eq(&self, other: &Self) -> bool {
        let lhs = (
            &self.access_jti,
            self.refresh_jti,
            self.refresh_ttl,
            self.expiration_time,
            self.not_before,
            &self.audience,
            &self.issuer,
        );
        let rhs = (
            &other.access_jti,
            other.refresh_jti,
            other.refresh_ttl,
            other.expiration_time,
            other.not_before,
            &other.audience,
            &other.issuer,
        );
        lhs == rhs
    }
}
impl RefreshToken {
pub fn is_access_valid(&self) -> bool {
self.issues_at + self.access_ttl.0 >= OffsetDateTime::now_utc()
}
pub fn is_refresh_valid(&self) -> bool {
self.issues_at + self.refresh_ttl.0 >= OffsetDateTime::now_utc()
}
pub fn access_jti(&self) -> uuid::Uuid {
Uuid::parse_str(&self.access_jti).unwrap()
}
}
impl Claims for RefreshToken {
    /// A refresh token is identified by its own `refresh_jti`.
    fn jti(&self) -> uuid::Uuid {
        let Self { refresh_jti, .. } = self;
        *refresh_jti
    }

    /// Every refresh token reports the fixed subject `"refresh-token"`.
    fn subject(&self) -> &str {
        const REFRESH_SUBJECT: &str = "refresh-token";
        REFRESH_SUBJECT
    }
}
/// Access token and matching refresh token, as returned by `SessionStorage::store` and
/// `SessionStorage::refresh`.
pub struct Pair<ClaimsType: Claims> {
/// The access token claims.
pub jwt: Authenticated<ClaimsType>,
/// The refresh token belonging to the same session.
pub refresh: Authenticated<RefreshToken>,
}
#[derive(Debug, thiserror::Error, PartialEq, Clone, Copy)]
pub enum Error {
#[error("Failed to obtain redis connection")]
RedisConn,
#[error("Record not found")]
NotFound,
#[error("Record malformed")]
RecordMalformed,
#[error("Invalid session")]
InvalidSession,
#[error("Claims can't be loaded")]
LoadError,
#[error("Storage claims and given claims are different")]
DontMatch,
#[error("Given token in invalid. Can't decode claims")]
CantDecode,
#[error("No http authentication header")]
NoAuthHeader,
#[error("Failed to serialize claims")]
SerializeFailed,
#[error("Unable to write claims to storage")]
WriteFailed,
#[error("Access token expired")]
JWTExpired,
}
impl actix_web::ResponseError for Error {
    /// `RedisConn` is a server-side failure (500); everything else is treated as an
    /// authentication failure (401).
    fn status_code(&self) -> actix_web::http::StatusCode {
        use actix_web::http::StatusCode;

        if matches!(self, Self::RedisConn) {
            StatusCode::INTERNAL_SERVER_ERROR
        } else {
            StatusCode::UNAUTHORIZED
        }
    }

    /// Responds with the status from `status_code` and an empty body, so no error
    /// details are leaked to the client.
    fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
        let status = self.status_code();
        HttpResponse::build(status).body("")
    }
}
/// Claims of an authenticated session, bundled with the key and algorithm needed to
/// sign them again.
#[derive(Clone)]
pub struct Authenticated<T> {
/// The claims; behind an `Arc` so clones share them.
pub claims: Arc<T>,
/// Key used by `encode` to sign the claims.
pub jwt_encoding_key: Arc<EncodingKey>,
/// Signing algorithm used by `encode`.
pub algorithm: Algorithm,
}
impl<T> std::ops::Deref for Authenticated<T> {
    type Target = T;

    /// Gives direct read access to the wrapped claims.
    fn deref(&self) -> &Self::Target {
        &*self.claims
    }
}
impl<T: Claims> Authenticated<T> {
pub fn encode(&self) -> Result<String, jsonwebtoken::errors::Error> {
encode(
&jsonwebtoken::Header::new(self.algorithm),
&*self.claims,
&self.jwt_encoding_key,
)
}
}
impl<T: Claims> FromRequest for Authenticated<T> {
    type Error = actix_web::error::Error;
    type Future = std::future::Ready<Result<Self, actix_web::Error>>;

    /// Extracts the `Authenticated<T>` previously stored in the request extensions.
    ///
    /// # Errors
    /// Fails with [`Error::NotFound`] (rendered as 401) when no session of type `T` is
    /// attached to the request.
    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        // This is a read-only lookup, so a shared borrow of the extensions is enough.
        // `extensions_mut` would take the RefCell exclusively for no reason.
        let value = req.extensions().get::<Authenticated<T>>().cloned();
        std::future::ready(value.ok_or_else(|| Error::NotFound.into()))
    }
}
/// Extractor yielding the request's session when present. The request is not rejected
/// when it is absent; the inner value is then `None`.
pub struct MaybeAuthenticated<ClaimsType: Claims>(Option<Authenticated<ClaimsType>>);
impl<ClaimsType: Claims> MaybeAuthenticated<ClaimsType> {
    /// Returns `true` when a session was attached to the request.
    pub fn is_authenticated(&self) -> bool {
        matches!(self.0, Some(_))
    }

    /// Consumes the extractor and returns the optional session.
    pub fn into_option(self) -> Option<Authenticated<ClaimsType>> {
        let Self(inner) = self;
        inner
    }
}
impl<ClaimsType: Claims> std::ops::Deref for MaybeAuthenticated<ClaimsType> {
    type Target = Option<Authenticated<ClaimsType>>;

    /// Exposes the optional session without consuming the extractor.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
impl<T: Claims> FromRequest for MaybeAuthenticated<T> {
    type Error = actix_web::error::Error;
    type Future = std::future::Ready<Result<Self, actix_web::Error>>;

    /// Extracts the optional `Authenticated<T>` from the request extensions.
    ///
    /// Never fails: a missing session becomes `MaybeAuthenticated(None)`.
    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        // This is a read-only lookup, so a shared borrow of the extensions is enough.
        let value = req.extensions().get::<Authenticated<T>>().cloned();
        std::future::ready(Ok(MaybeAuthenticated(value)))
    }
}
/// Persistence backend for serialized session records, keyed by token id bytes.
///
/// Methods take `self: Arc<Self>` so the backend can be shared as `Arc<dyn TokenStorage>`.
/// Because of `async_trait(?Send)`, the returned futures are not required to be `Send`.
#[async_trait(?Send)]
pub trait TokenStorage: Send + Sync {
/// Returns the bytes stored under `jti`.
async fn get_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<Vec<u8>, Error>;
/// Persists `bytes` for a session that expires after `exp`.
/// NOTE(review): `SessionStorage` later looks the record up by either `jwt_jti` or
/// `refresh_jti`, so implementations presumably must make it retrievable under both ids.
async fn set_by_jti(
self: Arc<Self>,
jwt_jti: &[u8],
refresh_jti: &[u8],
bytes: &[u8],
exp: Duration,
) -> Result<(), Error>;
/// Removes whatever is stored under `jti`.
async fn remove_by_jti(self: Arc<Self>, jti: &[u8]) -> Result<(), Error>;
}
/// Session store that keeps access/refresh token pairs in a [`TokenStorage`] backend.
///
/// The key and algorithm are attached to every [`Authenticated`] value it hands out.
/// Cloning only bumps `Arc` reference counts.
#[derive(Clone)]
pub struct SessionStorage {
storage: Arc<dyn TokenStorage>,
jwt_encoding_key: Arc<EncodingKey>,
algorithm: Algorithm,
}
impl std::ops::Deref for SessionStorage {
    type Target = Arc<dyn TokenStorage>;

    /// Exposes the underlying storage backend directly.
    fn deref(&self) -> &Self::Target {
        let Self { storage, .. } = self;
        storage
    }
}
#[doc(hidden)]
/// One session as persisted in the [`TokenStorage`]; `SessionStorage` bincode-encodes it.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SessionRecord {
// Id of the refresh token.
refresh_jti: uuid::Uuid,
// Id of the access token (JWT).
jwt_jti: uuid::Uuid,
// JSON-serialized `RefreshToken`.
refresh_token: String,
// JSON-serialized access claims.
jwt: String,
}
impl SessionRecord {
    /// Builds the record persisted for a freshly issued access/refresh token pair.
    ///
    /// Both tokens are stored as JSON strings next to their ids.
    ///
    /// # Errors
    /// Returns [`Error::SerializeFailed`] if either token cannot be serialized.
    fn new<ClaimsType: Claims>(claims: ClaimsType, refresh: RefreshToken) -> Result<Self, Error> {
        // Each id must come from its own token: `refresh_jti` identifies the refresh token
        // and `jwt_jti` the access token. `SessionStorage::store_pair` forwards them to
        // `TokenStorage::set_by_jti` under those parameter names.
        let refresh_jti = refresh.refresh_jti;
        let jwt_jti = claims.jti();
        let refresh_token = serde_json::to_string(&refresh).map_err(|e| {
            #[cfg(feature = "use-tracing")]
            tracing::debug!("Failed to serialize Refresh Token to construct pair: {e:?}");
            Error::SerializeFailed
        })?;
        let jwt = serde_json::to_string(&claims).map_err(|e| {
            #[cfg(feature = "use-tracing")]
            tracing::debug!("Failed to serialize JWT from to construct pair {e:?}");
            Error::SerializeFailed
        })?;
        Ok(Self {
            refresh_jti,
            jwt_jti,
            refresh_token,
            jwt,
        })
    }

    /// Decodes the stored refresh token.
    ///
    /// # Errors
    /// Returns [`Error::RecordMalformed`] if the JSON does not parse as a [`RefreshToken`].
    fn refresh_token(&self) -> Result<RefreshToken, Error> {
        serde_json::from_str(&self.refresh_token).map_err(|e| {
            #[cfg(feature = "use-tracing")]
            tracing::debug!("Failed to deserialize refresh token from pair: {e:?}");
            Error::RecordMalformed
        })
    }

    /// Decodes one of the record's JSON fields as claims of type `CT`.
    ///
    /// # Errors
    /// Returns [`Error::RecordMalformed`] if `s` does not parse as `CT`.
    fn from_field<CT: Claims>(s: &str) -> Result<CT, Error> {
        serde_json::from_str(s).map_err(|e| {
            #[cfg(feature = "use-tracing")]
            tracing::debug!(
                "Failed to deserialize {} for pair: {e:?}",
                std::any::type_name::<CT>()
            );
            Error::RecordMalformed
        })
    }

    /// Replaces the stored refresh token with `refresh`.
    ///
    /// Before serializing, `expiration_time` is recomputed from `refresh_ttl`.
    ///
    /// # Errors
    /// Returns [`Error::SerializeFailed`] if serialization fails.
    fn set_refresh_token(&mut self, mut refresh: RefreshToken) -> Result<(), Error> {
        // NOTE(review): this is the TTL length in seconds, not an absolute timestamp. It
        // is kept consistent with `SessionStorage::store`; confirm before validating `exp`.
        refresh.expiration_time = refresh.refresh_ttl.0.as_seconds_f64() as u64;
        let refresh_token = serde_json::to_string(&refresh).map_err(|e| {
            #[cfg(feature = "use-tracing")]
            tracing::debug!("Failed to serialize refresh token for pair: {e:?}");
            Error::SerializeFailed
        })?;
        self.refresh_token = refresh_token;
        Ok(())
    }
}
impl SessionStorage {
    /// Creates a session store backed by `storage`.
    ///
    /// `jwt_encoding_key` and `algorithm` are copied into every [`Authenticated`] value
    /// returned by `store` and `refresh`.
    pub fn new(
        storage: Arc<dyn TokenStorage>,
        jwt_encoding_key: Arc<EncodingKey>,
        algorithm: Algorithm,
    ) -> Self {
        Self {
            storage,
            jwt_encoding_key,
            algorithm,
        }
    }

    /// Loads the claims stored for the session identified by `jti`.
    ///
    /// When `ClaimsType` is [`RefreshToken`], the session's refresh token is returned as-is.
    /// For any other claims type, the access token's TTL is checked first.
    ///
    /// # Errors
    /// - any error from the storage backend's `get_by_jti`
    /// - [`Error::RecordMalformed`] if the record or one of its tokens cannot be decoded
    /// - [`Error::JWTExpired`] if the access token's TTL has elapsed
    pub async fn find_jwt<ClaimsType: Claims>(&self, jti: uuid::Uuid) -> Result<ClaimsType, Error> {
        let record = self.load_pair_by_jwt(jti).await?;
        // Decoded up front so that a corrupted refresh token is reported in both branches.
        let refresh_token = record.refresh_token()?;
        // `TypeId` is a real type identity. `type_name` output is documented as not unique
        // and unsuitable for comparisons. `Claims: 'static` makes `TypeId::of` available.
        if std::any::TypeId::of::<ClaimsType>() == std::any::TypeId::of::<RefreshToken>() {
            return SessionRecord::from_field(&record.refresh_token);
        }
        if !refresh_token.is_access_valid() {
            #[cfg(feature = "use-tracing")]
            tracing::debug!("JWT expired");
            return Err(Error::JWTExpired);
        }
        SessionRecord::from_field(&record.jwt)
    }

    /// Extends the session identified by `refresh_jti`.
    ///
    /// Resets the refresh token's `issues_at` to now and rewrites the record with the full
    /// refresh TTL. Returns the stored access claims together with the updated refresh
    /// token. Token ids are not rotated.
    ///
    /// # Errors
    /// Storage errors, [`Error::RecordMalformed`] or [`Error::SerializeFailed`].
    pub async fn refresh<ClaimsType: Claims>(
        &self,
        refresh_jti: uuid::Uuid,
    ) -> Result<Pair<ClaimsType>, Error> {
        let mut record = self.load_pair_by_refresh(refresh_jti).await?;
        let mut refresh_token = record.refresh_token()?;
        let ttl = refresh_token.refresh_ttl;
        // Both TTLs are measured from `issues_at`, so moving it to now restarts the access
        // and the refresh lifetime.
        refresh_token.issues_at = OffsetDateTime::now_utc();
        record.set_refresh_token(refresh_token)?;
        self.store_pair(record.clone(), ttl).await?;
        let claims = SessionRecord::from_field::<ClaimsType>(&record.jwt)?;
        let refresh = SessionRecord::from_field::<RefreshToken>(&record.refresh_token)?;
        Ok(self.make_pair(claims, refresh))
    }

    /// Starts a new session for `claims` and issues a matching refresh token.
    ///
    /// The record is kept in storage for `refresh_ttl`.
    ///
    /// # Errors
    /// [`Error::SerializeFailed`] or any error from the storage backend's `set_by_jti`.
    pub async fn store<ClaimsType: Claims>(
        &self,
        claims: ClaimsType,
        access_ttl: JwtTtl,
        refresh_ttl: RefreshTtl,
    ) -> Result<Pair<ClaimsType>, Error> {
        let now = OffsetDateTime::now_utc();
        let access_jti = claims.jti().hyphenated().to_string();
        let refresh = RefreshToken {
            refresh_jti: uuid::Uuid::new_v4(),
            refresh_ttl,
            access_jti: access_jti.clone(),
            access_ttl,
            issues_at: now,
            // NOTE(review): TTL length in seconds, not an absolute timestamp.
            expiration_time: refresh_ttl.0.as_seconds_f64() as u64,
            issuer: access_jti,
            not_before: 0,
            audience: claims.subject().to_string(),
        };
        let record = SessionRecord::new(claims.clone(), refresh.clone())?;
        self.store_pair(record, refresh_ttl).await?;
        Ok(self.make_pair(claims, refresh))
    }

    /// Removes the session whose access token id is `jti`, under both of its ids.
    ///
    /// # Errors
    /// Any error from loading the record or from the storage backend's `remove_by_jti`.
    /// If removing the refresh entry fails, the access entry is left in place.
    pub async fn erase<ClaimsType: Claims>(&self, jti: Uuid) -> Result<(), Error> {
        let record = self.load_pair_by_jwt(jti).await?;
        self.storage
            .clone()
            .remove_by_jti(record.refresh_jti.as_bytes())
            .await?;
        self.storage
            .clone()
            .remove_by_jti(record.jwt_jti.as_bytes())
            .await?;
        Ok(())
    }

    /// Wraps access claims and a refresh token into a [`Pair`] that carries this store's
    /// signing key and algorithm.
    fn make_pair<ClaimsType: Claims>(
        &self,
        claims: ClaimsType,
        refresh: RefreshToken,
    ) -> Pair<ClaimsType> {
        Pair {
            jwt: Authenticated {
                claims: Arc::new(claims),
                jwt_encoding_key: self.jwt_encoding_key.clone(),
                algorithm: self.algorithm,
            },
            refresh: Authenticated {
                claims: Arc::new(refresh),
                jwt_encoding_key: self.jwt_encoding_key.clone(),
                algorithm: self.algorithm,
            },
        }
    }

    /// Bincode-encodes `record` and hands it to the backend under both of its ids,
    /// expiring after `refresh_ttl`.
    async fn store_pair(
        &self,
        record: SessionRecord,
        refresh_ttl: RefreshTtl,
    ) -> Result<(), Error> {
        let value = bincode::serialize(&record).map_err(|e| {
            #[cfg(feature = "use-tracing")]
            tracing::debug!("Serialize pair to bytes failed: {e:?}");
            Error::SerializeFailed
        })?;
        self.storage
            .clone()
            .set_by_jti(
                record.jwt_jti.as_bytes(),
                record.refresh_jti.as_bytes(),
                &value,
                refresh_ttl.0,
            )
            .await?;
        Ok(())
    }

    /// Fetches and bincode-decodes the record stored under the access token id `jti`.
    async fn load_pair_by_jwt(&self, jti: Uuid) -> Result<SessionRecord, Error> {
        self.storage
            .clone()
            .get_by_jti(jti.as_bytes())
            .await
            .and_then(|bytes| {
                bincode::deserialize(&bytes).map_err(|e| {
                    #[cfg(feature = "use-tracing")]
                    tracing::debug!("Deserialize pair while loading for JWT ID failed: {e:?}");
                    Error::RecordMalformed
                })
            })
    }

    /// Fetches and bincode-decodes the record stored under the refresh token id `jti`.
    async fn load_pair_by_refresh(&self, jti: Uuid) -> Result<SessionRecord, Error> {
        self.storage
            .clone()
            .get_by_jti(jti.as_bytes())
            .await
            .and_then(|bytes| {
                bincode::deserialize(&bytes).map_err(|e| {
                    #[cfg(feature = "use-tracing")]
                    tracing::debug!("Deserialize pair while loading for refresh id failed: {e:?}");
                    Error::RecordMalformed
                })
            })
    }
}
pub mod builder;
pub use builder::*;
#[cfg(feature = "routes")]
pub mod actix_routes;
#[cfg(feature = "routes")]
pub use actix_routes::configure;
mod extractors;
pub use extractors::*;
/// Ed25519 key pair used to sign (encoding key) and verify (decoding key) JWTs.
pub struct JwtSigningKeys {
/// Private signing key, built from a PKCS#8 DER document.
pub encoding_key: EncodingKey,
/// Public verification key.
pub decoding_key: DecodingKey,
}
impl JwtSigningKeys {
pub fn load_or_create() -> Self {
match Self::load_from_files() {
Ok(s) => s,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
Self::generate(true).expect("Generating new jwt signing keys must succeed")
}
Err(e) => panic!("Failed to load or generate jwt signing keys: {:?}", e),
}
}
pub fn generate(save: bool) -> Result<Self, Box<dyn std::error::Error>> {
use jsonwebtoken::*;
use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair};
let doc = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())?;
let keypair = Ed25519KeyPair::from_pkcs8(doc.as_ref())?;
let encoding_key = EncodingKey::from_ed_der(doc.as_ref());
let decoding_key = DecodingKey::from_ed_der(keypair.public_key().as_ref());
if save {
std::fs::write("./config/jwt-encoding.bin", doc.as_ref()).unwrap_or_else(|e| {
panic!("Failed to write ./config/jwt-encoding.bin: {:?}", e);
});
std::fs::write("./config/jwt-decoding.bin", keypair.public_key()).unwrap_or_else(|e| {
panic!("Failed to write ./config/jwt-decoding.bin: {:?}", e);
});
}
Ok(JwtSigningKeys {
encoding_key,
decoding_key,
})
}
pub fn load_from_files() -> std::io::Result<Self> {
use std::io::Read;
use jsonwebtoken::*;
let mut buf = Vec::new();
let mut e = std::fs::File::open("./config/jwt-encoding.bin")?;
e.read_to_end(&mut buf).unwrap_or_else(|e| {
panic!("Failed to read jwt encoding key: {:?}", e);
});
let encoding_key: EncodingKey = EncodingKey::from_ed_der(&buf);
let mut buf = Vec::new();
let mut e = std::fs::File::open("./config/jwt-decoding.bin")?;
e.read_to_end(&mut buf).unwrap_or_else(|e| {
panic!("Failed to read jwt decoding key: {:?}", e);
});
let decoding_key = DecodingKey::from_ed_der(&buf);
Ok(Self {
encoding_key,
decoding_key,
})
}
}
/// Guards against a TTL smaller than `$min`.
///
/// When `$ttl < $min`, it first logs a warning (requires the `use-tracing` feature), then:
/// - with `panic-bad-ttl`: panics with `$panic_msg`;
/// - otherwise, with `override-bad-ttl`: assigns `$min` to `$ttl`, so `$ttl` must be a
///   mutable place;
/// - otherwise: leaves `$ttl` unchanged.
///
/// `cfg!`/`#[cfg]` are evaluated where the macro is expanded. The features checked are
/// therefore those of the invoking crate.
#[macro_export]
macro_rules! bad_ttl {
    ($ttl: expr, $min: expr, $panic_msg: expr) => {
        if $ttl < $min {
            #[cfg(feature = "use-tracing")]
            tracing::warn!(
                "Expiration time is below 1s. This is not allowed for redis server. Overriding!"
            );
            if cfg!(feature = "panic-bad-ttl") {
                panic!($panic_msg);
            } else if cfg!(feature = "override-bad-ttl") {
                $ttl = $min;
            }
        }
    };
}
mod middleware;
pub use middleware::*;
#[cfg(feature = "redis")]
mod redis_adapter;
#[allow(unused_imports)]
#[cfg(feature = "redis")]
pub use redis_adapter::*;
#[cfg(feature = "hashing")]
mod hashing;
#[cfg(feature = "hashing")]
pub use hashing::*;