car_inference/backend_cache.rs
//! LRU-evicting cache for loaded inference backends.
//!
//! Solves three problems at once:
//!
//! 1. **Cold start on every call.** Before: `FluxBackend::load()` /
//!    `LtxBackend::load()` / `KokoroBackend::load()` ran for every
//!    `generate_image` / `generate_video` / `synth_tts` request, paying
//!    the 1–14 s model-load cost on a hot path. After: the first call
//!    loads, subsequent calls get a cheap `Arc<Mutex<T>>` handle.
//!
//! 2. **Concurrent calls racing on the same backend.** `mlx-rs` `Array`
//!    is `!Sync`; two tokio tasks calling the same backend simultaneously
//!    is undefined behavior. The per-entry `Mutex` serializes concurrent
//!    callers onto the same backend. Different backends still run in
//!    parallel (MLX itself queues them on the single Metal driver).
//!
//! 3. **Unbounded RAM growth.** Before: loading Flux + LTX + Kokoro +
//!    every text-gen model held ~30 GB of quantized weights forever. The
//!    cache tracks an approximate per-entry size (the sum of the model
//!    directory's `.safetensors` bytes) and evicts LRU entries once the
//!    total exceeds `budget_bytes`. The budget is configured via the
//!    `CAR_INFERENCE_MODEL_CACHE_MB` env var (default 24 GB); set it to
//!    0 to effectively disable caching.
//!
//! The cache is generic over `T: Send + 'static`. Inference backends
//! don't need to implement any trait — just wrap them on insert.
//!
//! Invariant: an evicted entry is only removed from the cache map; any
//! outstanding `Arc<Mutex<T>>` handle continues to work until the last
//! caller drops it. This makes eviction safe even during a long-running
//! inference call — RAM is reclaimed lazily when the last user finishes.
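//!
//! # Example
//!
//! A minimal sketch of the intended call pattern. `FluxBackend`, its
//! `load` / `generate_image` calls, the cache key, the model directory,
//! and `request` are illustrative stand-ins supplied by the caller; only
//! `BackendCache`, `get_or_load`, and `estimate_model_size` come from
//! this module.
//!
//! ```ignore
//! use std::path::Path;
//!
//! let cache: BackendCache<FluxBackend> = BackendCache::from_env();
//! let model_dir = Path::new("/path/to/flux");
//!
//! let handle = cache.get_or_load(
//!     "flux-schnell",
//!     estimate_model_size(model_dir),
//!     || FluxBackend::load(model_dir),
//! )?;
//!
//! // Hold the per-entry lock for the duration of the inference call so
//! // concurrent requests for the same model are serialized.
//! let mut backend = handle.lock().expect("backend mutex poisoned");
//! let image = backend.generate_image(&request)?;
//! ```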

use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::sync::{Arc, Mutex};

/// Handle to a cached backend. Callers lock the inner mutex for the
/// duration of an inference call to serialize with concurrent requests
/// for the same model.
pub type CachedBackend<T> = Arc<Mutex<T>>;

struct Entry<T> {
    backend: CachedBackend<T>,
    size_bytes: u64,
}

struct Inner<T> {
    /// Loaded backends keyed by stable ID (typically `ModelSchema.id`).
    map: HashMap<String, Entry<T>>,
    /// LRU recency list — back is most recent, front is oldest.
    lru: VecDeque<String>,
    total_bytes: u64,
}

/// LRU-bounded cache of loaded inference backends.
///
/// `T` itself never needs to be `Sync`: the cache stores `Arc<Mutex<T>>`
/// handles only, so `BackendCache<T>` is `Send + Sync` for any `T: Send`.
/// The lock guarding the map is coarse (held only for look-up / insert /
/// evict); per-entry mutation is serialized by each entry's own `Mutex`.
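///
/// Because of that, a single cache can be shared across tokio tasks
/// behind an `Arc`. A sketch, assuming the surrounding service runs on
/// tokio and that `FluxBackend`, `size`, `load_flux`, and `request` are
/// defined by the caller (they are illustrative here):
///
/// ```ignore
/// let cache = Arc::new(BackendCache::<FluxBackend>::from_env());
/// let cache_for_task = Arc::clone(&cache);
/// tokio::task::spawn_blocking(move || {
///     let handle = cache_for_task.get_or_load("flux-schnell", size, load_flux)?;
///     // Inference holds the entry lock; run it off the async runtime.
///     let mut backend = handle.lock().expect("backend mutex poisoned");
///     backend.generate_image(&request)
/// });
/// ```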
pub struct BackendCache<T: Send + 'static> {
    inner: Mutex<Inner<T>>,
    budget_bytes: u64,
}

impl<T: Send + 'static> BackendCache<T> {
    /// Create a new cache with the given RAM budget in bytes.
    pub fn new(budget_bytes: u64) -> Self {
        Self {
            inner: Mutex::new(Inner {
                map: HashMap::new(),
                lru: VecDeque::new(),
                total_bytes: 0,
            }),
            budget_bytes,
        }
    }

    /// Read the `CAR_INFERENCE_MODEL_CACHE_MB` env var, default 24 GB.
    /// A value of 0 disables caching (every call loads fresh).
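    ///
    /// The value is interpreted as megabytes; a sketch of typical
    /// settings (the numbers are illustrative):
    ///
    /// ```text
    /// CAR_INFERENCE_MODEL_CACHE_MB=8192    # 8 GB budget
    /// CAR_INFERENCE_MODEL_CACHE_MB=0       # disable caching entirely
    /// CAR_INFERENCE_MODEL_CACHE_MB unset   # default 24 GB (24 * 1024 MB)
    /// ```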
    pub fn from_env() -> Self {
        let mb = std::env::var("CAR_INFERENCE_MODEL_CACHE_MB")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(24 * 1024);
        Self::new(mb.saturating_mul(1024 * 1024))
    }

    /// Is this a disabled (zero-budget) cache?
    pub fn is_disabled(&self) -> bool {
        self.budget_bytes == 0
    }

    /// Get a handle to the cached backend, or load + insert it.
    ///
    /// `loader` is invoked only on a cache miss. If the new entry plus
    /// existing entries exceed the budget, the oldest entries are
    /// evicted until we're within budget (or only the new entry remains).
    ///
    /// `size_bytes` should be an approximate on-disk or in-memory size
    /// for the loaded backend; [`estimate_model_size`] is a reasonable
    /// default that sums the `.safetensors` file sizes in a model dir.
    pub fn get_or_load<E>(
        &self,
        key: &str,
        size_bytes: u64,
        loader: impl FnOnce() -> Result<T, E>,
    ) -> Result<CachedBackend<T>, E> {
        // Fast path: already cached.
        {
            let mut guard = self.inner.lock().expect("backend cache poisoned");
            if let Some(entry) = guard.map.get(key) {
                let handle = Arc::clone(&entry.backend);
                // Promote to most-recently-used.
                guard.lru.retain(|k| k != key);
                guard.lru.push_back(key.to_string());
                return Ok(handle);
            }
        }

        // Slow path: load outside the lock. Concurrent loads of the same
        // key may duplicate work; the loser of the race finds the winner's
        // entry at the re-check below and drops its own freshly loaded
        // backend, so nothing is retained twice.
        let backend = loader()?;
        let handle = Arc::new(Mutex::new(backend));

        if self.budget_bytes == 0 {
            // Caching disabled: return the handle without retaining it.
            return Ok(handle);
        }

        let mut guard = self.inner.lock().expect("backend cache poisoned");

        // Re-check in case someone else inserted while we were loading.
        if let Some(existing) = guard.map.get(key) {
            return Ok(Arc::clone(&existing.backend));
        }

        guard.total_bytes = guard.total_bytes.saturating_add(size_bytes);
        guard.map.insert(
            key.to_string(),
            Entry {
                backend: Arc::clone(&handle),
                size_bytes,
            },
        );
        guard.lru.push_back(key.to_string());

        // Evict LRU entries until within budget (but never evict the
        // just-inserted key — that would defeat the purpose).
        while guard.total_bytes > self.budget_bytes {
            let Some(victim_key) = guard.lru.pop_front() else {
                break;
            };
            if victim_key == key {
                // Popped our own entry; put it back and stop (we've
                // already visited every other key).
                guard.lru.push_front(victim_key);
                break;
            }
            if let Some(victim) = guard.map.remove(&victim_key) {
                guard.total_bytes = guard.total_bytes.saturating_sub(victim.size_bytes);
                // The `Arc` may still have outstanding handles; they
                // continue to work. Memory is reclaimed when the last
                // one drops.
                drop(victim);
            }
        }

        Ok(handle)
    }

    /// Manually remove a key, e.g. when model weights have been updated
    /// on disk. Outstanding `Arc<Mutex<T>>` handles remain valid.
    pub fn invalidate(&self, key: &str) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        if let Some(entry) = guard.map.remove(key) {
            guard.total_bytes = guard.total_bytes.saturating_sub(entry.size_bytes);
            guard.lru.retain(|k| k != key);
        }
    }

    /// Evict all entries. Outstanding handles continue to work.
    pub fn clear(&self) {
        let mut guard = self.inner.lock().expect("backend cache poisoned");
        guard.map.clear();
        guard.lru.clear();
        guard.total_bytes = 0;
    }

    /// Returns `(entries, total_bytes, budget_bytes)` for diagnostics.
    pub fn stats(&self) -> (usize, u64, u64) {
        let guard = self.inner.lock().expect("backend cache poisoned");
        (guard.map.len(), guard.total_bytes, self.budget_bytes)
    }
}

/// Sum the sizes of all `.safetensors` files under `model_dir`.
/// A loose upper-bound RAM estimate for the loaded model (quantized
/// tensors live in MLX-owned memory roughly matching their on-disk size).
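///
/// A minimal sketch pairing it with [`BackendCache::get_or_load`] (the
/// `cache`, `KokoroBackend`, path, and key are illustrative):
///
/// ```ignore
/// let model_dir = Path::new("/path/to/kokoro");
/// let size = estimate_model_size(model_dir);
/// let handle = cache.get_or_load("kokoro", size, || KokoroBackend::load(model_dir))?;
/// ```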
pub fn estimate_model_size(model_dir: &Path) -> u64 {
    fn visit(dir: &Path, total: &mut u64) {
        let Ok(entries) = std::fs::read_dir(dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                visit(&path, total);
                continue;
            }
            if path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
                if let Ok(meta) = path.metadata() {
                    *total = total.saturating_add(meta.len());
                }
            }
        }
    }
    let mut total = 0u64;
    visit(model_dir, &mut total);
    total
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_hit_returns_same_handle() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let a = cache.get_or_load::<()>("a", 100, || Ok(42)).unwrap();
        let b = cache
            .get_or_load::<()>("a", 100, || panic!("should not reload"))
            .unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn evicts_lru_when_over_budget() {
        let cache: BackendCache<u32> = BackendCache::new(250);
        let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
        let _b = cache.get_or_load::<()>("b", 100, || Ok(2)).unwrap();
        // Total 200, still under 250. Touch `a` so `b` is LRU.
        let _a_again = cache
            .get_or_load::<()>("a", 100, || panic!("cached"))
            .unwrap();
        // Insert `c`: pushes us to 300, evicts `b` (the LRU one).
        let _c = cache.get_or_load::<()>("c", 100, || Ok(3)).unwrap();
        let (n, bytes, budget) = cache.stats();
        assert_eq!(n, 2, "a + c should remain, b evicted");
        assert_eq!(bytes, 200);
        assert_eq!(budget, 250);
    }

    #[test]
    fn zero_budget_disables_cache_but_returns_handle() {
        let cache: BackendCache<u32> = BackendCache::new(0);
        let mut load_count = 0u32;
        let a = cache
            .get_or_load::<()>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*a.lock().unwrap(), 1);
        let b = cache
            .get_or_load::<()>("a", 100, || {
                load_count += 1;
                Ok(1)
            })
            .unwrap();
        assert_eq!(*b.lock().unwrap(), 1);
        assert_eq!(load_count, 2, "disabled cache reloads every call");
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn invalidate_removes_key() {
        let cache: BackendCache<u32> = BackendCache::new(1024);
        let _a = cache.get_or_load::<()>("a", 100, || Ok(1)).unwrap();
        assert_eq!(cache.stats().0, 1);
        cache.invalidate("a");
        assert_eq!(cache.stats().0, 0);
    }
}