use crate::TaskMeta;
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use uuid::Uuid;
/// A cached [`TaskMeta`] plus the bookkeeping used for expiry and eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached value; every hit hands out a clone of it.
    data: TaskMeta,
    /// Wall-clock time at which this entry was created.
    cached_at: DateTime<Utc>,
    /// Number of times the value has been read through [`CacheEntry::access`].
    access_count: u64,
}
impl CacheEntry {
    /// Wraps `data`, stamping it with the current UTC time and a zero access count.
    fn new(data: TaskMeta) -> Self {
        Self {
            data,
            cached_at: Utc::now(),
            access_count: 0,
        }
    }
    /// Returns `true` once strictly more than `ttl` has elapsed since `cached_at`,
    /// compared at millisecond granularity.
    ///
    /// TTLs too large to express as `i64` milliseconds (e.g. `Duration::MAX`) are
    /// clamped to `i64::MAX`, so they behave as "effectively never expires".
    ///
    /// NOTE(review): age is measured with the wall clock, so a backwards clock jump
    /// yields a negative age and the entry looks fresh; confirm this is acceptable
    /// or switch to a monotonic clock.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age_ms = (Utc::now() - self.cached_at).num_milliseconds();
        // `as_millis` returns u128; a plain `as i64` would wrap huge TTLs to small or
        // negative numbers (Duration::MAX -> -1), making every entry instantly
        // expired. Saturate instead.
        let ttl_ms = ttl.as_millis().min(i64::MAX as u128) as i64;
        age_ms > ttl_ms
    }
    /// Records one read and returns a clone of the cached value.
    fn access(&mut self) -> TaskMeta {
        self.access_count += 1;
        self.data.clone()
    }
}
/// Tuning knobs for a [`ResultCache`].
///
/// Start from [`CacheConfig::new`] (or `Default`) and adjust with the `with_*`
/// builder methods, or use [`CacheConfig::disabled`] for a cache that stores nothing.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Entry count at which inserting a new key first evicts the oldest entry.
    pub capacity: usize,
    /// How long an entry remains valid after it was inserted.
    pub ttl: Duration,
    /// Master switch; when `false` the cache neither stores nor returns anything.
    pub enabled: bool,
}
impl Default for CacheConfig {
    /// An enabled cache of 1000 entries, each valid for five minutes.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        CacheConfig {
            enabled: true,
            capacity: 1000,
            ttl: five_minutes,
        }
    }
}
impl CacheConfig {
    /// Equivalent to [`CacheConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }
    /// A configuration with caching switched off (zero capacity, zero TTL).
    pub fn disabled() -> Self {
        CacheConfig {
            enabled: false,
            capacity: 0,
            ttl: Duration::new(0, 0),
        }
    }
    /// Returns this configuration with `capacity` replaced.
    pub fn with_capacity(self, capacity: usize) -> Self {
        Self { capacity, ..self }
    }
    /// Returns this configuration with `ttl` replaced.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }
    /// Returns this configuration with the `enabled` flag replaced.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
/// Cache of [`TaskMeta`] values keyed by task id, with TTL expiry and a size cap.
///
/// Entries live in an `Arc<RwLock<HashMap<..>>>`, so cloning a `ResultCache`
/// yields another handle onto the *same* map: writes through one clone are
/// visible through the others.
///
/// If the lock is ever poisoned, no method panics. Instead:
/// - lookups miss and writes/removals are dropped;
/// - `len` and `cleanup_expired` return 0;
/// - `stats` reports an empty cache.
#[derive(Debug, Clone)]
pub struct ResultCache {
    config: CacheConfig,
    entries: Arc<RwLock<HashMap<Uuid, CacheEntry>>>,
}
impl Default for ResultCache {
    /// A cache built from [`CacheConfig::default`].
    fn default() -> Self {
        Self::new(CacheConfig::default())
    }
}
impl ResultCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            config,
            entries: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Creates a cache built from [`CacheConfig::disabled`]: it never stores anything.
    pub fn disabled() -> Self {
        Self::new(CacheConfig::disabled())
    }
    /// Returns whether caching is switched on in the configuration.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
    /// Returns the configuration this cache was built with.
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }
    /// Looks up `task_id` and returns a clone of the cached metadata.
    ///
    /// Returns `None` in any of these cases:
    /// - the cache is disabled;
    /// - the key is absent;
    /// - the lock is poisoned;
    /// - the entry has outlived the TTL, in which case it is also removed.
    ///
    /// Takes the write lock even on a hit, because a hit bumps the entry's access
    /// counter and a stale entry is evicted in place.
    pub fn get(&self, task_id: Uuid) -> Option<TaskMeta> {
        if !self.config.enabled {
            return None;
        }
        let mut entries = self.entries.write().ok()?;
        if let Some(entry) = entries.get_mut(&task_id) {
            if entry.is_expired(self.config.ttl) {
                entries.remove(&task_id);
                return None;
            }
            Some(entry.access())
        } else {
            None
        }
    }
    /// Inserts or replaces the metadata for `task_id`, restarting its TTL.
    ///
    /// Inserting a *new* key into a full cache first evicts the entry with the
    /// earliest insertion time. Does nothing if any of these hold:
    /// - the cache is disabled;
    /// - it has zero capacity;
    /// - the lock is poisoned.
    pub fn put(&self, task_id: Uuid, meta: TaskMeta) {
        // Zero capacity must mean "store nothing". Without this guard the eviction
        // step finds nothing to evict in an empty map and the cache grows to one entry.
        if !self.config.enabled || self.config.capacity == 0 {
            return;
        }
        // Drop the write on poisoning, matching the other methods, rather than
        // panicking inside a cache.
        let mut entries = match self.entries.write() {
            Ok(guard) => guard,
            Err(_) => return,
        };
        if entries.len() >= self.config.capacity && !entries.contains_key(&task_id) {
            self.evict_oldest(&mut entries);
        }
        entries.insert(task_id, CacheEntry::new(meta));
    }
    /// Removes `task_id` from the cache, if present.
    ///
    /// No-op when the cache is disabled or the lock is poisoned.
    pub fn invalidate(&self, task_id: Uuid) {
        if !self.config.enabled {
            return;
        }
        if let Ok(mut entries) = self.entries.write() {
            entries.remove(&task_id);
        }
    }
    /// Removes every entry, regardless of the `enabled` flag.
    ///
    /// No-op if the lock is poisoned.
    pub fn clear(&self) {
        if let Ok(mut entries) = self.entries.write() {
            entries.clear();
        }
    }
    /// Returns the number of stored entries.
    ///
    /// Expired entries that have not been removed yet are included. Returns 0 if
    /// the lock is poisoned.
    pub fn len(&self) -> usize {
        self.entries.read().map(|e| e.len()).unwrap_or(0)
    }
    /// Returns `true` when [`ResultCache::len`] is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Removes every entry whose TTL has elapsed and returns how many were removed.
    ///
    /// Returns 0 when the cache is disabled or the lock is poisoned.
    pub fn cleanup_expired(&self) -> usize {
        if !self.config.enabled {
            return 0;
        }
        let mut entries = match self.entries.write() {
            Ok(e) => e,
            Err(_) => return 0,
        };
        let ttl = self.config.ttl;
        let before_count = entries.len();
        // In-place removal; avoids collecting keys into a temporary Vec first.
        entries.retain(|_, entry| !entry.is_expired(ttl));
        before_count - entries.len()
    }
    /// Removes the entry with the earliest `cached_at`, if there is one.
    ///
    /// This is a linear scan over all entries. When several entries share the
    /// earliest timestamp, which one is removed depends on the map's unspecified
    /// iteration order.
    fn evict_oldest(&self, entries: &mut HashMap<Uuid, CacheEntry>) {
        if let Some((&oldest_key, _)) = entries.iter().min_by_key(|(_, entry)| entry.cached_at) {
            entries.remove(&oldest_key);
        }
    }
    /// Takes a snapshot of size, capacity and expired-entry count.
    ///
    /// On a poisoned lock, reports `size` and `expired_count` as 0.
    pub fn stats(&self) -> CacheStats {
        let entries = match self.entries.read() {
            Ok(e) => e,
            Err(_) => {
                return CacheStats {
                    size: 0,
                    capacity: self.config.capacity,
                    expired_count: 0,
                }
            }
        };
        let expired_count = entries
            .values()
            .filter(|entry| entry.is_expired(self.config.ttl))
            .count();
        CacheStats {
            size: entries.len(),
            capacity: self.config.capacity,
            expired_count,
        }
    }
}
/// Point-in-time snapshot of a [`ResultCache`]'s occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entries currently stored, counting expired ones not yet removed.
    pub size: usize,
    /// The configured capacity.
    pub capacity: usize,
    /// How many stored entries had already outlived the TTL at snapshot time.
    pub expired_count: usize,
}
impl std::fmt::Display for CacheStats {
    /// Renders as `Cache: <size>/<capacity> entries (<expired_count> expired)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CacheStats {
            size,
            capacity,
            expired_count,
        } = self;
        write!(f, "Cache: {}/{} entries ({} expired)", size, capacity, expired_count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Builds a `TaskMeta` for `id` carrying `label`.
    fn sample(id: Uuid, label: &str) -> TaskMeta {
        TaskMeta::new(id, label.to_string())
    }
    /// Inserts `count` fresh tasks labelled `test-0`, `test-1`, ... and returns
    /// their ids in insertion order.
    fn fill(cache: &ResultCache, count: usize) -> Vec<Uuid> {
        (0..count)
            .map(|i| {
                let id = Uuid::new_v4();
                cache.put(id, sample(id, &format!("test-{}", i)));
                id
            })
            .collect()
    }
    #[test]
    fn test_cache_config_defaults() {
        let cfg = CacheConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.ttl.as_secs(), 300);
        assert_eq!(cfg.capacity, 1000);
    }
    #[test]
    fn test_cache_config_builder() {
        let cfg = CacheConfig::new()
            .with_capacity(500)
            .with_ttl(Duration::from_secs(60))
            .with_enabled(true);
        assert!(cfg.enabled);
        assert_eq!(cfg.ttl.as_secs(), 60);
        assert_eq!(cfg.capacity, 500);
    }
    #[test]
    fn test_cache_disabled() {
        let cache = ResultCache::disabled();
        assert!(!cache.is_enabled());
        let id = Uuid::new_v4();
        cache.put(id, sample(id, "test"));
        assert!(cache.get(id).is_none());
    }
    #[test]
    fn test_cache_put_and_get() {
        let cache = ResultCache::new(CacheConfig::new());
        let id = Uuid::new_v4();
        cache.put(id, sample(id, "test"));
        let hit = cache
            .get(id)
            .expect("freshly inserted entry should be cached");
        assert_eq!(hit.task_id, id);
    }
    #[test]
    fn test_cache_miss() {
        let cache = ResultCache::new(CacheConfig::new());
        assert!(cache.get(Uuid::new_v4()).is_none());
    }
    #[test]
    fn test_cache_invalidate() {
        let cache = ResultCache::new(CacheConfig::new());
        let id = Uuid::new_v4();
        cache.put(id, sample(id, "test"));
        assert!(cache.get(id).is_some());
        cache.invalidate(id);
        assert!(cache.get(id).is_none());
    }
    #[test]
    fn test_cache_clear() {
        let cache = ResultCache::new(CacheConfig::new());
        fill(&cache, 10);
        assert_eq!(cache.len(), 10);
        cache.clear();
        assert_eq!(cache.len(), 0);
    }
    #[test]
    fn test_cache_capacity() {
        let cache = ResultCache::new(CacheConfig::new().with_capacity(5));
        let ids = fill(&cache, 10);
        assert_eq!(cache.len(), 5);
        // The five most recent insertions must have survived eviction.
        for id in &ids[5..] {
            assert!(cache.get(*id).is_some());
        }
    }
    #[test]
    fn test_cache_expiration() {
        let cache = ResultCache::new(CacheConfig::new().with_ttl(Duration::from_millis(50)));
        let id = Uuid::new_v4();
        cache.put(id, sample(id, "test"));
        assert!(cache.get(id).is_some());
        std::thread::sleep(Duration::from_millis(100));
        assert!(cache.get(id).is_none());
    }
    #[test]
    fn test_cache_cleanup_expired() {
        let cache = ResultCache::new(CacheConfig::new().with_ttl(Duration::from_millis(50)));
        fill(&cache, 5);
        assert_eq!(cache.len(), 5);
        std::thread::sleep(Duration::from_millis(100));
        assert_eq!(cache.cleanup_expired(), 5);
        assert_eq!(cache.len(), 0);
    }
    #[test]
    fn test_cache_stats() {
        let cache = ResultCache::new(CacheConfig::new().with_capacity(10));
        fill(&cache, 5);
        let stats = cache.stats();
        assert_eq!(stats.size, 5);
        assert_eq!(stats.capacity, 10);
    }
}