use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Shared, lockable handle to a loaded backend.
///
/// A handle stays usable after the cache evicts its entry: eviction only drops
/// the cache's own clone, so the backend is destroyed when the last clone goes
/// away.
pub type CachedBackend<T> = Arc<Mutex<T>>;

/// One cached backend together with the bookkeeping used for eviction.
struct Entry<T> {
    /// Handle cloned out to callers on every hit.
    backend: CachedBackend<T>,
    /// Caller-supplied size charged against the byte budget.
    size_bytes: u64,
    /// Time of insertion or of the most recent hit; read by `evict_idle`.
    last_used: Instant,
}

/// Mutable cache state; always accessed through `BackendCache::inner`.
///
/// Invariant: `lru` holds exactly the keys of `map`, each once, and
/// `total_bytes` is the (saturating) sum of the entries' `size_bytes`.
struct Inner<T> {
    map: HashMap<String, Entry<T>>,
    /// Keys ordered from least recently used (front) to most recent (back).
    lru: VecDeque<String>,
    total_bytes: u64,
}

impl<T> Inner<T> {
    /// Marks `key` as most recently used and returns a clone of its handle,
    /// or `None` when the key is not cached.
    fn touch(&mut self, key: &str) -> Option<CachedBackend<T>> {
        let entry = self.map.get_mut(key)?;
        entry.last_used = Instant::now();
        let handle = Arc::clone(&entry.backend);
        // Already at the back: the O(n) reshuffle would reproduce the same order.
        if self.lru.back().map(String::as_str) != Some(key) {
            self.lru.retain(|k| k != key);
            self.lru.push_back(key.to_string());
        }
        Some(handle)
    }

    /// Removes `key` from `map` and subtracts its size from `total_bytes`.
    ///
    /// The caller is responsible for removing the key from `lru`, and should
    /// drop the returned entry only after releasing the cache lock.
    fn remove_entry(&mut self, key: &str) -> Option<Entry<T>> {
        let entry = self.map.remove(key)?;
        self.total_bytes = self.total_bytes.saturating_sub(entry.size_bytes);
        Some(entry)
    }
}

/// Byte-budgeted LRU cache of loaded backends, keyed by string.
///
/// Entries are evicted least-recently-used first whenever an insertion pushes
/// the total over `budget_bytes`, and optionally by [`BackendCache::evict_idle`]
/// once they have been unused for `idle_ttl`. A budget of `0` disables caching.
pub struct BackendCache<T: Send + 'static> {
    inner: Mutex<Inner<T>>,
    budget_bytes: u64,
    idle_ttl: Option<Duration>,
}

impl<T: Send + 'static> BackendCache<T> {
    /// Budget (in MiB) used by `from_env` when the variable is unset or invalid.
    const DEFAULT_BUDGET_MB: u64 = 24 * 1024;
    /// Idle timeout (in seconds) used by `from_env` when unset or invalid.
    const DEFAULT_IDLE_SECS: u64 = 300;

    /// Creates a cache with the given byte budget and no idle eviction.
    pub fn new(budget_bytes: u64) -> Self {
        Self {
            inner: Mutex::new(Inner {
                map: HashMap::new(),
                lru: VecDeque::new(),
                total_bytes: 0,
            }),
            budget_bytes,
            idle_ttl: None,
        }
    }

    /// Creates a cache with the given byte budget and optional idle timeout.
    ///
    /// With `idle_ttl == None`, [`BackendCache::evict_idle`] is a no-op.
    pub fn with_idle_ttl(budget_bytes: u64, idle_ttl: Option<Duration>) -> Self {
        Self {
            idle_ttl,
            ..Self::new(budget_bytes)
        }
    }

    /// Builds a cache configured from the process environment.
    ///
    /// * `CAR_INFERENCE_MODEL_CACHE_MB` — budget in MiB (default 24576);
    ///   `0` disables caching.
    /// * `CAR_INFERENCE_MODEL_IDLE_SECS` — idle timeout in seconds
    ///   (default 300); `0` disables idle eviction.
    ///
    /// Unset or unparsable values fall back to the defaults. Surrounding
    /// whitespace is ignored so that a value such as `"0\n"` is honoured
    /// instead of silently selecting the default.
    pub fn from_env() -> Self {
        let budget_mb = Self::env_u64("CAR_INFERENCE_MODEL_CACHE_MB", Self::DEFAULT_BUDGET_MB);
        let idle_secs = Self::env_u64("CAR_INFERENCE_MODEL_IDLE_SECS", Self::DEFAULT_IDLE_SECS);
        let idle_ttl = (idle_secs > 0).then(|| Duration::from_secs(idle_secs));
        Self::with_idle_ttl(budget_mb.saturating_mul(1024 * 1024), idle_ttl)
    }

    /// Reads `name` from the environment as a `u64`, or returns `default`.
    fn env_u64(name: &str, default: u64) -> u64 {
        std::env::var(name)
            .ok()
            .and_then(|raw| raw.trim().parse::<u64>().ok())
            .unwrap_or(default)
    }

    /// Locks the cache state.
    ///
    /// Panics if the mutex is poisoned, i.e. a thread panicked while holding it.
    fn lock(&self) -> std::sync::MutexGuard<'_, Inner<T>> {
        self.inner.lock().expect("backend cache poisoned")
    }

    /// Returns `true` when the budget is zero, i.e. nothing is ever cached.
    pub fn is_disabled(&self) -> bool {
        self.budget_bytes == 0
    }

    /// Returns the cached backend for `key`, invoking `loader` on a miss.
    ///
    /// * `size_bytes` is charged against the budget when the entry is
    ///   inserted; it is ignored on a hit.
    /// * `loader` runs without the cache lock held, so it may be slow and may
    ///   itself use the cache. If the key gets inserted while `loader` is
    ///   running, the already-cached backend wins and the freshly loaded one
    ///   is discarded.
    /// * After insertion, least-recently-used entries are evicted until the
    ///   total fits the budget. The entry just inserted is never evicted here,
    ///   so a single oversized entry stays cached on its own.
    /// * With a zero budget the loader runs on every call and nothing is
    ///   stored.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `loader`; the cache is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn get_or_load<E>(
        &self,
        key: &str,
        size_bytes: u64,
        loader: impl FnOnce() -> Result<T, E>,
    ) -> Result<CachedBackend<T>, E> {
        if !self.is_disabled() {
            // The guard is a temporary of this statement, so the lock is
            // released before the (potentially slow) loader runs.
            let cached = self.lock().touch(key);
            if let Some(handle) = cached {
                return Ok(handle);
            }
        }

        let handle = Arc::new(Mutex::new(loader()?));
        if self.is_disabled() {
            return Ok(handle);
        }

        // Declared before the guard so victims are dropped after the lock is
        // released: tearing down a backend can be slow, and a `Drop` impl that
        // touches the cache would otherwise deadlock.
        let mut evicted: Vec<Entry<T>> = Vec::new();
        {
            let mut guard = self.lock();

            // Someone inserted the same key while we were loading. Treat it as
            // a regular hit (refresh recency) so the entry we hand out is not
            // the next eviction victim.
            if let Some(existing) = guard.touch(key) {
                return Ok(existing);
            }

            guard.total_bytes = guard.total_bytes.saturating_add(size_bytes);
            guard.map.insert(
                key.to_string(),
                Entry {
                    backend: Arc::clone(&handle),
                    size_bytes,
                    last_used: Instant::now(),
                },
            );
            guard.lru.push_back(key.to_string());

            while guard.total_bytes > self.budget_bytes {
                let Some(victim_key) = guard.lru.pop_front() else {
                    break;
                };
                if victim_key == key {
                    // Only the new entry is left; keep it even if oversized.
                    guard.lru.push_front(victim_key);
                    break;
                }
                if let Some(victim) = guard.remove_entry(&victim_key) {
                    evicted.push(victim);
                }
            }
        }
        drop(evicted);
        Ok(handle)
    }

    /// Removes `key` from the cache if present.
    ///
    /// Handles already handed out remain valid.
    pub fn invalidate(&self, key: &str) {
        let removed = {
            let mut guard = self.lock();
            let removed = guard.remove_entry(key);
            if removed.is_some() {
                guard.lru.retain(|k| k != key);
            }
            removed
        };
        // Dropped after the lock has been released.
        drop(removed);
    }

    /// Evicts every entry that has been unused for at least the idle timeout.
    ///
    /// Returns `(entries_evicted, bytes_evicted)`; `(0, 0)` when no idle
    /// timeout is configured.
    pub fn evict_idle(&self) -> (usize, u64) {
        let Some(ttl) = self.idle_ttl else {
            return (0, 0);
        };

        let victims: Vec<Entry<T>> = {
            let mut guard = self.lock();
            // Sampled under the lock so no entry can carry a later timestamp.
            let now = Instant::now();
            let stale: Vec<String> = guard
                .map
                .iter()
                .filter(|(_, entry)| now.saturating_duration_since(entry.last_used) >= ttl)
                .map(|(k, _)| k.clone())
                .collect();
            let removed: Vec<Entry<T>> = stale
                .iter()
                .filter_map(|k| guard.remove_entry(k))
                .collect();
            if !removed.is_empty() {
                // One pass over the LRU list instead of one pass per stale key.
                let Inner { map, lru, .. } = &mut *guard;
                lru.retain(|k| map.contains_key(k));
            }
            removed
        };

        let bytes = victims
            .iter()
            .fold(0u64, |sum, victim| sum.saturating_add(victim.size_bytes));
        // `victims` is dropped on return, after the lock has been released.
        (victims.len(), bytes)
    }

    /// Removes every entry from the cache.
    ///
    /// Handles already handed out remain valid.
    pub fn clear(&self) {
        let old_entries = {
            let mut guard = self.lock();
            guard.lru.clear();
            guard.total_bytes = 0;
            let taken = std::mem::take(&mut guard.map);
            taken
        };
        // Dropped after the lock has been released.
        drop(old_entries);
    }

    /// Returns `(entry_count, total_bytes, budget_bytes)`.
    pub fn stats(&self) -> (usize, u64, u64) {
        let guard = self.lock();
        (guard.map.len(), guard.total_bytes, self.budget_bytes)
    }
}
/// Estimates a model's size as the combined byte length of every
/// `.safetensors` file found under `model_dir`, searching subdirectories too.
///
/// Directories that cannot be listed and files whose metadata cannot be read
/// are skipped, so a missing `model_dir` yields `0`. The extension match is
/// case-sensitive, and the sum saturates at `u64::MAX`.
pub fn estimate_model_size(model_dir: &Path) -> u64 {
    // Explicit work list of directories still to be scanned.
    let mut pending = vec![model_dir.to_path_buf()];
    let mut total_bytes = 0u64;

    while let Some(dir) = pending.pop() {
        let listing = match std::fs::read_dir(&dir) {
            Ok(listing) => listing,
            Err(_) => continue,
        };
        for candidate in listing.flatten().map(|item| item.path()) {
            if candidate.is_dir() {
                pending.push(candidate);
            } else if candidate.extension().and_then(|ext| ext.to_str()) == Some("safetensors") {
                if let Ok(info) = candidate.metadata() {
                    total_bytes = total_bytes.saturating_add(info.len());
                }
            }
        }
    }

    total_bytes
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every test caches plain `u32` values as stand-in backends.
    type TestCache = BackendCache<u32>;

    /// Nominal size charged for each entry in these tests.
    const ENTRY_SIZE: u64 = 100;

    /// Loads `value` under `key`; the loader always succeeds.
    fn load(cache: &TestCache, key: &str, value: u32) -> CachedBackend<u32> {
        cache
            .get_or_load::<()>(key, ENTRY_SIZE, || Ok(value))
            .unwrap()
    }

    /// Fetches `key`, panicking with `msg` if the loader is invoked.
    fn expect_hit(cache: &TestCache, key: &str, msg: &str) -> CachedBackend<u32> {
        cache
            .get_or_load::<()>(key, ENTRY_SIZE, || panic!("{}", msg))
            .unwrap()
    }

    /// A second lookup of the same key returns the very same handle.
    #[test]
    fn cache_hit_returns_same_handle() {
        let cache = TestCache::new(1024);
        let first = load(&cache, "a", 42);
        let second = expect_hit(&cache, "a", "should not reload");
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// With room for two entries, the least recently used one is evicted.
    #[test]
    fn evicts_lru_when_over_budget() {
        let cache = TestCache::new(250);
        let _keep_a = load(&cache, "a", 1);
        let _keep_b = load(&cache, "b", 2);
        // Touch "a" so that "b" becomes the least recently used entry.
        let _keep_a_again = expect_hit(&cache, "a", "cached");
        let _keep_c = load(&cache, "c", 3);

        let snapshot = cache.stats();
        assert_eq!(snapshot.0, 2, "a + c should remain, b evicted");
        assert_eq!(snapshot.1, 200);
        assert_eq!(snapshot.2, 250);
    }

    /// A zero budget runs the loader on every call and stores nothing.
    #[test]
    fn zero_budget_disables_cache_but_returns_handle() {
        let cache = TestCache::new(0);
        let mut load_count = 0u32;

        let first = cache
            .get_or_load::<()>("a", ENTRY_SIZE, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*first.lock().unwrap(), 1);

        let second = cache
            .get_or_load::<()>("a", ENTRY_SIZE, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*second.lock().unwrap(), 1);

        assert_eq!(load_count, 2, "disabled cache reloads every call");
        assert!(!Arc::ptr_eq(&first, &second));
    }

    /// Idle eviction sweeps entries past the TTL even when under budget.
    #[test]
    fn evict_idle_drops_stale_entries_below_capacity() {
        let ttl = Duration::from_millis(20);
        let cache = TestCache::with_idle_ttl(1_000_000, Some(ttl));
        let _keep_a = load(&cache, "a", 1);
        let _keep_b = load(&cache, "b", 2);
        assert_eq!(cache.stats().0, 2);

        // Nothing is stale yet.
        assert_eq!(cache.evict_idle(), (0, 0));

        std::thread::sleep(Duration::from_millis(40));
        // Refresh "a" so only "b" is past the TTL.
        let _keep_a_again = expect_hit(&cache, "a", "cached");

        let swept = cache.evict_idle();
        assert_eq!(swept, (1, 100), "only b should be swept");

        let (remaining, total, _) = cache.stats();
        assert_eq!(remaining, 1, "a remains");
        assert_eq!(total, 100);
    }

    /// Without an idle TTL, `evict_idle` leaves the cache untouched.
    #[test]
    fn evict_idle_noop_when_disabled() {
        let cache = TestCache::new(1024);
        let _keep_a = load(&cache, "a", 1);
        assert_eq!(cache.evict_idle(), (0, 0));
        assert_eq!(cache.stats().0, 1, "disabled idle eviction keeps the entry");
    }

    /// `invalidate` removes the named entry.
    #[test]
    fn invalidate_removes_key() {
        let cache = TestCache::new(1024);
        let _keep_a = load(&cache, "a", 1);
        assert_eq!(cache.stats().0, 1);
        cache.invalidate("a");
        assert_eq!(cache.stats().0, 0);
    }
}