use std::sync::Arc;
use std::time::Duration;
use moka::future::Cache;
use super::types::AuthResult;
/// Cache lifetime used when a token carries no usable future `exp` claim.
const DEFAULT_TTL: Duration = Duration::from_secs(5 * 60);
/// Upper bound on how long any authentication result stays cached.
const MAX_TTL: Duration = Duration::from_secs(60 * 60);
#[derive(Clone)]
pub struct JwtCache {
inner: Cache<Arc<str>, Arc<AuthResult>>,
}
impl JwtCache {
pub fn new(max_entries: u64) -> Self {
let inner = Cache::builder()
.max_capacity(max_entries)
.time_to_live(MAX_TTL)
.build();
Self { inner }
}
pub async fn get(&self, token: &str) -> Option<Arc<AuthResult>> {
self.inner.get(&Arc::<str>::from(token)).await
}
pub async fn insert(&self, token: &str, result: AuthResult) {
let ttl = ttl_from_claims(&result);
self.inner.insert(Arc::from(token), Arc::new(result)).await;
let _ = ttl; }
pub fn invalidate_all(&self) {
self.inner.invalidate_all();
}
pub fn entry_count(&self) -> u64 {
self.inner.entry_count()
}
}
/// Computes how long a cached [`AuthResult`] may be served.
///
/// The result depends on the `exp` claim (seconds since the Unix epoch):
/// - If `exp` is in the future, the TTL is the time remaining until then, capped at
///   [`MAX_TTL`].
/// - If `exp` is absent, non-numeric, or not in the future, the TTL is
///   [`DEFAULT_TTL`].
///
/// `exp` is a JWT NumericDate and may be non-integral. A fractional value is
/// truncated to whole seconds rather than ignored. Ignoring it would fall back to
/// `DEFAULT_TTL` regardless of when the token actually expires.
fn ttl_from_claims(result: &AuthResult) -> Duration {
    let exp = result
        .claims
        .get("exp")
        .and_then(|v| v.as_i64().or_else(|| v.as_f64().map(|secs| secs as i64)));
    if let Some(exp) = exp {
        let now = chrono::Utc::now().timestamp();
        if exp > now {
            // `exp > now` makes the difference strictly positive, so the cast is lossless.
            let remaining = Duration::from_secs((exp - now) as u64);
            return remaining.min(MAX_TTL);
        }
    }
    DEFAULT_TTL
}
impl std::fmt::Debug for JwtCache {
    /// Formats the cache as `JwtCache { entry_count: N }`.
    ///
    /// Only the entry count is shown; cached tokens and claims are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entries = self.inner.entry_count();
        let mut builder = f.debug_struct("JwtCache");
        builder.field("entry_count", &entries);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use compact_str::CompactString;

    /// Builds an `AuthResult` with the given role and optional `exp` claim.
    fn make_result(role: &str, exp: Option<i64>) -> AuthResult {
        let mut claims = serde_json::Map::new();
        claims.insert(
            "role".to_string(),
            serde_json::Value::String(role.to_string()),
        );
        if let Some(e) = exp {
            claims.insert("exp".to_string(), serde_json::json!(e));
        }
        AuthResult {
            role: CompactString::from(role),
            claims,
        }
    }

    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = JwtCache::new(100);
        let result = make_result("admin", Some(chrono::Utc::now().timestamp() + 3600));
        cache.insert("token_abc", result).await;
        let cached = cache.get("token_abc").await.unwrap();
        assert_eq!(cached.role.as_str(), "admin");
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = JwtCache::new(100);
        assert!(cache.get("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_invalidate_all() {
        let cache = JwtCache::new(100);
        let result = make_result("user", Some(chrono::Utc::now().timestamp() + 3600));
        cache.insert("token1", result.clone()).await;
        cache.insert("token2", result).await;
        assert!(cache.get("token1").await.is_some());
        cache.invalidate_all();
        // Invalidated entries are hidden from lookups immediately, even before
        // moka's maintenance work drops them from `entry_count`.
        assert!(cache.get("token1").await.is_none());
        assert!(cache.get("token2").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_capacity() {
        let cache = JwtCache::new(2);
        let result = make_result("user", Some(chrono::Utc::now().timestamp() + 3600));
        for i in 0..5 {
            cache.insert(&format!("token_{i}"), result.clone()).await;
        }
        // moka enforces `max_capacity` during its pending maintenance tasks; flush
        // them so the capacity bound is actually observable.
        cache.inner.run_pending_tasks().await;
        assert!(cache.entry_count() <= 2);
    }

    #[test]
    fn test_ttl_from_claims_with_exp() {
        let result = make_result("user", Some(chrono::Utc::now().timestamp() + 600));
        let ttl = ttl_from_claims(&result);
        assert!(ttl.as_secs() >= 598 && ttl.as_secs() <= 601);
    }

    #[test]
    fn test_ttl_from_claims_capped() {
        let result = make_result("user", Some(chrono::Utc::now().timestamp() + 7200));
        let ttl = ttl_from_claims(&result);
        assert_eq!(ttl, MAX_TTL);
    }

    #[test]
    fn test_ttl_from_claims_no_exp() {
        let result = make_result("user", None);
        let ttl = ttl_from_claims(&result);
        assert_eq!(ttl, DEFAULT_TTL);
    }

    #[test]
    fn test_ttl_from_claims_expired() {
        let result = make_result("user", Some(chrono::Utc::now().timestamp() - 100));
        let ttl = ttl_from_claims(&result);
        assert_eq!(ttl, DEFAULT_TTL);
    }
}