use chrono::{Duration, Utc};
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, dangerous::insecure_decode,
decode, encode,
};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::collections::HashMap;
use crate::error::{Error, Result, TokenError};
/// Signature algorithms this module can sign and verify with.
///
/// `HS*` variants use a shared HMAC secret, `RS*` variants an RSA key pair and
/// `ES*` variants an elliptic-curve key pair. The default is [`JwtAlgorithm::HS256`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (default).
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
    /// ECDSA using P-256 and SHA-256.
    ES256,
    /// ECDSA using P-384 and SHA-384.
    ES384,
}
/// Maps this module's algorithm enum onto the `jsonwebtoken` equivalent.
///
/// The mapping is one-to-one; every variant has a same-named counterpart.
impl From<JwtAlgorithm> for Algorithm {
    fn from(alg: JwtAlgorithm) -> Self {
        match alg {
            JwtAlgorithm::HS256 => Self::HS256,
            JwtAlgorithm::HS384 => Self::HS384,
            JwtAlgorithm::HS512 => Self::HS512,
            JwtAlgorithm::RS256 => Self::RS256,
            JwtAlgorithm::RS384 => Self::RS384,
            JwtAlgorithm::RS512 => Self::RS512,
            JwtAlgorithm::ES256 => Self::ES256,
            JwtAlgorithm::ES384 => Self::ES384,
        }
    }
}
/// Claim set carried by tokens built and validated in this module.
///
/// Registered claims are typed fields; every other claim is kept in
/// [`Claims::custom`]. Fields that are `None` are omitted when serializing.
/// The time claims (`exp`, `nbf`, `iat`) are Unix timestamps in seconds.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Claims {
    /// `sub`: the principal the token is about.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub: Option<String>,
    /// `iss`: the party that issued the token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// `aud`: the intended recipient, as a single string.
    ///
    /// NOTE(review): RFC 7519 also permits an array of audiences; a token
    /// carrying an array here will fail to deserialize into this field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<String>,
    /// `exp`: expiry time (Unix seconds).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<i64>,
    /// `nbf`: the token must not be accepted before this time (Unix seconds).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nbf: Option<i64>,
    /// `iat`: issue time (Unix seconds).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<i64>,
    /// `jti`: unique token identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
    /// Any non-registered claims, flattened into the top-level JSON object.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}

impl Claims {
    /// Returns an empty claim set; identical to `Claims::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Reports whether `exp` is set and the current time is strictly after it.
    ///
    /// A claim set without `exp` is never considered expired.
    pub fn is_expired(&self) -> bool {
        self.exp.is_some_and(|exp| exp < Utc::now().timestamp())
    }

    /// Reports whether `nbf` is set and the current time is strictly before it.
    ///
    /// A claim set without `nbf` is always considered valid by this check.
    pub fn is_not_yet_valid(&self) -> bool {
        self.nbf.is_some_and(|nbf| nbf > Utc::now().timestamp())
    }

    /// Deserializes the custom claim `key` into `T`.
    ///
    /// Returns `None` both when the key is absent and when its JSON value
    /// cannot be converted into `T`.
    pub fn get_custom<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        let raw = self.custom.get(key)?;
        serde_json::from_value(raw.clone()).ok()
    }
}
#[derive(Debug, Clone, Default)]
pub struct JwtBuilder {
claims: Claims,
algorithm: JwtAlgorithm,
}
impl JwtBuilder {
pub fn new() -> Self {
let mut builder = Self::default();
builder.claims.iat = Some(Utc::now().timestamp());
builder
}
pub fn subject(mut self, sub: impl Into<String>) -> Self {
self.claims.sub = Some(sub.into());
self
}
pub fn issuer(mut self, iss: impl Into<String>) -> Self {
self.claims.iss = Some(iss.into());
self
}
pub fn audience(mut self, aud: impl Into<String>) -> Self {
self.claims.aud = Some(aud.into());
self
}
pub fn expires_in_seconds(mut self, seconds: i64) -> Self {
self.claims.exp = Some(Utc::now().timestamp() + seconds);
self
}
pub fn expires_in_minutes(self, minutes: i64) -> Self {
self.expires_in_seconds(minutes * 60)
}
pub fn expires_in_hours(self, hours: i64) -> Self {
self.expires_in_seconds(hours * 3600)
}
pub fn expires_in_days(self, days: i64) -> Self {
self.expires_in_seconds(days * 86400)
}
pub fn expires_in(mut self, duration: Duration) -> Self {
self.claims.exp = Some(Utc::now().timestamp() + duration.num_seconds());
self
}
pub fn not_before(mut self, nbf: i64) -> Self {
self.claims.nbf = Some(nbf);
self
}
pub fn not_before_in_seconds(mut self, seconds: i64) -> Self {
self.claims.nbf = Some(Utc::now().timestamp() + seconds);
self
}
pub fn jwt_id(mut self, jti: impl Into<String>) -> Self {
self.claims.jti = Some(jti.into());
self
}
pub fn with_random_jwt_id(mut self) -> Self {
self.claims.jti = Some(crate::random::generate_random_hex(16).unwrap_or_default());
self
}
pub fn claim<V: Serialize>(mut self, key: impl Into<String>, value: V) -> Self {
if let Ok(json_value) = serde_json::to_value(value) {
self.claims.custom.insert(key.into(), json_value);
}
self
}
pub fn algorithm(mut self, algorithm: JwtAlgorithm) -> Self {
self.algorithm = algorithm;
self
}
pub fn build_with_secret(self, secret: &[u8]) -> Result<String> {
let header = Header::new(self.algorithm.into());
let key = EncodingKey::from_secret(secret);
encode(&header, &self.claims, &key).map_err(|e| {
Error::Token(TokenError::EncodingFailed(format!(
"failed to encode JWT: {}",
e
)))
})
}
pub fn build_with_rsa_private_key(self, private_key_pem: &[u8]) -> Result<String> {
let header = Header::new(self.algorithm.into());
let key = EncodingKey::from_rsa_pem(private_key_pem).map_err(|e| {
Error::Token(TokenError::EncodingFailed(format!(
"invalid RSA key: {}",
e
)))
})?;
encode(&header, &self.claims, &key).map_err(|e| {
Error::Token(TokenError::EncodingFailed(format!(
"failed to encode JWT: {}",
e
)))
})
}
pub fn build_with_ec_private_key(self, private_key_pem: &[u8]) -> Result<String> {
let header = Header::new(self.algorithm.into());
let key = EncodingKey::from_ec_pem(private_key_pem).map_err(|e| {
Error::Token(TokenError::EncodingFailed(format!("invalid EC key: {}", e)))
})?;
encode(&header, &self.claims, &key).map_err(|e| {
Error::Token(TokenError::EncodingFailed(format!(
"failed to encode JWT: {}",
e
)))
})
}
pub fn get_claims(&self) -> &Claims {
&self.claims
}
}
/// Settings that control which tokens a [`JwtValidator`] accepts.
#[derive(Debug, Clone)]
pub struct JwtValidatorConfig {
    /// Reject tokens whose `exp` lies in the past.
    pub validate_exp: bool,
    /// Reject tokens whose `nbf` lies in the future.
    pub validate_nbf: bool,
    /// Expected `iss` value; `None` skips the issuer check.
    pub issuer: Option<String>,
    /// Expected `aud` value; `None` means no audience is expected.
    pub audience: Option<String>,
    /// Allowed clock skew, in seconds, for the time-based checks.
    pub leeway: u64,
    /// Signature algorithms accepted during verification.
    pub algorithms: Vec<JwtAlgorithm>,
}

impl Default for JwtValidatorConfig {
    /// Checks `exp` and `nbf` with zero leeway, pins no issuer or audience,
    /// and accepts HS256 only.
    fn default() -> Self {
        Self {
            algorithms: vec![JwtAlgorithm::HS256],
            leeway: 0,
            issuer: None,
            audience: None,
            validate_nbf: true,
            validate_exp: true,
        }
    }
}

impl JwtValidatorConfig {
    /// Same as [`JwtValidatorConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Requires the token's `iss` claim to equal `issuer`.
    pub fn with_issuer(self, issuer: impl Into<String>) -> Self {
        Self {
            issuer: Some(issuer.into()),
            ..self
        }
    }

    /// Requires the token's `aud` claim to equal `audience`.
    pub fn with_audience(self, audience: impl Into<String>) -> Self {
        Self {
            audience: Some(audience.into()),
            ..self
        }
    }

    /// Sets the clock-skew tolerance in seconds.
    pub fn with_leeway(self, leeway: u64) -> Self {
        Self { leeway, ..self }
    }

    /// Replaces the list of accepted algorithms.
    pub fn with_algorithms(self, algorithms: Vec<JwtAlgorithm>) -> Self {
        Self { algorithms, ..self }
    }

    /// Turns off the `exp` check.
    pub fn without_exp_validation(self) -> Self {
        Self {
            validate_exp: false,
            ..self
        }
    }

    /// Turns off the `nbf` check.
    pub fn without_nbf_validation(self) -> Self {
        Self {
            validate_nbf: false,
            ..self
        }
    }
}
pub struct JwtValidator {
decoding_key: DecodingKey,
config: JwtValidatorConfig,
}
impl JwtValidator {
pub fn new(secret: &[u8]) -> Self {
Self {
decoding_key: DecodingKey::from_secret(secret),
config: JwtValidatorConfig::default(),
}
}
pub fn with_config(secret: &[u8], config: JwtValidatorConfig) -> Self {
Self {
decoding_key: DecodingKey::from_secret(secret),
config,
}
}
pub fn from_rsa_public_key(public_key_pem: &[u8]) -> Result<Self> {
let key = DecodingKey::from_rsa_pem(public_key_pem).map_err(|e| {
Error::Token(TokenError::DecodingFailed(format!(
"invalid RSA public key: {}",
e
)))
})?;
Ok(Self {
decoding_key: key,
config: JwtValidatorConfig {
algorithms: vec![
JwtAlgorithm::RS256,
JwtAlgorithm::RS384,
JwtAlgorithm::RS512,
],
..Default::default()
},
})
}
pub fn from_ec_public_key(public_key_pem: &[u8]) -> Result<Self> {
let key = DecodingKey::from_ec_pem(public_key_pem).map_err(|e| {
Error::Token(TokenError::DecodingFailed(format!(
"invalid EC public key: {}",
e
)))
})?;
Ok(Self {
decoding_key: key,
config: JwtValidatorConfig {
algorithms: vec![JwtAlgorithm::ES256, JwtAlgorithm::ES384],
..Default::default()
},
})
}
pub fn set_config(&mut self, config: JwtValidatorConfig) {
self.config = config;
}
pub fn validate(&self, token: &str) -> Result<Claims> {
self.validate_with_claims::<Claims>(token)
}
pub fn validate_with_claims<T: DeserializeOwned>(&self, token: &str) -> Result<T> {
let validation = self.build_validation();
let token_data: TokenData<T> =
decode(token, &self.decoding_key, &validation).map_err(|e| {
let error = match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => TokenError::Expired,
jsonwebtoken::errors::ErrorKind::InvalidSignature => {
TokenError::InvalidSignature
}
jsonwebtoken::errors::ErrorKind::InvalidToken => {
TokenError::InvalidFormat("invalid token structure".to_string())
}
jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
TokenError::InvalidClaim("invalid issuer".to_string())
}
jsonwebtoken::errors::ErrorKind::InvalidAudience => {
TokenError::InvalidClaim("invalid audience".to_string())
}
jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
TokenError::InvalidClaim("token not yet valid".to_string())
}
_ => TokenError::DecodingFailed(e.to_string()),
};
Error::Token(error)
})?;
Ok(token_data.claims)
}
pub fn decode_without_validation(token: &str) -> Result<Claims> {
let token_data: TokenData<Claims> = insecure_decode(token).map_err(|e| {
Error::Token(TokenError::DecodingFailed(format!(
"failed to decode JWT: {}",
e
)))
})?;
Ok(token_data.claims)
}
fn build_validation(&self) -> Validation {
let algorithms: Vec<Algorithm> =
self.config.algorithms.iter().map(|a| (*a).into()).collect();
let mut validation = Validation::new(algorithms[0]);
validation.algorithms = algorithms;
validation.validate_exp = self.config.validate_exp;
validation.validate_nbf = self.config.validate_nbf;
validation.leeway = self.config.leeway;
if let Some(ref iss) = self.config.issuer {
validation.set_issuer(&[iss]);
}
if let Some(ref aud) = self.config.audience {
validation.set_audience(&[aud]);
}
validation
}
}
/// An access/refresh token pair as returned by `TokenPairGenerator::generate`.
///
/// The `*_expires_at` fields are Unix timestamps in seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPair {
/// Short-lived token carrying the custom claim `type = "access"`.
pub access_token: String,
/// Longer-lived token carrying `type = "refresh"` and a random `jti`.
pub refresh_token: String,
/// Expiry of `access_token`.
pub access_token_expires_at: i64,
/// Expiry of `refresh_token`.
pub refresh_token_expires_at: i64,
/// Authorization scheme; `TokenPairGenerator::generate` sets this to `"Bearer"`.
pub token_type: String,
}
pub struct TokenPairGenerator {
secret: Vec<u8>,
issuer: Option<String>,
audience: Option<String>,
access_token_lifetime: Duration,
refresh_token_lifetime: Duration,
}
impl TokenPairGenerator {
pub fn new(secret: &[u8]) -> Self {
Self {
secret: secret.to_vec(),
issuer: None,
audience: None,
access_token_lifetime: Duration::hours(1), refresh_token_lifetime: Duration::days(7), }
}
pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
self.issuer = Some(issuer.into());
self
}
pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
self.audience = Some(audience.into());
self
}
pub fn with_access_token_lifetime(mut self, duration: Duration) -> Self {
self.access_token_lifetime = duration;
self
}
pub fn with_refresh_token_lifetime(mut self, duration: Duration) -> Self {
self.refresh_token_lifetime = duration;
self
}
pub fn generate(&self, subject: impl Into<String>) -> Result<TokenPair> {
let sub = subject.into();
let now = Utc::now().timestamp();
let access_exp = now + self.access_token_lifetime.num_seconds();
let refresh_exp = now + self.refresh_token_lifetime.num_seconds();
let mut access_builder = JwtBuilder::new()
.subject(&sub)
.expires_in(self.access_token_lifetime)
.claim("type", "access");
if let Some(ref iss) = self.issuer {
access_builder = access_builder.issuer(iss);
}
if let Some(ref aud) = self.audience {
access_builder = access_builder.audience(aud);
}
let access_token = access_builder.build_with_secret(&self.secret)?;
let mut refresh_builder = JwtBuilder::new()
.subject(&sub)
.expires_in(self.refresh_token_lifetime)
.with_random_jwt_id()
.claim("type", "refresh");
if let Some(ref iss) = self.issuer {
refresh_builder = refresh_builder.issuer(iss);
}
if let Some(ref aud) = self.audience {
refresh_builder = refresh_builder.audience(aud);
}
let refresh_token = refresh_builder.build_with_secret(&self.secret)?;
Ok(TokenPair {
access_token,
refresh_token,
access_token_expires_at: access_exp,
refresh_token_expires_at: refresh_exp,
token_type: "Bearer".to_string(),
})
}
pub fn refresh(&self, refresh_token: &str) -> Result<TokenPair> {
let validator = JwtValidator::new(&self.secret);
let claims = validator.validate(refresh_token)?;
let token_type: Option<String> = claims.get_custom("type");
if token_type.as_deref() != Some("refresh") {
return Err(Error::Token(TokenError::InvalidClaim(
"not a refresh token".to_string(),
)));
}
let subject = claims
.sub
.ok_or_else(|| Error::Token(TokenError::MissingClaim("sub".to_string())))?;
self.generate(subject)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_SECRET: &[u8] = b"test-secret-key-at-least-32-bytes!";

    #[test]
    fn test_jwt_builder_basic() {
        let token = JwtBuilder::new()
            .subject("user123")
            .issuer("test-app")
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        assert!(!token.is_empty());
        // A compact JWS has exactly three dot-separated segments.
        assert_eq!(token.matches('.').count(), 2);
    }

    #[test]
    fn test_jwt_validate() {
        let token = JwtBuilder::new()
            .subject("user123")
            .issuer("test-app")
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let validator = JwtValidator::new(TEST_SECRET);
        let claims = validator.validate(&token).unwrap();
        assert_eq!(claims.sub, Some("user123".to_string()));
        assert_eq!(claims.iss, Some("test-app".to_string()));
    }

    #[test]
    fn test_jwt_custom_claims() {
        let token = JwtBuilder::new()
            .subject("user123")
            .claim("role", "admin")
            .claim("permissions", vec!["read", "write"])
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let validator = JwtValidator::new(TEST_SECRET);
        let claims = validator.validate(&token).unwrap();
        let role: Option<String> = claims.get_custom("role");
        assert_eq!(role, Some("admin".to_string()));
        let permissions: Option<Vec<String>> = claims.get_custom("permissions");
        assert_eq!(
            permissions,
            Some(vec!["read".to_string(), "write".to_string()])
        );
    }

    #[test]
    fn test_jwt_expired() {
        let token = JwtBuilder::new()
            .subject("user123")
            // Already expired at issuance.
            .expires_in_seconds(-10)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let validator = JwtValidator::new(TEST_SECRET);
        let result = validator.validate(&token);
        assert!(
            matches!(result, Err(Error::Token(TokenError::Expired))),
            "Expected TokenError::Expired"
        );
    }

    #[test]
    fn test_jwt_invalid_signature() {
        let token = JwtBuilder::new()
            .subject("user123")
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let wrong_secret = b"wrong-secret-key-at-least-32-bytes!";
        let validator = JwtValidator::new(wrong_secret);
        let result = validator.validate(&token);
        assert!(
            matches!(result, Err(Error::Token(TokenError::InvalidSignature))),
            "Expected TokenError::InvalidSignature"
        );
    }

    #[test]
    fn test_jwt_validator_config() {
        let token = JwtBuilder::new()
            .subject("user123")
            .issuer("test-app")
            .audience("api")
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let config = JwtValidatorConfig::new()
            .with_issuer("test-app")
            .with_audience("api");
        let validator = JwtValidator::with_config(TEST_SECRET, config);
        let claims = validator.validate(&token).unwrap();
        assert_eq!(claims.sub, Some("user123".to_string()));
    }

    #[test]
    fn test_jwt_validator_wrong_issuer() {
        let token = JwtBuilder::new()
            .subject("user123")
            .issuer("test-app")
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let config = JwtValidatorConfig::new().with_issuer("wrong-app");
        let validator = JwtValidator::with_config(TEST_SECRET, config);
        let result = validator.validate(&token);
        assert!(matches!(
            result,
            Err(Error::Token(TokenError::InvalidClaim(_)))
        ));
    }

    #[test]
    fn test_token_pair_generator() {
        let generator = TokenPairGenerator::new(TEST_SECRET)
            .with_issuer("test-app")
            .with_access_token_lifetime(Duration::minutes(15))
            .with_refresh_token_lifetime(Duration::days(7));
        let pair = generator.generate("user123").unwrap();
        assert!(!pair.access_token.is_empty());
        assert!(!pair.refresh_token.is_empty());
        assert_eq!(pair.token_type, "Bearer");
        assert!(pair.access_token_expires_at > Utc::now().timestamp());
        assert!(pair.refresh_token_expires_at > pair.access_token_expires_at);
    }

    #[test]
    fn test_token_refresh() {
        let generator = TokenPairGenerator::new(TEST_SECRET);
        let pair = generator.generate("user123").unwrap();
        // No sleep needed: expiries are only required to be non-decreasing.
        let new_pair = generator.refresh(&pair.refresh_token).unwrap();
        assert!(!new_pair.access_token.is_empty());
        assert!(new_pair.access_token_expires_at >= pair.access_token_expires_at);
        let claims = JwtValidator::new(TEST_SECRET)
            .validate(&new_pair.access_token)
            .unwrap();
        assert_eq!(claims.sub, Some("user123".to_string()));
    }

    #[test]
    fn test_cannot_refresh_with_access_token() {
        let generator = TokenPairGenerator::new(TEST_SECRET);
        let pair = generator.generate("user123").unwrap();
        let result = generator.refresh(&pair.access_token);
        assert!(matches!(
            result,
            Err(Error::Token(TokenError::InvalidClaim(_)))
        ));
    }

    #[test]
    fn test_claims_is_expired() {
        let mut claims = Claims::new();
        assert!(!claims.is_expired());
        claims.exp = Some(Utc::now().timestamp() - 100);
        assert!(claims.is_expired());
        claims.exp = Some(Utc::now().timestamp() + 100);
        assert!(!claims.is_expired());
    }

    #[test]
    fn test_claims_is_not_yet_valid() {
        let mut claims = Claims::new();
        assert!(!claims.is_not_yet_valid());
        claims.nbf = Some(Utc::now().timestamp() + 100);
        assert!(claims.is_not_yet_valid());
        claims.nbf = Some(Utc::now().timestamp() - 100);
        assert!(!claims.is_not_yet_valid());
    }

    #[test]
    fn test_get_custom_missing_or_mistyped() {
        let mut claims = Claims::new();
        claims.custom.insert(
            "role".to_string(),
            serde_json::Value::String("admin".to_string()),
        );
        assert_eq!(claims.get_custom::<String>("role"), Some("admin".to_string()));
        assert_eq!(claims.get_custom::<i64>("role"), None);
        assert_eq!(claims.get_custom::<String>("missing"), None);
    }

    #[test]
    fn test_decode_without_validation() {
        let token = JwtBuilder::new()
            .subject("user123")
            .claim("data", "test")
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let claims = JwtValidator::decode_without_validation(&token).unwrap();
        assert_eq!(claims.sub, Some("user123".to_string()));
        assert_eq!(claims.get_custom::<String>("data"), Some("test".to_string()));
    }

    #[test]
    fn test_jwt_algorithm() {
        let token = JwtBuilder::new()
            .algorithm(JwtAlgorithm::HS512)
            .subject("user123")
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let config = JwtValidatorConfig::new().with_algorithms(vec![JwtAlgorithm::HS512]);
        let validator = JwtValidator::with_config(TEST_SECRET, config);
        let claims = validator.validate(&token).unwrap();
        assert_eq!(claims.sub, Some("user123".to_string()));
    }

    #[test]
    fn test_jwt_with_random_id() {
        let token = JwtBuilder::new()
            .subject("user123")
            .with_random_jwt_id()
            .expires_in_hours(1)
            .build_with_secret(TEST_SECRET)
            .unwrap();
        let validator = JwtValidator::new(TEST_SECRET);
        let claims = validator.validate(&token).unwrap();
        assert!(claims.jti.is_some());
        assert!(!claims.jti.unwrap().is_empty());
    }
}