use super::test_helpers::*;
use crate::jwks;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::{Mutex, RwLock};
/// Number of simulated seconds a stored key set stays fresh; added to the clock value
/// at store time to compute the expiry.
const TEST_CACHE_DURATION: u64 = 2;

/// Largest age, in simulated seconds since it was stored, at which a key set may still
/// be handed out as a stale fallback.
const TEST_CACHE_MAX_AGE: u64 = 5;
/// Self-contained JWKS cache used by the tests in this file.
///
/// Time is simulated: freshness checks read `current_time` instead of the system clock,
/// so tests move time forward with `advance_time`. Network access is simulated as well,
/// through the `network_available` flag.
struct TestableJwksCache {
    // --- cached data and its timestamps (each behind its own lock) ---
    /// Last key set stored by `set_cache`; `None` until something is stored.
    cache: Arc<RwLock<Option<jwks::JwksResponse>>>,
    /// Simulated time (seconds) from which `cache` is no longer considered fresh.
    expires_at: Arc<RwLock<Option<u64>>>,
    /// Simulated time (seconds) at which `cache` was stored; drives the stale-age limit.
    cached_at: Arc<RwLock<Option<u64>>>,

    // --- simulated clock ---
    /// Current simulated time in whole seconds; seeded from the system clock in `new`.
    current_time: Arc<RwLock<u64>>,

    // --- counters updated by `get_jwks_with_auto_refresh` ---
    /// Lookups answered directly from the fresh cache.
    cache_hits: Arc<AtomicU64>,
    /// Lookups that found no fresh entry.
    cache_misses: Arc<AtomicU64>,
    /// Simulated network refreshes that were started.
    refresh_attempts: Arc<AtomicU64>,
    /// Simulated network refreshes that succeeded.
    refresh_successes: Arc<AtomicU64>,

    // --- test switches ---
    /// When `false`, `simulate_network_request` returns an error.
    network_available: Arc<AtomicBool>,
    /// When `false`, an expired cache is never refreshed.
    auto_refresh_enabled: Arc<AtomicBool>,

    /// Held for the duration of a refresh so that refreshes run one at a time.
    refresh_mutex: Arc<Mutex<()>>,
}
impl TestableJwksCache {
    /// Creates an empty cache whose simulated clock starts at the current wall-clock
    /// time, in whole seconds since the Unix epoch.
    ///
    /// Network availability and auto-refresh both start enabled, and every counter
    /// starts at zero.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the Unix epoch.
    pub fn new() -> Self {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock must not be set before the Unix epoch")
            .as_secs();
        Self {
            cache: Arc::new(RwLock::new(None)),
            expires_at: Arc::new(RwLock::new(None)),
            cached_at: Arc::new(RwLock::new(None)),
            current_time: Arc::new(RwLock::new(now)),
            cache_hits: Arc::new(AtomicU64::new(0)),
            cache_misses: Arc::new(AtomicU64::new(0)),
            refresh_attempts: Arc::new(AtomicU64::new(0)),
            refresh_successes: Arc::new(AtomicU64::new(0)),
            network_available: Arc::new(AtomicBool::new(true)),
            auto_refresh_enabled: Arc::new(AtomicBool::new(true)),
            refresh_mutex: Arc::new(Mutex::new(())),
        }
    }

    /// Moves the simulated clock forward by `seconds`.
    pub async fn advance_time(&self, seconds: u64) {
        let mut time = self.current_time.write().await;
        *time += seconds;
    }

    /// Returns the current simulated time in seconds.
    pub async fn get_current_time(&self) -> u64 {
        *self.current_time.read().await
    }

    /// Stores `jwks` and stamps it with the current simulated time.
    ///
    /// The entry is fresh until `now + TEST_CACHE_DURATION`. The payload and its two
    /// timestamps live behind separate locks and are written one after another, so a
    /// concurrent reader may briefly observe a new payload with the previous timestamps.
    pub async fn set_cache(&self, jwks: jwks::JwksResponse) {
        let now = self.get_current_time().await;
        *self.cache.write().await = Some(jwks);
        *self.expires_at.write().await = Some(now + TEST_CACHE_DURATION);
        *self.cached_at.write().await = Some(now);
    }

    /// Returns a clone of the stored key set while it is still fresh.
    ///
    /// Fresh means the simulated time is strictly before the recorded expiry; at the
    /// expiry instant itself the entry already counts as expired. Returns `None` when
    /// nothing has been stored or the entry has expired.
    pub async fn get_cached_jwks(&self) -> Option<jwks::JwksResponse> {
        let now = self.get_current_time().await;
        // Copy the value out so the read guard is released before touching `cache`.
        let expires_at = *self.expires_at.read().await;
        match expires_at {
            Some(expires) if now < expires => self.cache.read().await.clone(),
            _ => None,
        }
    }

    /// Returns a clone of the stored key set as long as its age does not exceed
    /// `TEST_CACHE_MAX_AGE` seconds (the limit is inclusive), regardless of freshness.
    ///
    /// Returns `None` when nothing has been stored or the entry is older than the limit.
    pub async fn get_stale_cache(&self) -> Option<jwks::JwksResponse> {
        let now = self.get_current_time().await;
        // Copy the value out so the read guard is released before touching `cache`.
        let cached_at = *self.cached_at.read().await;
        match cached_at {
            // `saturating_sub` keeps the age at zero instead of underflowing should the
            // store timestamp ever be ahead of the clock.
            Some(stored_at) if now.saturating_sub(stored_at) <= TEST_CACHE_MAX_AGE => {
                self.cache.read().await.clone()
            }
            _ => None,
        }
    }

    /// Switches the simulated network on or off.
    pub fn set_network_available(&self, available: bool) {
        self.network_available.store(available, Ordering::Relaxed);
    }

    /// Pretends to fetch the key set over the network.
    ///
    /// # Errors
    ///
    /// Returns `"Network unavailable"` when the network flag is off.
    pub async fn simulate_network_request(&self) -> Result<jwks::JwksResponse, String> {
        if self.network_available.load(Ordering::Relaxed) {
            Ok(create_test_jwks_response())
        } else {
            Err("Network unavailable".to_string())
        }
    }

    /// Enables or disables refreshing an expired cache.
    pub fn set_auto_refresh_enabled(&self, enabled: bool) {
        self.auto_refresh_enabled.store(enabled, Ordering::Relaxed);
    }

    /// Returns the key set, refreshing it through the simulated network when the cached
    /// copy is no longer fresh.
    ///
    /// Order of attempts:
    /// 1. A fresh cached entry is returned immediately (counted as a cache hit).
    /// 2. Otherwise the call is counted as a cache miss. If auto-refresh is enabled, the
    ///    refresh lock is taken and the cache is checked once more, because another
    ///    caller may have refreshed it while this one was waiting.
    /// 3. If it is still not fresh, a refresh is attempted; on success the new key set
    ///    is stored and returned.
    /// 4. If the refresh fails, a stale entry within the max-age limit is returned.
    ///
    /// Each call increments exactly one of the hit/miss counters.
    ///
    /// # Errors
    ///
    /// Returns an error when auto-refresh is disabled and the cache is not fresh (the
    /// stale entry is deliberately not consulted in that case), or when the refresh
    /// fails and no usable stale entry exists.
    pub async fn get_jwks_with_auto_refresh(&self) -> Result<jwks::JwksResponse, String> {
        if let Some(cached) = self.get_cached_jwks().await {
            self.cache_hits.fetch_add(1, Ordering::Relaxed);
            return Ok(cached);
        }
        self.cache_misses.fetch_add(1, Ordering::Relaxed);

        if !self.auto_refresh_enabled.load(Ordering::Relaxed) {
            return Err("No cache available and refresh failed".to_string());
        }

        let _refresh_guard = self.refresh_mutex.lock().await;

        // Re-check under the lock: without this, every caller that queued up behind an
        // in-flight refresh would repeat the network request once it got the lock.
        if let Some(refreshed_by_peer) = self.get_cached_jwks().await {
            return Ok(refreshed_by_peer);
        }

        self.refresh_attempts.fetch_add(1, Ordering::Relaxed);
        match self.simulate_network_request().await {
            Ok(jwks) => {
                self.set_cache(jwks.clone()).await;
                self.refresh_successes.fetch_add(1, Ordering::Relaxed);
                Ok(jwks)
            }
            // The network error itself is not surfaced; fall back to a stale entry.
            Err(_) => self
                .get_stale_cache()
                .await
                .ok_or_else(|| "No cache available and refresh failed".to_string()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly stored key set is served both immediately and partway through its
    /// fresh window.
    #[tokio::test]
    async fn test_valid_cache_hit() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        let stored_at = cache.get_current_time().await;

        let fresh = cache.get_cached_jwks().await;
        assert!(fresh.is_some());
        assert_eq!(fresh.unwrap().keys[0].kid, TEST_KID);

        cache.advance_time(1).await;
        assert!(cache.get_current_time().await < stored_at + TEST_CACHE_DURATION);
        let fresh = cache.get_cached_jwks().await;
        assert!(fresh.is_some());
        assert_eq!(fresh.unwrap().keys[0].kid, TEST_KID);
    }

    /// Between expiry and max age the fresh lookup misses but the stale lookup hits.
    #[tokio::test]
    async fn test_stale_cache_fallback() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        let stored_at = cache.get_current_time().await;

        cache.advance_time(TEST_CACHE_DURATION + 1).await;
        assert!(cache.get_current_time().await > stored_at + TEST_CACHE_DURATION);
        assert!(cache.get_current_time().await < stored_at + TEST_CACHE_MAX_AGE);

        assert!(cache.get_cached_jwks().await.is_none());
        let stale = cache.get_stale_cache().await;
        assert!(stale.is_some());
        assert_eq!(stale.unwrap().keys[0].kid, TEST_KID);
    }

    /// Past max age neither the fresh nor the stale lookup returns anything.
    #[tokio::test]
    async fn test_cache_max_age_expiration() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        let stored_at = cache.get_current_time().await;

        cache.advance_time(TEST_CACHE_MAX_AGE + 1).await;
        assert!(cache.get_current_time().await > stored_at + TEST_CACHE_MAX_AGE);

        assert!(cache.get_cached_jwks().await.is_none());
        assert!(cache.get_stale_cache().await.is_none());
    }

    /// Storing again restarts the fresh window from the time of the second store.
    #[tokio::test]
    async fn test_cache_refresh_extends_validity() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;

        cache.advance_time(1).await;
        let updated_jwks = create_test_jwks_response();
        cache.set_cache(updated_jwks).await;
        let refresh_time = cache.get_current_time().await;

        // One second after the refresh: would be at the first entry's expiry, but the
        // refreshed entry is still fresh.
        cache.advance_time(1).await;
        assert!(cache.get_current_time().await < refresh_time + TEST_CACHE_DURATION);
        let fresh = cache.get_cached_jwks().await;
        assert!(fresh.is_some());
        assert_eq!(fresh.unwrap().keys[0].kid, TEST_KID);

        cache.advance_time(2).await;
        assert!(cache.get_current_time().await > refresh_time + TEST_CACHE_DURATION);
        assert!(cache.get_cached_jwks().await.is_none());
        assert!(cache.get_stale_cache().await.is_some());
    }

    /// A new cache holds nothing, fresh or stale.
    #[tokio::test]
    async fn test_initial_state_is_empty() {
        let cache = TestableJwksCache::new();
        let valid_jwks = cache.get_cached_jwks().await;
        assert!(valid_jwks.is_none(), "Initial valid cache should be empty");
        let stale_jwks = cache.get_stale_cache().await;
        assert!(stale_jwks.is_none(), "Initial stale cache should be empty");
    }

    /// Concurrent readers and concurrent writers leave the cache consistent and never
    /// move the simulated clock.
    #[tokio::test]
    async fn test_concurrent_access_is_safe() {
        let cache = Arc::new(TestableJwksCache::new());
        cache.set_cache(create_test_jwks_response()).await;
        let initial_time = cache.get_current_time().await;

        // Phase 1: ten concurrent readers.
        let readers: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&cache);
                tokio::spawn(async move {
                    let jwks = shared.get_cached_jwks().await;
                    assert!(jwks.is_some(), "Task {i} should read valid cache");
                    assert_eq!(jwks.unwrap().keys[0].kid, TEST_KID);
                    let stale_jwks = shared.get_stale_cache().await;
                    assert!(stale_jwks.is_some(), "Task {i} should read stale cache");
                    assert_eq!(stale_jwks.unwrap().keys[0].kid, TEST_KID);
                })
            })
            .collect();
        for reader in readers {
            reader.await.unwrap();
        }

        let final_time = cache.get_current_time().await;
        assert_eq!(
            final_time, initial_time,
            "Time should not advance during read-only operations"
        );
        let valid_jwks = cache.get_cached_jwks().await;
        assert!(
            valid_jwks.is_some(),
            "Cache should still be valid after concurrent reads"
        );

        // Phase 2: five concurrent writers storing the same payload.
        let new_jwks = create_test_jwks_response();
        let before_write_time = cache.get_current_time().await;
        let writers: Vec<_> = (0..5)
            .map(|_| {
                let shared = Arc::clone(&cache);
                let replacement = new_jwks.clone();
                tokio::spawn(async move {
                    shared.set_cache(replacement).await;
                })
            })
            .collect();
        for writer in writers {
            writer.await.unwrap();
        }

        let after_write_time = cache.get_current_time().await;
        assert_eq!(
            before_write_time, after_write_time,
            "Time should not advance during concurrent set_cache calls"
        );

        // Phase 3: everything expires once the clock passes max age.
        cache.advance_time(TEST_CACHE_MAX_AGE + 1).await;
        let final_time = cache.get_current_time().await;
        assert!(final_time > after_write_time + TEST_CACHE_MAX_AGE);
        let expired_valid_jwks = cache.get_cached_jwks().await;
        assert!(
            expired_valid_jwks.is_none(),
            "Cache should be expired after advancing time past MAX_AGE"
        );
        let expired_stale_jwks = cache.get_stale_cache().await;
        assert!(
            expired_stale_jwks.is_none(),
            "Stale cache should also be expired after advancing time past MAX_AGE"
        );
    }

    /// An expired cache is refreshed exactly once when the network is up.
    #[tokio::test]
    async fn test_basic_refresh_on_network_available() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        cache.advance_time(TEST_CACHE_DURATION + 1).await;
        cache.set_network_available(true);
        cache.set_auto_refresh_enabled(true);

        let result = cache.get_jwks_with_auto_refresh().await;
        assert!(
            result.is_ok(),
            "Should refresh successfully when network is available"
        );
        assert_eq!(cache.refresh_attempts.load(Ordering::Relaxed), 1);
        assert_eq!(cache.refresh_successes.load(Ordering::Relaxed), 1);
    }

    /// With auto-refresh explicitly enabled and the network down, the stale entry is
    /// served after one failed attempt.
    #[tokio::test]
    async fn test_fallback_to_stale_cache_on_network_failure() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        cache.advance_time(TEST_CACHE_DURATION + 1).await;
        cache.set_network_available(false);
        cache.set_auto_refresh_enabled(true);

        let result = cache.get_jwks_with_auto_refresh().await;
        assert!(
            result.is_ok(),
            "Should fallback to stale cache when network fails"
        );
        assert_eq!(cache.refresh_attempts.load(Ordering::Relaxed), 1);
        assert_eq!(cache.refresh_successes.load(Ordering::Relaxed), 0);
    }

    /// With auto-refresh disabled, an expired cache yields an error and no refresh is
    /// attempted, even though the network is up.
    #[tokio::test]
    async fn test_disabled_refresh_fallback() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        cache.advance_time(TEST_CACHE_DURATION + 1).await;
        cache.set_auto_refresh_enabled(false);
        cache.set_network_available(true);

        let result = cache.get_jwks_with_auto_refresh().await;
        assert!(
            result.is_err(),
            "Should fail when refresh is disabled and cache is expired"
        );
        assert_eq!(cache.refresh_attempts.load(Ordering::Relaxed), 0);
    }

    /// Same scenario as the stale-fallback test, but relying on auto-refresh being
    /// enabled by default rather than setting it explicitly.
    #[tokio::test]
    async fn test_network_failure_fallback() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        cache.advance_time(TEST_CACHE_DURATION + 1).await;
        cache.set_network_available(false);

        let result = cache.get_jwks_with_auto_refresh().await;
        assert!(
            result.is_ok(),
            "Should use stale cache during network failure"
        );
        assert_eq!(cache.refresh_attempts.load(Ordering::Relaxed), 1);
        assert_eq!(cache.refresh_successes.load(Ordering::Relaxed), 0);
    }

    /// After an outage served from the stale entry, the next call refreshes
    /// successfully and leaves a fresh entry behind.
    #[tokio::test]
    async fn test_network_recovery() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        cache.advance_time(TEST_CACHE_DURATION + 1).await;

        cache.set_network_available(false);
        let during_outage = cache.get_jwks_with_auto_refresh().await;
        assert!(
            during_outage.is_ok(),
            "Should serve stale cache while the network is down"
        );
        assert_eq!(cache.refresh_attempts.load(Ordering::Relaxed), 1);
        assert_eq!(cache.refresh_successes.load(Ordering::Relaxed), 0);

        cache.set_network_available(true);
        let result = cache.get_jwks_with_auto_refresh().await;
        assert!(
            result.is_ok(),
            "Should work normally after network recovery"
        );
        assert_eq!(cache.refresh_attempts.load(Ordering::Relaxed), 2);
        assert_eq!(cache.refresh_successes.load(Ordering::Relaxed), 1);
        assert!(
            cache.get_cached_jwks().await.is_some(),
            "Cache should be fresh again after a successful refresh"
        );
    }

    /// Pins the exact boundaries: the fresh window is exclusive at its end, the stale
    /// window is inclusive at max age.
    #[tokio::test]
    async fn test_cache_expiration_boundaries() {
        let cache = TestableJwksCache::new();
        cache.set_cache(create_test_jwks_response()).await;
        let initial_time = cache.get_current_time().await;

        // One second before expiry: still fresh.
        cache.advance_time(TEST_CACHE_DURATION - 1).await;
        let cached = cache.get_cached_jwks().await;
        assert!(
            cached.is_some(),
            "Cache should still be valid before expiration"
        );

        // Exactly at expiry: no longer fresh, but available as stale.
        cache.advance_time(1).await;
        assert_eq!(
            cache.get_current_time().await,
            initial_time + TEST_CACHE_DURATION
        );
        let cached = cache.get_cached_jwks().await;
        assert!(
            cached.is_none(),
            "Cache should be expired at expiration time"
        );
        let stale = cache.get_stale_cache().await;
        assert!(
            stale.is_some(),
            "Stale cache should be available within max age"
        );

        // Exactly at max age: the stale limit is inclusive.
        cache
            .advance_time(TEST_CACHE_MAX_AGE - TEST_CACHE_DURATION)
            .await;
        assert_eq!(
            cache.get_current_time().await,
            initial_time + TEST_CACHE_MAX_AGE
        );
        let stale = cache.get_stale_cache().await;
        assert!(
            stale.is_some(),
            "Stale cache should still be available exactly at max age"
        );

        // One second past max age: gone.
        cache.advance_time(1).await;
        let stale = cache.get_stale_cache().await;
        assert!(
            stale.is_none(),
            "Stale cache should not be available beyond max age"
        );
    }
}