// oxibonsai_runtime/model_cache.rs
//! In-process model cache for GGUF files.
//!
//! Avoids reloading model weights for each request by keeping a bounded set of
//! [`ModelEntry`] values in a [`ModelCache`].  The cache uses LRU-like eviction
//! (evict the entry with the longest idle time) when the slot limit is reached.
//!
//! A companion [`ModelWarmup`] helper runs a small number of dummy inference
//! passes on a freshly-loaded engine so that internal caches and JIT paths are
//! primed before the first real request.

11use std::collections::HashMap;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::{Arc, Mutex};
14use std::time::{Duration, Instant};
15
16use oxibonsai_core::config::Qwen3Config;
17
18use crate::engine::InferenceEngine;
19use crate::sampling::SamplingParams;
20
// ─────────────────────────────────────────────────────────────────────────────
// ModelEntry
// ─────────────────────────────────────────────────────────────────────────────

25/// A single cached model entry, storing metadata about a loaded model.
26///
27/// The entry does **not** own the actual weight tensors; those live in the
28/// [`InferenceEngine`] that the cache manages externally.  The entry tracks
29/// usage statistics so that the cache can decide which entries to evict.
30pub struct ModelEntry {
31    /// Model configuration extracted from GGUF metadata.
32    pub config: Qwen3Config,
33    /// Filesystem path to the GGUF file (if known).
34    pub model_path: Option<String>,
35    /// Wall-clock time at which this entry was first inserted.
36    pub loaded_at: Instant,
37    /// Wall-clock time of the most recent cache hit for this entry.
38    pub last_used: Instant,
39    /// Cumulative number of times this entry has been returned from the cache.
40    pub use_count: u64,
41    /// Estimated resident-memory footprint of the loaded model.
42    pub memory_bytes: usize,
43}
44
45impl ModelEntry {
46    /// Create a new entry stamped with the current time.
47    pub fn new(config: Qwen3Config, model_path: Option<String>, memory_bytes: usize) -> Self {
48        let now = Instant::now();
49        Self {
50            config,
51            model_path,
52            loaded_at: now,
53            last_used: now,
54            use_count: 0,
55            memory_bytes,
56        }
57    }
58
59    /// How long this entry has been in the cache.
60    pub fn age(&self) -> Duration {
61        self.loaded_at.elapsed()
62    }
63
64    /// How long since this entry was last accessed.
65    pub fn idle_time(&self) -> Duration {
66        self.last_used.elapsed()
67    }
68
69    /// Whether this entry has been idle for longer than `ttl`.
70    pub fn is_stale(&self, ttl: Duration) -> bool {
71        self.idle_time() >= ttl
72    }
73}
74
// ─────────────────────────────────────────────────────────────────────────────
// ModelCacheConfig
// ─────────────────────────────────────────────────────────────────────────────

/// Tunable knobs for [`ModelCache`].
#[derive(Debug, Clone)]
pub struct ModelCacheConfig {
    /// Hard cap on how many model entries the cache may hold at once.
    pub max_models: usize,
    /// Time-to-live: an entry idle for at least this long is eviction-eligible.
    pub ttl: Duration,
    /// If `true`, the cache proactively evicts entries whenever the total
    /// resident memory would exceed `memory_budget_bytes`.
    pub evict_on_memory_pressure: bool,
    /// Optional memory ceiling in bytes.  When the aggregate `memory_bytes`
    /// of all cached entries would exceed this value, least-recently-used
    /// entries are evicted to make room before a new insert.
    pub memory_budget_bytes: Option<usize>,
}
94
95impl Default for ModelCacheConfig {
96    fn default() -> Self {
97        Self {
98            max_models: 4,
99            ttl: Duration::from_secs(3600),
100            evict_on_memory_pressure: true,
101            memory_budget_bytes: None,
102        }
103    }
104}
105
// ─────────────────────────────────────────────────────────────────────────────
// ModelCacheStats
// ─────────────────────────────────────────────────────────────────────────────

110/// Snapshot of cache utilisation metrics, suitable for serialisation to JSON.
111#[derive(Debug, serde::Serialize)]
112pub struct ModelCacheStats {
113    /// Number of entries currently held in the cache.
114    pub cached_models: usize,
115    /// Cumulative cache hits since the cache was created.
116    pub total_hits: u64,
117    /// Cumulative cache misses since the cache was created.
118    pub total_misses: u64,
119    /// Hit rate as a fraction in `[0.0, 1.0]`.
120    pub hit_rate: f32,
121    /// Sum of `memory_bytes` across all cached entries.
122    pub total_memory_bytes: usize,
123    /// Age in seconds of the oldest entry, or `None` if the cache is empty.
124    pub oldest_entry_age_secs: Option<u64>,
125}
126
// ─────────────────────────────────────────────────────────────────────────────
// ModelCache
// ─────────────────────────────────────────────────────────────────────────────

131/// Thread-safe in-process model cache.
132///
133/// Uses a [`Mutex`]-guarded [`HashMap`] internally.  Eviction is based on
134/// idle time (longest-idle entry is removed first) when the slot or memory
135/// budget is exceeded.
136pub struct ModelCache {
137    entries: Mutex<HashMap<String, ModelEntry>>,
138    config: ModelCacheConfig,
139    /// Cumulative number of cache hits.
140    pub hits: AtomicU64,
141    /// Cumulative number of cache misses.
142    pub misses: AtomicU64,
143}
144
145impl ModelCache {
146    /// Create a new, empty cache with the given configuration.
147    pub fn new(config: ModelCacheConfig) -> Self {
148        Self {
149            entries: Mutex::new(HashMap::new()),
150            config,
151            hits: AtomicU64::new(0),
152            misses: AtomicU64::new(0),
153        }
154    }
155
156    /// Return a shared reference to the cached entry for `key`, or insert a
157    /// new one produced by `loader` if none exists (or if the existing entry
158    /// is stale).
159    ///
160    /// The returned [`Arc`] allows callers to hold a reference to the entry
161    /// while the cache mutex is not held.
162    pub fn get_or_insert<F>(&self, key: &str, loader: F) -> Arc<ModelEntry>
163    where
164        F: FnOnce() -> ModelEntry,
165    {
166        let mut entries = self
167            .entries
168            .lock()
169            .expect("model cache mutex should not be poisoned");
170
171        // Check for a live (non-stale) existing entry.
172        if let Some(entry) = entries.get_mut(key) {
173            if !entry.is_stale(self.config.ttl) {
174                entry.last_used = Instant::now();
175                entry.use_count += 1;
176                self.hits.fetch_add(1, Ordering::Relaxed);
177                // Clone the relevant fields into a new Arc — we cannot hand
178                // out a reference into the HashMap while the mutex is held by
179                // the caller.
180                return Arc::new(ModelEntry {
181                    config: entry.config.clone(),
182                    model_path: entry.model_path.clone(),
183                    loaded_at: entry.loaded_at,
184                    last_used: entry.last_used,
185                    use_count: entry.use_count,
186                    memory_bytes: entry.memory_bytes,
187                });
188            }
189            // Stale — remove and fall through to reload.
190            entries.remove(key);
191        }
192
193        // Cache miss: invoke the loader.
194        self.misses.fetch_add(1, Ordering::Relaxed);
195        let new_entry = loader();
196
197        // Evict if necessary before inserting.
198        self.evict_if_needed_locked(&mut entries, new_entry.memory_bytes);
199
200        let result = Arc::new(ModelEntry {
201            config: new_entry.config.clone(),
202            model_path: new_entry.model_path.clone(),
203            loaded_at: new_entry.loaded_at,
204            last_used: new_entry.last_used,
205            use_count: new_entry.use_count,
206            memory_bytes: new_entry.memory_bytes,
207        });
208
209        entries.insert(key.to_owned(), new_entry);
210        result
211    }
212
213    /// Return `true` if a non-stale entry exists for `key`.
214    pub fn contains(&self, key: &str) -> bool {
215        let entries = self
216            .entries
217            .lock()
218            .expect("model cache mutex should not be poisoned");
219        entries
220            .get(key)
221            .map(|e| !e.is_stale(self.config.ttl))
222            .unwrap_or(false)
223    }
224
225    /// Remove the entry for `key`.  Returns `true` if an entry was removed.
226    pub fn evict(&self, key: &str) -> bool {
227        let mut entries = self
228            .entries
229            .lock()
230            .expect("model cache mutex should not be poisoned");
231        entries.remove(key).is_some()
232    }
233
234    /// Remove all entries that have been idle longer than the configured TTL.
235    /// Returns the number of entries removed.
236    pub fn evict_stale(&self) -> usize {
237        let mut entries = self
238            .entries
239            .lock()
240            .expect("model cache mutex should not be poisoned");
241        let ttl = self.config.ttl;
242        let before = entries.len();
243        entries.retain(|_, e| !e.is_stale(ttl));
244        before - entries.len()
245    }
246
247    /// Number of entries currently in the cache.
248    pub fn len(&self) -> usize {
249        self.entries
250            .lock()
251            .expect("model cache mutex should not be poisoned")
252            .len()
253    }
254
255    /// `true` when the cache holds no entries.
256    pub fn is_empty(&self) -> bool {
257        self.len() == 0
258    }
259
260    /// Cache hit rate as a fraction in `[0.0, 1.0]`.
261    ///
262    /// Returns `0.0` when no lookups have been performed yet.
263    pub fn hit_rate(&self) -> f32 {
264        let hits = self.hits.load(Ordering::Relaxed);
265        let misses = self.misses.load(Ordering::Relaxed);
266        let total = hits + misses;
267        if total == 0 {
268            return 0.0;
269        }
270        hits as f32 / total as f32
271    }
272
273    /// Sum of `memory_bytes` across all cached entries.
274    pub fn total_memory_bytes(&self) -> usize {
275        self.entries
276            .lock()
277            .expect("model cache mutex should not be poisoned")
278            .values()
279            .map(|e| e.memory_bytes)
280            .sum()
281    }
282
283    /// Take a statistics snapshot of the current cache state.
284    pub fn stats(&self) -> ModelCacheStats {
285        let entries = self
286            .entries
287            .lock()
288            .expect("model cache mutex should not be poisoned");
289        let hits = self.hits.load(Ordering::Relaxed);
290        let misses = self.misses.load(Ordering::Relaxed);
291        let total = hits + misses;
292        let hit_rate = if total == 0 {
293            0.0
294        } else {
295            hits as f32 / total as f32
296        };
297        let total_memory_bytes: usize = entries.values().map(|e| e.memory_bytes).sum();
298        let oldest_entry_age_secs = entries.values().map(|e| e.age().as_secs()).max();
299
300        ModelCacheStats {
301            cached_models: entries.len(),
302            total_hits: hits,
303            total_misses: misses,
304            hit_rate,
305            total_memory_bytes,
306            oldest_entry_age_secs,
307        }
308    }
309
310    // ── Private helpers ────────────────────────────────────────────────────
311
312    /// Evict entries while over capacity or over memory budget.
313    ///
314    /// Must be called with the mutex already held.
315    fn evict_if_needed_locked(
316        &self,
317        entries: &mut HashMap<String, ModelEntry>,
318        incoming_bytes: usize,
319    ) {
320        // Slot limit.
321        while entries.len() >= self.config.max_models {
322            Self::evict_lru(entries);
323        }
324
325        // Memory budget.
326        if self.config.evict_on_memory_pressure {
327            if let Some(budget) = self.config.memory_budget_bytes {
328                let current: usize = entries.values().map(|e| e.memory_bytes).sum();
329                let projected = current.saturating_add(incoming_bytes);
330                while projected > budget && !entries.is_empty() {
331                    Self::evict_lru(entries);
332                }
333            }
334        }
335    }
336
337    /// Remove the entry with the longest idle time (LRU eviction policy).
338    fn evict_lru(entries: &mut HashMap<String, ModelEntry>) {
339        if entries.is_empty() {
340            return;
341        }
342        let lru_key = entries
343            .iter()
344            .max_by_key(|(_, e)| {
345                // Convert to a comparable integer (microseconds since last use).
346                e.idle_time().as_micros()
347            })
348            .map(|(k, _)| k.clone());
349
350        if let Some(key) = lru_key {
351            entries.remove(&key);
352        }
353    }
354}
355
// ─────────────────────────────────────────────────────────────────────────────
// ModelWarmup
// ─────────────────────────────────────────────────────────────────────────────

/// Helper that runs a few dummy inference passes against a freshly-created
/// [`InferenceEngine`], priming internal allocation caches and JIT paths
/// before the first real request arrives.
pub struct ModelWarmup {
    /// How many tokens to generate during the warmup pass.
    pub num_warmup_tokens: usize,
    /// Prompt text fed to the engine during warmup.
    pub warmup_prompt: String,
}
369
370impl Default for ModelWarmup {
371    fn default() -> Self {
372        Self::new()
373    }
374}
375
376impl ModelWarmup {
377    /// Create a warmup helper with sensible defaults (32 tokens, generic prompt).
378    pub fn new() -> Self {
379        Self {
380            num_warmup_tokens: 32,
381            warmup_prompt: "Warm up the inference engine.".to_owned(),
382        }
383    }
384
385    /// Override the number of warmup tokens.
386    pub fn with_tokens(mut self, n: usize) -> Self {
387        self.num_warmup_tokens = n;
388        self
389    }
390
391    /// Override the warmup prompt text.
392    pub fn with_prompt(mut self, p: &str) -> Self {
393        self.warmup_prompt = p.to_owned();
394        self
395    }
396
397    /// Execute the warmup passes on `engine` using `params`.
398    ///
399    /// Generates up to [`ModelWarmup::num_warmup_tokens`] tokens from a small
400    /// synthetic token sequence and discards the output.  Returns the elapsed
401    /// wall-clock time in milliseconds.
402    ///
403    /// Errors from the engine are logged as warnings but do **not** propagate —
404    /// warmup failure is non-fatal.
405    pub fn run(&self, engine: &mut InferenceEngine<'_>, params: &SamplingParams) -> u64 {
406        let start = Instant::now();
407
408        // Build a minimal synthetic prompt from the warmup text.
409        // Without a real tokenizer we use a fixed representative token sequence.
410        let dummy_tokens: Vec<u32> = self
411            .warmup_prompt
412            .bytes()
413            .take(16)
414            .map(|b| u32::from(b) % 32000)
415            .collect();
416
417        let prompt_tokens = if dummy_tokens.is_empty() {
418            vec![151644u32] // <|im_start|>
419        } else {
420            dummy_tokens
421        };
422
423        // Temporarily swap in the caller-supplied params via generate_with_seed.
424        match engine.generate_with_seed(&prompt_tokens, self.num_warmup_tokens, 0, params) {
425            Ok(toks) => {
426                tracing::debug!(generated = toks.len(), "warmup pass completed");
427            }
428            Err(e) => {
429                tracing::warn!(error = %e, "warmup pass encountered an error (non-fatal)");
430            }
431        }
432
433        // Reset state so the engine is clean for real requests.
434        engine.reset();
435
436        start.elapsed().as_millis() as u64
437    }
438
439    /// Heuristic: should this engine be warmed up?
440    ///
441    /// The current implementation always returns `true` — callers are
442    /// responsible for deciding when to apply warmup (e.g. once after initial
443    /// load, or after a cache miss).
444    pub fn needs_warmup(_engine: &InferenceEngine<'_>) -> bool {
445        true
446    }
447}
448
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Build a 1 KiB dummy entry pointing at a fake path under the temp dir.
    fn tiny_entry() -> ModelEntry {
        let fake_path = std::env::temp_dir().join("tiny.gguf").display().to_string();
        ModelEntry::new(Qwen3Config::tiny_test(), Some(fake_path), 1024)
    }

    // ── ModelEntry ────────────────────────────────────────────────────────

    #[test]
    fn test_model_entry_age() {
        // Immediately after creation the age must be sub-second.
        assert!(tiny_entry().age() < Duration::from_secs(1));
    }

    #[test]
    fn test_model_entry_is_stale() {
        let entry = tiny_entry();
        // Brand-new entry with a 1-hour TTL: not stale.
        assert!(!entry.is_stale(Duration::from_secs(3600)));
        // Zero TTL makes every entry stale.
        assert!(entry.is_stale(Duration::from_nanos(0)));
    }

    // ── ModelCache — miss path ─────────────────────────────────────────────

    #[test]
    fn test_model_cache_miss_calls_loader() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        let mut invoked = false;

        let _ = cache.get_or_insert("model-a", || {
            invoked = true;
            tiny_entry()
        });

        assert!(invoked, "loader should have been called on a miss");
        assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
        assert_eq!(cache.hits.load(Ordering::Relaxed), 0);
        assert_eq!(cache.len(), 1);
    }

    // ── ModelCache — hit path ──────────────────────────────────────────────

    #[test]
    fn test_model_cache_hit_skips_loader() {
        let cache = ModelCache::new(ModelCacheConfig::default());

        // First lookup misses and populates the cache.
        cache.get_or_insert("model-b", tiny_entry);

        // Second lookup must hit without running the loader again.
        let mut reloaded = false;
        cache.get_or_insert("model-b", || {
            reloaded = true;
            tiny_entry()
        });

        assert!(!reloaded, "loader must not be called on a hit");
        assert_eq!(cache.hits.load(Ordering::Relaxed), 1);
        assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
    }

    // ── ModelCache — manual eviction ──────────────────────────────────────

    #[test]
    fn test_model_cache_evict() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        cache.get_or_insert("model-c", tiny_entry);
        assert!(cache.contains("model-c"));

        assert!(cache.evict("model-c"), "first eviction removes the entry");
        assert!(!cache.contains("model-c"));
        assert_eq!(cache.len(), 0);

        // A non-existent key reports false.
        assert!(!cache.evict("no-such-model"));
    }

    // ── ModelCache — stale eviction ──────────────────────────────────────

    #[test]
    fn test_model_cache_evict_stale() {
        // Zero TTL: every entry is immediately stale.
        let cache = ModelCache::new(ModelCacheConfig {
            ttl: Duration::from_nanos(0),
            ..Default::default()
        });

        // Place an entry directly in the map, bypassing get_or_insert.
        cache
            .entries
            .lock()
            .expect("mutex should not be poisoned")
            .insert("model-d".to_owned(), tiny_entry());

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.evict_stale(), 1);
        assert_eq!(cache.len(), 0);
    }

    // ── ModelCache — hit rate ─────────────────────────────────────────────

    #[test]
    fn test_model_cache_hit_rate() {
        let cache = ModelCache::new(ModelCacheConfig::default());

        // No lookups performed yet → exactly 0.0.
        assert!((cache.hit_rate() - 0.0).abs() < f32::EPSILON);

        cache.get_or_insert("rate-model", tiny_entry); // miss
        for _ in 0..2 {
            cache.get_or_insert("rate-model", tiny_entry); // hit
        }

        // 2 hits out of 3 lookups → ~0.667.
        let rate = cache.hit_rate();
        assert!(rate > 0.6 && rate < 0.7, "expected ~0.667, got {rate}");
    }

    // ── ModelCache — stats snapshot ───────────────────────────────────────

    #[test]
    fn test_model_cache_stats() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        cache.get_or_insert("stats-model", tiny_entry); // miss

        let snapshot = cache.stats();
        assert_eq!(snapshot.cached_models, 1);
        assert_eq!(snapshot.total_misses, 1);
        assert_eq!(snapshot.total_hits, 0);
        assert_eq!(snapshot.total_memory_bytes, 1024);
        assert!(snapshot.oldest_entry_age_secs.is_some());
    }

    // ── ModelWarmup ───────────────────────────────────────────────────────

    #[test]
    fn test_warmup_runs_without_panic() {
        let params = SamplingParams::default();
        let mut engine = InferenceEngine::new(Qwen3Config::tiny_test(), params.clone(), 42);

        let warmup = ModelWarmup::new().with_tokens(4).with_prompt("Hello");
        let elapsed_ms = warmup.run(&mut engine, &params);

        // Warmup must complete (even if it generates 0 tokens on a tiny
        // model); check it returned a sensible elapsed time.
        assert!(elapsed_ms < 60_000, "warmup should complete in under 60 s");
        assert!(ModelWarmup::needs_warmup(&engine));
    }
}