use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
/// A bounded, concurrent key-value cache with per-entry TTLs and
/// approximate-LRU eviction.
///
/// Entries live in a sharded [`DashMap`], so lookups and inserts from many
/// tasks proceed without a global lock. Total estimated memory use is
/// tracked in an atomic counter and enforced (together with an entry-count
/// limit) at insert time; hit/miss/eviction counters live in [`CacheStats`].
#[derive(Debug)]
pub struct SmartCache<V: CacheValue> {
    // Key -> entry map; Arc-shared so background tasks can hold the cache.
    entries: Arc<DashMap<String, CacheEntry<V>>>,
    // Hard cap on the number of entries.
    max_entries: usize,
    // Hard cap on the summed `CacheEntry::size_bytes` estimates.
    max_size_bytes: usize,
    // Running total of estimated bytes across all entries.
    current_size: Arc<AtomicUsize>,
    // TTL applied by `set` when the caller does not specify one.
    default_ttl: Duration,
    // Hit/miss/eviction/expiration counters.
    stats: Arc<CacheStats>,
}
/// A single cached value plus the bookkeeping the cache needs for TTL
/// expiry, LRU eviction, and byte accounting.
#[derive(Debug, Clone)]
pub struct CacheEntry<V> {
    // The cached value itself.
    pub value: V,
    // Insertion time; expiry is measured from here, not from last access.
    pub created_at: Instant,
    // Updated on every hit; used to pick LRU eviction victims.
    pub last_accessed: Instant,
    // Estimated footprint (value + key + entry struct), charged to the cache.
    pub size_bytes: usize,
    // Per-entry time-to-live.
    pub ttl: Duration,
    // Saturating hit counter for this entry.
    pub access_count: u32,
}
/// Types storable in [`SmartCache`]: cloneable, thread-safe, and able to
/// report an approximate in-memory footprint for byte-budget accounting.
pub trait CacheValue: Clone + Send + Sync + 'static {
    /// Approximate number of bytes this value occupies in memory.
    fn estimated_size(&self) -> usize;
}

impl CacheValue for String {
    fn estimated_size(&self) -> usize {
        // Inline ptr/len/capacity header plus the heap payload.
        std::mem::size_of::<Self>() + self.len()
    }
}

impl CacheValue for Vec<u8> {
    fn estimated_size(&self) -> usize {
        // Inline ptr/len/capacity header plus the heap payload.
        std::mem::size_of::<Self>() + self.len()
    }
}
impl CacheValue for serde_json::Value {
    fn estimated_size(&self) -> usize {
        // The enum itself plus a recursive estimate of its contents.
        std::mem::size_of::<Self>() + estimate_json_size(self)
    }
}
fn estimate_json_size(v: &serde_json::Value) -> usize {
match v {
serde_json::Value::Null => 4,
serde_json::Value::Bool(_) => 5,
serde_json::Value::Number(n) => {
if n.is_f64() {
24
} else {
20
}
}
serde_json::Value::String(s) => s.len() + 2 + std::mem::size_of::<String>(),
serde_json::Value::Array(arr) => {
let inner: usize = arr.iter().map(estimate_json_size).sum();
inner
+ arr.len() * std::mem::size_of::<serde_json::Value>()
+ std::mem::size_of::<Vec<serde_json::Value>>()
}
serde_json::Value::Object(map) => {
let inner: usize = map
.iter()
.map(|(k, v)| k.len() + std::mem::size_of::<String>() + estimate_json_size(v))
.sum();
inner + std::mem::size_of::<serde_json::Map<String, serde_json::Value>>()
}
}
}
/// Lock-free counters describing cache behavior.
///
/// All counters use relaxed atomics: they are monotonic tallies read for
/// reporting, so no ordering with other memory operations is required.
#[derive(Debug, Default)]
pub struct CacheStats {
    // Lookups that returned a live value.
    hits: AtomicUsize,
    // Lookups that found nothing, or found only an expired entry.
    misses: AtomicUsize,
    // Entries removed to satisfy the count/byte limits.
    evictions: AtomicUsize,
    // Entries removed because their TTL elapsed.
    expirations: AtomicUsize,
}
impl CacheStats {
    /// Number of successful lookups.
    pub fn hits(&self) -> usize {
        self.hits.load(Ordering::Relaxed)
    }

    /// Number of lookups that found nothing (or only an expired entry).
    pub fn misses(&self) -> usize {
        self.misses.load(Ordering::Relaxed)
    }

    /// Number of entries evicted to satisfy the count/byte limits.
    pub fn evictions(&self) -> usize {
        self.evictions.load(Ordering::Relaxed)
    }

    /// Number of entries removed because their TTL elapsed.
    ///
    /// Added for consistency: this counter was incremented throughout the
    /// cache but, unlike the other three, had no public accessor.
    pub fn expirations(&self) -> usize {
        self.expirations.load(Ordering::Relaxed)
    }

    /// Fraction of lookups that hit, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet (avoids 0/0).
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits() as f64;
        let total = hits + self.misses() as f64;
        if total > 0.0 {
            hits / total
        } else {
            0.0
        }
    }

    /// Resets every counter to zero.
    pub fn reset(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
        self.evictions.store(0, Ordering::Relaxed);
        self.expirations.store(0, Ordering::Relaxed);
    }
}
impl<V: CacheValue> SmartCache<V> {
    /// Creates a cache with default limits: 1000 entries, 50 MiB, 5-minute TTL.
    pub fn new() -> Self {
        Self::with_limits(1000, 50 * 1024 * 1024, Duration::from_secs(300))
    }

    /// Creates a cache bounded by `max_entries` entries and `max_size_bytes`
    /// estimated bytes; [`Self::set`] applies `default_ttl` to new entries.
    pub fn with_limits(max_entries: usize, max_size_bytes: usize, default_ttl: Duration) -> Self {
        Self {
            // Cap the up-front reservation so a huge `max_entries` does not
            // pre-allocate an enormous map.
            entries: Arc::new(DashMap::with_capacity(max_entries.min(10000))),
            max_entries,
            max_size_bytes,
            current_size: Arc::new(AtomicUsize::new(0)),
            default_ttl,
            stats: Arc::new(CacheStats::default()),
        }
    }

    /// Preset: 100 entries, 5 MiB, 60-second TTL.
    pub fn small() -> Self {
        Self::with_limits(100, 5 * 1024 * 1024, Duration::from_secs(60))
    }

    /// Preset: 10 000 entries, 200 MiB, 10-minute TTL.
    pub fn large() -> Self {
        Self::with_limits(10000, 200 * 1024 * 1024, Duration::from_secs(600))
    }

    /// Returns a clone of the cached value for `key`, or `None` on a miss.
    ///
    /// A hit refreshes the entry's `last_accessed` time (feeding LRU
    /// eviction) and its access count. An entry whose TTL has elapsed is
    /// removed and the lookup counts as a miss.
    pub async fn get(&self, key: &str) -> Option<V> {
        let now = Instant::now();
        if let Some(mut entry) = self.entries.get_mut(key) {
            if now.duration_since(entry.created_at) <= entry.ttl {
                let value = entry.value.clone();
                entry.last_accessed = now;
                entry.access_count = entry.access_count.saturating_add(1);
                self.stats.hits.fetch_add(1, Ordering::Relaxed);
                return Some(value);
            }
            // Expired: release the shard guard before removing so we don't
            // re-enter the same shard lock.
            drop(entry);
            // Bug fix: the previous code removed unconditionally and
            // subtracted the size captured before `drop`. A concurrent `set`
            // between `drop` and `remove` could insert a fresh entry that was
            // then deleted, while the byte counter was adjusted by the wrong
            // size. `remove_if` re-checks expiry under the shard lock, and we
            // charge the size of the entry actually removed.
            if let Some((_, expired)) = self
                .entries
                .remove_if(key, |_, e| now.duration_since(e.created_at) > e.ttl)
            {
                self.current_size
                    .fetch_sub(expired.size_bytes, Ordering::Relaxed);
                self.stats.expirations.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.stats.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Inserts `value` under `key` with the cache's default TTL.
    pub async fn set(&self, key: impl Into<String>, value: V) {
        self.set_with_ttl(key, value, self.default_ttl).await;
    }

    /// Inserts `value` under `key` with an explicit `ttl`, evicting other
    /// entries first if the count or byte budget would otherwise be exceeded.
    pub async fn set_with_ttl(&self, key: impl Into<String>, value: V, ttl: Duration) {
        let key = key.into();
        // Charge the value, the key bytes, and the bookkeeping struct itself.
        let size = value.estimated_size() + key.len() + std::mem::size_of::<CacheEntry<V>>();
        // Replace-in-place: drop any previous entry and its byte charge first.
        if let Some((_, old)) = self.entries.remove(&key) {
            self.current_size
                .fetch_sub(old.size_bytes, Ordering::Relaxed);
        }
        // How far over the limits would we be after inserting?
        // `max_entries - 1` reserves a slot for the entry inserted below.
        let over_count = self
            .entries
            .len()
            .saturating_sub(self.max_entries.saturating_sub(1));
        let over_bytes =
            (self.current_size.load(Ordering::Relaxed) + size).saturating_sub(self.max_size_bytes);
        if over_count > 0 || over_bytes > 0 {
            self.batch_evict(over_count, over_bytes);
        }
        let now = Instant::now();
        let entry = CacheEntry {
            value,
            created_at: now,
            last_accessed: now,
            size_bytes: size,
            ttl,
            access_count: 1,
        };
        self.entries.insert(key, entry);
        self.current_size.fetch_add(size, Ordering::Relaxed);
    }

    /// Removes `key` and returns its value if it was present (expired or not).
    pub async fn remove(&self, key: &str) -> Option<V> {
        let (_, entry) = self.entries.remove(key)?;
        self.current_size
            .fetch_sub(entry.size_bytes, Ordering::Relaxed);
        Some(entry.value)
    }

    /// Removes every entry and resets the byte counter. Statistics are kept.
    pub async fn clear(&self) {
        self.entries.clear();
        self.current_size.store(0, Ordering::Relaxed);
    }

    /// Current number of entries (expired-but-unreaped entries included).
    pub async fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when the cache holds no entries.
    pub async fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Estimated total size of all cached entries, in bytes.
    pub fn size_bytes(&self) -> usize {
        self.current_size.load(Ordering::Relaxed)
    }

    /// Hit/miss/eviction/expiration counters for this cache.
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Best-effort eviction of at least `min_count` entries and `min_bytes`
    /// bytes. Expired entries are reclaimed first; any remaining deficit is
    /// covered by evicting live entries in least-recently-used order.
    fn batch_evict(&self, min_count: usize, min_bytes: usize) {
        if self.entries.is_empty() {
            return;
        }
        let now = Instant::now();
        let mut freed_count = 0usize;
        let mut freed_bytes = 0usize;
        // Phase 1: expired entries are free to drop, so reap them first.
        let expired_keys: Vec<String> = self
            .entries
            .iter()
            .filter(|r| now.duration_since(r.value().created_at) > r.value().ttl)
            .map(|r| r.key().clone())
            .collect();
        for key in &expired_keys {
            if let Some((_, entry)) = self.entries.remove(key) {
                freed_count += 1;
                freed_bytes += entry.size_bytes;
                self.current_size
                    .fetch_sub(entry.size_bytes, Ordering::Relaxed);
                self.stats.expirations.fetch_add(1, Ordering::Relaxed);
            }
        }
        // (The old extra `remaining == 0` check here was dead code: it was
        // true exactly when this condition is, so a single check suffices.)
        if freed_count >= min_count && freed_bytes >= min_bytes {
            return;
        }
        // Phase 2: evict live entries, oldest access first. The previous
        // implementation only partially ordered candidates with
        // `select_nth_unstable_by_key`, so entries beyond the pivot were
        // evicted in arbitrary order whenever the byte target was not yet
        // met; a full sort makes eviction strictly LRU.
        let mut candidates: Vec<(String, Instant, usize)> = self
            .entries
            .iter()
            .map(|r| (r.key().clone(), r.value().last_accessed, r.value().size_bytes))
            .collect();
        candidates.sort_unstable_by_key(|&(_, last_accessed, _)| last_accessed);
        for (key, _, _) in &candidates {
            if freed_count >= min_count && freed_bytes >= min_bytes {
                break;
            }
            if let Some((_, entry)) = self.entries.remove(key) {
                freed_count += 1;
                freed_bytes += entry.size_bytes;
                self.current_size
                    .fetch_sub(entry.size_bytes, Ordering::Relaxed);
                self.stats.evictions.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Removes every entry whose TTL has elapsed.
    pub async fn cleanup_expired(&self) {
        let now = Instant::now();
        let expired_keys: Vec<String> = self
            .entries
            .iter()
            .filter(|r| now.duration_since(r.value().created_at) > r.value().ttl)
            .map(|r| r.key().clone())
            .collect();
        for key in expired_keys {
            // Re-check expiry under the shard lock (same race as in `get`):
            // a fresh entry written concurrently under this key must survive.
            if let Some((_, entry)) = self
                .entries
                .remove_if(&key, |_, e| now.duration_since(e.created_at) > e.ttl)
            {
                self.current_size
                    .fetch_sub(entry.size_bytes, Ordering::Relaxed);
                self.stats.expirations.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Spawns a background task that reaps expired entries every `interval`.
    ///
    /// The task loops until the returned handle is aborted or the runtime
    /// shuts down; keep the handle if you need to stop it.
    pub fn start_cleanup_task(self: Arc<Self>, interval: Duration) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;
                self.cleanup_expired().await;
            }
        })
    }
}
impl<V: CacheValue> Default for SmartCache<V> {
    /// Equivalent to [`SmartCache::new`] (1000 entries, 50 MiB, 5-minute TTL).
    fn default() -> Self {
        Self::new()
    }
}
/// String-valued cache intended for HTML content; see [`html_cache`].
pub type HtmlCache = SmartCache<String>;
/// Cache of parsed JSON documents; see [`json_cache`].
pub type JsonCache = SmartCache<serde_json::Value>;
pub fn html_cache(max_pages: usize, max_mb: usize) -> HtmlCache {
SmartCache::with_limits(max_pages, max_mb * 1024 * 1024, Duration::from_secs(300))
}
pub fn json_cache(max_entries: usize, max_mb: usize) -> JsonCache {
SmartCache::with_limits(max_entries, max_mb * 1024 * 1024, Duration::from_secs(300))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Insert then read back; a missing key yields `None`.
    #[tokio::test]
    async fn test_smart_cache_basic() {
        let cache: SmartCache<String> = SmartCache::new();
        cache.set("key1", "value1".to_string()).await;
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));
        assert_eq!(cache.get("key2").await, None);
    }

    // With a 2-entry limit, a third insert must evict down to the cap, and
    // the entry just inserted must survive the eviction.
    #[tokio::test]
    async fn test_smart_cache_eviction() {
        let cache: SmartCache<String> =
            SmartCache::with_limits(2, 1024 * 1024, Duration::from_secs(60));
        cache.set("key1", "value1".to_string()).await;
        cache.set("key2", "value2".to_string()).await;
        cache.set("key3", "value3".to_string()).await;
        assert_eq!(cache.len().await, 2);
        assert_eq!(cache.get("key3").await, Some("value3".to_string()));
    }

    // The byte budget is honored to within a small slack (sizes are
    // estimates, so an exact bound is not asserted).
    #[tokio::test]
    async fn test_smart_cache_size_limit() {
        let cache: SmartCache<String> = SmartCache::with_limits(100, 1024, Duration::from_secs(60));
        for i in 0..50 {
            cache.set(format!("key{}", i), "x".repeat(100)).await;
        }
        assert!(cache.size_bytes() <= 1024 + 200);
    }

    // An entry readable before its 50 ms TTL elapses is gone after sleeping
    // past it.
    #[tokio::test]
    async fn test_smart_cache_ttl() {
        let cache: SmartCache<String> =
            SmartCache::with_limits(100, 1024 * 1024, Duration::from_millis(50));
        cache.set("key1", "value1".to_string()).await;
        assert!(cache.get("key1").await.is_some());
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(cache.get("key1").await.is_none());
    }

    // Two hits plus one miss -> hit rate of 2/3.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache: SmartCache<String> = SmartCache::new();
        cache.set("key1", "value1".to_string()).await;
        cache.get("key1").await;
        cache.get("key1").await;
        cache.get("key2").await;
        let stats = cache.stats();
        assert_eq!(stats.hits(), 2);
        assert_eq!(stats.misses(), 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    // Stress a single hot key with 64 concurrent readers and 8 writers; no
    // task may panic, the key must remain readable, and the entry count must
    // stay within the configured cap.
    #[tokio::test]
    async fn test_smart_cache_concurrent_reads_and_writes() {
        let cache: Arc<SmartCache<String>> = Arc::new(SmartCache::with_limits(
            1024,
            16 * 1024 * 1024,
            Duration::from_secs(120),
        ));
        cache.set("shared", "v0".to_string()).await;
        let mut tasks = tokio::task::JoinSet::new();
        for reader_idx in 0..64usize {
            let cache = cache.clone();
            tasks.spawn(async move {
                let mut observed = 0usize;
                for _ in 0..250usize {
                    if cache.get("shared").await.is_some() {
                        observed += 1;
                    }
                }
                (reader_idx, observed)
            });
        }
        for writer_idx in 0..8usize {
            let cache = cache.clone();
            tasks.spawn(async move {
                for round in 0..120usize {
                    let value = format!("writer-{writer_idx}-round-{round}");
                    cache.set("shared", value).await;
                }
                (writer_idx, 120usize)
            });
        }
        while let Some(joined) = tasks.join_next().await {
            assert!(joined.is_ok(), "task panicked under concurrency");
        }
        assert!(cache.get("shared").await.is_some());
        assert!(cache.len().await <= 1024);
        assert!(cache.stats().hits() > 0);
    }

    // 24 writers hammer distinct keys far past the limits; afterwards the
    // entry count and (with slack for in-flight inserts) the byte counter
    // must still respect the configured bounds.
    #[tokio::test]
    async fn test_smart_cache_concurrent_eviction_stays_bounded() {
        let cache: Arc<SmartCache<String>> = Arc::new(SmartCache::with_limits(
            64,
            64 * 1024,
            Duration::from_secs(120),
        ));
        let mut tasks = tokio::task::JoinSet::new();
        for worker in 0..24usize {
            let cache = cache.clone();
            tasks.spawn(async move {
                for n in 0..180usize {
                    let key = format!("w{worker}-k{n}");
                    let value = "x".repeat(256);
                    cache.set(key, value).await;
                }
            });
        }
        while let Some(joined) = tasks.join_next().await {
            assert!(joined.is_ok(), "worker panicked during eviction stress");
        }
        assert!(cache.len().await <= 64);
        assert!(cache.size_bytes() <= (64 * 1024) + 4096);
    }
}