use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use rustauth_core::context::AuthContext;
use rustauth_core::db::{Session, User};
use rustauth_core::error::RustAuthError;
use super::{Jwk, JwkAlgorithm, JwtClaims};
/// Boxed, `Send` future that resolves to `Result<T, RustAuthError>` and may borrow for `'a`.
///
/// Every public `Jwt*Future` alias below is this shape with a concrete `T`.
type JwtResultFuture<'a, T> =
    Pin<Box<dyn Future<Output = Result<T, RustAuthError>> + Send + 'a>>;

/// Future resolving to a set of JWT claims (returned by payload-definition handlers).
pub type JwtClaimsFuture<'a> = JwtResultFuture<'a, JwtClaims>;
/// Future resolving to a `String` (returned by the subject and sign handlers).
pub type JwtStringFuture<'a> = JwtResultFuture<'a, String>;
/// Future resolving to a list of JWKs (returned by the `get_jwks` adapter hook).
pub type JwtJwksFuture<'a> = JwtResultFuture<'a, Vec<Jwk>>;
/// Future resolving to a single JWK (returned by the `create_jwk` adapter hook).
pub type JwtJwkFuture<'a> = JwtResultFuture<'a, Jwk>;
/// Async hook called with a borrowed [`JwtSessionContext`]; resolves to the [`JwtClaims`]
/// to use. Stored in `JwtSigningOptions::define_payload`.
pub type JwtDefinePayloadHandler =
Arc<dyn for<'a> Fn(&'a JwtSessionContext) -> JwtClaimsFuture<'a> + Send + Sync>;
/// Async hook called with a borrowed [`JwtSessionContext`]; resolves to a `String`
/// (presumably the token subject — confirm at the call site). Stored in
/// `JwtSigningOptions::get_subject`.
pub type JwtGetSubjectHandler =
Arc<dyn for<'a> Fn(&'a JwtSessionContext) -> JwtStringFuture<'a> + Send + Sync>;
/// Custom signer: takes ownership of the claims and resolves to a `String` (presumably the
/// encoded token). The returned future is `'static`, so it cannot borrow from the caller.
/// When set on `JwtSigningOptions::sign`, `JwtOptions::validate` requires
/// `jwks.remote_url` to be set.
pub type JwtSignHandler = Arc<dyn Fn(JwtClaims) -> JwtStringFuture<'static> + Send + Sync>;
/// Adapter hook that loads the JWK set given the [`AuthContext`]. Stored in
/// `JwtAdapterOptions::get_jwks`.
pub type JwtGetJwksHandler =
Arc<dyn for<'a> Fn(&'a AuthContext) -> JwtJwksFuture<'a> + Send + Sync>;
/// Adapter hook that receives the [`AuthContext`] and a [`Jwk`] and resolves to a `Jwk`
/// (presumably the persisted key — confirm at the call site). Stored in
/// `JwtAdapterOptions::create_jwk`.
pub type JwtCreateJwkHandler =
Arc<dyn for<'a> Fn(&'a AuthContext, Jwk) -> JwtJwkFuture<'a> + Send + Sync>;
/// Session and user pair passed (by reference) to the payload and subject handlers
/// ([`JwtDefinePayloadHandler`], [`JwtGetSubjectHandler`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JwtSessionContext {
/// The session record.
pub session: Session,
/// The user record associated with this context.
pub user: User,
}
/// Top-level JWT/JWKS configuration.
///
/// Prefer [`JwtOptions::builder`]; its `build` step runs [`JwtOptions::validate`]. Values
/// created through `Default` or struct literals are not validated automatically, so call
/// `validate` yourself in that case.
#[derive(Clone, Default)]
pub struct JwtOptions {
/// Key-set settings (remote URL, algorithm, rotation, route path).
pub jwks: JwtJwksOptions,
/// Token-signing settings and custom signing hooks.
pub jwt: JwtSigningOptions,
/// Optional hooks overriding how JWKs are loaded and created.
pub adapter: JwtAdapterOptions,
/// When `true`, presumably suppresses setting a JWT response header — confirm at use site.
pub disable_setting_jwt_header: bool,
/// Schema customization options (defined in `super::schema`).
pub schema: super::schema::JwtSchemaOptions,
}
/// Settings for the JSON Web Key Set. See the `Default` impl for the default values.
#[derive(Debug, Clone)]
pub struct JwtJwksOptions {
/// Remote JWKS URL. Required when `jwt.sign` is set; when present, `key_pair_algorithm`
/// must also be `Some` (both enforced by `JwtOptions::validate`).
pub remote_url: Option<String>,
/// Key algorithm. Defaults to `Some(JwkAlgorithm::EdDsa)`; `JwtOptions::algorithm` also
/// falls back to EdDSA when this is `None`.
pub key_pair_algorithm: Option<JwkAlgorithm>,
/// RSA modulus length (presumably in bits). If set, must be at least 2048.
pub rsa_modulus_length: Option<u32>,
/// When `true`, presumably disables encryption of stored private keys — confirm at use site.
pub disable_private_key_encryption: bool,
/// Key rotation interval; `None` by default. Units are not visible here — TODO confirm.
pub rotation_interval: Option<i64>,
/// Grace period; defaults to `60 * 60 * 24 * 30` (30 days if the unit is seconds —
/// TODO confirm units).
pub grace_period: i64,
/// Route path for the key set. Defaults to `"/jwks"`; must start with `/` and must not
/// contain `..`.
pub jwks_path: String,
}
impl Default for JwtJwksOptions {
    /// Default key-set settings:
    /// - EdDSA keys, with private-key encryption left enabled;
    /// - no remote URL, no explicit RSA modulus, no rotation interval;
    /// - a grace period of `30 * 24 * 60 * 60`;
    /// - the key set served at `/jwks`.
    fn default() -> Self {
        // 2_592_000 — thirty days when read as seconds (unit assumed; confirm at use site).
        const GRACE_PERIOD_THIRTY_DAYS: i64 = 30 * 24 * 60 * 60;

        Self {
            jwks_path: String::from("/jwks"),
            grace_period: GRACE_PERIOD_THIRTY_DAYS,
            rotation_interval: None,
            disable_private_key_encryption: false,
            rsa_modulus_length: None,
            key_pair_algorithm: Some(JwkAlgorithm::EdDsa),
            remote_url: None,
        }
    }
}
/// Token-signing settings and optional signing/payload hooks.
///
/// `Debug` is implemented by hand because the hook fields hold `dyn Fn` trait objects.
#[derive(Clone, Default)]
pub struct JwtSigningOptions {
/// Issuer value (presumably written to the `iss` claim — confirm at use site).
pub issuer: Option<String>,
/// Audience values (presumably written to the `aud` claim — confirm at use site).
pub audience: Option<Vec<String>>,
/// Token lifetime as a `TimeInput`; its interpretation is defined elsewhere.
pub expiration_time: Option<super::TimeInput>,
/// Optional hook returning the claims for a session context.
pub define_payload: Option<JwtDefinePayloadHandler>,
/// Optional hook returning the subject string for a session context.
pub get_subject: Option<JwtGetSubjectHandler>,
/// Optional custom signer. When set, `JwtOptions::validate` requires `jwks.remote_url`.
pub sign: Option<JwtSignHandler>,
}
impl fmt::Debug for JwtSigningOptions {
    /// Prints the plain fields normally.
    ///
    /// Each hook closure is shown as `Some("<placeholder>")` when set, or `None` when unset,
    /// because trait-object closures cannot be printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let payload_hook = self.define_payload.is_some().then_some("<define-payload>");
        let subject_hook = self.get_subject.is_some().then_some("<get-subject>");
        let sign_hook = self.sign.is_some().then_some("<sign-handler>");

        let mut out = formatter.debug_struct("JwtSigningOptions");
        out.field("issuer", &self.issuer);
        out.field("audience", &self.audience);
        out.field("expiration_time", &self.expiration_time);
        out.field("define_payload", &payload_hook);
        out.field("get_subject", &subject_hook);
        out.field("sign", &sign_hook);
        out.finish()
    }
}
/// Optional hooks that override how JWKs are loaded and created.
///
/// `Debug` is implemented by hand because both fields hold `dyn Fn` trait objects.
#[derive(Clone, Default)]
pub struct JwtAdapterOptions {
/// Hook returning the JWK set for an [`AuthContext`].
pub get_jwks: Option<JwtGetJwksHandler>,
/// Hook receiving an [`AuthContext`] and a [`Jwk`] and returning a `Jwk`.
pub create_jwk: Option<JwtCreateJwkHandler>,
}
impl fmt::Debug for JwtAdapterOptions {
    /// Shows each adapter hook as a placeholder string when set, or `None` when unset.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let jwks_loader = self.get_jwks.is_some().then_some("<get-jwks>");
        let jwk_creator = self.create_jwk.is_some().then_some("<create-jwk>");

        let mut out = formatter.debug_struct("JwtAdapterOptions");
        out.field("get_jwks", &jwks_loader);
        out.field("create_jwk", &jwk_creator);
        out.finish()
    }
}
impl JwtOptions {
    /// Starts a [`JwtOptionsBuilder`] with every field unset.
    ///
    /// Fields left unset fall back to the `JwtOptions::default()` values when
    /// [`JwtOptionsBuilder::build`] runs.
    #[must_use]
    pub fn builder() -> JwtOptionsBuilder {
        JwtOptionsBuilder::default()
    }

    /// Checks the options for inconsistent or unsafe settings.
    ///
    /// # Errors
    ///
    /// Returns `RustAuthError::InvalidConfig` for the first rule that fails, checked in order:
    /// 1. `jwt.sign` is set but `jwks.remote_url` is not;
    /// 2. `jwks.remote_url` is set but `jwks.key_pair_algorithm` is `None`;
    /// 3. `jwks.rsa_modulus_length` is set to less than 2048;
    /// 4. `jwks.jwks_path` does not start with `/` (including when empty) or contains `..`.
    pub fn validate(&self) -> Result<(), RustAuthError> {
        let invalid = |message: &str| -> Result<(), RustAuthError> {
            Err(RustAuthError::InvalidConfig(message.to_owned()))
        };
        let jwks = &self.jwks;

        if self.jwt.sign.is_some() && jwks.remote_url.is_none() {
            return invalid("options.jwks.remoteUrl must be set when using options.jwt.sign");
        }
        if jwks.remote_url.is_some() && jwks.key_pair_algorithm.is_none() {
            return invalid(
                "options.jwks.keyPairConfig.alg must be specified when using remoteUrl",
            );
        }
        if matches!(jwks.rsa_modulus_length, Some(bits) if bits < 2048) {
            return invalid("options.jwks.keyPairConfig.modulusLength must be at least 2048");
        }

        // An empty path already fails `starts_with('/')`, so no separate emptiness check is needed.
        let path = jwks.jwks_path.as_str();
        let path_is_acceptable = path.starts_with('/') && !path.contains("..");
        if !path_is_acceptable {
            return invalid(
                "options.jwks.jwksPath must be a non-empty string starting with '/' and not contain '..'",
            );
        }

        Ok(())
    }

    /// Returns the configured key algorithm, falling back to EdDSA when none is set.
    pub(crate) fn algorithm(&self) -> JwkAlgorithm {
        match self.jwks.key_pair_algorithm {
            Some(configured) => configured,
            None => JwkAlgorithm::EdDsa,
        }
    }
}
/// Builder for [`JwtOptions`].
///
/// Each field is `None` until its setter is called. `build` fills any unset field from
/// `JwtOptions::default()` and then validates the result.
#[derive(Clone, Default)]
pub struct JwtOptionsBuilder {
jwks: Option<JwtJwksOptions>,
jwt: Option<JwtSigningOptions>,
adapter: Option<JwtAdapterOptions>,
disable_setting_jwt_header: Option<bool>,
schema: Option<super::schema::JwtSchemaOptions>,
}
impl JwtOptionsBuilder {
    /// Overrides the key-set settings.
    #[must_use]
    pub fn jwks(self, jwks: JwtJwksOptions) -> Self {
        Self {
            jwks: Some(jwks),
            ..self
        }
    }

    /// Overrides the token-signing settings.
    #[must_use]
    pub fn jwt(self, jwt: JwtSigningOptions) -> Self {
        Self {
            jwt: Some(jwt),
            ..self
        }
    }

    /// Overrides the JWK adapter hooks.
    #[must_use]
    pub fn adapter(self, adapter: JwtAdapterOptions) -> Self {
        Self {
            adapter: Some(adapter),
            ..self
        }
    }

    /// Sets `disable_setting_jwt_header` to `disabled`.
    #[must_use]
    pub fn disable_setting_jwt_header(self, disabled: bool) -> Self {
        Self {
            disable_setting_jwt_header: Some(disabled),
            ..self
        }
    }

    /// Overrides the schema options.
    #[must_use]
    pub fn schema(self, schema: super::schema::JwtSchemaOptions) -> Self {
        Self {
            schema: Some(schema),
            ..self
        }
    }

    /// Assembles a [`JwtOptions`] and validates it.
    ///
    /// Unset fields take their `Default` values. Those match `JwtOptions::default()`, because
    /// `JwtOptions` derives `Default` field by field.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`JwtOptions::validate`].
    pub fn build(self) -> Result<JwtOptions, RustAuthError> {
        let Self {
            jwks,
            jwt,
            adapter,
            disable_setting_jwt_header,
            schema,
        } = self;

        let options = JwtOptions {
            jwks: jwks.unwrap_or_default(),
            jwt: jwt.unwrap_or_default(),
            adapter: adapter.unwrap_or_default(),
            disable_setting_jwt_header: disable_setting_jwt_header.unwrap_or_default(),
            schema: schema.unwrap_or_default(),
        };

        options.validate().map(|()| options)
    }
}