car-inference 0.13.0

Local model inference for CAR — Candle backend with Qwen3 models
//! LRU-evicting cache for loaded inference backends.
//!
//! Solves three problems at once:
//!
//! 1. **Cold start on every call.** Before: `FluxBackend::load()` /
//!    `LtxBackend::load()` / `KokoroBackend::load()` ran for every
//!    `generate_image` / `generate_video` / `synth_tts` request, paying
//!    the 1–14 s model-load cost on a hot path. After: first call loads,
//!    subsequent calls get a cheap `Arc<Mutex<T>>` handle.
//!
//! 2. **Concurrent calls racing on the same backend.** `mlx-rs` `Array`
//!    is `!Sync`; two tokio tasks calling into the same backend at once
//!    would be undefined behavior. The per-entry `Mutex` serializes concurrent
//!    callers onto the same backend. Different backends still run in
//!    parallel (MLX itself queues them on the single Metal driver).
//!
//! 3. **Unbounded RAM growth.** Before: loading Flux + LTX + Kokoro +
//!    every text-gen model held ~30 GB of quantized weights forever. The
//!    cache tracks an approximate per-entry size (sum of the model
//!    directory's `.safetensors` bytes) and evicts LRU entries once the
//!    total exceeds `budget_bytes`. The budget comes from the
//!    `CAR_INFERENCE_MODEL_CACHE_MB` env var (default 24 GiB); set it
//!    to 0 to disable caching entirely.
//!
//! The cache is generic over `T: Send + 'static`. Inference backends
//! don't need to implement any trait — just wrap them on insert.
//!
//! Invariant: an evicted entry is only removed from the cache map; any
//! outstanding `Arc<Mutex<T>>` handle continues to work until the last
//! caller drops it. This makes eviction safe even during a long-running
//! inference call — RAM is reclaimed lazily when the last user finishes.
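//!
//! A minimal usage sketch (the exact `FluxBackend::load` and
//! `generate_image` signatures here are illustrative, not this crate's
//! real API):
//!
//! ```ignore
//! use std::sync::OnceLock;
//!
//! static FLUX: OnceLock<BackendCache<FluxBackend>> = OnceLock::new();
//!
//! let cache = FLUX.get_or_init(BackendCache::from_env);
//! let dir = Path::new("/models/flux");
//! let backend = cache.get_or_load("flux", estimate_model_size(dir), || {
//!     FluxBackend::load(dir)
//! })?;
//! // Hold the lock for the duration of the call; concurrent requests
//! // for the same model serialize here.
//! let image = backend.lock().expect("poisoned").generate_image(&prompt)?;
//! ```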

use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::sync::{Arc, Mutex};

/// Handle to a cached backend. Callers lock the inner mutex for the
/// duration of an inference call to serialize with concurrent requests
/// for the same model.
pub type CachedBackend<T> = Arc<Mutex<T>>;

struct Entry<T> {
    backend: CachedBackend<T>,
    size_bytes: u64,
}

struct Inner<T> {
    /// Loaded backends keyed by stable ID (typically `ModelSchema.id`).
    map: HashMap<String, Entry<T>>,
    /// LRU recency list — back is most recent, front is oldest.
    lru: VecDeque<String>,
    total_bytes: u64,
}

/// LRU-bounded cache of loaded inference backends.
///
/// `Sync` for any `T: Send`; `T` itself need not be `Sync`, since the
/// cache stores handles only. The lock guarding the map is coarse but
/// held only briefly (look-up / insert / evict), and per-entry mutation
/// is serialized by each entry's own `Mutex`.
pub struct BackendCache<T: Send + 'static> {
    inner: Mutex<Inner<T>>,
    budget_bytes: u64,
}

impl<T: Send + 'static> BackendCache<T> {
    /// Create a new cache with the given RAM budget in bytes.
    pub fn new(budget_bytes: u64) -> Self {
        Self {
            inner: Mutex::new(Inner {
                map: HashMap::new(),
                lru: VecDeque::new(),
                total_bytes: 0,
            }),
            budget_bytes,
        }
    }

    /// Read the `CAR_INFERENCE_MODEL_CACHE_MB` env var (default 24 GiB).
    /// A value of 0 disables caching (every call loads fresh).
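    ///
    /// A sketch with a hypothetical `MyBackend` type:
    ///
    /// ```ignore
    /// // With CAR_INFERENCE_MODEL_CACHE_MB=8192 in the environment,
    /// // the budget is 8 GiB.
    /// let cache: BackendCache<MyBackend> = BackendCache::from_env();
    /// assert!(!cache.is_disabled());
    /// ```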
    pub fn from_env() -> Self {
        let mb = std::env::var("CAR_INFERENCE_MODEL_CACHE_MB")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(24 * 1024);
        Self::new(mb.saturating_mul(1024 * 1024))
    }

    /// Is this a disabled (zero-budget) cache?
    pub fn is_disabled(&self) -> bool {
        self.budget_bytes == 0
    }

    /// Get a handle to the cached backend, or load + insert it.
    ///
    /// `loader` is invoked only on a cache miss. If the new entry plus
    /// existing entries exceed the budget, the oldest entries are
    /// evicted until we're within budget (or only the new entry remains).
    ///
    /// `size_bytes` should be an approximate on-disk or in-memory size
    /// for the loaded backend; [`estimate_model_size`] is a reasonable
    /// default that sums the `.safetensors` file sizes in a model dir.
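    ///
    /// A self-contained sketch using `u32` as a stand-in backend:
    ///
    /// ```ignore
    /// let cache: BackendCache<u32> = BackendCache::new(1024);
    /// // Miss: the loader runs and the result is cached.
    /// let first = cache.get_or_load::<(), _>("m", 512, || Ok(7)).unwrap();
    /// // Hit: the loader is not invoked; the same handle comes back.
    /// let again = cache
    ///     .get_or_load::<(), _>("m", 512, || unreachable!())
    ///     .unwrap();
    /// assert!(Arc::ptr_eq(&first, &again));
    /// ```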
    pub fn get_or_load<E, F>(
        &self,
        key: &str,
        size_bytes: u64,
        loader: F,
    ) -> Result<CachedBackend<T>, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        // Fast path: already cached.
        {
            let mut guard = self.inner.lock().expect("backend cache poisoned");
            if let Some(entry) = guard.map.get(key) {
                let handle = Arc::clone(&entry.backend);
                // Promote to most-recently-used.
                guard.lru.retain(|k| k != key);
                guard.lru.push_back(key.to_string());
                return Ok(handle);
            }
        }

        // Slow path: load outside the lock. Concurrent loads of the same
        // key may duplicate work, but the second caller's insert simply
        // loses the race; its freshly loaded backend is dropped when its
        // local handle goes out of scope, so no memory is leaked.
        let backend = loader()?;
        let handle = Arc::new(Mutex::new(backend));

        if self.budget_bytes == 0 {
            // Caching disabled: return the handle without retaining it.
            return Ok(handle);
        }

        let mut guard = self.inner.lock().expect("backend cache poisoned");

        // Re-check in case someone else inserted while we were loading.
        if let Some(existing) = guard.map.get(key) {
            return Ok(Arc::clone(&existing.backend));
        }

        guard.total_bytes = guard.total_bytes.saturating_add(size_bytes);
        guard.map.insert(
            key.to_string(),
            Entry {
                backend: Arc::clone(&handle),
                size_bytes,
            },
        );
        guard.lru.push_back(key.to_string());

        // Evict LRU entries until within budget (but never evict the
        // just-inserted key — that would defeat the purpose).
        while guard.total_bytes > self.budget_bytes {
            let Some(victim_key) = guard.lru.pop_front() else {
                break;
            };
            if victim_key == key {
                // Skipped our own entry; put it back at the front and stop
                // (we've already visited every other key).
                guard.lru.push_front(victim_key);
                break;
            }
            if let Some(victim) = guard.map.remove(&victim_key) {
                guard.total_bytes = guard.total_bytes.saturating_sub(victim.size_bytes);
                // The `Arc` may still have outstanding handles; they
                // continue to work. Memory is reclaimed when the last
                // one drops.
                drop(victim);
            }
        }

        Ok(handle)
    }

    /// Manually remove a key, e.g. when model weights have been updated
    /// on disk. Outstanding `Arc<Mutex<T>>` handles remain valid.
    pub fn invalidate(&self, key: &str) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        if let Some(entry) = guard.map.remove(key) {
            guard.total_bytes = guard.total_bytes.saturating_sub(entry.size_bytes);
            guard.lru.retain(|k| k != key);
        }
    }

    /// Evict all entries. Outstanding handles continue to work.
    pub fn clear(&self) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        guard.map.clear();
        guard.lru.clear();
        guard.total_bytes = 0;
    }

    /// Returns `(entries, total_bytes, budget_bytes)` for diagnostics.
    pub fn stats(&self) -> (usize, u64, u64) {
        let guard = self.inner.lock().expect("backend cache poisoned");
        (guard.map.len(), guard.total_bytes, self.budget_bytes)
    }
}

/// Sum the sizes of all `.safetensors` files under `model_dir`.
/// A loose upper-bound RAM estimate for the loaded model (quantized
/// tensors live in MLX-owned memory roughly matching their on-disk size).
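///
/// A sketch (`cache` and `load` are assumed to be in scope):
///
/// ```ignore
/// let dir = Path::new("/models/kokoro");
/// let approx_bytes = estimate_model_size(dir);
/// let handle = cache.get_or_load("kokoro", approx_bytes, || load(dir))?;
/// ```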
pub fn estimate_model_size(model_dir: &Path) -> u64 {
    fn visit(dir: &Path, total: &mut u64) {
        let Ok(entries) = std::fs::read_dir(dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                visit(&path, total);
                continue;
            }
            if path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
                if let Ok(meta) = path.metadata() {
                    *total = total.saturating_add(meta.len());
                }
            }
        }
    }
    let mut total = 0u64;
    visit(model_dir, &mut total);
    total
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_hit_returns_same_handle() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let a = cache.get_or_load::<(), _>("a", 100, || Ok(42)).unwrap();
        let b = cache
            .get_or_load::<(), _>("a", 100, || panic!("should not reload"))
            .unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn evicts_lru_when_over_budget() {
        let cache: BackendCache<u32> = BackendCache::new(250);
        let _a = cache.get_or_load::<(), _>("a", 100, || Ok(1)).unwrap();
        let _b = cache.get_or_load::<(), _>("b", 100, || Ok(2)).unwrap();
        // Total 200, still under 250. Touch `a` so `b` is LRU.
        let _a_again = cache
            .get_or_load::<(), _>("a", 100, || panic!("cached"))
            .unwrap();
        // Insert `c`: pushes us to 300, evicts `b` (the LRU one).
        let _c = cache.get_or_load::<(), _>("c", 100, || Ok(3)).unwrap();
        let (n, bytes, budget) = cache.stats();
        assert_eq!(n, 2, "a + c should remain, b evicted");
        assert_eq!(bytes, 200);
        assert_eq!(budget, 250);
    }

    #[test]
    fn zero_budget_disables_cache_but_returns_handle() {
        let cache: BackendCache<u32> = BackendCache::new(0);
        let mut load_count = 0u32;
        let a = cache
            .get_or_load::<(), _>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*a.lock().unwrap(), 1);
        let b = cache
            .get_or_load::<(), _>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*b.lock().unwrap(), 1);
        assert_eq!(load_count, 2, "disabled cache reloads every call");
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn invalidate_removes_key() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let _a = cache.get_or_load::<(), _>("a", 100, || Ok(1)).unwrap();
        assert_eq!(cache.stats().0, 1);
        cache.invalidate("a");
        assert_eq!(cache.stats().0, 0);
    }
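
    // A sketch test for the documented invariant: an evicted entry's
    // outstanding handle keeps working; RAM is reclaimed only when the
    // last holder drops it.
    #[test]
    fn evicted_handle_remains_usable() {
        let cache: BackendCache<u32> = BackendCache::new(150);
        let a = cache.get_or_load::<(), _>("a", 100, || Ok(1)).unwrap();
        // Inserting `b` (total 200 > 150) evicts `a` from the map...
        let _b = cache.get_or_load::<(), _>("b", 100, || Ok(2)).unwrap();
        assert_eq!(cache.stats().0, 1, "only b remains cached");
        // ...but the handle we already hold is still usable.
        assert_eq!(*a.lock().unwrap(), 1);
    }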
}