use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
use crate::revocation::TokenBlacklist;
/// Errors produced while issuing or verifying JWTs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The token could not be decoded or failed validation; carries the underlying message.
    InvalidToken(String),
    /// The token's expiry check failed.
    Expired,
    /// The token's `jti` is present on the revocation blacklist.
    Revoked,
    /// Signing was requested on a manager that has no encoding key (verify-only).
    EncodingKeyMissing,
    /// Signing failed; carries the underlying message.
    Encode(String),
    /// Setup or other internal failure (e.g. unparsable key material); carries a message.
    Internal(String),
}

impl fmt::Display for AuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidToken(detail) => write!(f, "invalid token: {detail}"),
            Self::Expired => f.write_str("token expired"),
            Self::Revoked => f.write_str("token revoked"),
            Self::EncodingKeyMissing => f.write_str("no encoding key configured"),
            Self::Encode(detail) => write!(f, "encoding failed: {detail}"),
            Self::Internal(detail) => write!(f, "internal error: {detail}"),
        }
    }
}

impl std::error::Error for AuthError {}
/// Claim types that may expose a JWT ID (`jti`).
///
/// `JwtManager::decode` consults this to decide whether a revocation check is
/// possible. Implementors without a `jti` can rely on the default, which opts
/// the claims out of blacklist checks.
pub trait HasJti {
    /// Returns the token identifier, or `None` when the claims carry no `jti`.
    fn jti(&self) -> Option<&str> {
        Option::None
    }
}
/// Key material and validation rules used by a `JwtManager`.
pub struct JwtConfig {
    /// Key used to verify token signatures.
    pub decoding_key: DecodingKey,
    /// Key used to sign tokens; `None` makes the manager verify-only.
    pub encoding_key: Option<EncodingKey>,
    /// Rules (accepted algorithms, leeway, required claims, ...) applied on decode.
    pub validation: Validation,
}
/// Issues and verifies JWTs, optionally rejecting revoked token IDs.
///
/// Cloning is cheap for the configuration: clones share one `JwtConfig`
/// through an `Arc`.
#[derive(Clone)]
pub struct JwtManager {
    /// Shared, immutable key material and validation rules.
    config: Arc<JwtConfig>,
    /// Optional revocation list checked during `decode`.
    blacklist: Option<TokenBlacklist>,
}
impl JwtManager {
pub fn new(secret: &[u8]) -> Self {
let mut validation = Validation::new(Algorithm::HS256);
validation.leeway = 60; Self {
config: Arc::new(JwtConfig {
decoding_key: DecodingKey::from_secret(secret),
encoding_key: Some(EncodingKey::from_secret(secret)),
validation,
}),
blacklist: None,
}
}
pub fn verify_only(secret: &[u8]) -> Self {
let mut validation = Validation::new(Algorithm::HS256);
validation.leeway = 60;
Self {
config: Arc::new(JwtConfig {
decoding_key: DecodingKey::from_secret(secret),
encoding_key: None,
validation,
}),
blacklist: None,
}
}
pub fn from_rsa_pem(private_key_pem: &[u8], public_key_pem: &[u8]) -> Result<Self, AuthError> {
let encoding_key = EncodingKey::from_rsa_pem(private_key_pem)
.map_err(|e| AuthError::Internal(format!("RSA private key: {e}")))?;
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem)
.map_err(|e| AuthError::Internal(format!("RSA public key: {e}")))?;
Ok(Self {
config: Arc::new(JwtConfig {
encoding_key: Some(encoding_key),
decoding_key,
validation: Validation::new(Algorithm::RS256),
}),
blacklist: None,
})
}
pub fn from_rsa_public_pem(public_key_pem: &[u8]) -> Result<Self, AuthError> {
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem)
.map_err(|e| AuthError::Internal(format!("RSA public key: {e}")))?;
Ok(Self {
config: Arc::new(JwtConfig {
encoding_key: None,
decoding_key,
validation: Validation::new(Algorithm::RS256),
}),
blacklist: None,
})
}
pub fn from_ec_pem(private_key_pem: &[u8], public_key_pem: &[u8]) -> Result<Self, AuthError> {
let encoding_key = EncodingKey::from_ec_pem(private_key_pem)
.map_err(|e| AuthError::Internal(format!("EC private key: {e}")))?;
let decoding_key = DecodingKey::from_ec_pem(public_key_pem)
.map_err(|e| AuthError::Internal(format!("EC public key: {e}")))?;
Ok(Self {
config: Arc::new(JwtConfig {
encoding_key: Some(encoding_key),
decoding_key,
validation: Validation::new(Algorithm::ES256),
}),
blacklist: None,
})
}
pub fn from_ec_public_pem(public_key_pem: &[u8]) -> Result<Self, AuthError> {
let decoding_key = DecodingKey::from_ec_pem(public_key_pem)
.map_err(|e| AuthError::Internal(format!("EC public key: {e}")))?;
Ok(Self {
config: Arc::new(JwtConfig {
encoding_key: None,
decoding_key,
validation: Validation::new(Algorithm::ES256),
}),
blacklist: None,
})
}
pub fn with_config(config: JwtConfig) -> Self {
Self {
config: Arc::new(config),
blacklist: None,
}
}
pub fn with_blacklist(mut self, blacklist: TokenBlacklist) -> Self {
self.blacklist = Some(blacklist);
self
}
pub fn decode<T>(&self, token: &str) -> Result<T, AuthError>
where
T: for<'de> Deserialize<'de> + HasJti,
{
let token_data = decode::<T>(token, &self.config.decoding_key, &self.config.validation)
.map_err(|e| match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::Expired,
_ => AuthError::InvalidToken(e.to_string()),
})?;
if let Some(bl) = &self.blacklist
&& let Some(jti) = token_data.claims.jti()
&& bl.is_revoked(jti)
{
return Err(AuthError::Revoked);
}
Ok(token_data.claims)
}
pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String, AuthError> {
let key = self
.config
.encoding_key
.as_ref()
.ok_or(AuthError::EncodingKeyMissing)?;
encode(&Header::default(), claims, key).map_err(|e| AuthError::Encode(e.to_string()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal claim set with no `jti`, so blacklist checks never apply to it.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestClaims {
        sub: String,
        exp: u64,
    }

    impl HasJti for TestClaims {}

    /// Unix timestamp for 9999-01-01T00:00:00Z, so test tokens never expire.
    fn far_future_exp() -> u64 {
        253_370_764_800_u64
    }

    /// Non-expiring claims for subject `sub`.
    fn claims_for(sub: &str) -> TestClaims {
        TestClaims {
            sub: sub.to_string(),
            exp: far_future_exp(),
        }
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let manager = JwtManager::new(b"test-secret-key");
        let original = claims_for("user-42");

        let jwt = manager.encode(&original).expect("encode should succeed");
        assert!(!jwt.is_empty());

        let round_tripped: TestClaims = manager.decode(&jwt).expect("decode should succeed");
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_decode_wrong_secret_fails() {
        let signer = JwtManager::new(b"correct-secret");
        let verifier = JwtManager::new(b"wrong-secret");

        let jwt = signer
            .encode(&claims_for("user-1"))
            .expect("encode must succeed");
        let outcome = verifier.decode::<TestClaims>(&jwt);

        assert!(outcome.is_err(), "decode with wrong secret should fail");
    }

    #[test]
    fn test_decode_invalid_token_fails() {
        let manager = JwtManager::new(b"any-secret");
        let outcome = manager.decode::<TestClaims>("not.a.jwt");
        assert!(outcome.is_err(), "decode of garbage should fail");
    }

    #[test]
    fn test_decode_mangled_token_fails() {
        let manager = JwtManager::new(b"secret");
        let mut jwt = manager.encode(&claims_for("u")).expect("encode ok");

        // Replace the final character so the token text differs from what was issued.
        let final_char = jwt.pop().unwrap();
        let substitute = if final_char == 'A' { 'B' } else { 'A' };
        jwt.push(substitute);

        let outcome = manager.decode::<TestClaims>(&jwt);
        assert!(outcome.is_err(), "mangled token should fail");
    }

    #[test]
    fn test_clone_shares_key() {
        let original = JwtManager::new(b"shared-key");
        let cloned = original.clone();

        let jwt = original.encode(&claims_for("u")).unwrap();
        let decoded: TestClaims = cloned.decode(&jwt).expect("clone should decode");

        assert_eq!(decoded.sub, "u");
    }

    #[test]
    fn test_encode_without_key_returns_error() {
        let manager = JwtManager::with_config(JwtConfig {
            decoding_key: DecodingKey::from_secret(b"secret"),
            encoding_key: None,
            validation: Validation::new(Algorithm::HS256),
        });

        let result = manager.encode(&claims_for("x"));
        assert!(
            matches!(result, Err(AuthError::EncodingKeyMissing)),
            "expected EncodingKeyMissing, got {result:?}"
        );
    }

    #[test]
    fn test_revoked_token_rejected() {
        use crate::revocation::TokenBlacklist;

        /// Claims that expose a `jti`, making them eligible for revocation checks.
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct ClaimsWithJti {
            sub: String,
            jti: String,
            exp: u64,
        }

        impl HasJti for ClaimsWithJti {
            fn jti(&self) -> Option<&str> {
                Some(self.jti.as_str())
            }
        }

        const TOKEN_ID: &str = "unique-jti-1";

        let blacklist = TokenBlacklist::new();
        let manager = JwtManager::new(b"s").with_blacklist(blacklist.clone());
        let jwt = manager
            .encode(&ClaimsWithJti {
                sub: "u".into(),
                jti: TOKEN_ID.into(),
                exp: far_future_exp(),
            })
            .unwrap();

        // Accepted while the ID is not yet revoked...
        manager
            .decode::<ClaimsWithJti>(&jwt)
            .expect("should be valid before revocation");

        // ...and rejected once the shared blacklist knows about it.
        blacklist.revoke(TOKEN_ID.into(), None);
        let result = manager.decode::<ClaimsWithJti>(&jwt);
        assert!(
            matches!(result, Err(AuthError::Revoked)),
            "revoked token should be rejected, got {result:?}"
        );
    }

    #[test]
    fn test_expired_token_returns_expired_error() {
        let manager = JwtManager::new(b"secret");

        // `exp` = 1 s after the epoch is far outside the 60-second leeway.
        let stale_claims = serde_json::json!({ "sub": "u", "exp": 1_u64 });
        let signing_key = EncodingKey::from_secret(b"secret");
        let jwt = encode(&Header::default(), &stale_claims, &signing_key).unwrap();

        let result: Result<TestClaims, _> = manager.decode(&jwt);
        assert!(
            matches!(result, Err(AuthError::Expired)),
            "expected Expired, got {result:?}"
        );
    }
}