use crate::errors::{AuthError, Result};
use chrono::Utc;
use jsonwebtoken::{Algorithm, Validation};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
/// Overall strictness of a [`JwtBestPracticesConfig`].
///
/// `JwtBestPracticesValidator::validate_algorithm` maps each level to the minimum
/// [`CryptoStrength`] an algorithm must have: `Minimum` -> `Acceptable`,
/// `Recommended` -> `Strong`, `Maximum` -> `High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecurityLevel {
/// Level used by `JwtBestPracticesConfig::minimum_security`.
Minimum,
/// Level used by `JwtBestPracticesConfig::default`.
Recommended,
/// Level used by `JwtBestPracticesConfig::maximum_security`.
Maximum,
}
/// Coarse rating of a signing algorithm's cryptographic strength.
///
/// Variants are declared weakest-first and the derived `PartialOrd`/`Ord` follow
/// declaration order, so `Weak < Acceptable < Strong < High`. Do not reorder the
/// variants: `JwtBestPracticesValidator::validate_algorithm` compares ratings with `<`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CryptoStrength {
/// Lowest rating; not assigned to any algorithm by `get_algorithm_crypto_strength`.
Weak,
/// Minimum rating accepted at `SecurityLevel::Minimum`.
Acceptable,
/// Minimum rating accepted at `SecurityLevel::Recommended`.
Strong,
/// Minimum rating accepted at `SecurityLevel::Maximum`.
High,
}
/// Policy applied by [`JwtBestPracticesValidator`].
///
/// Time values are in seconds: they are compared against `Utc::now().timestamp()`
/// and against differences of the `exp`/`iat` claims.
#[derive(Debug, Clone)]
pub struct JwtBestPracticesConfig {
/// Selects the minimum [`CryptoStrength`] enforced by `validate_algorithm`.
pub security_level: SecurityLevel,
/// Algorithms `validate_algorithm` accepts, subject to `security_level`.
pub allowed_algorithms: Vec<Algorithm>,
/// Algorithms always rejected; checked before `allowed_algorithms`.
pub forbidden_algorithms: Vec<Algorithm>,
/// Upper bound, in seconds, on a token's `exp - iat`.
pub max_lifetime: i64,
/// Lower bound, in seconds, on a token's `exp - iat`.
pub min_lifetime: i64,
/// Tolerance, in seconds, for `nbf`/`iat` lying in the future; also used as the
/// `jsonwebtoken::Validation` leeway.
pub clock_skew: i64,
/// When non-empty, the token's `iss` must be one of these values.
pub required_issuers: HashSet<String>,
/// When non-empty, at least one of the token's `aud` entries must be one of these.
pub required_audiences: HashSet<String>,
/// Flag requesting that tokens carry a subject (`sub`) claim.
pub require_subject: bool,
/// Flag requesting an issued-at (`iat`) claim; `get_security_recommendations`
/// suggests enabling it when false.
pub require_issued_at: bool,
/// Copied into `Validation::validate_exp` by `create_validation_rules`.
pub require_expiration: bool,
/// Copied into `Validation::validate_nbf` by `create_validation_rules`.
pub require_not_before: bool,
/// Enables `jti` replay tracking in `validate_standard_claims`.
pub require_jwt_id: bool,
/// NOTE(review): not read by any code visible in this module; presumably limits
/// nested-JWT depth. Confirm where (or whether) it is enforced.
pub max_nested_depth: u8,
}
impl Default for JwtBestPracticesConfig {
    /// Returns the `SecurityLevel::Recommended` policy.
    ///
    /// The policy allows asymmetric algorithms only and a 1-hour (3600 s) maximum
    /// lifetime, a 60 s minimum lifetime and 30 s of clock skew. It requires `sub`,
    /// `iat` and `exp`, but not `nbf` or `jti`. Issuer and audience restrictions
    /// are empty, so they are not enforced.
    fn default() -> Self {
        Self {
            security_level: SecurityLevel::Recommended,
            // Each algorithm is listed exactly once (EdDSA was previously duplicated).
            // NOTE(review): RS256 is rated `Acceptable` by get_algorithm_crypto_strength,
            // so validate_algorithm rejects it at the Recommended level even though it
            // is listed here. Confirm whether RS256 should stay in this list.
            allowed_algorithms: vec![
                Algorithm::RS256,
                Algorithm::RS384,
                Algorithm::RS512,
                Algorithm::ES256,
                Algorithm::ES384,
                Algorithm::EdDSA,
                Algorithm::PS256,
                Algorithm::PS384,
                Algorithm::PS512,
            ],
            forbidden_algorithms: vec![],
            max_lifetime: 3600,
            min_lifetime: 60,
            clock_skew: 30,
            required_issuers: HashSet::new(),
            required_audiences: HashSet::new(),
            require_subject: true,
            require_issued_at: true,
            require_expiration: true,
            require_not_before: false,
            require_jwt_id: false,
            max_nested_depth: 1,
        }
    }
}
impl JwtBestPracticesConfig {
    /// Most permissive preset (`SecurityLevel::Minimum`).
    ///
    /// Allows RS256, ES256 and PS256 with up to a 24-hour (86400 s) lifetime.
    /// `sub`, `iat` and `jti` are not required. Every other field comes from
    /// `Default`.
    pub fn minimum_security() -> Self {
        Self {
            security_level: SecurityLevel::Minimum,
            allowed_algorithms: vec![Algorithm::RS256, Algorithm::ES256, Algorithm::PS256],
            max_lifetime: 86400,
            require_subject: false,
            require_issued_at: false,
            require_jwt_id: false,
            ..Default::default()
        }
    }

    /// Strictest preset (`SecurityLevel::Maximum`).
    ///
    /// Only algorithms rated `CryptoStrength::High` are allowed, and the HS* family
    /// is explicitly forbidden. Lifetime must be between 30 s and 15 minutes
    /// (900 s), with 5 s of clock skew. `sub`, `iat`, `exp`, `nbf` and `jti` are
    /// all required, and `max_nested_depth` is 0. Issuer and audience restrictions
    /// come from `Default`.
    pub fn maximum_security() -> Self {
        Self {
            security_level: SecurityLevel::Maximum,
            // Each algorithm is listed exactly once (EdDSA was previously duplicated).
            allowed_algorithms: vec![
                Algorithm::ES384,
                Algorithm::EdDSA,
                Algorithm::PS384,
                Algorithm::PS512,
            ],
            forbidden_algorithms: vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512],
            max_lifetime: 900,
            min_lifetime: 30,
            clock_skew: 5,
            require_subject: true,
            require_issued_at: true,
            require_expiration: true,
            require_not_before: true,
            require_jwt_id: true,
            max_nested_depth: 0,
            ..Default::default()
        }
    }
}
/// Rates the cryptographic strength of a JWS signing algorithm.
///
/// `JwtBestPracticesValidator::validate_algorithm` compares this rating against
/// the minimum required by the configured [`SecurityLevel`]. No algorithm is
/// currently rated `CryptoStrength::Weak`.
pub fn get_algorithm_crypto_strength(algorithm: &Algorithm) -> CryptoStrength {
    match algorithm {
        Algorithm::HS256 | Algorithm::RS256 => CryptoStrength::Acceptable,
        Algorithm::HS384
        | Algorithm::HS512
        | Algorithm::RS384
        | Algorithm::RS512
        | Algorithm::ES256
        | Algorithm::PS256 => CryptoStrength::Strong,
        Algorithm::ES384 | Algorithm::EdDSA | Algorithm::PS384 | Algorithm::PS512 => {
            CryptoStrength::High
        }
    }
}
/// Returns `true` for the HMAC-based algorithms (HS256, HS384, HS512).
///
/// These use a single shared secret for both signing and verification. Every
/// other algorithm is treated as asymmetric.
pub fn is_algorithm_symmetric(algorithm: &Algorithm) -> bool {
    const SHARED_SECRET_FAMILY: [Algorithm; 3] =
        [Algorithm::HS256, Algorithm::HS384, Algorithm::HS512];
    SHARED_SECRET_FAMILY.contains(algorithm)
}
/// Returns `true` for any algorithm that is not in the symmetric HMAC family.
///
/// This is the exact complement of [`is_algorithm_symmetric`].
pub fn is_algorithm_asymmetric(algorithm: &Algorithm) -> bool {
    let uses_shared_secret = is_algorithm_symmetric(algorithm);
    !uses_shared_secret
}
/// Registered JWT claims plus any additional (custom) claims.
///
/// Timestamps are compared against `Utc::now().timestamp()`, i.e. seconds since
/// the Unix epoch. Every registered claim except `nbf` is a non-`Option` field, so
/// serde's derived `Deserialize` reports a missing-field error when one is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureJwtClaims {
/// Issuer.
pub iss: String,
/// Subject.
pub sub: String,
/// Audiences.
/// NOTE(review): declared as a list, so a token whose `aud` is a single JSON
/// string (permitted by RFC 7519) will presumably fail to deserialize. Confirm.
pub aud: Vec<String>,
/// Expiration time (seconds since epoch).
pub exp: i64,
/// Optional not-before time (seconds since epoch).
pub nbf: Option<i64>,
/// Issued-at time (seconds since epoch).
pub iat: i64,
/// JWT ID, used for replay detection when `require_jwt_id` is set.
pub jti: String,
/// All remaining claims not matched by the fields above.
#[serde(flatten)]
pub custom: HashMap<String, Value>,
}
/// Stateful validator that applies a [`JwtBestPracticesConfig`] to tokens.
///
/// It keeps replay-protection state, which is why `validate_standard_claims`
/// takes `&mut self`.
pub struct JwtBestPracticesValidator {
/// Policy applied by every validation method.
config: JwtBestPracticesConfig,
/// `jti` values already accepted while `require_jwt_id` is set. Entries are never
/// expired automatically; the set only shrinks via `clear_used_jtis`.
used_jtis: HashSet<String>, }
impl JwtBestPracticesValidator {
    /// Maximum accepted size, in bytes, of a compact-serialized token.
    const MAX_TOKEN_BYTES: usize = 8192;

    /// Creates a validator enforcing `config`, with an empty `jti` replay history.
    pub fn new(config: JwtBestPracticesConfig) -> Self {
        Self {
            config,
            used_jtis: HashSet::new(),
        }
    }

    /// Performs cheap structural checks on a compact-serialized JWT before any decoding.
    ///
    /// # Errors
    ///
    /// Returns a token error if:
    /// - the token is larger than 8192 bytes (`"Token too large"`);
    /// - it is not exactly three non-empty, `.`-separated segments for header,
    ///   payload and signature (`"Invalid JWT format"`).
    pub fn validate_token_format(&self, token: &str) -> Result<()> {
        // Reject oversized input before scanning/splitting it.
        if token.len() > Self::MAX_TOKEN_BYTES {
            return Err(AuthError::token("Token too large".to_string()));
        }
        // An empty header or payload cannot hold the required JSON objects, and an
        // empty signature segment would be an unsecured (unsigned) token.
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 || parts.iter().any(|part| part.is_empty()) {
            return Err(AuthError::token("Invalid JWT format".to_string()));
        }
        Ok(())
    }

    /// Checks `algorithm` against the configured policy.
    ///
    /// # Errors
    ///
    /// Returns a token error if:
    /// - the algorithm is in `forbidden_algorithms`;
    /// - it is absent from `allowed_algorithms`;
    /// - [`get_algorithm_crypto_strength`] rates it below the minimum for the
    ///   configured [`SecurityLevel`]: `Acceptable` for `Minimum`, `Strong` for
    ///   `Recommended`, `High` for `Maximum`.
    pub fn validate_algorithm(&self, algorithm: &Algorithm) -> Result<()> {
        if self.config.forbidden_algorithms.contains(algorithm) {
            return Err(AuthError::token(format!(
                "Forbidden algorithm: {:?}",
                algorithm
            )));
        }
        if !self.config.allowed_algorithms.contains(algorithm) {
            return Err(AuthError::token(format!(
                "Algorithm not allowed: {:?}",
                algorithm
            )));
        }
        let (required, message) = match self.config.security_level {
            SecurityLevel::Minimum => (CryptoStrength::Acceptable, "Algorithm too weak"),
            SecurityLevel::Recommended => (CryptoStrength::Strong, "Algorithm not recommended"),
            SecurityLevel::Maximum => (
                CryptoStrength::High,
                "Algorithm insufficient for maximum security",
            ),
        };
        if get_algorithm_crypto_strength(algorithm) < required {
            return Err(AuthError::token(message.to_string()));
        }
        Ok(())
    }

    /// Validates presence, time, lifetime, issuer and audience claims.
    ///
    /// When `require_jwt_id` is set, the token's `jti` is also recorded for replay
    /// detection. It is only recorded after every other check has passed.
    ///
    /// # Errors
    ///
    /// Returns a token error if:
    /// - a claim required by `require_subject`, `require_not_before` or
    ///   `require_jwt_id` is missing or empty;
    /// - `exp` is not in the future;
    /// - `nbf` or `iat` is more than `clock_skew` seconds in the future;
    /// - `exp - iat` is outside `[min_lifetime, max_lifetime]`;
    /// - `iss` or `aud` does not satisfy a non-empty `required_issuers` /
    ///   `required_audiences`;
    /// - the `jti` has already been accepted (replay).
    pub fn validate_standard_claims(&mut self, claims: &SecureJwtClaims) -> Result<()> {
        // Presence checks. `sub`/`jti` are plain Strings, so "absent" means empty.
        if self.config.require_subject && claims.sub.is_empty() {
            return Err(AuthError::token("Token is missing subject".to_string()));
        }
        if self.config.require_not_before && claims.nbf.is_none() {
            return Err(AuthError::token(
                "Token is missing not-before claim".to_string(),
            ));
        }
        if self.config.require_jwt_id && claims.jti.is_empty() {
            return Err(AuthError::token("Token is missing JWT ID".to_string()));
        }

        let now = Utc::now().timestamp();
        // Latest instant nbf/iat may name, allowing for clock skew between parties.
        let latest_acceptable = now.saturating_add(self.config.clock_skew);

        // NOTE: no clock-skew allowance is applied to exp here, unlike nbf/iat and
        // unlike the leeway configured by `create_validation_rules`.
        if claims.exp <= now {
            return Err(AuthError::token("Token has expired".to_string()));
        }
        if claims.nbf.is_some_and(|nbf| nbf > latest_acceptable) {
            return Err(AuthError::token("Token is not yet valid".to_string()));
        }
        if claims.iat > latest_acceptable {
            return Err(AuthError::token("Token issued in the future".to_string()));
        }

        // exp/iat come from the token; a plain `exp - iat` could overflow i64
        // (panicking in debug builds, wrapping in release), so saturate instead.
        let lifetime = claims.exp.saturating_sub(claims.iat);
        if lifetime > self.config.max_lifetime {
            return Err(AuthError::token("Token lifetime too long".to_string()));
        }
        if lifetime < self.config.min_lifetime {
            return Err(AuthError::token("Token lifetime too short".to_string()));
        }

        if !self.config.required_issuers.is_empty()
            && !self.config.required_issuers.contains(&claims.iss)
        {
            return Err(AuthError::token("Invalid issuer".to_string()));
        }
        if !self.config.required_audiences.is_empty() {
            let has_valid_audience = claims
                .aud
                .iter()
                .any(|aud| self.config.required_audiences.contains(aud));
            if !has_valid_audience {
                return Err(AuthError::token("Invalid audience".to_string()));
            }
        }

        // `insert` returns false when the jti was already recorded (replay).
        if self.config.require_jwt_id && !self.used_jtis.insert(claims.jti.clone()) {
            return Err(AuthError::token("Token replay detected".to_string()));
        }
        Ok(())
    }

    /// Builds a `jsonwebtoken::Validation` for `algorithm` that mirrors this policy.
    ///
    /// The policy supplies the leeway, the exp/nbf flags, the issuers and the
    /// audiences.
    ///
    /// NOTE: this does not check `algorithm` against the policy. Call
    /// [`Self::validate_algorithm`] first when the algorithm comes from a token
    /// header.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn create_validation_rules(&self, algorithm: Algorithm) -> Result<Validation> {
        let mut validation = Validation::new(algorithm);
        // `clock_skew as u64` would turn a negative skew into an enormous leeway,
        // effectively disabling exp/nbf checks; treat negative skew as zero.
        validation.leeway = u64::try_from(self.config.clock_skew).unwrap_or(0);
        validation.validate_exp = self.config.require_expiration;
        validation.validate_nbf = self.config.require_not_before;
        if !self.config.required_issuers.is_empty() {
            let issuers: Vec<&str> = self
                .config
                .required_issuers
                .iter()
                .map(String::as_str)
                .collect();
            validation.set_issuer(&issuers);
        }
        if !self.config.required_audiences.is_empty() {
            let audiences: Vec<&str> = self
                .config
                .required_audiences
                .iter()
                .map(String::as_str)
                .collect();
            validation.set_audience(&audiences);
        }
        Ok(validation)
    }

    /// Returns human-readable suggestions for tightening the current configuration.
    pub fn get_security_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();
        if self
            .config
            .allowed_algorithms
            .iter()
            .any(is_algorithm_symmetric)
        {
            recommendations.push(
                "Consider using asymmetric algorithms (RS*, ES*, PS*) for better security"
                    .to_string(),
            );
        }
        if self.config.max_lifetime > 3600 {
            recommendations.push("Consider reducing token lifetime to 1 hour or less".to_string());
        }
        if !self.config.require_jwt_id {
            recommendations
                .push("Consider enabling JWT ID (jti) claim for replay protection".to_string());
        }
        if !self.config.require_issued_at {
            recommendations
                .push("Consider requiring issued at (iat) claim for better validation".to_string());
        }
        recommendations
    }

    /// Forgets every recorded `jti`; the replay history otherwise grows unboundedly.
    pub fn clear_used_jtis(&mut self) {
        self.used_jtis.clear();
    }

    /// Returns the policy this validator enforces.
    pub fn get_config(&self) -> &JwtBestPracticesConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a validator using the default (Recommended) policy.
    fn default_validator() -> JwtBestPracticesValidator {
        JwtBestPracticesValidator::new(JwtBestPracticesConfig::default())
    }

    #[test]
    fn test_algorithm_strength_classification() {
        let cases = [
            (Algorithm::HS256, CryptoStrength::Acceptable),
            (Algorithm::ES384, CryptoStrength::High),
            (Algorithm::EdDSA, CryptoStrength::High),
        ];
        for (algorithm, expected) in cases {
            assert_eq!(
                get_algorithm_crypto_strength(&algorithm),
                expected,
                "unexpected rating for {algorithm:?}"
            );
        }
    }

    #[test]
    fn test_security_level_configuration() {
        let (min_config, max_config) = (
            JwtBestPracticesConfig::minimum_security(),
            JwtBestPracticesConfig::maximum_security(),
        );
        assert_eq!(min_config.security_level, SecurityLevel::Minimum);
        assert_eq!(max_config.security_level, SecurityLevel::Maximum);
        assert!(min_config.max_lifetime > max_config.max_lifetime);
        assert!(!min_config.require_jwt_id);
        assert!(max_config.require_jwt_id);
    }

    #[test]
    fn test_jwt_best_practices_validation() {
        let validator = default_validator();
        assert!(validator.validate_algorithm(&Algorithm::ES256).is_ok());
    }

    #[test]
    fn test_token_format_validation() {
        let validator = default_validator();
        assert!(validator.validate_token_format("header.payload.signature").is_ok());
        for malformed in ["invalid.format", "too.many.parts.here"] {
            assert!(
                validator.validate_token_format(malformed).is_err(),
                "expected rejection of {malformed:?}"
            );
        }
    }
}