use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Shared, lockable handle to a backend instance handed out by [`BackendCache`].
pub type CachedBackend<T> = Arc<Mutex<T>>;
/// Byte budget that one or more [`BackendCache`]s can draw from jointly.
///
/// A `budget_bytes` of `0` disables caching entirely.
pub struct SharedModelBudget {
    /// Upper bound on cached bytes; `0` means "caching disabled".
    budget_bytes: u64,
    /// Bytes currently charged by every cache sharing this budget.
    total_bytes: AtomicU64,
}

impl SharedModelBudget {
    /// Creates a shareable budget capped at `budget_bytes` (`0` disables caching).
    pub fn new(budget_bytes: u64) -> Arc<Self> {
        let budget = Self {
            total_bytes: AtomicU64::new(0),
            budget_bytes,
        };
        Arc::new(budget)
    }

    /// Builds a budget from `CAR_INFERENCE_MODEL_CACHE_MB` (in MiB), falling back to
    /// `default_mb` when the variable is unset or is not a valid `u64`.
    pub fn from_env_or(default_mb: u64) -> Arc<Self> {
        const BYTES_PER_MB: u64 = 1 << 20;
        let megabytes = match std::env::var("CAR_INFERENCE_MODEL_CACHE_MB") {
            Ok(raw) => raw.parse::<u64>().unwrap_or(default_mb),
            Err(_) => default_mb,
        };
        Self::new(megabytes.saturating_mul(BYTES_PER_MB))
    }

    /// Charges `n` bytes against the budget.
    fn add(&self, n: u64) {
        let _ = self.total_bytes.fetch_add(n, Ordering::Relaxed);
    }

    /// Refunds `n` bytes, clamping the running total at zero.
    fn sub(&self, n: u64) {
        let release = |current: u64| Some(current.saturating_sub(n));
        // `release` never yields `None`, so the update cannot be rejected.
        let _ = self
            .total_bytes
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, release);
    }

    /// Bytes currently charged against the budget.
    fn total(&self) -> u64 {
        self.total_bytes.load(Ordering::Relaxed)
    }

    /// `true` when caching is enabled and the charged bytes exceed the cap.
    fn over_budget(&self) -> bool {
        !self.is_disabled() && self.total() > self.budget_bytes
    }

    /// `true` when the cap is zero, i.e. caching is turned off.
    pub fn is_disabled(&self) -> bool {
        self.budget_bytes == 0
    }
}
/// Default model-cache budget in MiB, taken from the detected hardware's
/// `max_model_mb`.
pub fn default_model_cache_mb() -> u64 {
    let hardware = crate::hardware::HardwareInfo::detect();
    hardware.max_model_mb
}
/// Idle time-to-live for cached backends, read from `CAR_INFERENCE_MODEL_IDLE_SECS`.
///
/// Falls back to 300 seconds when the variable is unset or not a valid `u64`.
/// An explicit `0` disables idle eviction and yields `None`.
pub fn idle_ttl_from_env() -> Option<Duration> {
    const DEFAULT_IDLE_SECS: u64 = 300;
    let secs = match std::env::var("CAR_INFERENCE_MODEL_IDLE_SECS") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_IDLE_SECS),
        Err(_) => DEFAULT_IDLE_SECS,
    };
    if secs == 0 {
        None
    } else {
        Some(Duration::from_secs(secs))
    }
}
/// A cached backend plus the bookkeeping used for LRU and idle eviction.
struct Entry<T> {
    /// Shared handle given to callers of `get_or_load`.
    backend: CachedBackend<T>,
    /// Bytes this entry charges against the shared budget.
    size_bytes: u64,
    /// When this entry was last inserted or returned from `get_or_load`.
    last_used: Instant,
}

/// Mutex-protected state of a [`BackendCache`].
///
/// Invariant: `lru` holds exactly the keys of `map`, once each.
struct Inner<T> {
    map: HashMap<String, Entry<T>>,
    /// Cached keys ordered from least to most recently used.
    lru: VecDeque<String>,
}

impl<T> Inner<T> {
    /// Marks `key` as most recently used and returns its handle, or `None` on a miss.
    fn touch(&mut self, key: &str) -> Option<CachedBackend<T>> {
        let entry = self.map.get_mut(key)?;
        entry.last_used = Instant::now();
        let handle = Arc::clone(&entry.backend);
        self.lru.retain(|k| k != key);
        self.lru.push_back(key.to_string());
        Some(handle)
    }
}

/// Size-bounded LRU cache of loaded backends keyed by string.
///
/// Entry sizes are charged against a (possibly shared) [`SharedModelBudget`].
/// Entries can also be removed after an optional idle period.
pub struct BackendCache<T: Send + 'static> {
    inner: Mutex<Inner<T>>,
    budget: Arc<SharedModelBudget>,
    /// Entries unused for at least this long are removed by `evict_idle`;
    /// `None` disables idle eviction.
    idle_ttl: Option<Duration>,
}

impl<T: Send + 'static> BackendCache<T> {
    /// Creates a cache with a private budget of `budget_bytes` and no idle TTL.
    pub fn new(budget_bytes: u64) -> Self {
        Self::from_shared(SharedModelBudget::new(budget_bytes), None)
    }

    /// Creates a cache with a private budget of `budget_bytes` and the given idle TTL.
    pub fn with_idle_ttl(budget_bytes: u64, idle_ttl: Option<Duration>) -> Self {
        Self::from_shared(SharedModelBudget::new(budget_bytes), idle_ttl)
    }

    /// Creates a cache that charges entries against `budget`.
    ///
    /// `budget` may be shared with other caches.
    pub fn from_shared(budget: Arc<SharedModelBudget>, idle_ttl: Option<Duration>) -> Self {
        Self {
            inner: Mutex::new(Inner {
                map: HashMap::new(),
                lru: VecDeque::new(),
            }),
            budget,
            idle_ttl,
        }
    }

    /// Creates a cache configured from the environment.
    ///
    /// The budget comes from `SharedModelBudget::from_env_or`, defaulting to
    /// `default_model_cache_mb()`. The idle TTL comes from `idle_ttl_from_env()`.
    pub fn from_env() -> Self {
        Self::from_shared(
            SharedModelBudget::from_env_or(default_model_cache_mb()),
            idle_ttl_from_env(),
        )
    }

    /// `true` when the budget is zero and nothing is ever cached.
    pub fn is_disabled(&self) -> bool {
        self.budget.is_disabled()
    }

    /// Locks the cache state.
    ///
    /// # Panics
    /// Panics if a previous lock holder panicked (poisoned mutex).
    fn lock(&self) -> std::sync::MutexGuard<'_, Inner<T>> {
        self.inner.lock().expect("backend cache poisoned")
    }

    /// Returns the cached backend for `key`, calling `loader` on a miss.
    ///
    /// On a miss the new backend is charged `size_bytes` against the budget. This
    /// cache's least-recently-used entries are then evicted until the shared budget
    /// is satisfied or only the new entry remains. The new entry is kept even if it
    /// alone exceeds the budget.
    ///
    /// When the budget is disabled (`0`), the loaded backend is returned without
    /// being cached.
    ///
    /// `loader` runs without holding the cache lock, so concurrent misses on the
    /// same key may each load. The first to insert wins, and later callers receive
    /// the winner's handle.
    ///
    /// # Errors
    /// Propagates the error returned by `loader`; nothing is cached in that case.
    pub fn get_or_load<E>(
        &self,
        key: &str,
        size_bytes: u64,
        loader: impl FnOnce() -> Result<T, E>,
    ) -> Result<CachedBackend<T>, E> {
        // Bind the lookup result first so the guard is released before `loader` runs.
        let hit = self.lock().touch(key);
        if let Some(handle) = hit {
            return Ok(handle);
        }

        let handle = Arc::new(Mutex::new(loader()?));
        if self.budget.is_disabled() {
            return Ok(handle);
        }

        let mut guard = self.lock();
        if let Some(existing) = guard.touch(key) {
            // Lost a load race. Share (and refresh) the winner's instance. Our own
            // copy in `handle` is dropped after `guard` (reverse declaration order),
            // i.e. outside the lock.
            return Ok(existing);
        }

        self.budget.add(size_bytes);
        guard.map.insert(
            key.to_string(),
            Entry {
                backend: Arc::clone(&handle),
                size_bytes,
                last_used: Instant::now(),
            },
        );
        guard.lru.push_back(key.to_string());

        let mut evicted = Vec::new();
        while self.budget.over_budget() {
            let Some(victim_key) = guard.lru.pop_front() else {
                break;
            };
            if victim_key == key {
                // Only the new entry is left in this cache; keep it.
                guard.lru.push_front(victim_key);
                break;
            }
            if let Some(victim) = guard.map.remove(&victim_key) {
                self.budget.sub(victim.size_bytes);
                evicted.push(victim);
            }
        }
        drop(guard);
        // Release evicted backends only after unlocking. Freeing large models then
        // does not stall concurrent lookups on this cache.
        drop(evicted);
        Ok(handle)
    }

    /// Removes `key` (if cached) and refunds its bytes to the budget.
    ///
    /// Callers that still hold the handle keep the backend alive.
    pub fn invalidate(&self, key: &str) {
        let mut guard = self.lock();
        let Some(entry) = guard.map.remove(key) else {
            return;
        };
        self.budget.sub(entry.size_bytes);
        guard.lru.retain(|k| k != key);
        drop(guard);
        drop(entry);
    }

    /// Removes every entry unused for at least the idle TTL, regardless of budget.
    ///
    /// Returns `(entries_removed, bytes_freed)`; always `(0, 0)` when no TTL is set.
    pub fn evict_idle(&self) -> (usize, u64) {
        let Some(ttl) = self.idle_ttl else {
            return (0, 0);
        };
        let mut guard = self.lock();
        // Sample the clock under the lock so no entry's `last_used` is newer than `now`.
        let now = Instant::now();
        let Inner { map, lru } = &mut *guard;
        let stale: Vec<String> = map
            .iter()
            .filter(|(_, e)| now.saturating_duration_since(e.last_used) >= ttl)
            .map(|(k, _)| k.clone())
            .collect();
        if stale.is_empty() {
            return (0, 0);
        }
        let evicted: Vec<Entry<T>> = stale.iter().filter_map(|k| map.remove(k)).collect();
        // One pass over the LRU order instead of one `retain` per evicted key.
        lru.retain(|k| map.contains_key(k));
        let freed = evicted
            .iter()
            .fold(0u64, |acc, e| acc.saturating_add(e.size_bytes));
        self.budget.sub(freed);
        drop(guard);
        let count = evicted.len();
        // Free the backends outside the lock.
        drop(evicted);
        (count, freed)
    }

    /// Empties this cache and refunds all of its bytes to the shared budget.
    pub fn clear(&self) {
        let mut guard = self.lock();
        let drained = std::mem::take(&mut guard.map);
        guard.lru.clear();
        let freed = drained
            .values()
            .fold(0u64, |acc, e| acc.saturating_add(e.size_bytes));
        self.budget.sub(freed);
        drop(guard);
        // Free the backends outside the lock.
        drop(drained);
    }

    /// Returns a snapshot of the cache state as a tuple of three values:
    /// - the number of entries in this cache,
    /// - the bytes charged to the shared budget by all caches,
    /// - the budget cap in bytes.
    pub fn stats(&self) -> (usize, u64, u64) {
        let guard = self.lock();
        (guard.map.len(), self.budget.total(), self.budget.budget_bytes)
    }
}
/// Estimates a model's on-disk size from its `*.safetensors` files.
///
/// The result is the total byte length of every such file under `model_dir`,
/// searched recursively.
///
/// Symlinks are followed, to both files and directories. Each real directory is
/// visited at most once, so a symlink cycle cannot cause unbounded recursion.
/// Unreadable directories and files are skipped, and a missing or unreadable
/// `model_dir` yields `0`.
pub fn estimate_model_size(model_dir: &Path) -> u64 {
    /// Adds the sizes of `*.safetensors` files under `dir` to `total`, skipping
    /// directories already recorded in `seen`.
    fn visit(
        dir: &Path,
        seen: &mut std::collections::HashSet<std::path::PathBuf>,
        total: &mut u64,
    ) {
        // Key on the canonical path so that different routes to the same directory
        // (e.g. through symlinks) are recognised. Fall back to the raw path if it
        // cannot be resolved.
        let key = std::fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());
        if !seen.insert(key) {
            return;
        }
        let Ok(entries) = std::fs::read_dir(dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                visit(&path, seen, total);
            } else if path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
                if let Ok(meta) = path.metadata() {
                    *total = total.saturating_add(meta.len());
                }
            }
        }
    }

    let mut seen = std::collections::HashSet::new();
    let mut total = 0u64;
    visit(model_dir, &mut seen, &mut total);
    total
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Moves `key`'s `last_used` timestamp `by` into the past. Idle-eviction tests
    /// then do not depend on real sleeps or scheduler latency.
    fn backdate(cache: &BackendCache<u32>, key: &str, by: Duration) {
        let mut guard = cache.inner.lock().unwrap();
        let entry = guard.map.get_mut(key).expect("key should be cached");
        entry.last_used = entry
            .last_used
            .checked_sub(by)
            .expect("monotonic clock too close to its origin to backdate");
    }

    #[test]
    fn cache_hit_returns_same_handle() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let a = cache.get_or_load::<()>("a", 100, || Ok(42)).unwrap();
        let b = cache
            .get_or_load::<()>("a", 100, || panic!("should not reload"))
            .unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn evicts_lru_when_over_budget() {
        let cache: BackendCache<u32> = BackendCache::new(250);
        let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
        let _b = cache.get_or_load::<()>("b", 100, || Ok(2)).unwrap();
        let _a_again = cache
            .get_or_load::<()>("a", 100, || panic!("cached"))
            .unwrap();
        let _c = cache.get_or_load::<()>("c", 100, || Ok(3)).unwrap();
        let (n, bytes, budget) = cache.stats();
        assert_eq!(n, 2, "a + c should remain, b evicted");
        assert_eq!(bytes, 200);
        assert_eq!(budget, 250);
    }

    #[test]
    fn zero_budget_disables_cache_but_returns_handle() {
        let cache: BackendCache<u32> = BackendCache::new(0);
        let mut load_count = 0u32;
        let a = cache
            .get_or_load::<()>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*a.lock().unwrap(), 1);
        let b = cache
            .get_or_load::<()>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*b.lock().unwrap(), 1);
        assert_eq!(load_count, 2, "disabled cache reloads every call");
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn evict_idle_drops_stale_entries_below_capacity() {
        let ttl = Duration::from_secs(5);
        let cache: BackendCache<u32> = BackendCache::with_idle_ttl(1_000_000, Some(ttl));
        let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
        let _b = cache.get_or_load::<()>("b", 100, || Ok(2)).unwrap();
        assert_eq!(cache.stats().0, 2);
        assert_eq!(cache.evict_idle(), (0, 0), "fresh entries are not idle");
        // Age both entries well past the TTL, then touch `a` so only `b` stays stale.
        backdate(&cache, "a", ttl * 2);
        backdate(&cache, "b", ttl * 2);
        let _a_again = cache
            .get_or_load::<()>("a", 100, || panic!("cached"))
            .unwrap();
        let (entries, bytes) = cache.evict_idle();
        assert_eq!((entries, bytes), (1, 100), "only b should be swept");
        let (n, total, _) = cache.stats();
        assert_eq!(n, 1, "a remains");
        assert_eq!(total, 100);
        assert_eq!(
            cache.inner.lock().unwrap().lru.len(),
            1,
            "swept key is removed from the LRU order"
        );
    }

    #[test]
    fn evict_idle_noop_when_disabled() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
        assert_eq!(cache.evict_idle(), (0, 0));
        assert_eq!(cache.stats().0, 1, "disabled idle eviction keeps the entry");
    }

    #[test]
    fn invalidate_removes_key() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
        assert_eq!(cache.stats().0, 1);
        cache.invalidate("a");
        assert_eq!(cache.stats(), (0, 0, 1024), "entry and its bytes are released");
    }

    #[test]
    fn shared_budget_spans_caches_and_each_self_trims() {
        let budget = SharedModelBudget::new(250);
        let a: BackendCache<u32> = BackendCache::from_shared(budget.clone(), None);
        let b: BackendCache<u32> = BackendCache::from_shared(budget.clone(), None);
        let _ = a.get_or_load::<()>("a1", 100, || Ok(1)).unwrap();
        let _ = b.get_or_load::<()>("b1", 100, || Ok(2)).unwrap();
        assert_eq!(budget.total(), 200);
        let _ = a.get_or_load::<()>("a2", 100, || Ok(3)).unwrap();
        assert_eq!(budget.total(), 200, "A self-trimmed to the shared budget");
        assert_eq!(a.stats().0, 1, "A kept a2, evicted a1");
        assert_eq!(b.stats().0, 1, "B's b1 not evicted by A");
    }

    #[test]
    fn clear_releases_budget_but_keeps_outstanding_handles() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let a = cache.get_or_load::<()>("a", 100, || Ok(7)).unwrap();
        let _b = cache.get_or_load::<()>("b", 200, || Ok(8)).unwrap();
        cache.clear();
        assert_eq!(cache.stats(), (0, 0, 1024));
        assert_eq!(*a.lock().unwrap(), 7, "outstanding handles survive clear");
    }

    #[test]
    fn loader_error_is_propagated_and_not_cached() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let err = cache
            .get_or_load("a", 100, || Err::<u32, &str>("boom"))
            .unwrap_err();
        assert_eq!(err, "boom");
        assert_eq!(cache.stats(), (0, 0, 1024));
    }

    #[test]
    fn default_budget_is_ram_derived_not_flat() {
        let mb = default_model_cache_mb();
        assert!(mb > 0, "RAM-derived default should be positive");
    }
}