use async_trait::async_trait;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use thiserror::Error;
/// Errors produced while configuring the JWT session backend or while
/// encoding/decoding session tokens.
///
/// The enum is `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum JwtSessionError {
/// Serializing the session payload or signing the token failed; carries the
/// underlying error rendered as text.
#[error("JWT encoding error: {0}")]
EncodingError(String),
/// Verifying the token or deserializing its payload failed for a reason
/// other than expiry; carries the underlying error rendered as text.
#[error("JWT decoding error: {0}")]
DecodingError(String),
/// No token for the given key. Not constructed anywhere in the visible part
/// of this file.
#[error("Token not found: {0}")]
TokenNotFound(String),
/// The token's expiry check failed (mapped from `ErrorKind::ExpiredSignature`).
#[error("Token expired")]
TokenExpired,
/// Generic invalid-token marker. Not constructed anywhere in the visible part
/// of this file.
#[error("Invalid token")]
InvalidToken,
/// The configured secret is shorter than the minimum enforced for the chosen
/// HMAC algorithm.
#[error(
"Invalid HMAC key length: {algorithm:?} requires at least {required} bytes, but got {actual} bytes"
)]
InvalidKeyLength {
// Algorithm the secret was validated against.
algorithm: Algorithm,
// Minimum accepted secret length, in bytes.
required: usize,
// Length of the supplied secret, in bytes.
actual: usize,
},
}
/// Returns the minimum secret length, in bytes, enforced for an HMAC algorithm.
///
/// The minimum is the numeric suffix of the algorithm name divided by eight
/// (HS256 -> 32, HS384 -> 48, HS512 -> 64). Every other algorithm yields
/// `None`, meaning no length requirement is enforced here.
fn min_hmac_key_length(algorithm: Algorithm) -> Option<usize> {
    let bits: usize = match algorithm {
        Algorithm::HS256 => 256,
        Algorithm::HS384 => 384,
        Algorithm::HS512 => 512,
        // Non-HMAC algorithms have no secret-length rule in this module.
        _ => return None,
    };
    Some(bits / 8)
}
/// Verifies that `secret` is long enough for `algorithm`.
///
/// The length is measured in bytes (`str::len`). Algorithms without a
/// minimum (see `min_hmac_key_length`) always pass.
///
/// # Errors
///
/// Returns [`JwtSessionError::InvalidKeyLength`] carrying the algorithm, the
/// required minimum and the actual length when the secret is too short.
fn validate_hmac_key_length(algorithm: Algorithm, secret: &str) -> Result<(), JwtSessionError> {
    let actual = secret.len();
    match min_hmac_key_length(algorithm) {
        Some(required) if actual < required => Err(JwtSessionError::InvalidKeyLength {
            algorithm,
            required,
            actual,
        }),
        // Either the algorithm has no minimum or the secret satisfies it.
        _ => Ok(()),
    }
}
/// Settings used to sign and validate session tokens.
///
/// `Debug` is implemented by hand elsewhere in this file so that `secret` is
/// never printed.
#[derive(Clone, PartialEq, Eq)]
pub struct JwtConfig {
/// Shared secret; its bytes are used to build both the encoding and the
/// decoding key.
pub secret: String,
/// Algorithm placed in the token header and required during validation.
pub algorithm: Algorithm,
/// Default token lifetime in seconds, used when no explicit TTL is given.
pub expiration: u64,
/// When set, written to the `iss` claim and required during validation.
pub issuer: Option<String>,
/// When set, written to the `aud` claim and required during validation.
pub audience: Option<String>,
}
/// Hand-written `Debug` that masks the secret so it cannot leak through
/// logs or panic messages.
impl std::fmt::Debug for JwtConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `JwtConfig` forces this
        // impl to be revisited instead of silently omitting the new field.
        let Self {
            secret: _,
            algorithm,
            expiration,
            issuer,
            audience,
        } = self;

        let mut out = f.debug_struct("JwtConfig");
        out.field("secret", &"[REDACTED]");
        out.field("algorithm", algorithm);
        out.field("expiration", expiration);
        out.field("issuer", issuer);
        out.field("audience", audience);
        out.finish()
    }
}
impl JwtConfig {
    /// Builds a configuration around `secret` with the defaults used by this
    /// module: HS256, a 3600-second lifetime, and no issuer or audience.
    pub fn new(secret: String) -> Self {
        let algorithm = Algorithm::HS256;
        let expiration = 3600;
        Self {
            secret,
            algorithm,
            expiration,
            issuer: None,
            audience: None,
        }
    }

    /// Returns the configuration with the signing algorithm replaced.
    pub fn with_algorithm(self, algorithm: Algorithm) -> Self {
        Self { algorithm, ..self }
    }

    /// Returns the configuration with the default token lifetime (in seconds)
    /// replaced.
    pub fn with_expiration(self, expiration: u64) -> Self {
        Self { expiration, ..self }
    }

    /// Returns the configuration with the issuer set.
    pub fn with_issuer(self, issuer: String) -> Self {
        Self {
            issuer: Some(issuer),
            ..self
        }
    }

    /// Returns the configuration with the audience set.
    pub fn with_audience(self, audience: String) -> Self {
        Self {
            audience: Some(audience),
            ..self
        }
    }
}
/// Claim set carried inside every session token.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SessionClaims {
// Session payload, stored as arbitrary JSON.
data: serde_json::Value,
// Expiry as a Unix timestamp in seconds.
exp: usize,
// Issue time as a Unix timestamp in seconds.
iat: usize,
// Issuer; left out of the serialized claims when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
iss: Option<String>,
// Audience; left out of the serialized claims when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
aud: Option<String>,
}
/// Session backend that keeps one signed JWT per session key in an in-memory
/// map.
///
/// Both fields are behind `Arc`, so clones of the backend share the same
/// configuration and the same token map.
#[derive(Debug, Clone)]
pub struct JwtSessionBackend {
// Signing/validation settings.
config: Arc<JwtConfig>,
// Session key -> encoded token, guarded by a read/write lock.
tokens: Arc<RwLock<HashMap<String, String>>>,
}
impl JwtSessionBackend {
    /// Creates a backend with an empty token map.
    ///
    /// # Errors
    ///
    /// Returns [`JwtSessionError::InvalidKeyLength`] when the configured
    /// secret is shorter than the minimum enforced for the configured HMAC
    /// algorithm.
    pub fn new(config: JwtConfig) -> Result<Self, JwtSessionError> {
        validate_hmac_key_length(config.algorithm, &config.secret)?;
        Ok(Self {
            config: Arc::new(config),
            tokens: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Serializes `data` and signs it into a token.
    ///
    /// `ttl` is the lifetime in seconds; `None` falls back to the configured
    /// default expiration. The issuer and audience claims are copied from the
    /// configuration.
    ///
    /// # Errors
    ///
    /// Returns [`JwtSessionError::EncodingError`] when the current time or the
    /// lifetime cannot be represented as `usize`, when the expiry timestamp
    /// would overflow, when `data` cannot be converted to JSON, or when
    /// signing fails.
    fn encode_token<T>(&self, data: &T, ttl: Option<u64>) -> Result<String, JwtSessionError>
    where
        T: Serialize,
    {
        // Checked conversion: an `as` cast would silently wrap a negative or
        // oversized timestamp.
        let issued_at = usize::try_from(chrono::Utc::now().timestamp()).map_err(|e| {
            JwtSessionError::EncodingError(format!("current timestamp is out of range: {e}"))
        })?;
        let json_data = serde_json::to_value(data)
            .map_err(|e| JwtSessionError::EncodingError(e.to_string()))?;

        let lifetime = usize::try_from(ttl.unwrap_or(self.config.expiration)).map_err(|e| {
            JwtSessionError::EncodingError(format!("token lifetime is out of range: {e}"))
        })?;
        // A huge TTL must not panic (debug) or wrap around to a timestamp in
        // the past (release), which would yield an already-expired token.
        let expires_at = issued_at.checked_add(lifetime).ok_or_else(|| {
            JwtSessionError::EncodingError("token expiration time overflows".to_string())
        })?;

        let claims = SessionClaims {
            data: json_data,
            exp: expires_at,
            iat: issued_at,
            iss: self.config.issuer.clone(),
            aud: self.config.audience.clone(),
        };
        let header = Header::new(self.config.algorithm);
        let encoding_key = EncodingKey::from_secret(self.config.secret.as_bytes());
        encode(&header, &claims, &encoding_key)
            .map_err(|e| JwtSessionError::EncodingError(e.to_string()))
    }

    /// Verifies `token` and deserializes its payload into `T`.
    ///
    /// Validation uses the configured algorithm and, when configured, the
    /// expected issuer and audience.
    ///
    /// # Errors
    ///
    /// Returns [`JwtSessionError::TokenExpired`] when validation reports an
    /// expired signature, and [`JwtSessionError::DecodingError`] for every
    /// other validation failure or when the payload does not deserialize
    /// into `T`.
    fn decode_token<T>(&self, token: &str) -> Result<T, JwtSessionError>
    where
        T: for<'de> Deserialize<'de>,
    {
        let decoding_key = DecodingKey::from_secret(self.config.secret.as_bytes());
        let mut validation = Validation::new(self.config.algorithm);
        if let Some(issuer) = &self.config.issuer {
            validation.set_issuer(&[issuer]);
        }
        if let Some(audience) = &self.config.audience {
            validation.set_audience(&[audience]);
        }

        let token_data = decode::<SessionClaims>(token, &decoding_key, &validation).map_err(
            |e| match e.kind() {
                // Expiry gets its own variant so callers can treat it as a
                // missing session rather than a hard failure.
                jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtSessionError::TokenExpired,
                _ => JwtSessionError::DecodingError(e.to_string()),
            },
        )?;

        serde_json::from_value(token_data.claims.data)
            .map_err(|e| JwtSessionError::DecodingError(e.to_string()))
    }
}
use super::cache::{SessionBackend, SessionError};
#[async_trait]
impl SessionBackend for JwtSessionBackend {
async fn load<T>(&self, session_key: &str) -> Result<Option<T>, SessionError>
where
T: for<'de> Deserialize<'de> + Send,
{
let tokens = self
.tokens
.read()
.map_err(|e| SessionError::CacheError(format!("Lock error: {}", e)))?;
if let Some(token) = tokens.get(session_key) {
match self.decode_token::<T>(token) {
Ok(data) => Ok(Some(data)),
Err(JwtSessionError::TokenExpired) => Ok(None),
Err(e) => Err(SessionError::CacheError(e.to_string())),
}
} else {
Ok(None)
}
}
async fn save<T>(
&self,
session_key: &str,
data: &T,
ttl: Option<u64>,
) -> Result<(), SessionError>
where
T: Serialize + Send + Sync,
{
let token = self
.encode_token(data, ttl)
.map_err(|e| SessionError::SerializationError(e.to_string()))?;
let mut tokens = self
.tokens
.write()
.map_err(|e| SessionError::CacheError(format!("Lock error: {}", e)))?;
tokens.insert(session_key.to_string(), token);
Ok(())
}
async fn delete(&self, session_key: &str) -> Result<(), SessionError> {
let mut tokens = self
.tokens
.write()
.map_err(|e| SessionError::CacheError(format!("Lock error: {}", e)))?;
tokens.remove(session_key);
Ok(())
}
async fn exists(&self, session_key: &str) -> Result<bool, SessionError> {
let tokens = self
.tokens
.read()
.map_err(|e| SessionError::CacheError(format!("Lock error: {}", e)))?;
if let Some(token) = tokens.get(session_key) {
match self.decode_token::<serde_json::Value>(token) {
Ok(_) => Ok(true),
Err(JwtSessionError::TokenExpired) => Ok(false),
Err(_) => Ok(false),
}
} else {
Ok(false)
}
}
}
// Unit tests for the JWT session backend: round-trips through save/load,
// expiry handling, deletion, existence checks and HMAC key-length validation.
#[cfg(test)]
mod tests {
use super::*;
use jsonwebtoken::{EncodingKey, Header, encode};
use rstest::rstest;
use serde_json::json;
// A saved payload can be loaded back under the same key.
#[rstest]
#[tokio::test]
async fn test_jwt_session_save_and_load() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let session_data = json!({
"user_id": 123,
"username": "test_user",
});
backend
.save("test_session", &session_data, Some(3600))
.await
.unwrap();
let loaded: Option<serde_json::Value> = backend.load("test_session").await.unwrap();
assert!(loaded.is_some());
assert_eq!(loaded.unwrap()["user_id"], 123);
}
// Loading an already-expired token yields `None` rather than an error.
#[rstest]
#[tokio::test]
async fn test_jwt_session_expiration() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config.clone()).unwrap();
let now = chrono::Utc::now().timestamp() as usize;
// Hand-build claims whose expiry lies one hour in the past.
let expired_claims = SessionClaims {
data: json!({
"user_id": 456,
}),
exp: now - 3600,
iat: now - 7200,
iss: None,
aud: None,
};
let header = Header::new(config.algorithm);
let encoding_key = EncodingKey::from_secret(config.secret.as_bytes());
let expired_token = encode(&header, &expired_claims, &encoding_key).unwrap();
// Insert directly into the map, bypassing `save`, which would compute a
// fresh expiry.
backend
.tokens
.write()
.unwrap()
.insert("expired_session".to_string(), expired_token);
let loaded: Option<serde_json::Value> = backend.load("expired_session").await.unwrap();
assert!(loaded.is_none());
}
// A deleted session no longer exists.
#[rstest]
#[tokio::test]
async fn test_jwt_session_delete() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let session_data = json!({
"user_id": 789,
});
backend
.save("delete_test", &session_data, Some(3600))
.await
.unwrap();
assert!(backend.exists("delete_test").await.unwrap());
backend.delete("delete_test").await.unwrap();
assert!(!backend.exists("delete_test").await.unwrap());
}
// `exists` is false for an unknown key and true after a save.
#[rstest]
#[tokio::test]
async fn test_jwt_session_exists() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let session_data = json!({
"user_id": 999,
});
assert!(!backend.exists("non_existent").await.unwrap());
backend
.save("exists_test", &session_data, Some(3600))
.await
.unwrap();
assert!(backend.exists("exists_test").await.unwrap());
}
// Save/load round-trip also works with HS512 and a sufficiently long secret.
#[rstest]
#[tokio::test]
async fn test_jwt_with_different_algorithms() {
let config = JwtConfig::new(
"test-secret-key-for-jwt-testing-hs512-algorithm-minimum-64-bytes!!".to_string(),
)
.with_algorithm(Algorithm::HS512);
let backend = JwtSessionBackend::new(config).unwrap();
let session_data = json!({
"user_id": 111,
});
backend
.save("hs512_test", &session_data, Some(3600))
.await
.unwrap();
let loaded: Option<serde_json::Value> = backend.load("hs512_test").await.unwrap();
assert!(loaded.is_some());
assert_eq!(loaded.unwrap()["user_id"], 111);
}
// Tokens issued with an issuer and audience validate against the same config.
#[rstest]
#[tokio::test]
async fn test_jwt_with_issuer_and_audience() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string())
.with_issuer("test-app".to_string())
.with_audience("test-users".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let session_data = json!({
"user_id": 222,
});
backend
.save("iss_aud_test", &session_data, Some(3600))
.await
.unwrap();
let loaded: Option<serde_json::Value> = backend.load("iss_aud_test").await.unwrap();
assert!(loaded.is_some());
assert_eq!(loaded.unwrap()["user_id"], 222);
}
// Nested objects and arrays survive the round-trip.
#[rstest]
#[tokio::test]
async fn test_jwt_session_complex_data() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let session_data = json!({
"user_id": 333,
"username": "complex_user",
"roles": ["admin", "editor"],
"metadata": {
"last_login": "2024-01-01T00:00:00Z",
"preferences": {
"theme": "dark",
"language": "en"
}
}
});
backend
.save("complex_test", &session_data, Some(3600))
.await
.unwrap();
let loaded: Option<serde_json::Value> = backend.load("complex_test").await.unwrap();
let data = loaded.unwrap();
assert_eq!(data["user_id"], 333);
assert_eq!(data["roles"][0], "admin");
assert_eq!(data["metadata"]["preferences"]["theme"], "dark");
}
// Secrets shorter than the per-algorithm minimum are rejected, and the error
// reports the algorithm, the required length and the actual length.
#[rstest]
#[case::hs256_short_key(Algorithm::HS256, "short-key", 32)]
#[case::hs384_short_key(Algorithm::HS384, "this-key-is-only-32-bytes-long!!", 48)]
#[case::hs512_short_key(Algorithm::HS512, "this-key-is-32-bytes-but-not-64!", 64)]
fn test_jwt_rejects_short_hmac_key(
#[case] algorithm: Algorithm,
#[case] secret: &str,
#[case] expected_min_length: usize,
) {
let config = JwtConfig::new(secret.to_string()).with_algorithm(algorithm);
let result = JwtSessionBackend::new(config);
assert_eq!(
result.unwrap_err(),
JwtSessionError::InvalidKeyLength {
algorithm,
required: expected_min_length,
actual: secret.len(),
}
);
}
// Secrets of exactly the minimum length are accepted (boundary check).
#[rstest]
#[case::hs256_exact(Algorithm::HS256, "exactly-32-bytes-long-secret-key")]
#[case::hs384_exact(Algorithm::HS384, "exactly-48-bytes-long-secret-key-for-hs384-algo!")]
#[case::hs512_exact(
Algorithm::HS512,
"exactly-64-bytes-long-secret-key-for-hs512-algorithm-testing!!!!"
)]
fn test_jwt_accepts_minimum_length_hmac_key(
#[case] algorithm: Algorithm,
#[case] secret: &str,
) {
let config = JwtConfig::new(secret.to_string()).with_algorithm(algorithm);
let result = JwtSessionBackend::new(config);
assert!(
result.is_ok(),
"{:?} should accept a {}-byte key",
algorithm,
secret.len()
);
}
// Loading a key that was never saved yields `None`.
#[rstest]
#[tokio::test]
async fn test_jwt_session_load_nonexistent_key() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let loaded: Option<serde_json::Value> = backend.load("unknown_key").await.unwrap();
assert!(loaded.is_none());
}
// Saving twice under one key keeps only the most recent payload.
#[rstest]
#[tokio::test]
async fn test_jwt_session_save_overwrite() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let data_v1 = json!({"version": 1, "name": "old"});
let data_v2 = json!({"version": 2, "name": "new"});
backend
.save("overwrite_key", &data_v1, Some(3600))
.await
.unwrap();
backend
.save("overwrite_key", &data_v2, Some(3600))
.await
.unwrap();
let loaded: Option<serde_json::Value> = backend.load("overwrite_key").await.unwrap();
let loaded = loaded.unwrap();
assert_eq!(loaded["version"], 2);
assert_eq!(loaded["name"], "new");
}
// Loading after a delete yields `None`.
#[rstest]
#[tokio::test]
async fn test_jwt_session_delete_then_load() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let data = json!({"to_delete": true});
backend
.save("del_load_key", &data, Some(3600))
.await
.unwrap();
backend.delete("del_load_key").await.unwrap();
let loaded: Option<serde_json::Value> = backend.load("del_load_key").await.unwrap();
assert!(loaded.is_none());
}
// `exists` flips from true to false across a delete.
#[rstest]
#[tokio::test]
async fn test_jwt_session_exists_after_delete() {
let config = JwtConfig::new("test-secret-key-for-jwt-testing!!".to_string());
let backend = JwtSessionBackend::new(config).unwrap();
let data = json!({"exists_check": true});
backend
.save("exists_del_key", &data, Some(3600))
.await
.unwrap();
assert!(backend.exists("exists_del_key").await.unwrap());
backend.delete("exists_del_key").await.unwrap();
assert!(!backend.exists("exists_del_key").await.unwrap());
}
}