car_inference/backend_cache.rs

//! LRU-evicting cache for loaded inference backends.
//!
//! Solves three problems at once:
//!
//! 1. **Cold start on every call.** Before: `FluxBackend::load()` /
//!    `LtxBackend::load()` / `KokoroBackend::load()` ran for every
//!    `generate_image` / `generate_video` / `synth_tts` request, paying
//!    the 1–14 s model-load cost on a hot path. After: the first call
//!    loads, subsequent calls get a cheap `Arc<Mutex<T>>` handle.
//!
//! 2. **Concurrent calls racing on the same backend.** `mlx-rs` `Array`
//!    is `!Sync`; two tokio tasks calling the same backend simultaneously
//!    is undefined behavior. The per-entry `Mutex` serializes concurrent
//!    callers onto the same backend. Different backends still run in
//!    parallel (MLX itself queues them on the single Metal driver).
//!
//! 3. **Unbounded RAM growth.** Before: loading Flux + LTX + Kokoro +
//!    every text-gen model held ~30 GB of quantized weights forever. The
//!    cache tracks an approximate per-entry size (the sum of the model
//!    directory's `.safetensors` bytes) and evicts LRU entries once the
//!    total exceeds `budget_bytes`. The budget defaults to 24 GB and is
//!    configurable via `CAR_INFERENCE_MODEL_CACHE_MB`; set it to 0 to
//!    disable caching.
//!
//! The cache is generic over `T: Send + 'static`. Inference backends
//! don't need to implement any trait — just wrap them on insert.
//!
//! Invariant: an evicted entry is only removed from the cache map; any
//! outstanding `Arc<Mutex<T>>` handle continues to work until the last
//! caller drops it. This makes eviction safe even during a long-running
//! inference call — RAM is reclaimed lazily when the last user finishes.
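//!
//! A minimal usage sketch (`FluxBackend`, its `load`/`generate` methods,
//! and the model path are illustrative placeholders, not APIs defined in
//! this module):
//!
//! ```ignore
//! let cache: BackendCache<FluxBackend> = BackendCache::from_env();
//! let dir = Path::new("/models/flux-dev");
//! // First call pays the load cost; later calls with the same key are cheap.
//! let handle = cache.get_or_load("flux-dev", estimate_model_size(dir), || {
//!     FluxBackend::load(dir)
//! })?;
//! // Hold the per-entry lock for the duration of the inference call.
//! let mut backend = handle.lock().unwrap();
//! let image = backend.generate("a red car")?;
//! ```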

use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::sync::{Arc, Mutex};

/// Handle to a cached backend. Callers lock the inner mutex for the
/// duration of an inference call to serialize with concurrent requests
/// for the same model.
pub type CachedBackend<T> = Arc<Mutex<T>>;

struct Entry<T> {
    backend: CachedBackend<T>,
    size_bytes: u64,
}

struct Inner<T> {
    /// Loaded backends keyed by stable ID (typically `ModelSchema.id`).
    map: HashMap<String, Entry<T>>,
    /// LRU recency list — back is most recent, front is oldest.
    lru: VecDeque<String>,
    total_bytes: u64,
}

/// LRU-bounded cache of loaded inference backends.
///
/// Safe to share across threads even when `T` is `!Sync`: the cache
/// stores handles only, and the lock guarding the map is coarse (held
/// only for look-up / insert / evict). Per-entry mutation is serialized
/// by each entry's own `Mutex`.
pub struct BackendCache<T: Send + 'static> {
    inner: Mutex<Inner<T>>,
    budget_bytes: u64,
}

impl<T: Send + 'static> BackendCache<T> {
    /// Create a new cache with the given RAM budget in bytes.
    pub fn new(budget_bytes: u64) -> Self {
        Self {
            inner: Mutex::new(Inner {
                map: HashMap::new(),
                lru: VecDeque::new(),
                total_bytes: 0,
            }),
            budget_bytes,
        }
    }

    /// Read the `CAR_INFERENCE_MODEL_CACHE_MB` env var, default 24 GB.
    /// A value of 0 disables caching (every call loads fresh).
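    ///
    /// For example, to cap the cache at 8 GB (value illustrative):
    ///
    /// ```text
    /// CAR_INFERENCE_MODEL_CACHE_MB=8192
    /// ```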
    pub fn from_env() -> Self {
        let mb = std::env::var("CAR_INFERENCE_MODEL_CACHE_MB")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(24 * 1024);
        Self::new(mb.saturating_mul(1024 * 1024))
    }

    /// Is this a disabled (zero-budget) cache?
    pub fn is_disabled(&self) -> bool {
        self.budget_bytes == 0
    }

    /// Get a handle to the cached backend, or load + insert it.
    ///
    /// `loader` is invoked only on a cache miss. If the new entry plus
    /// existing entries exceed the budget, the oldest entries are
    /// evicted until we're within budget (or only the new entry remains).
    ///
    /// `size_bytes` should be an approximate on-disk or in-memory size
    /// for the loaded backend; [`estimate_model_size`] is a reasonable
    /// default that sums the `.safetensors` file sizes in a model dir.
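    ///
    /// A minimal sketch of the hit/miss contract (`FluxBackend` is an
    /// illustrative placeholder):
    ///
    /// ```ignore
    /// let size = estimate_model_size(dir);
    /// let first = cache.get_or_load("flux-dev", size, || FluxBackend::load(dir))?;
    /// let again = cache.get_or_load("flux-dev", size, || FluxBackend::load(dir))?;
    /// // Second call was a hit: same handle, loader not invoked again.
    /// assert!(Arc::ptr_eq(&first, &again));
    /// ```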
    pub fn get_or_load<E, F>(
        &self,
        key: &str,
        size_bytes: u64,
        loader: F,
    ) -> Result<CachedBackend<T>, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        // Fast path: already cached.
        {
            let mut guard = self.inner.lock().expect("backend cache poisoned");
            if let Some(entry) = guard.map.get(key) {
                let handle = Arc::clone(&entry.backend);
                // Promote to most-recently-used.
                guard.lru.retain(|k| k != key);
                guard.lru.push_back(key.to_string());
                return Ok(handle);
            }
        }

        // Slow path: load outside the lock. Concurrent loads of the same
        // key may duplicate work, but the loser of the race never inserts:
        // the re-check below returns the winner's entry, and the duplicate
        // backend is dropped once the loser's local handle goes out of scope.
        let backend = loader()?;
        let handle = Arc::new(Mutex::new(backend));

        if self.budget_bytes == 0 {
            // Caching disabled: return the handle without retaining it.
            return Ok(handle);
        }

        let mut guard = self.inner.lock().expect("backend cache poisoned");

        // Re-check in case someone else inserted while we were loading.
        if let Some(existing) = guard.map.get(key) {
            return Ok(Arc::clone(&existing.backend));
        }

        guard.total_bytes = guard.total_bytes.saturating_add(size_bytes);
        guard.map.insert(
            key.to_string(),
            Entry {
                backend: Arc::clone(&handle),
                size_bytes,
            },
        );
        guard.lru.push_back(key.to_string());

        // Evict LRU entries until within budget (but never evict the
        // just-inserted key — that would defeat the purpose).
        while guard.total_bytes > self.budget_bytes {
            let Some(victim_key) = guard.lru.pop_front() else {
                break;
            };
            if victim_key == key {
                // Skipped our own entry; put it back at the front and stop
                // (we've already visited every other key).
                guard.lru.push_front(victim_key);
                break;
            }
            if let Some(victim) = guard.map.remove(&victim_key) {
                guard.total_bytes = guard.total_bytes.saturating_sub(victim.size_bytes);
                // The `Arc` may still have outstanding handles; they
                // continue to work. Memory is reclaimed when the last
                // one drops.
                drop(victim);
            }
        }

        Ok(handle)
    }

    /// Manually remove a key, e.g. when model weights have been updated
    /// on disk. Outstanding `Arc<Mutex<T>>` handles remain valid.
    pub fn invalidate(&self, key: &str) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        if let Some(entry) = guard.map.remove(key) {
            guard.total_bytes = guard.total_bytes.saturating_sub(entry.size_bytes);
            guard.lru.retain(|k| k != key);
        }
    }

    /// Evict all entries. Outstanding handles continue to work.
    pub fn clear(&self) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        guard.map.clear();
        guard.lru.clear();
        guard.total_bytes = 0;
    }

    /// Returns `(entries, total_bytes, budget_bytes)` for diagnostics.
    pub fn stats(&self) -> (usize, u64, u64) {
        let guard = self.inner.lock().expect("backend cache poisoned");
        (guard.map.len(), guard.total_bytes, self.budget_bytes)
    }
}

/// Sum the sizes of all `.safetensors` files under `model_dir`.
/// A loose upper-bound RAM estimate for the loaded model (quantized
/// tensors live in MLX-owned memory roughly matching their on-disk size).
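///
/// A minimal sketch (the path is an illustrative placeholder):
///
/// ```ignore
/// let approx_bytes = estimate_model_size(Path::new("/models/flux-dev"));
/// ```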
pub fn estimate_model_size(model_dir: &Path) -> u64 {
    fn visit(dir: &Path, total: &mut u64) {
        let Ok(entries) = std::fs::read_dir(dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                visit(&path, total);
                continue;
            }
            if path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
                if let Ok(meta) = path.metadata() {
                    *total = total.saturating_add(meta.len());
                }
            }
        }
    }
    let mut total = 0u64;
    visit(model_dir, &mut total);
    total
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_hit_returns_same_handle() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let a = cache.get_or_load::<(), _>("a", 100, || Ok(42)).unwrap();
        let b = cache
            .get_or_load::<(), _>("a", 100, || panic!("should not reload"))
            .unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn evicts_lru_when_over_budget() {
        let cache: BackendCache<u32> = BackendCache::new(250);
        let _a = cache.get_or_load::<(), _>("a", 100, || Ok(1)).unwrap();
        let _b = cache.get_or_load::<(), _>("b", 100, || Ok(2)).unwrap();
        // Total 200, still under 250. Touch `a` so `b` is LRU.
        let _a_again = cache
            .get_or_load::<(), _>("a", 100, || panic!("cached"))
            .unwrap();
        // Insert `c`: pushes us to 300, evicts `b` (the LRU one).
        let _c = cache.get_or_load::<(), _>("c", 100, || Ok(3)).unwrap();
        let (n, bytes, budget) = cache.stats();
        assert_eq!(n, 2, "a + c should remain, b evicted");
        assert_eq!(bytes, 200);
        assert_eq!(budget, 250);
    }

    #[test]
    fn zero_budget_disables_cache_but_returns_handle() {
        let cache: BackendCache<u32> = BackendCache::new(0);
        let mut load_count = 0u32;
        let a = cache
            .get_or_load::<(), _>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*a.lock().unwrap(), 1);
        let b = cache
            .get_or_load::<(), _>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*b.lock().unwrap(), 1);
        assert_eq!(load_count, 2, "disabled cache reloads every call");
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn invalidate_removes_key() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let _a = cache.get_or_load::<(), _>("a", 100, || Ok(1)).unwrap();
        assert_eq!(cache.stats().0, 1);
        cache.invalidate("a");
        assert_eq!(cache.stats().0, 0);
    }
}