use super::entry::CacheEntry;
use rand::prelude::IndexedRandom;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use tokio::task::AbortHandle;
/// Tuning knobs for the probabilistic (sampling-based) expiry pass run by
/// `LayeredCacheStore::cleanup_active_sampling`.
///
/// Plain configuration data, so it is `Copy`: stores can duplicate it freely
/// instead of rebuilding it field by field.
#[derive(Debug, Clone, Copy)]
struct ActiveSampler {
    /// Maximum number of keys inspected in one sampling round (capped by the
    /// number of keys actually stored).
    sample_size: usize,
    /// Fraction of expired keys in a sample above which another round is run.
    threshold: f32,
}

impl Default for ActiveSampler {
    /// 20 keys per round; keep sampling while more than 25% of a sample expired.
    fn default() -> Self {
        Self {
            sample_size: 20,
            threshold: 0.25,
        }
    }
}
/// Secondary expiry index: bucket (whole seconds since the Unix epoch, rounded
/// down from an entry's expiry time) -> keys that were indexed into that bucket.
type TtlIndex = HashMap<u64, Vec<String>>;

/// Primary key -> entry map.
type EntryMap = HashMap<String, CacheEntry>;

/// In-memory byte cache with layered expiry: entries are expired passively
/// when read, and actively by a TTL-bucket reaper and a random-sampling reaper.
///
/// All state lives behind `Arc`s, so clones of a store share the same data.
pub struct LayeredCacheStore {
    /// Key/value storage.
    store: Arc<RwLock<EntryMap>>,
    /// Expiry buckets consulted by the TTL-index reaper.
    ttl_index: Arc<RwLock<TtlIndex>>,
    /// Settings for the sampling reaper.
    active_sampler: ActiveSampler,
    /// Abort handle of the running background cleanup task, if one was started.
    cleanup_handle: Arc<std::sync::Mutex<Option<AbortHandle>>>,
}
impl LayeredCacheStore {
pub fn new() -> Self {
Self {
store: Arc::new(RwLock::new(HashMap::new())),
ttl_index: Arc::new(RwLock::new(HashMap::new())),
active_sampler: ActiveSampler::default(),
cleanup_handle: Arc::new(std::sync::Mutex::new(None)),
}
}
pub fn with_sampler(sample_size: usize, threshold: f32) -> Self {
Self {
store: Arc::new(RwLock::new(HashMap::new())),
ttl_index: Arc::new(RwLock::new(HashMap::new())),
active_sampler: ActiveSampler {
sample_size,
threshold,
},
cleanup_handle: Arc::new(std::sync::Mutex::new(None)),
}
}
pub async fn get(&self, key: &str) -> Option<Vec<u8>> {
let mut store = self.store.write().await;
if let Some(entry) = store.get_mut(key) {
if entry.is_expired() {
store.remove(key);
return None;
}
entry.touch();
return Some(entry.value.clone());
}
None
}
pub async fn set(&self, key: String, value: Vec<u8>, ttl: Option<std::time::Duration>) {
let entry = CacheEntry::new(value, ttl);
let mut store = self.store.write().await;
if let Some(expires_at) = entry.expires_at {
let timestamp = expires_at
.duration_since(SystemTime::UNIX_EPOCH)
.ok()
.map(|d| d.as_secs())
.unwrap_or(0);
let mut ttl_index = self.ttl_index.write().await;
ttl_index
.entry(timestamp)
.or_insert_with(Vec::new)
.push(key.clone());
}
store.insert(key, entry);
}
pub async fn delete(&self, key: &str) {
let mut store = self.store.write().await;
store.remove(key);
}
pub async fn has_key(&self, key: &str) -> bool {
let store = self.store.read().await;
if let Some(entry) = store.get(key) {
!entry.is_expired()
} else {
false
}
}
pub async fn clear(&self) {
let mut store = self.store.write().await;
let mut ttl_index = self.ttl_index.write().await;
store.clear();
ttl_index.clear();
}
pub async fn len(&self) -> usize {
let store = self.store.read().await;
store.len()
}
pub async fn is_empty(&self) -> bool {
let store = self.store.read().await;
store.is_empty()
}
pub async fn keys(&self) -> Vec<String> {
let store = self.store.read().await;
store.keys().cloned().collect()
}
pub(crate) async fn get_store_clone(&self) -> HashMap<String, CacheEntry> {
let store = self.store.read().await;
store.clone()
}
pub(crate) async fn get_entry(&self, key: &str) -> Option<CacheEntry> {
let store = self.store.read().await;
store.get(key).cloned()
}
pub async fn get_entry_timestamps(
&self,
key: &str,
) -> Option<(SystemTime, Option<SystemTime>)> {
let store = self.store.read().await;
if let Some(entry) = store.get(key) {
if entry.is_expired() {
return None;
}
Some((entry.created_at, entry.accessed_at))
} else {
None
}
}
pub async fn cleanup_active_sampling(&self) {
const MAX_ROUNDS: usize = 100;
for _ in 0..MAX_ROUNDS {
let keys = {
let store = self.store.read().await;
store.keys().cloned().collect::<Vec<_>>()
};
if keys.is_empty() {
return;
}
let sample_size = self.active_sampler.sample_size.min(keys.len());
let sample: Vec<_> = {
let mut rng = rand::rng();
keys.choose_multiple(&mut rng, sample_size)
.cloned()
.collect()
};
let mut expired_keys = Vec::new();
{
let store = self.store.read().await;
for key in &sample {
if let Some(entry) = store.get::<String>(key)
&& entry.is_expired()
{
expired_keys.push(key.clone());
}
}
}
let expired_ratio = expired_keys.len() as f32 / sample.len() as f32;
if expired_ratio > self.active_sampler.threshold {
let mut store = self.store.write().await;
for key in expired_keys {
store.remove(&key);
}
} else {
return;
}
}
}
pub async fn cleanup_ttl_index(&self) {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.ok()
.map(|d| d.as_secs())
.unwrap_or(0);
let expired_timestamps: Vec<u64> = {
let ttl_index = self.ttl_index.read().await;
ttl_index.keys().filter(|&&ts| ts <= now).cloned().collect()
};
if expired_timestamps.is_empty() {
return;
}
let mut store = self.store.write().await;
let mut ttl_index = self.ttl_index.write().await;
for timestamp in expired_timestamps {
if let Some(keys) = ttl_index.remove(×tamp) {
for key in keys {
store.remove(&key);
}
}
}
}
pub async fn cleanup(&self) {
self.cleanup_ttl_index().await;
self.cleanup_active_sampling().await;
}
pub fn start_auto_cleanup(&self, interval: std::time::Duration)
where
Self: Clone,
{
let mut handle_guard = self
.cleanup_handle
.lock()
.unwrap_or_else(|e| e.into_inner());
if let Some(existing) = handle_guard.take() {
existing.abort();
}
let store = self.clone();
let abort_handle = tokio::spawn(async move {
let mut interval_timer = tokio::time::interval(interval);
loop {
interval_timer.tick().await;
store.cleanup().await;
}
})
.abort_handle();
*handle_guard = Some(abort_handle);
}
pub fn stop_auto_cleanup(&self) {
let mut handle_guard = self
.cleanup_handle
.lock()
.unwrap_or_else(|e| e.into_inner());
if let Some(handle) = handle_guard.take() {
handle.abort();
}
}
}
impl Clone for LayeredCacheStore {
fn clone(&self) -> Self {
Self {
store: Arc::clone(&self.store),
ttl_index: Arc::clone(&self.ttl_index),
active_sampler: ActiveSampler {
sample_size: self.active_sampler.sample_size,
threshold: self.active_sampler.threshold,
},
cleanup_handle: Arc::clone(&self.cleanup_handle),
}
}
}
impl Default for LayeredCacheStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Re-evaluates `condition` every `interval` until it yields `true`, or
    /// returns an error once `timeout` has elapsed.
    async fn poll_until<F, Fut>(
        timeout: Duration,
        interval: Duration,
        mut condition: F,
    ) -> std::result::Result<(), String>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let start = std::time::Instant::now();
        while start.elapsed() < timeout {
            if condition().await {
                return Ok(());
            }
            tokio::time::sleep(interval).await;
        }
        Err(format!("Timeout after {:?} waiting for condition", timeout))
    }

    #[tokio::test]
    async fn test_passive_expiration() {
        let store = LayeredCacheStore::new();
        store
            .set(
                "key1".to_string(),
                vec![1, 2, 3],
                Some(Duration::from_millis(50)),
            )
            .await;
        assert!(store.get("key1").await.is_some());
        poll_until(
            Duration::from_millis(150),
            Duration::from_millis(10),
            || async { store.get("key1").await.is_none() },
        )
        .await
        .expect("Key should expire and be deleted within 150ms");
        assert!(!store.has_key("key1").await);
    }

    #[tokio::test]
    async fn test_active_sampling_basic() {
        let store = LayeredCacheStore::with_sampler(10, 0.25);
        for i in 0..50 {
            store
                .set(
                    format!("key{}", i),
                    vec![i as u8],
                    Some(Duration::from_millis(50)),
                )
                .await;
        }
        assert_eq!(store.len().await, 50);
        // Poll the key inserted LAST: presumably CacheEntry::new stamps expiry
        // from the current time, so once it has expired every other key has too.
        poll_until(
            Duration::from_millis(150),
            Duration::from_millis(10),
            || async { store.get("key49").await.is_none() },
        )
        .await
        .expect("Keys should expire within 150ms");
        store.cleanup_active_sampling().await;
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_active_sampling_threshold() {
        let store = LayeredCacheStore::with_sampler(20, 0.25);
        for i in 0..80 {
            store
                .set(
                    format!("expired{}", i),
                    vec![i as u8],
                    Some(Duration::from_millis(50)),
                )
                .await;
        }
        for i in 0..20 {
            store
                .set(format!("permanent{}", i), vec![i as u8], None)
                .await;
        }
        // Last short-lived key inserted; see test_active_sampling_basic.
        poll_until(
            Duration::from_millis(150),
            Duration::from_millis(10),
            || async { store.get("expired79").await.is_none() },
        )
        .await
        .expect("Expired keys should expire within 150ms");
        store.cleanup_active_sampling().await;
        assert!(store.get("permanent0").await.is_some());
        assert!(store.get("permanent10").await.is_some());
        assert!(store.get("expired0").await.is_none());
    }

    #[tokio::test]
    async fn test_ttl_index_cleanup() {
        let store = LayeredCacheStore::new();
        for i in 0..100 {
            store
                .set(
                    format!("key{}", i),
                    vec![i as u8],
                    Some(Duration::from_secs(1)),
                )
                .await;
        }
        assert_eq!(store.len().await, 100);
        // Last key inserted, so every key has expired once this one has.
        poll_until(
            Duration::from_secs(2),
            Duration::from_millis(100),
            || async { store.get("key99").await.is_none() },
        )
        .await
        .expect("Keys should expire within 2 seconds");
        store.cleanup_ttl_index().await;
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_ttl_index_partial_cleanup() {
        let store = LayeredCacheStore::new();
        for i in 0..50 {
            store
                .set(
                    format!("short{}", i),
                    vec![i as u8],
                    Some(Duration::from_millis(50)),
                )
                .await;
        }
        for i in 0..50 {
            store
                .set(
                    format!("long{}", i),
                    vec![i as u8],
                    Some(Duration::from_secs(10)),
                )
                .await;
        }
        tokio::time::sleep(Duration::from_millis(60)).await;
        store.cleanup_ttl_index().await;
        // Check the count before any `get`, which would expire keys passively
        // and hide whether the index reaper did its job.
        assert_eq!(store.len().await, 50);
        assert!(store.get("short0").await.is_none());
        assert!(store.get("long0").await.is_some());
    }

    #[tokio::test]
    async fn test_combined_cleanup() {
        let store = LayeredCacheStore::new();
        for i in 0..100 {
            let ttl = if i < 50 {
                Some(Duration::from_millis(50))
            } else {
                None
            };
            store.set(format!("key{}", i), vec![i as u8], ttl).await;
        }
        tokio::time::sleep(Duration::from_millis(60)).await;
        store.cleanup().await;
        assert_eq!(store.len().await, 50);
        assert!(store.get("key0").await.is_none());
        assert!(store.get("key49").await.is_none());
        assert!(store.get("key50").await.is_some());
        assert!(store.get("key99").await.is_some());
    }

    #[tokio::test]
    async fn test_layered_vs_naive_performance() {
        let store = LayeredCacheStore::new();
        let num_keys = 10000;
        for i in 0..num_keys {
            store
                .set(
                    format!("key{}", i),
                    vec![i as u8],
                    Some(Duration::from_millis(50)),
                )
                .await;
        }
        let start = std::time::Instant::now();
        store.cleanup().await;
        let layered_duration = start.elapsed();
        println!(
            "Layered cleanup for {} keys: {:?}",
            num_keys, layered_duration
        );
    }

    #[tokio::test]
    async fn test_basic_operations() {
        let store = LayeredCacheStore::new();
        store.set("key1".to_string(), vec![1, 2, 3], None).await;
        assert_eq!(store.get("key1").await, Some(vec![1, 2, 3]));
        assert!(store.has_key("key1").await);
        assert!(!store.has_key("nonexistent").await);
        store.delete("key1").await;
        assert!(store.get("key1").await.is_none());
        store.set("key2".to_string(), vec![4, 5, 6], None).await;
        store.set("key3".to_string(), vec![7, 8, 9], None).await;
        assert_eq!(store.len().await, 2);
        store.clear().await;
        assert_eq!(store.len().await, 0);
    }

    #[tokio::test]
    async fn test_overwrite_replaces_value() {
        let store = LayeredCacheStore::new();
        store
            .set("key".to_string(), vec![1], Some(Duration::from_secs(60)))
            .await;
        store.set("key".to_string(), vec![2], None).await;
        assert_eq!(store.get("key").await, Some(vec![2]));
        assert_eq!(store.len().await, 1);
    }

    #[tokio::test]
    async fn test_keys_listing() {
        let store = LayeredCacheStore::new();
        assert!(store.keys().await.is_empty());
        store.set("key1".to_string(), vec![1], None).await;
        store.set("key2".to_string(), vec![2], None).await;
        store.set("key3".to_string(), vec![3], None).await;
        let keys = store.keys().await;
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"key1".to_string()));
        assert!(keys.contains(&"key2".to_string()));
        assert!(keys.contains(&"key3".to_string()));
    }
}