use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::sync::{Arc, Mutex};
/// Shared handle to a cached backend. Callers lock the inner `Mutex` to use
/// the backend; the cache holds exactly one extra `Arc` clone per entry, so
/// an evicted backend is dropped as soon as all caller handles are gone.
pub type CachedBackend<T> = Arc<Mutex<T>>;

/// One cache entry: the backend handle plus the size the caller reported
/// for it at insert time (used for budget accounting — the cache never
/// measures the backend itself).
struct Entry<T> {
    backend: CachedBackend<T>,
    size_bytes: u64,
}

/// Mutable cache state. Kept behind a single mutex so the map, the LRU
/// order, and the byte accounting always change together.
struct Inner<T> {
    map: HashMap<String, Entry<T>>,
    /// Keys ordered least- to most-recently used; front is the next victim.
    lru: VecDeque<String>,
    /// Sum of `size_bytes` over all entries currently in `map`.
    total_bytes: u64,
}

impl<T> Inner<T> {
    /// Moves `key` to the most-recently-used (back) position.
    fn touch(&mut self, key: &str) {
        self.lru.retain(|k| k != key);
        self.lru.push_back(key.to_string());
    }
}

/// A byte-budgeted LRU cache of loaded backends, safe to share across
/// threads. A budget of 0 disables caching entirely: every `get_or_load`
/// runs its loader and nothing is retained.
pub struct BackendCache<T: Send + 'static> {
    inner: Mutex<Inner<T>>,
    budget_bytes: u64,
}

impl<T: Send + 'static> BackendCache<T> {
    /// Creates a cache that evicts least-recently-used entries once the
    /// reported sizes exceed `budget_bytes`. A budget of 0 disables caching.
    pub fn new(budget_bytes: u64) -> Self {
        Self {
            inner: Mutex::new(Inner {
                map: HashMap::new(),
                lru: VecDeque::new(),
                total_bytes: 0,
            }),
            budget_bytes,
        }
    }

    /// Builds a cache sized from the `CAR_INFERENCE_MODEL_CACHE_MB` env var
    /// (megabytes). Unset or unparseable values fall back to 24 * 1024 MB
    /// (24 GiB); an explicit `0` disables caching.
    pub fn from_env() -> Self {
        let mb = std::env::var("CAR_INFERENCE_MODEL_CACHE_MB")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(24 * 1024);
        Self::new(mb.saturating_mul(1024 * 1024))
    }

    /// True when the budget is 0, i.e. caching is disabled.
    pub fn is_disabled(&self) -> bool {
        self.budget_bytes == 0
    }

    /// Returns the cached backend for `key`, or runs `loader` to create it.
    ///
    /// The loader runs *without* the cache lock held, so concurrent callers
    /// for different keys do not serialize on a slow load. Two callers racing
    /// on the same key may both run the loader; only one result is kept.
    /// A loader error is propagated and nothing is inserted.
    pub fn get_or_load<E>(
        &self,
        key: &str,
        size_bytes: u64,
        loader: impl FnOnce() -> Result<T, E>,
    ) -> Result<CachedBackend<T>, E> {
        // Disabled cache: the map is provably always empty (nothing is ever
        // inserted when the budget is 0), so skip the lock and the lookup
        // and just hand back a fresh, uncached backend.
        if self.budget_bytes == 0 {
            return Ok(Arc::new(Mutex::new(loader()?)));
        }
        // Fast path: cache hit. Scoped block so the lock is released before
        // the (potentially slow) loader could run below.
        {
            let mut guard = self.inner.lock().expect("backend cache poisoned");
            if let Some(entry) = guard.map.get(key) {
                let handle = Arc::clone(&entry.backend);
                guard.touch(key);
                return Ok(handle);
            }
        }
        // Miss: load outside the lock.
        let backend = loader()?;
        let handle = Arc::new(Mutex::new(backend));
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        // Double-check: another thread may have inserted this key while we
        // were loading. Use its entry — and refresh its recency, so the
        // losing racer's request still counts as a "use" — and let our
        // freshly loaded backend drop.
        if let Some(existing) = guard.map.get(key) {
            let winner = Arc::clone(&existing.backend);
            guard.touch(key);
            return Ok(winner);
        }
        guard.total_bytes = guard.total_bytes.saturating_add(size_bytes);
        guard.map.insert(
            key.to_string(),
            Entry {
                backend: Arc::clone(&handle),
                size_bytes,
            },
        );
        guard.lru.push_back(key.to_string());
        // Evict least-recently-used entries until we fit the budget. The
        // entry just inserted is never evicted, even when it alone exceeds
        // the budget — the newest backend must stay usable, so the budget
        // is a soft cap. (The new key sits at the back of the LRU, so it is
        // popped only once every other entry is gone.)
        while guard.total_bytes > self.budget_bytes {
            let Some(victim_key) = guard.lru.pop_front() else {
                break;
            };
            if victim_key == key {
                // Only the new entry remains; keep it (soft cap).
                guard.lru.push_front(victim_key);
                break;
            }
            if let Some(victim) = guard.map.remove(&victim_key) {
                guard.total_bytes = guard.total_bytes.saturating_sub(victim.size_bytes);
                // Drops the cache's Arc; the backend itself is freed once
                // all outstanding caller handles are gone too.
                drop(victim);
            }
        }
        Ok(handle)
    }

    /// Removes `key` from the cache, if present, and reclaims its budget.
    /// Outstanding handles to the backend remain valid.
    pub fn invalidate(&self, key: &str) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        if let Some(entry) = guard.map.remove(key) {
            guard.total_bytes = guard.total_bytes.saturating_sub(entry.size_bytes);
            guard.lru.retain(|k| k != key);
        }
    }

    /// Drops every entry and resets the byte accounting to zero.
    pub fn clear(&self) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        guard.map.clear();
        guard.lru.clear();
        guard.total_bytes = 0;
    }

    /// Returns `(entry_count, total_bytes, budget_bytes)`.
    pub fn stats(&self) -> (usize, u64, u64) {
        let guard = self.inner.lock().expect("backend cache poisoned");
        (guard.map.len(), guard.total_bytes, self.budget_bytes)
    }
}
/// Sums the sizes (in bytes) of every `.safetensors` file under `model_dir`,
/// descending into subdirectories. Best-effort: unreadable directories or
/// files are silently skipped and contribute nothing; an empty or missing
/// directory yields 0.
pub fn estimate_model_size(model_dir: &Path) -> u64 {
    let mut total: u64 = 0;
    // Iterative traversal: a worklist of directories still to visit.
    let mut pending = vec![model_dir.to_path_buf()];
    while let Some(dir) = pending.pop() {
        let Ok(entries) = std::fs::read_dir(&dir) else {
            continue; // unreadable directory: skip, best-effort
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                pending.push(path);
            } else if path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
                if let Ok(meta) = path.metadata() {
                    total = total.saturating_add(meta.len());
                }
            }
        }
    }
    total
}
#[cfg(test)]
mod tests {
use super::*;
// A second lookup of the same key must return a clone of the same Arc,
// and its loader must never run (the panic closure proves the hit).
#[test]
fn cache_hit_returns_same_handle() {
let cache: BackendCache<u32> = BackendCache::new(1024);
let a = cache.get_or_load::<()>("a", 100, || Ok(42)).unwrap();
let b = cache
.get_or_load::<()>("a", 100, || panic!("should not reload"))
.unwrap();
assert!(Arc::ptr_eq(&a, &b));
}
// Touching "a" after inserting "b" makes "b" the least-recently-used
// entry, so inserting "c" (pushing total to 300 > 250) evicts "b".
#[test]
fn evicts_lru_when_over_budget() {
let cache: BackendCache<u32> = BackendCache::new(250);
let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
let _b = cache.get_or_load::<()>("b", 100, || Ok(2)).unwrap();
// Refreshes "a"'s recency — this hit is what demotes "b" to victim.
let _a_again = cache
.get_or_load::<()>("a", 100, || panic!("cached"))
.unwrap();
let _c = cache.get_or_load::<()>("c", 100, || Ok(3)).unwrap();
let (n, bytes, budget) = cache.stats();
assert_eq!(n, 2, "a + c should remain, b evicted");
assert_eq!(bytes, 200);
assert_eq!(budget, 250);
}
// With a zero budget nothing is retained: every call re-runs the loader
// and returns a distinct handle, but the handle itself is still usable.
#[test]
fn zero_budget_disables_cache_but_returns_handle() {
let cache: BackendCache<u32> = BackendCache::new(0);
let mut load_count = 0u32;
let a = cache
.get_or_load::<()>("a", 100, || {
load_count += 1;
Ok(1)
})
.unwrap();
assert_eq!(*a.lock().unwrap(), 1);
let b = cache
.get_or_load::<()>("a", 100, || {
load_count += 1;
Ok(1)
})
.unwrap();
assert_eq!(*b.lock().unwrap(), 1);
assert_eq!(load_count, 2, "disabled cache reloads every call");
assert!(!Arc::ptr_eq(&a, &b));
}
// invalidate() drops the entry from the cache (entry count goes to 0).
#[test]
fn invalidate_removes_key() {
let cache: BackendCache<u32> = BackendCache::new(1024);
let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
assert_eq!(cache.stats().0, 1);
cache.invalidate("a");
assert_eq!(cache.stats().0, 0);
}
}