use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::time::interval;
use tracing::{debug, info};
use super::token::TokenInfo;
/// Tunables for [`MemoryTokenCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries the cache is configured to hold.
    pub max_size: usize,
    /// Configured time-to-live for cached entries.
    pub ttl: Duration,
    /// Period of the background expiry sweep; `Duration::ZERO` disables the sweep.
    pub cleanup_interval: Duration,
    /// Whether statistics collection is enabled.
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a one-hour TTL, a five-minute cleanup period and stats enabled.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        CacheConfig {
            max_size: 1_000,
            ttl: one_minute * 60,
            cleanup_interval: one_minute * 5,
            enable_stats: true,
        }
    }
}
/// Counters describing cache activity; `MemoryTokenCache::stats` returns a clone of them.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
/// Lookup hit counter.
pub hits: u64,
/// Lookup miss counter.
pub misses: u64,
/// Number of background sweeps that removed at least one expired entry.
pub cleanups: u64,
/// Entry count as last recorded by the cache; `validate` compares it with the live map size.
pub current_size: usize,
}
/// A cached token together with its bookkeeping timestamps.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached token.
    token_info: TokenInfo,
    /// Creation time; eviction in `put` picks the entry with the smallest value.
    created_at: Instant,
    /// Last access time; set by `new` and `update_access`, not consulted by eviction.
    last_accessed: Instant,
    /// The entry counts as expired from this instant onward.
    expires_at: Instant,
}

impl CacheEntry {
    /// Longest lifetime an entry may be given (about 100 years). Larger TTLs are
    /// clamped so computing `expires_at` cannot overflow `Instant`.
    const MAX_TTL: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);

    /// Wraps `token_info` in an entry that expires `ttl` from now.
    ///
    /// A TTL above [`Self::MAX_TTL`] is clamped instead of panicking. If even the
    /// clamped expiry cannot be represented on this platform, the entry is created
    /// already expired, so it is simply not served.
    fn new(token_info: TokenInfo, ttl: Duration) -> Self {
        let now = Instant::now();
        // `Instant + Duration` panics on overflow (e.g. `Duration::MAX`), so clamp first.
        let expires_at = now.checked_add(ttl.min(Self::MAX_TTL)).unwrap_or(now);
        Self {
            token_info,
            created_at: now,
            last_accessed: now,
            expires_at,
        }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    fn is_expired(&self) -> bool {
        Instant::now() >= self.expires_at
    }

    /// Records the current time as the last access.
    // Currently has no caller in this module, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn update_access(&mut self) {
        self.last_accessed = Instant::now();
    }
}
/// In-memory token cache keyed by string, with per-entry expiry and an
/// optional background cleanup task (not started when `cleanup_interval` is zero).
#[derive(Debug)]
pub struct MemoryTokenCache {
/// Settings supplied to `new`.
config: CacheConfig,
/// Key -> entry map; an `Arc` clone of it is moved into the cleanup task.
cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
/// Activity counters; also shared with the cleanup task through an `Arc` clone.
stats: Arc<RwLock<CacheStats>>,
/// Handle of the spawned cleanup task (`None` when none was started); aborted in `Drop`.
cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
impl MemoryTokenCache {
    /// Creates an empty cache and, when enabled, starts its background expiry
    /// sweep (see `start_cleanup_task`).
    pub fn new(config: CacheConfig) -> Self {
        let mut instance = Self {
            config,
            cache: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(CacheStats::default())),
            cleanup_handle: None,
        };
        instance.start_cleanup_task();
        instance
    }

    /// Spawns a task that purges expired entries every `cleanup_interval`.
    ///
    /// This is skipped in two cases:
    /// - `cleanup_interval` is zero (`tokio::time::interval` panics on a zero period).
    /// - No Tokio runtime is running (`tokio::spawn` would panic).
    ///
    /// When skipped, expired entries are only dropped lazily by `get` and `put`.
    fn start_cleanup_task(&mut self) {
        if self.config.cleanup_interval.is_zero() {
            return;
        }
        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            debug!("No Tokio runtime available; background cache cleanup disabled");
            return;
        };
        let cache = Arc::clone(&self.cache);
        let stats = Arc::clone(&self.stats);
        let cleanup_interval = self.config.cleanup_interval;
        let handle = runtime.spawn(async move {
            let mut ticker = interval(cleanup_interval);
            loop {
                ticker.tick().await;
                Self::cleanup_expired_entries(&cache, &stats);
            }
        });
        self.cleanup_handle = Some(handle);
    }

    /// Removes every expired entry and re-syncs `stats.current_size`.
    ///
    /// `stats.cleanups` is incremented only when at least one entry was removed.
    fn cleanup_expired_entries(
        cache: &Arc<RwLock<HashMap<String, CacheEntry>>>,
        stats: &Arc<RwLock<CacheStats>>,
    ) {
        let mut cache_guard = cache.write().unwrap_or_else(|e| e.into_inner());
        let before_size = cache_guard.len();
        cache_guard.retain(|_, entry| !entry.is_expired());
        let after_size = cache_guard.len();
        let removed = before_size - after_size;

        // Updated while the cache lock is still held so `validate` never sees a stale size.
        let mut stats_guard = stats.write().unwrap_or_else(|e| e.into_inner());
        stats_guard.current_size = after_size;
        if removed > 0 {
            stats_guard.cleanups += 1;
            debug!(
                "Cleaned up {} expired cache entries, current size: {}",
                removed, after_size
            );
        }
    }

    /// Writes the cache's entry count into `stats.current_size`.
    ///
    /// Callers invoke this while still holding the cache write lock, so the two
    /// locks are always taken in the same order (cache, then stats) and
    /// `validate` never observes them out of sync.
    fn set_current_size(stats: &RwLock<CacheStats>, len: usize) {
        stats.write().unwrap_or_else(|e| e.into_inner()).current_size = len;
    }

    /// Increments `hits` or `misses` when `enable_stats` is set.
    fn record_lookup(&self, hit: bool) {
        if !self.config.enable_stats {
            return;
        }
        let mut stats_guard = self.stats.write().unwrap_or_else(|e| e.into_inner());
        if hit {
            stats_guard.hits += 1;
        } else {
            stats_guard.misses += 1;
        }
    }

    /// Caches `token` under `key`, replacing any existing entry.
    ///
    /// The entry expires after the token's remaining lifetime
    /// (`expires_in_seconds`) or the configured `ttl`, whichever is shorter.
    ///
    /// When a new key is inserted into a full cache, expired entries are purged
    /// first. The oldest entries (by creation time) are then evicted until there
    /// is room.
    ///
    /// A `max_size` of zero disables caching.
    pub async fn put(&self, key: &str, token: TokenInfo) {
        let max_size = self.config.max_size;
        if max_size == 0 {
            debug!("Cache max_size is 0; token not cached");
            return;
        }
        // Never cache beyond the token's own lifetime, nor beyond the configured TTL.
        let ttl = Duration::from_secs(token.expires_in_seconds()).min(self.config.ttl);
        let entry = CacheEntry::new(token, ttl);

        let mut cache_guard = self.cache.write().unwrap_or_else(|e| e.into_inner());
        // Replacing an existing key never grows the map, so only make room for new keys.
        if !cache_guard.contains_key(key) && cache_guard.len() >= max_size {
            cache_guard.retain(|_, existing| !existing.is_expired());
            while cache_guard.len() >= max_size {
                let Some(oldest_key) = cache_guard
                    .iter()
                    .min_by_key(|(_, existing)| existing.created_at)
                    .map(|(k, _)| k.clone())
                else {
                    break;
                };
                cache_guard.remove(&oldest_key);
            }
        }
        cache_guard.insert(key.to_string(), entry);
        Self::set_current_size(&self.stats, cache_guard.len());
        debug!("Cached token for key");
    }

    /// Returns a clone of the token cached under `key`.
    ///
    /// Returns `None` if the key is absent or its entry has expired. An expired
    /// entry found here is removed.
    ///
    /// Hits and misses are counted when `enable_stats` is set.
    pub async fn get(&self, key: &str) -> Option<TokenInfo> {
        let (token, found_expired) = {
            let cache_guard = self.cache.read().unwrap_or_else(|e| e.into_inner());
            match cache_guard.get(key) {
                Some(entry) if entry.is_expired() => (None, true),
                Some(entry) => (Some(entry.token_info.clone()), false),
                None => (None, false),
            }
        };

        if found_expired {
            let mut cache_guard = self.cache.write().unwrap_or_else(|e| e.into_inner());
            // Re-check under the write lock: another task may have stored a fresh
            // token for this key after the read lock was released.
            if cache_guard.get(key).is_some_and(CacheEntry::is_expired) {
                cache_guard.remove(key);
                Self::set_current_size(&self.stats, cache_guard.len());
            }
        }

        self.record_lookup(token.is_some());
        token
    }

    /// Removes the entry under `key` and returns its token, expired or not.
    pub async fn remove(&self, key: &str) -> Option<TokenInfo> {
        let mut cache_guard = self.cache.write().unwrap_or_else(|e| e.into_inner());
        let entry = cache_guard.remove(key)?;
        Self::set_current_size(&self.stats, cache_guard.len());
        Some(entry.token_info)
    }

    /// Removes every entry.
    pub async fn clear(&self) {
        let mut cache_guard = self.cache.write().unwrap_or_else(|e| e.into_inner());
        cache_guard.clear();
        Self::set_current_size(&self.stats, 0);
        info!("Cleared all cache entries");
    }

    /// Number of entries currently in the map, including expired ones not yet purged.
    pub async fn size(&self) -> usize {
        let cache = self.cache.read().unwrap_or_else(|e| e.into_inner());
        cache.len()
    }

    /// Snapshot (clone) of the current statistics.
    pub async fn stats(&self) -> CacheStats {
        let stats = self.stats.read().unwrap_or_else(|e| e.into_inner());
        stats.clone()
    }

    /// Checks that `stats.current_size` matches the live map size.
    ///
    /// # Errors
    /// Returns a message describing both sizes when they differ.
    pub async fn validate(&self) -> Result<usize, String> {
        let cache = self.cache.read().unwrap_or_else(|e| e.into_inner());
        let stats = self.stats.read().unwrap_or_else(|e| e.into_inner());
        let cache_size = cache.len();
        let stats_size = stats.current_size;
        if cache_size == stats_size {
            Ok(cache_size)
        } else {
            Err(format!(
                "Cache size ({}) differs from stats ({})",
                cache_size, stats_size
            ))
        }
    }
}
impl Drop for MemoryTokenCache {
    /// Aborts the background cleanup task, if one was spawned.
    ///
    /// The task holds its own `Arc` clones of the cache, so without this it
    /// would keep running after the cache is gone.
    fn drop(&mut self) {
        let Some(cleanup_task) = self.cleanup_handle.take() else {
            return;
        };
        cleanup_task.abort();
    }
}
/// Async storage backend for `TokenInfo` values keyed by string.
///
/// Errors are reported as boxed `std::error::Error` trait objects.
// `async fn` in a public trait trips the `async_fn_in_trait` lint because the
// returned futures carry no `Send` bound that callers can rely on; the lint is
// allowed here.
#[allow(async_fn_in_trait)]
pub trait TokenStorage: Send + Sync {
/// Stores `token` under `key`.
async fn store(&self, key: &str, token: &TokenInfo) -> Result<(), Box<dyn std::error::Error>>;
/// Fetches the token stored under `key`; `Ok(None)` when there is none.
async fn retrieve(&self, key: &str) -> Result<Option<TokenInfo>, Box<dyn std::error::Error>>;
/// Deletes the token stored under `key`, returning it if one was present.
async fn delete(&self, key: &str) -> Result<Option<TokenInfo>, Box<dyn std::error::Error>>;
/// Lists stored keys; when `prefix` is given, only keys starting with it.
async fn list(&self, prefix: Option<&str>) -> Result<Vec<String>, Box<dyn std::error::Error>>;
}
impl TokenStorage for MemoryTokenCache {
    /// Caches a clone of `token` via `put`; the in-memory backend never fails.
    async fn store(&self, key: &str, token: &TokenInfo) -> Result<(), Box<dyn std::error::Error>> {
        self.put(key, token.clone()).await;
        Ok(())
    }

    /// Delegates to `get`, so expired tokens are reported as absent.
    async fn retrieve(&self, key: &str) -> Result<Option<TokenInfo>, Box<dyn std::error::Error>> {
        Ok(self.get(key).await)
    }

    /// Delegates to `remove`.
    async fn delete(&self, key: &str) -> Result<Option<TokenInfo>, Box<dyn std::error::Error>> {
        Ok(self.remove(key).await)
    }

    /// Lists the keys of unexpired entries; when `prefix` is given, only keys
    /// starting with it.
    ///
    /// Expired entries still awaiting cleanup are skipped, so the listing agrees
    /// with what `retrieve` reports at the time of the call.
    async fn list(&self, prefix: Option<&str>) -> Result<Vec<String>, Box<dyn std::error::Error>> {
        let cache = self.cache.read().unwrap_or_else(|e| e.into_inner());
        let keys = cache
            .iter()
            .filter(|(key, entry)| {
                !entry.is_expired() && prefix.is_none_or(|p| key.starts_with(p))
            })
            .map(|(key, _)| key.clone())
            .collect();
        Ok(keys)
    }
}
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use super::*;
    use crate::auth::TokenType;
    use tokio::time::{sleep, Duration};

    /// Lifetime for tokens that must stay valid for the whole test.
    const ONE_HOUR: Duration = Duration::from_secs(3600);

    /// Builds an app access token for `"test_app"` with the given access string and lifetime.
    fn app_token(access_token: &str, lifetime: Duration) -> TokenInfo {
        TokenInfo::new(
            access_token.to_string(),
            TokenType::AppAccessToken,
            lifetime,
            "test_app".to_string(),
        )
    }

    #[tokio::test]
    async fn test_memory_cache_basic_operations() {
        let cache = MemoryTokenCache::new(CacheConfig::default());
        cache.put("test_key", app_token("test_token", ONE_HOUR)).await;

        let fetched = cache.get("test_key").await.expect("token should be cached");
        assert_eq!(fetched.access_token, "test_token");
        assert_eq!(cache.stats().await.current_size, 1);
    }

    #[tokio::test]
    async fn test_memory_cache_expiry() {
        let cache = MemoryTokenCache::new(CacheConfig {
            ttl: Duration::from_millis(100),
            ..Default::default()
        });
        // A zero-lifetime token is already expired by the time it is read back.
        cache.put("test_key", app_token("test_token", Duration::from_secs(0))).await;
        assert!(cache.get("test_key").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_remove() {
        let cache = MemoryTokenCache::new(CacheConfig::default());
        cache.put("test_key", app_token("test_token", ONE_HOUR)).await;
        assert_eq!(cache.size().await, 1);

        let removed = cache.remove("test_key").await.expect("entry should exist");
        assert_eq!(removed.access_token, "test_token");
        assert_eq!(cache.size().await, 0);
        assert!(cache.get("test_key").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_clear() {
        let cache = MemoryTokenCache::new(CacheConfig::default());
        for i in 0..5 {
            cache.put(&format!("key_{i}"), app_token(&format!("token_{i}"), ONE_HOUR)).await;
        }
        assert_eq!(cache.size().await, 5);

        cache.clear().await;
        assert_eq!(cache.size().await, 0);
        for i in 0..5 {
            assert!(cache.get(&format!("key_{i}")).await.is_none());
        }
    }

    #[tokio::test]
    async fn test_memory_cache_capacity_limit() {
        let cache = MemoryTokenCache::new(CacheConfig::default());
        for i in 0..1005 {
            cache.put(&format!("key_{i}"), app_token(&format!("token_{i}"), ONE_HOUR)).await;
        }
        assert!(cache.size().await <= 1000);
        // The earliest insert is evicted; the latest survives.
        assert!(cache.get("key_0").await.is_none());
        assert!(cache.get("key_1004").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_entry_creation() {
        let entry = CacheEntry::new(app_token("test_token", ONE_HOUR), Duration::from_secs(300));
        assert_eq!(entry.token_info.access_token, "test_token");
        assert!(!entry.is_expired());
    }

    #[tokio::test]
    async fn test_cache_entry_expiry() {
        let entry = CacheEntry::new(app_token("test_token", ONE_HOUR), Duration::from_millis(1));
        sleep(Duration::from_millis(10)).await;
        assert!(entry.is_expired());
    }

    #[tokio::test]
    async fn test_token_storage_trait() {
        let cache = MemoryTokenCache::new(CacheConfig::default());
        let token = app_token("test_token", ONE_HOUR);

        cache.store("test_key", &token).await.unwrap();
        let retrieved = cache
            .retrieve("test_key")
            .await
            .unwrap()
            .expect("stored token should be retrievable");
        assert_eq!(retrieved.access_token, "test_token");

        assert!(cache.delete("test_key").await.unwrap().is_some());
        assert!(cache.retrieve("test_key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cache_validation() {
        let cache = MemoryTokenCache::new(CacheConfig::default());
        assert_eq!(cache.validate().await, Ok(0));

        cache.put("test_key", app_token("test_token", ONE_HOUR)).await;
        assert_eq!(cache.validate().await, Ok(1));
    }
}