use std::collections::HashMap;
use std::hash::Hash;
use std::sync::RwLock;
use std::time::Duration;
use tokio::time::Instant;
use tracing::{debug, warn};
/// A cached value together with the timestamp and TTL that decide its freshness.
#[derive(Debug, Clone)]
struct CacheEntry<V> {
    /// The cached value.
    value: V,
    /// When the entry was created. This is read from `tokio::time::Instant`, so tests
    /// running on tokio's paused clock control it via `tokio::time::advance`.
    inserted_at: Instant,
    /// How long after `inserted_at` the entry counts as fresh.
    ttl: Duration,
}

impl<V> CacheEntry<V> {
    /// Creates an entry stamped with the current time.
    fn new(value: V, ttl: Duration) -> Self {
        Self {
            value,
            inserted_at: Instant::now(),
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since insertion.
    fn is_expired(&self) -> bool {
        self.age() > self.ttl
    }

    /// Returns `true` once strictly more than three quarters of `ttl` has elapsed,
    /// i.e. the entry should be refreshed soon. Expired entries are also stale.
    fn is_stale(&self) -> bool {
        // `ttl - ttl / 4` rather than `ttl * 3 / 4`: multiplying first overflows and
        // panics for very large TTLs (e.g. `Duration::MAX` used as "never expire").
        // The two forms differ by at most one nanosecond of rounding.
        let threshold = self.ttl - self.ttl / 4;
        self.age() > threshold
    }

    /// Time elapsed since the entry was inserted.
    fn age(&self) -> Duration {
        self.inserted_at.elapsed()
    }
}
/// A key/value cache, guarded by a reader/writer lock, whose entries expire after a
/// time-to-live (TTL).
///
/// Expired entries are hidden by [`TtlCache::get`] but are not removed on read.
/// [`TtlCache::get_stale`] keeps returning them until one of the following drops them:
/// - [`TtlCache::cleanup`], [`TtlCache::remove`] or [`TtlCache::clear`];
/// - overwriting the key;
/// - eviction to make room in [`TtlCache::insert_with_ttl`].
///
/// A poisoned lock is recovered (logging a `warn!`) instead of propagating the panic.
pub struct TtlCache<K, V> {
    /// Cached entries by key.
    entries: RwLock<HashMap<K, CacheEntry<V>>>,
    /// TTL applied by [`TtlCache::insert`].
    default_ttl: Duration,
    /// Entry count at which inserting a *new* key triggers eviction first.
    max_capacity: usize,
}

/// Capacity used by [`TtlCache::new`].
const DEFAULT_MAX_CAPACITY: usize = 1024;

impl<K, V> TtlCache<K, V>
where
    K: Eq + Hash + Clone + std::fmt::Debug,
    V: Clone,
{
    /// Creates an empty cache with the given default TTL and a capacity of
    /// [`DEFAULT_MAX_CAPACITY`] entries.
    pub fn new(default_ttl: Duration) -> Self {
        Self::with_max_capacity(default_ttl, DEFAULT_MAX_CAPACITY)
    }

    /// Creates an empty cache with the given default TTL and capacity.
    ///
    /// A `max_capacity` of 0 behaves like 1: the most recent insert is always stored.
    pub fn with_max_capacity(default_ttl: Duration, max_capacity: usize) -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
            default_ttl,
            max_capacity,
        }
    }

    /// Acquires the read lock, recovering the map if a previous holder panicked.
    fn read_entries(&self) -> std::sync::RwLockReadGuard<'_, HashMap<K, CacheEntry<V>>> {
        self.entries.read().unwrap_or_else(|poisoned| {
            warn!("Cache read lock poisoned, recovering");
            poisoned.into_inner()
        })
    }

    /// Acquires the write lock, recovering the map if a previous holder panicked.
    fn write_entries(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<K, CacheEntry<V>>> {
        self.entries.write().unwrap_or_else(|poisoned| {
            warn!("Cache write lock poisoned, recovering");
            poisoned.into_inner()
        })
    }

    /// Drops every expired entry from `entries` and returns how many were removed.
    fn purge_expired(entries: &mut HashMap<K, CacheEntry<V>>) -> usize {
        let before = entries.len();
        entries.retain(|_, entry| !entry.is_expired());
        before - entries.len()
    }

    /// Frees space before a new key is inserted into a full map. Expired entries are
    /// purged first; if the map is still full, the entry inserted earliest is evicted.
    fn make_room(&self, entries: &mut HashMap<K, CacheEntry<V>>) {
        let removed = Self::purge_expired(entries);
        if removed > 0 {
            debug!(removed, "Evicted expired entries to make room");
        }
        if entries.len() < self.max_capacity {
            return;
        }
        // The earliest `inserted_at` is the oldest entry; comparing stamps avoids
        // reading the clock once per entry.
        let oldest = entries
            .iter()
            .min_by_key(|(_, entry)| entry.inserted_at)
            .map(|(k, _)| k.clone());
        if let Some(oldest_key) = oldest {
            entries.remove(&oldest_key);
            debug!(?oldest_key, "Evicted oldest entry to make room");
        }
    }

    /// Returns a clone of the value for `key`, or `None` if it is missing or expired.
    pub fn get(&self, key: &K) -> Option<V> {
        let entries = self.read_entries();
        let entry = entries.get(key)?;
        if entry.is_expired() {
            debug!(
                hit = false,
                ?key,
                age_secs = entry.age().as_secs(),
                "cache lookup (expired)"
            );
            None
        } else {
            debug!(hit = true, ?key, "cache lookup");
            Some(entry.value.clone())
        }
    }

    /// Returns a clone of the value for `key` even if it has expired, or `None` if the
    /// key is not stored. Serving an expired value is logged at debug level.
    pub fn get_stale(&self, key: &K) -> Option<V> {
        let entries = self.read_entries();
        entries.get(key).map(|entry| {
            if entry.is_expired() {
                debug!(
                    ?key,
                    age_secs = entry.age().as_secs(),
                    "Serving stale cache entry"
                );
            }
            entry.value.clone()
        })
    }

    /// Returns `true` if `key` is stored and more than three quarters of its TTL has
    /// elapsed (expired entries included). Returns `false` for keys that are not stored.
    pub fn needs_refresh(&self, key: &K) -> bool {
        self.read_entries()
            .get(key)
            .is_some_and(|entry| entry.is_stale())
    }

    /// Inserts `value` under `key` using the cache's default TTL.
    pub fn insert(&self, key: K, value: V) {
        self.insert_with_ttl(key, value, self.default_ttl);
    }

    /// Inserts `value` under `key` with an explicit TTL.
    ///
    /// Any existing entry is replaced and its insertion time is reset. If the cache is
    /// at capacity and `key` is new, expired entries are purged first, then the oldest
    /// entry is evicted if needed.
    pub fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) {
        let mut entries = self.write_entries();
        if entries.len() >= self.max_capacity && !entries.contains_key(&key) {
            self.make_room(&mut entries);
        }
        debug!(?key, ttl_secs = ttl.as_secs(), "Inserting cache entry");
        entries.insert(key, CacheEntry::new(value, ttl));
    }

    /// Removes `key`, returning its value whether or not it had expired.
    pub fn remove(&self, key: &K) -> Option<V> {
        self.write_entries().remove(key).map(|entry| entry.value)
    }

    /// Removes every expired entry.
    pub fn cleanup(&self) {
        let mut entries = self.write_entries();
        let removed = Self::purge_expired(&mut entries);
        if removed > 0 {
            debug!(removed, remaining = entries.len(), "Cache cleanup complete");
        }
    }

    /// Number of stored entries, including expired entries not yet removed.
    pub fn len(&self) -> usize {
        self.read_entries().len()
    }

    /// Returns `true` if nothing is stored (expired-but-unremoved entries count as stored).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every entry, expired or not.
    pub fn clear(&self) {
        self.write_entries().clear();
    }
}
/// A cache, guarded by a reader/writer lock, holding at most one value that expires
/// after a fixed TTL.
///
/// An expired value is kept, and still served by [`SingleValueCache::get_stale`], until
/// it is replaced by [`SingleValueCache::set`] or dropped by [`SingleValueCache::clear`].
/// A poisoned lock is recovered (logging a `warn!`) instead of propagating the panic.
pub struct SingleValueCache<V> {
    /// The current entry, or `None` if nothing has been set or it was cleared.
    entry: RwLock<Option<CacheEntry<V>>>,
    /// TTL applied to every value stored by [`SingleValueCache::set`].
    ttl: Duration,
}

impl<V: Clone> SingleValueCache<V> {
    /// Creates an empty cache whose values live for `ttl`.
    pub fn new(ttl: Duration) -> Self {
        Self {
            entry: RwLock::new(None),
            ttl,
        }
    }

    /// Acquires the read lock, recovering the slot if a previous holder panicked.
    fn read_slot(&self) -> std::sync::RwLockReadGuard<'_, Option<CacheEntry<V>>> {
        self.entry.read().unwrap_or_else(|poisoned| {
            warn!("SingleValueCache read lock poisoned, recovering");
            poisoned.into_inner()
        })
    }

    /// Acquires the write lock, recovering the slot if a previous holder panicked.
    fn write_slot(&self) -> std::sync::RwLockWriteGuard<'_, Option<CacheEntry<V>>> {
        self.entry.write().unwrap_or_else(|poisoned| {
            warn!("SingleValueCache write lock poisoned, recovering");
            poisoned.into_inner()
        })
    }

    /// Returns a clone of the value if one is set and has not expired.
    pub fn get(&self) -> Option<V> {
        self.read_slot()
            .as_ref()
            .filter(|entry| !entry.is_expired())
            .map(|entry| entry.value.clone())
    }

    /// Returns a clone of the value if one is set, even if it has expired.
    pub fn get_stale(&self) -> Option<V> {
        self.read_slot().as_ref().map(|entry| entry.value.clone())
    }

    /// Returns `true` when no value is set, or when more than three quarters of the
    /// TTL has elapsed since it was set (expired values included).
    pub fn needs_refresh(&self) -> bool {
        match self.read_slot().as_ref() {
            Some(entry) => entry.is_stale(),
            None => true,
        }
    }

    /// Returns `true` if a value has been set, even if it has since expired.
    pub fn has_value(&self) -> bool {
        self.read_slot().is_some()
    }

    /// Stores `value`, replacing any previous one and restarting the TTL.
    pub fn set(&self, value: V) {
        *self.write_slot() = Some(CacheEntry::new(value, self.ttl));
    }

    /// Drops the stored value, if any.
    pub fn clear(&self) {
        *self.write_slot() = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Time-dependent tests run on tokio's paused clock and move time explicitly with
    // `tokio::time::advance`. They therefore neither sleep nor depend on scheduler
    // latency: with a real clock, a 10ms TTL could expire before the first "still
    // cached" assertion on a loaded machine. This works because cache entries are
    // stamped with `tokio::time::Instant`.

    fn k(s: &str) -> String {
        s.to_string()
    }

    #[test]
    fn test_cache_insert_and_get() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));
        cache.insert(k("key"), k("value"));
        assert_eq!(cache.get(&k("key")), Some(k("value")));
    }

    #[test]
    fn test_cache_get_missing_key() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));
        assert_eq!(cache.get(&k("missing")), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_cache_expiration() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_millis(10));
        cache.insert(k("key"), k("value"));
        assert_eq!(cache.get(&k("key")), Some(k("value")));
        tokio::time::advance(Duration::from_millis(20)).await;
        assert_eq!(cache.get(&k("key")), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_cache_get_stale_after_expiration() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_millis(10));
        cache.insert(k("key"), k("value"));
        tokio::time::advance(Duration::from_millis(20)).await;
        assert_eq!(cache.get(&k("key")), None);
        assert_eq!(cache.get_stale(&k("key")), Some(k("value")));
    }

    #[test]
    fn test_cache_remove() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));
        cache.insert(k("key"), k("value"));
        assert!(cache.get(&k("key")).is_some());
        assert_eq!(cache.remove(&k("key")), Some(k("value")));
        assert!(cache.get(&k("key")).is_none());
        assert_eq!(cache.remove(&k("key")), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_cache_cleanup() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_millis(10));
        cache.insert(k("key1"), k("value1"));
        cache.insert(k("key2"), k("value2"));
        tokio::time::advance(Duration::from_millis(20)).await;
        cache.insert_with_ttl(k("key3"), k("value3"), Duration::from_secs(3600));
        // `len` still counts the two expired entries until `cleanup` runs.
        assert_eq!(cache.len(), 3);
        cache.cleanup();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&k("key3")), Some(k("value3")));
    }

    #[test]
    fn test_cache_clear() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));
        cache.insert(k("key1"), k("value1"));
        cache.insert(k("key2"), k("value2"));
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn test_capacity_evicts_oldest_entry() {
        let cache: TtlCache<String, String> =
            TtlCache::with_max_capacity(Duration::from_secs(3600), 2);
        cache.insert(k("a"), k("1"));
        tokio::time::advance(Duration::from_millis(1)).await;
        cache.insert(k("b"), k("2"));
        tokio::time::advance(Duration::from_millis(1)).await;
        cache.insert(k("c"), k("3"));
        assert_eq!(cache.len(), 2);
        assert!(cache.get_stale(&k("a")).is_none());
        assert_eq!(cache.get(&k("b")), Some(k("2")));
        assert_eq!(cache.get(&k("c")), Some(k("3")));
    }

    #[tokio::test(start_paused = true)]
    async fn test_capacity_purges_expired_before_evicting_live_entries() {
        let cache: TtlCache<String, String> =
            TtlCache::with_max_capacity(Duration::from_secs(3600), 2);
        cache.insert(k("old"), k("1"));
        tokio::time::advance(Duration::from_millis(1)).await;
        cache.insert_with_ttl(k("short"), k("2"), Duration::from_millis(5));
        tokio::time::advance(Duration::from_millis(10)).await;
        cache.insert(k("new"), k("3"));
        // "short" has expired and is purged, so the older but live "old" survives.
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&k("old")), Some(k("1")));
        assert!(cache.get_stale(&k("short")).is_none());
        assert_eq!(cache.get(&k("new")), Some(k("3")));
    }

    #[test]
    fn test_capacity_overwriting_existing_key_does_not_evict() {
        let cache: TtlCache<String, String> =
            TtlCache::with_max_capacity(Duration::from_secs(3600), 2);
        cache.insert(k("a"), k("1"));
        cache.insert(k("b"), k("2"));
        cache.insert(k("a"), k("updated"));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&k("a")), Some(k("updated")));
        assert_eq!(cache.get(&k("b")), Some(k("2")));
    }

    #[test]
    fn test_single_value_cache() {
        let cache: SingleValueCache<String> = SingleValueCache::new(Duration::from_secs(3600));
        assert!(!cache.has_value());
        assert!(cache.get().is_none());
        cache.set(k("value"));
        assert!(cache.has_value());
        assert_eq!(cache.get(), Some(k("value")));
        cache.clear();
        assert!(!cache.has_value());
        assert!(cache.get_stale().is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn test_single_value_cache_expiration() {
        let cache: SingleValueCache<String> = SingleValueCache::new(Duration::from_millis(10));
        cache.set(k("value"));
        assert_eq!(cache.get(), Some(k("value")));
        tokio::time::advance(Duration::from_millis(20)).await;
        assert!(cache.get().is_none());
        assert_eq!(cache.get_stale(), Some(k("value")));
        assert!(cache.has_value(), "has_value ignores expiry");
    }

    #[tokio::test(start_paused = true)]
    async fn test_single_value_cache_needs_refresh() {
        let cache: SingleValueCache<String> = SingleValueCache::new(Duration::from_secs(1));
        assert!(cache.needs_refresh(), "an empty cache needs a refresh");
        cache.set(k("value"));
        assert!(!cache.needs_refresh());
        tokio::time::advance(Duration::from_millis(800)).await;
        assert!(cache.needs_refresh(), "stale at t=800ms (> 750ms threshold)");
    }

    #[tokio::test(start_paused = true)]
    async fn test_needs_refresh() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(1));
        cache.insert(k("key"), k("value"));
        assert!(!cache.needs_refresh(&k("key")));
        tokio::time::advance(Duration::from_millis(800)).await;
        assert!(
            cache.needs_refresh(&k("key")),
            "entry must be stale at t=800ms (> 750ms threshold)"
        );
        assert!(
            cache.get(&k("key")).is_some(),
            "entry must not be expired at t=800ms (< 1000ms TTL)"
        );
        tokio::time::advance(Duration::from_millis(300)).await;
        assert!(
            cache.get(&k("key")).is_none(),
            "entry must be expired at t=1100ms (> 1000ms TTL)"
        );
    }

    #[test]
    fn test_needs_refresh_missing_key() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(1));
        assert!(!cache.needs_refresh(&k("missing")));
    }
}