use crate::errors::{AuthError, Result};
use jsonwebtoken::{Algorithm, DecodingKey};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::time::Duration;
/// Claims carried by a JWT that [`SecureJwtValidator`] decodes and checks.
///
/// `sub`, `iss`, `aud`, `exp`, `nbf`, `iat` and `jti` use the RFC 7519 registered claim
/// names; the remaining fields are application-specific. Every non-`Option` field is
/// required when deserializing, so a token missing any of them fails validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureJwtClaims {
    /// Subject (`sub`).
    pub sub: String,
    /// Issuer (`iss`); must match one of `SecureJwtConfig::required_issuers` when that set
    /// is non-empty.
    pub iss: String,
    /// Audience (`aud`); checked against `SecureJwtConfig::required_audiences` when that
    /// set is non-empty.
    /// NOTE(review): typed as a single string, so a token whose `aud` is a JSON array will
    /// fail to deserialize — confirm issuers always emit a scalar audience.
    pub aud: String,
    /// Expiration time (`exp`) as Unix seconds; always validated.
    pub exp: i64,
    /// Not-before time (`nbf`) as Unix seconds; validated when
    /// `SecureJwtConfig::validate_nbf` is set.
    pub nbf: i64,
    /// Issued-at time (`iat`) as Unix seconds; `exp - iat` is compared against
    /// `SecureJwtConfig::max_token_lifetime`.
    pub iat: i64,
    /// Token identifier (`jti`); the key used by the validator's revocation list.
    pub jti: String,
    /// Granted scope. Not interpreted by this module — presumably space-delimited OAuth
    /// scopes; confirm with the issuer.
    pub scope: String,
    /// Token type; checked against `SecureJwtConfig::allowed_token_types` when both the
    /// set and this value are non-empty.
    pub typ: String,
    /// Not read by this module — presumably a session identifier; verify against callers.
    pub sid: Option<String>,
    /// Not read by this module — presumably the client the token was issued to.
    pub client_id: Option<String>,
    /// Not read by this module — NOTE(review): looks like a hash binding the token to an
    /// authentication context; confirm its format with the issuer.
    pub auth_ctx_hash: Option<String>,
}
/// Policy and key material used by [`SecureJwtValidator`].
///
/// `Debug` is implemented by hand so that `jwt_secret` is never written to logs (a derived
/// `Debug` would print it verbatim). The public-key PEMs are printed as-is because they are
/// not secret.
#[derive(Clone)]
pub struct SecureJwtConfig {
    /// Algorithms a token's header `alg` may name; any other algorithm is rejected.
    pub allowed_algorithms: Vec<Algorithm>,
    /// Accepted `iss` values. An empty set disables issuer validation.
    pub required_issuers: HashSet<String>,
    /// Accepted `aud` values. An empty set disables audience validation.
    pub required_audiences: HashSet<String>,
    /// Upper bound on a token's `exp - iat` span.
    pub max_token_lifetime: Duration,
    /// Leeway applied to `exp`/`nbf` checks; only whole seconds are honoured.
    pub clock_skew: Duration,
    /// Reject tokens whose `jti` is empty.
    pub require_jti: bool,
    /// Enforce the `nbf` claim.
    pub validate_nbf: bool,
    /// Accepted `typ` values. An empty set disables the check, and tokens whose `typ` is
    /// empty are not checked either.
    pub allowed_token_types: HashSet<String>,
    /// NOTE(review): not read by the validator in this module — confirm where it is enforced.
    pub require_secure_transport: bool,
    /// Shared secret for HS256/HS384/HS512, used as raw bytes. Redacted in `Debug` output.
    pub jwt_secret: String,
    /// PEM-encoded public key for RS256/384/512 and PS256/384/512.
    pub rsa_public_key_pem: Option<String>,
    /// PEM-encoded public key for ES256/ES384.
    pub ec_public_key_pem: Option<String>,
    /// PEM-encoded public key for EdDSA.
    pub ed_public_key_pem: Option<String>,
}

impl std::fmt::Debug for SecureJwtConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field (possibly another secret) becomes a compile
        // error here, forcing an explicit decision about whether it may be printed.
        let Self {
            allowed_algorithms,
            required_issuers,
            required_audiences,
            max_token_lifetime,
            clock_skew,
            require_jti,
            validate_nbf,
            allowed_token_types,
            require_secure_transport,
            jwt_secret: _,
            rsa_public_key_pem,
            ec_public_key_pem,
            ed_public_key_pem,
        } = self;
        f.debug_struct("SecureJwtConfig")
            .field("allowed_algorithms", allowed_algorithms)
            .field("required_issuers", required_issuers)
            .field("required_audiences", required_audiences)
            .field("max_token_lifetime", max_token_lifetime)
            .field("clock_skew", clock_skew)
            .field("require_jti", require_jti)
            .field("validate_nbf", validate_nbf)
            .field("allowed_token_types", allowed_token_types)
            .field("require_secure_transport", require_secure_transport)
            .field("jwt_secret", &"<redacted>")
            .field("rsa_public_key_pem", rsa_public_key_pem)
            .field("ec_public_key_pem", ec_public_key_pem)
            .field("ed_public_key_pem", ed_public_key_pem)
            .finish()
    }
}
impl Default for SecureJwtConfig {
    /// Builds a strict, HS256-only configuration.
    ///
    /// The HMAC secret is 32 bytes from the system CSPRNG, hex-encoded to 64 characters, so
    /// every call yields a different secret. Tokens are accepted from issuer
    /// `"auth-framework"` with type `access`, `refresh` or `JARM`. The maximum lifetime is one
    /// hour and the clock skew is 30 seconds. Both `jti` and `nbf` are enforced.
    ///
    /// # Panics
    ///
    /// Panics if the operating system cannot supply cryptographic randomness.
    fn default() -> Self {
        use ring::rand::{SecureRandom, SystemRandom};

        let mut secret_bytes = [0u8; 32];
        SystemRandom::new().fill(&mut secret_bytes).expect(
            "AuthFramework fatal: system CSPRNG unavailable — the operating system cannot provide cryptographic randomness",
        );
        let jwt_secret: String = secret_bytes
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect();

        let allowed_token_types: HashSet<String> = ["access", "refresh", "JARM"]
            .iter()
            .map(|&kind| String::from(kind))
            .collect();
        let required_issuers: HashSet<String> =
            std::iter::once(String::from("auth-framework")).collect();

        Self {
            allowed_algorithms: vec![Algorithm::HS256],
            required_issuers,
            required_audiences: HashSet::new(),
            max_token_lifetime: Duration::from_secs(60 * 60),
            clock_skew: Duration::from_secs(30),
            require_jti: true,
            validate_nbf: true,
            allowed_token_types,
            require_secure_transport: true,
            jwt_secret,
            rsa_public_key_pem: None,
            ec_public_key_pem: None,
            ed_public_key_pem: None,
        }
    }
}
/// Reports whether `alg` is one of the shared-secret HMAC algorithms (HS256/HS384/HS512),
/// i.e. the ones verified with `SecureJwtConfig::jwt_secret`.
fn is_hmac_algorithm(alg: Algorithm) -> bool {
    const HMAC_ALGORITHMS: [Algorithm; 3] = [Algorithm::HS256, Algorithm::HS384, Algorithm::HS512];
    HMAC_ALGORITHMS.contains(&alg)
}
impl SecureJwtConfig {
pub fn builder() -> SecureJwtConfigBuilder {
SecureJwtConfigBuilder::default()
}
}
/// Fluent builder for [`SecureJwtConfig`]; obtain one via [`SecureJwtConfig::builder`].
pub struct SecureJwtConfigBuilder {
    /// Configuration being assembled; returned unchanged by `build`.
    config: SecureJwtConfig,
}
impl Default for SecureJwtConfigBuilder {
    /// Starts from [`SecureJwtConfig::default`].
    fn default() -> Self {
        let config = SecureJwtConfig::default();
        Self { config }
    }
}
impl SecureJwtConfigBuilder {
    /// Allows one more algorithm in addition to those already allowed.
    ///
    /// The builder starts from [`SecureJwtConfig::default`], which allows only `HS256`, so
    /// this *appends* to that list; call [`Self::with_algorithms`] to replace it instead.
    /// Adding an algorithm that is already allowed is a no-op, so this call never introduces
    /// a duplicate entry.
    pub fn with_algorithm(mut self, algo: Algorithm) -> Self {
        if !self.config.allowed_algorithms.contains(&algo) {
            self.config.allowed_algorithms.push(algo);
        }
        self
    }

    /// Replaces the whole allowed-algorithm list with `algos`.
    pub fn with_algorithms(mut self, algos: Vec<Algorithm>) -> Self {
        self.config.allowed_algorithms = algos;
        self
    }

    /// Adds `issuer` to the accepted `iss` values (the default set already contains
    /// `"auth-framework"`).
    pub fn require_issuer(mut self, issuer: impl Into<String>) -> Self {
        self.config.required_issuers.insert(issuer.into());
        self
    }

    /// Adds `audience` to the accepted `aud` values (the default set is empty, which
    /// disables audience checks).
    pub fn require_audience(mut self, audience: impl Into<String>) -> Self {
        self.config.required_audiences.insert(audience.into());
        self
    }

    /// Sets the maximum accepted `exp - iat` span.
    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
        self.config.max_token_lifetime = lifetime;
        self
    }

    /// Sets the leeway applied to time-based claim checks.
    pub fn with_clock_skew(mut self, skew: Duration) -> Self {
        self.config.clock_skew = skew;
        self
    }

    /// Sets whether tokens must carry a non-empty `jti`.
    pub fn require_jti(mut self, require: bool) -> Self {
        self.config.require_jti = require;
        self
    }

    /// Sets whether the `nbf` claim is enforced.
    pub fn validate_nbf(mut self, validate: bool) -> Self {
        self.config.validate_nbf = validate;
        self
    }

    /// Replaces the randomly generated HMAC secret.
    pub fn with_secret(mut self, secret: impl Into<String>) -> Self {
        self.config.jwt_secret = secret.into();
        self
    }

    /// Sets the PEM-encoded RSA public key used for RS*/PS* algorithms.
    pub fn with_rsa_public_key_pem(mut self, pem: impl Into<String>) -> Self {
        self.config.rsa_public_key_pem = Some(pem.into());
        self
    }

    /// Sets the PEM-encoded EC public key used for ES256/ES384.
    pub fn with_ec_public_key_pem(mut self, pem: impl Into<String>) -> Self {
        self.config.ec_public_key_pem = Some(pem.into());
        self
    }

    /// Sets the PEM-encoded Ed25519 public key used for EdDSA.
    pub fn with_ed_public_key_pem(mut self, pem: impl Into<String>) -> Self {
        self.config.ed_public_key_pem = Some(pem.into());
        self
    }

    /// Returns the assembled configuration. Nothing is validated here;
    /// [`SecureJwtValidator::new`] checks the configuration for consistency.
    pub fn build(self) -> SecureJwtConfig {
        self.config
    }
}
/// Validates JWTs against a [`SecureJwtConfig`] and keeps an in-memory revocation list.
pub struct SecureJwtValidator {
    /// Validation policy and key material.
    config: SecureJwtConfig,
    /// Revoked token IDs (`jti`) mapped to the wall-clock time each was revoked. Nothing in
    /// this module persists it; it lives only as long as the validator.
    revoked_tokens: std::sync::Mutex<std::collections::HashMap<String, std::time::SystemTime>>,
    /// Optional hook called with the `jti` after `revoke_token` records a revocation.
    on_revoke: std::sync::Mutex<Option<Box<dyn Fn(&str) + Send + Sync>>>,
}
impl std::fmt::Debug for SecureJwtValidator {
    /// Prints the configuration and revocation map. For `on_revoke` it reports only whether
    /// a callback is installed (`None` if that lock is poisoned), since closures are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let callback_installed: Option<bool> = self
            .on_revoke
            .lock()
            .ok()
            .map(|guard| guard.is_some());
        let mut out = f.debug_struct("SecureJwtValidator");
        out.field("config", &self.config);
        out.field("revoked_tokens", &self.revoked_tokens);
        out.field("on_revoke", &callback_installed);
        out.finish()
    }
}
impl SecureJwtValidator {
pub fn new(config: SecureJwtConfig) -> Result<Self> {
let has_hmac = config
.allowed_algorithms
.iter()
.any(|a| is_hmac_algorithm(*a));
let has_rsa = config.allowed_algorithms.iter().any(|a| {
matches!(
a,
Algorithm::RS256
| Algorithm::RS384
| Algorithm::RS512
| Algorithm::PS256
| Algorithm::PS384
| Algorithm::PS512
)
});
let has_ec = config
.allowed_algorithms
.iter()
.any(|a| matches!(a, Algorithm::ES256 | Algorithm::ES384));
let has_eddsa = config
.allowed_algorithms
.iter()
.any(|a| matches!(a, Algorithm::EdDSA));
if has_hmac {
#[cfg(not(test))]
if config.jwt_secret.len() < 32 {
return Err(AuthError::Configuration {
message: "SecureJwtConfig::jwt_secret must be at least 32 characters \
when HMAC algorithms are enabled"
.to_string(),
help: Some(
"Provide a cryptographically random secret unique to your deployment"
.to_string(),
),
docs_url: None,
source: None,
suggested_fix: None,
});
}
}
if has_rsa && config.rsa_public_key_pem.is_none() {
return Err(AuthError::Configuration {
message: "SecureJwtConfig::rsa_public_key_pem must be set when RSA/PS algorithms are enabled".to_string(),
help: Some("Set rsa_public_key_pem in SecureJwtConfig".to_string()),
docs_url: None,
source: None,
suggested_fix: None,
});
}
if has_ec && config.ec_public_key_pem.is_none() {
return Err(AuthError::Configuration {
message:
"SecureJwtConfig::ec_public_key_pem must be set when EC algorithms are enabled"
.to_string(),
help: Some("Set ec_public_key_pem in SecureJwtConfig".to_string()),
docs_url: None,
source: None,
suggested_fix: None,
});
}
if has_eddsa && config.ed_public_key_pem.is_none() {
return Err(AuthError::Configuration {
message: "SecureJwtConfig::ed_public_key_pem must be set when EdDSA is enabled"
.to_string(),
help: Some("Set ed_public_key_pem in SecureJwtConfig".to_string()),
docs_url: None,
source: None,
suggested_fix: None,
});
}
Ok(Self {
config,
revoked_tokens: std::sync::Mutex::new(std::collections::HashMap::new()),
on_revoke: std::sync::Mutex::new(None),
})
}
pub fn set_on_revoke<F>(&self, callback: F)
where
F: Fn(&str) + Send + Sync + 'static,
{
let mut guard = match self.on_revoke.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
*guard = Some(Box::new(callback));
}
pub fn get_decoding_key(&self) -> jsonwebtoken::DecodingKey {
jsonwebtoken::DecodingKey::from_secret(self.config.jwt_secret.as_bytes())
}
pub fn get_encoding_key(&self) -> jsonwebtoken::EncodingKey {
jsonwebtoken::EncodingKey::from_secret(self.config.jwt_secret.as_bytes())
}
fn decoding_key_for_algorithm(&self, alg: Algorithm) -> Result<DecodingKey> {
match alg {
Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
Ok(DecodingKey::from_secret(self.config.jwt_secret.as_bytes()))
}
Algorithm::RS256
| Algorithm::RS384
| Algorithm::RS512
| Algorithm::PS256
| Algorithm::PS384
| Algorithm::PS512 => {
let pem = self.config.rsa_public_key_pem.as_deref().ok_or_else(|| {
AuthError::Configuration {
message: "RSA public key PEM not configured".to_string(),
help: Some(
"Set rsa_public_key_pem in SecureJwtConfig for RSA/PS algorithms"
.to_string(),
),
docs_url: None,
source: None,
suggested_fix: None,
}
})?;
DecodingKey::from_rsa_pem(pem.as_bytes()).map_err(|e| AuthError::Configuration {
message: format!("Invalid RSA public key PEM: {e}"),
help: None,
docs_url: None,
source: None,
suggested_fix: None,
})
}
Algorithm::ES256 | Algorithm::ES384 => {
let pem = self.config.ec_public_key_pem.as_deref().ok_or_else(|| {
AuthError::Configuration {
message: "EC public key PEM not configured".to_string(),
help: Some(
"Set ec_public_key_pem in SecureJwtConfig for EC algorithms"
.to_string(),
),
docs_url: None,
source: None,
suggested_fix: None,
}
})?;
DecodingKey::from_ec_pem(pem.as_bytes()).map_err(|e| AuthError::Configuration {
message: format!("Invalid EC public key PEM: {e}"),
help: None,
docs_url: None,
source: None,
suggested_fix: None,
})
}
Algorithm::EdDSA => {
let pem = self.config.ed_public_key_pem.as_deref().ok_or_else(|| {
AuthError::Configuration {
message: "Ed25519 public key PEM not configured".to_string(),
help: Some(
"Set ed_public_key_pem in SecureJwtConfig for EdDSA".to_string(),
),
docs_url: None,
source: None,
suggested_fix: None,
}
})?;
DecodingKey::from_ed_pem(pem.as_bytes()).map_err(|e| AuthError::Configuration {
message: format!("Invalid Ed25519 public key PEM: {e}"),
help: None,
docs_url: None,
source: None,
suggested_fix: None,
})
}
}
}
pub fn validate(&self, token: &str) -> Result<SecureJwtClaims> {
let header = jsonwebtoken::decode_header(token)
.map_err(|e| AuthError::Unauthorized(format!("Invalid JWT header: {e}")))?;
if !self.config.allowed_algorithms.contains(&header.alg) {
return Err(AuthError::Unauthorized(format!(
"Token algorithm {:?} is not permitted; allowed: {:?}",
header.alg, self.config.allowed_algorithms
)));
}
let decoding_key = self.decoding_key_for_algorithm(header.alg)?;
let mut validation = jsonwebtoken::Validation::new(header.alg);
validation.algorithms = self.config.allowed_algorithms.clone();
validation.leeway = self.config.clock_skew.as_secs();
validation.validate_exp = true;
validation.validate_nbf = self.config.validate_nbf;
if !self.config.required_audiences.is_empty() {
validation.set_audience(
&self
.config
.required_audiences
.iter()
.collect::<Vec<&String>>(),
);
} else {
validation.validate_aud = false;
}
if !self.config.required_issuers.is_empty() {
validation.set_issuer(
&self
.config
.required_issuers
.iter()
.collect::<Vec<&String>>(),
);
}
let token_data = jsonwebtoken::decode::<SecureJwtClaims>(token, &decoding_key, &validation)
.map_err(|e| AuthError::Unauthorized(format!("JWT validation failed: {e}")))?;
let claims = token_data.claims;
if self.is_token_revoked(&claims.jti)? {
return Err(AuthError::Unauthorized("Token is revoked".to_string()));
}
let token_lifetime = claims.exp.saturating_sub(claims.iat);
if token_lifetime > 0 && (token_lifetime as u64) > self.config.max_token_lifetime.as_secs()
{
return Err(AuthError::Unauthorized(format!(
"Token lifetime ({token_lifetime}s) exceeds maximum allowed ({}s)",
self.config.max_token_lifetime.as_secs()
)));
}
if self.config.require_jti && claims.jti.is_empty() {
return Err(AuthError::Unauthorized(
"Token missing required JTI claim".to_string(),
));
}
if !self.config.allowed_token_types.is_empty() && !claims.typ.is_empty() {
if !self.config.allowed_token_types.contains(&claims.typ) {
return Err(AuthError::Unauthorized(format!(
"Token type '{}' is not permitted",
claims.typ
)));
}
}
Ok(claims)
}
pub fn validate_token(
&self,
token: &str,
decoding_key: &DecodingKey,
) -> Result<SecureJwtClaims> {
let header = jsonwebtoken::decode_header(token)
.map_err(|e| AuthError::Unauthorized(format!("Invalid JWT header: {e}")))?;
if !self.config.allowed_algorithms.contains(&header.alg) {
return Err(AuthError::Unauthorized(format!(
"Token algorithm {:?} is not permitted; allowed: {:?}",
header.alg, self.config.allowed_algorithms
)));
}
let mut validation = jsonwebtoken::Validation::new(header.alg);
validation.algorithms = self.config.allowed_algorithms.clone();
validation.leeway = self.config.clock_skew.as_secs();
validation.validate_exp = true;
validation.validate_nbf = self.config.validate_nbf;
if !self.config.required_audiences.is_empty() {
validation.set_audience(
&self
.config
.required_audiences
.iter()
.collect::<Vec<&String>>(),
);
} else {
validation.validate_aud = false;
}
if !self.config.required_issuers.is_empty() {
validation.set_issuer(
&self
.config
.required_issuers
.iter()
.collect::<Vec<&String>>(),
);
}
let token_data = jsonwebtoken::decode::<SecureJwtClaims>(token, decoding_key, &validation)
.map_err(|e| AuthError::Unauthorized(format!("JWT validation failed: {e}")))?;
let claims = token_data.claims;
if self.is_token_revoked(&claims.jti)? {
return Err(AuthError::Unauthorized("Token is revoked".to_string()));
}
let token_lifetime = claims.exp.saturating_sub(claims.iat);
if token_lifetime > 0 && (token_lifetime as u64) > self.config.max_token_lifetime.as_secs()
{
return Err(AuthError::Unauthorized(format!(
"Token lifetime ({token_lifetime}s) exceeds maximum allowed ({}s)",
self.config.max_token_lifetime.as_secs()
)));
}
if self.config.require_jti && claims.jti.is_empty() {
return Err(AuthError::Unauthorized(
"Token missing required JTI claim".to_string(),
));
}
if !self.config.allowed_token_types.is_empty() && !claims.typ.is_empty() {
if !self.config.allowed_token_types.contains(&claims.typ) {
return Err(AuthError::Unauthorized(format!(
"Token type '{}' is not permitted",
claims.typ
)));
}
}
Ok(claims)
}
pub fn is_token_revoked(&self, jti: &str) -> Result<bool> {
let revoked_tokens = self.revoked_tokens.lock().map_err(|_| {
AuthError::internal("Lock poisoned — a prior thread panicked while holding this lock")
})?;
Ok(revoked_tokens.contains_key(jti))
}
pub fn revoke_token(&self, jti: &str) -> Result<()> {
{
let mut revoked_tokens = self.revoked_tokens.lock().map_err(|_| {
AuthError::internal(
"Lock poisoned — a prior thread panicked while holding this lock",
)
})?;
revoked_tokens.insert(jti.to_string(), std::time::SystemTime::now());
}
if let Some(ref cb) = *self.on_revoke.lock().map_err(|_| {
AuthError::internal("Lock poisoned — a prior thread panicked while holding this lock")
})? {
cb(jti);
}
Ok(())
}
pub fn cleanup_revoked_tokens(&self, expired_cutoff: std::time::SystemTime) -> Result<()> {
const MAX_REVOKED_TOKENS: usize = 10_000;
let mut revoked_tokens = self.revoked_tokens.lock().map_err(|_| {
AuthError::internal("Lock poisoned — a prior thread panicked while holding this lock")
})?;
revoked_tokens.retain(|_, inserted_at| *inserted_at >= expired_cutoff);
if revoked_tokens.len() > MAX_REVOKED_TOKENS {
let target_len = MAX_REVOKED_TOKENS * 3 / 4;
let mut by_age: Vec<(String, std::time::SystemTime)> = revoked_tokens.drain().collect();
by_age.sort_unstable_by_key(|(_, t)| *t);
for (jti, inserted_at) in by_age.into_iter().rev().take(target_len) {
revoked_tokens.insert(jti, inserted_at);
}
tracing::warn!(
"Revoked token list exceeded {} entries; oldest entries were evicted.",
MAX_REVOKED_TOKENS
);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{Algorithm, EncodingKey, Header};

    /// HMAC secret shared by the tests; comfortably above the 32-byte production minimum.
    const TEST_SECRET: &str = "a]test_secret_that_is_longer_than_32_chars_for_security!";

    /// Default configuration with a fixed, known secret.
    fn test_config() -> SecureJwtConfig {
        SecureJwtConfig {
            jwt_secret: TEST_SECRET.to_string(),
            ..SecureJwtConfig::default()
        }
    }

    /// Signs `claims` with the config's HMAC secret using `alg`.
    fn sign_with(alg: Algorithm, config: &SecureJwtConfig, claims: &SecureJwtClaims) -> String {
        let key = EncodingKey::from_secret(config.jwt_secret.as_bytes());
        jsonwebtoken::encode(&Header::new(alg), claims, &key).unwrap()
    }

    /// Signs `claims` with HS256, the default allowed algorithm.
    fn issue_token(config: &SecureJwtConfig, claims: &SecureJwtClaims) -> String {
        sign_with(Algorithm::HS256, config, claims)
    }

    /// Claims that pass every check under `test_config()`: a 10-minute lifetime and a
    /// fresh random JTI.
    fn valid_claims() -> SecureJwtClaims {
        let now = chrono::Utc::now().timestamp();
        SecureJwtClaims {
            sub: "user123".to_string(),
            iss: "auth-framework".to_string(),
            aud: "test".to_string(),
            exp: now + 600,
            nbf: now - 10,
            iat: now,
            jti: uuid::Uuid::new_v4().to_string(),
            scope: "read".to_string(),
            typ: "access".to_string(),
            sid: None,
            client_id: None,
            auth_ctx_hash: None,
        }
    }

    #[test]
    fn test_default_config_generates_random_secret() {
        let first = SecureJwtConfig::default();
        let second = SecureJwtConfig::default();
        assert_ne!(first.jwt_secret, second.jwt_secret);
        assert!(first.jwt_secret.len() >= 32);
    }

    #[test]
    fn test_builder_fluent_api() {
        let config = SecureJwtConfig::builder()
            .with_secret(TEST_SECRET)
            .require_issuer("my-issuer")
            .require_audience("my-aud")
            .with_max_lifetime(Duration::from_secs(7200))
            .with_clock_skew(Duration::from_secs(60))
            .require_jti(false)
            .build();
        assert!(config.required_issuers.contains("my-issuer"));
        assert!(config.required_audiences.contains("my-aud"));
        assert_eq!(config.max_token_lifetime, Duration::from_secs(7200));
        assert_eq!(config.clock_skew, Duration::from_secs(60));
        assert!(!config.require_jti);
    }

    #[test]
    fn test_validate_valid_token() {
        let config = test_config();
        let token = issue_token(&config, &valid_claims());
        let validator = SecureJwtValidator::new(config).unwrap();
        let decoded = validator.validate(&token).unwrap();
        assert_eq!(decoded.sub, "user123");
        assert_eq!(decoded.iss, "auth-framework");
    }

    #[test]
    fn test_validate_rejects_expired_token() {
        let config = test_config();
        let expired_at = chrono::Utc::now().timestamp() - 3600;
        let claims = SecureJwtClaims {
            exp: expired_at,
            iat: expired_at - 600,
            ..valid_claims()
        };
        let token = issue_token(&config, &claims);
        let validator = SecureJwtValidator::new(config).unwrap();
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_validate_rejects_wrong_issuer() {
        let config = test_config();
        let claims = SecureJwtClaims {
            iss: "evil-issuer".to_string(),
            ..valid_claims()
        };
        let token = issue_token(&config, &claims);
        let validator = SecureJwtValidator::new(config).unwrap();
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_revoke_and_check() {
        let validator = SecureJwtValidator::new(test_config()).unwrap();
        let jti = "test-jti-123";
        assert!(!validator.is_token_revoked(jti).unwrap());
        validator.revoke_token(jti).unwrap();
        assert!(validator.is_token_revoked(jti).unwrap());
    }

    #[test]
    fn test_revoked_token_rejected() {
        let config = test_config();
        let claims = valid_claims();
        let token = issue_token(&config, &claims);
        let validator = SecureJwtValidator::new(config).unwrap();
        validator.revoke_token(&claims.jti).unwrap();
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_cleanup_removes_old_entries() {
        let validator = SecureJwtValidator::new(test_config()).unwrap();
        validator.revoke_token("old-jti").unwrap();
        let cutoff = std::time::SystemTime::now() + Duration::from_secs(3600);
        validator.cleanup_revoked_tokens(cutoff).unwrap();
        assert!(!validator.is_token_revoked("old-jti").unwrap());
    }

    #[test]
    fn test_on_revoke_callback() {
        use std::sync::{Arc, Mutex};

        let validator = SecureJwtValidator::new(test_config()).unwrap();
        let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&seen);
        validator.set_on_revoke(move |jti| sink.lock().unwrap().push(jti.to_string()));

        validator.revoke_token("cb-jti-1").unwrap();
        validator.revoke_token("cb-jti-2").unwrap();

        let recorded = seen.lock().unwrap();
        assert_eq!(recorded.len(), 2);
        assert!(recorded.iter().any(|j| j == "cb-jti-1"));
        assert!(recorded.iter().any(|j| j == "cb-jti-2"));
    }

    #[test]
    fn test_rejects_disallowed_algorithm() {
        let config = test_config();
        let token = sign_with(Algorithm::HS384, &config, &valid_claims());
        let validator = SecureJwtValidator::new(config).unwrap();
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_rejects_excessive_lifetime() {
        let config = SecureJwtConfig {
            max_token_lifetime: Duration::from_secs(300),
            ..test_config()
        };
        let now = chrono::Utc::now().timestamp();
        let claims = SecureJwtClaims {
            iat: now,
            exp: now + 600,
            ..valid_claims()
        };
        let token = issue_token(&config, &claims);
        let validator = SecureJwtValidator::new(config).unwrap();
        assert!(validator.validate(&token).is_err());
    }

    #[test]
    fn test_missing_rsa_key_rejected() {
        let config = SecureJwtConfig {
            allowed_algorithms: vec![Algorithm::RS256],
            ..test_config()
        };
        assert!(SecureJwtValidator::new(config).is_err());
    }
}