//! In-process model cache for GGUF files.
//!
//! Avoids reloading model weights for each request by keeping a bounded set of
//! [`ModelEntry`] values in a [`ModelCache`]. The cache uses LRU eviction
//! (the entry with the longest idle time is evicted first) when the slot
//! limit or the optional memory budget is exceeded.
//!
//! A companion [`ModelWarmup`] helper runs a small number of dummy inference
//! passes on a freshly-loaded engine so that internal caches and JIT paths are
//! primed before the first real request.
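//!
//! A minimal usage sketch (marked `ignore` because it assumes crate-local
//! paths; `load_config_from_gguf` is a hypothetical helper):
//!
//! ```ignore
//! let cache = ModelCache::new(ModelCacheConfig::default());
//! // The first lookup misses and runs the loader; subsequent lookups for
//! // the same key hit the cache until the entry exceeds its TTL.
//! let entry = cache.get_or_insert("qwen3-0.6b", || {
//!     ModelEntry::new(load_config_from_gguf(), Some("models/qwen3.gguf".into()), 600 << 20)
//! });
//! assert!(cache.contains("qwen3-0.6b"));
//! ```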
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use oxibonsai_core::config::Qwen3Config;
use crate::engine::InferenceEngine;
use crate::sampling::SamplingParams;
// ─────────────────────────────────────────────────────────────────────────────
// ModelEntry
// ─────────────────────────────────────────────────────────────────────────────
/// A single cached model entry, storing metadata about a loaded model.
///
/// The entry does **not** own the actual weight tensors; those live in the
/// [`InferenceEngine`] that the cache manages externally. The entry tracks
/// usage statistics so that the cache can decide which entries to evict.
#[derive(Clone)]
pub struct ModelEntry {
/// Model configuration extracted from GGUF metadata.
pub config: Qwen3Config,
/// Filesystem path to the GGUF file (if known).
pub model_path: Option<String>,
/// Wall-clock time at which this entry was first inserted.
pub loaded_at: Instant,
/// Wall-clock time of the most recent cache hit for this entry.
pub last_used: Instant,
/// Cumulative number of times this entry has been returned from the cache.
pub use_count: u64,
/// Estimated resident-memory footprint of the loaded model.
pub memory_bytes: usize,
}
impl ModelEntry {
/// Create a new entry stamped with the current time.
pub fn new(config: Qwen3Config, model_path: Option<String>, memory_bytes: usize) -> Self {
let now = Instant::now();
Self {
config,
model_path,
loaded_at: now,
last_used: now,
use_count: 0,
memory_bytes,
}
}
/// How long this entry has been in the cache.
pub fn age(&self) -> Duration {
self.loaded_at.elapsed()
}
/// How long since this entry was last accessed.
pub fn idle_time(&self) -> Duration {
self.last_used.elapsed()
}
/// Whether this entry has been idle for longer than `ttl`.
pub fn is_stale(&self, ttl: Duration) -> bool {
self.idle_time() >= ttl
}
}
// ─────────────────────────────────────────────────────────────────────────────
// ModelCacheConfig
// ─────────────────────────────────────────────────────────────────────────────
/// Configuration for [`ModelCache`].
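///
/// A sketch of a non-default configuration using struct-update syntax (the
/// values are arbitrary examples, not tuned recommendations):
///
/// ```ignore
/// let cfg = ModelCacheConfig {
///     max_models: 2,
///     // Evict LRU entries before an insert would push the aggregate
///     // entry memory past ~2 GiB.
///     memory_budget_bytes: Some(2 << 30),
///     ..Default::default()
/// };
/// ```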
#[derive(Debug, Clone)]
pub struct ModelCacheConfig {
/// Maximum number of model entries to keep in the cache simultaneously.
pub max_models: usize,
/// Time-to-live: entries idle longer than this are eligible for eviction.
pub ttl: Duration,
    /// When `true` and `memory_budget_bytes` is set, inserting a new entry
    /// evicts least-recently-used entries until the projected total fits
    /// within the budget.
    pub evict_on_memory_pressure: bool,
    /// Optional memory ceiling in bytes. When the aggregate `memory_bytes` of
    /// all cached entries plus an incoming entry would exceed this value,
    /// least-recently-used entries are evicted before the insert.
pub memory_budget_bytes: Option<usize>,
}
impl Default for ModelCacheConfig {
fn default() -> Self {
Self {
max_models: 4,
ttl: Duration::from_secs(3600),
evict_on_memory_pressure: true,
memory_budget_bytes: None,
}
}
}
// ─────────────────────────────────────────────────────────────────────────────
// ModelCacheStats
// ─────────────────────────────────────────────────────────────────────────────
/// Snapshot of cache utilisation metrics, suitable for serialisation to JSON.
#[derive(Debug, serde::Serialize)]
pub struct ModelCacheStats {
/// Number of entries currently held in the cache.
pub cached_models: usize,
/// Cumulative cache hits since the cache was created.
pub total_hits: u64,
/// Cumulative cache misses since the cache was created.
pub total_misses: u64,
/// Hit rate as a fraction in `[0.0, 1.0]`.
pub hit_rate: f32,
/// Sum of `memory_bytes` across all cached entries.
pub total_memory_bytes: usize,
/// Age in seconds of the oldest entry, or `None` if the cache is empty.
pub oldest_entry_age_secs: Option<u64>,
}
// ─────────────────────────────────────────────────────────────────────────────
// ModelCache
// ─────────────────────────────────────────────────────────────────────────────
/// Thread-safe in-process model cache.
///
/// Uses a [`Mutex`]-guarded [`HashMap`] internally. Eviction is based on
/// idle time (longest-idle entry is removed first) when the slot or memory
/// budget is exceeded.
pub struct ModelCache {
entries: Mutex<HashMap<String, ModelEntry>>,
config: ModelCacheConfig,
/// Cumulative number of cache hits.
pub hits: AtomicU64,
/// Cumulative number of cache misses.
pub misses: AtomicU64,
}
impl ModelCache {
/// Create a new, empty cache with the given configuration.
pub fn new(config: ModelCacheConfig) -> Self {
Self {
entries: Mutex::new(HashMap::new()),
config,
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
}
}
    /// Return a snapshot of the cached entry for `key`, or insert a new one
    /// produced by `loader` if none exists (or if the existing entry is
    /// stale).
    ///
    /// The returned [`Arc`] holds a copy of the entry's metadata taken under
    /// the lock, so callers can keep using it after the cache mutex is
    /// released; it does not reflect later updates to the cached entry.
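    ///
    /// A usage sketch (the `load_entry_from_disk` helper is hypothetical):
    ///
    /// ```ignore
    /// let entry = cache.get_or_insert("qwen3-4b", || load_entry_from_disk("models/qwen3-4b.gguf"));
    /// println!("path = {:?}, uses = {}", entry.model_path, entry.use_count);
    /// ```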
pub fn get_or_insert<F>(&self, key: &str, loader: F) -> Arc<ModelEntry>
where
F: FnOnce() -> ModelEntry,
{
let mut entries = self
.entries
.lock()
.expect("model cache mutex should not be poisoned");
// Check for a live (non-stale) existing entry.
if let Some(entry) = entries.get_mut(key) {
if !entry.is_stale(self.config.ttl) {
entry.last_used = Instant::now();
entry.use_count += 1;
self.hits.fetch_add(1, Ordering::Relaxed);
                // Snapshot the entry into a fresh Arc: a reference into the
                // HashMap cannot outlive the mutex guard.
                return Arc::new(entry.clone());
}
// Stale — remove and fall through to reload.
entries.remove(key);
}
// Cache miss: invoke the loader.
self.misses.fetch_add(1, Ordering::Relaxed);
let new_entry = loader();
// Evict if necessary before inserting.
self.evict_if_needed_locked(&mut entries, new_entry.memory_bytes);
        let result = Arc::new(new_entry.clone());
entries.insert(key.to_owned(), new_entry);
result
}
/// Return `true` if a non-stale entry exists for `key`.
pub fn contains(&self, key: &str) -> bool {
let entries = self
.entries
.lock()
.expect("model cache mutex should not be poisoned");
entries
.get(key)
.map(|e| !e.is_stale(self.config.ttl))
.unwrap_or(false)
}
/// Remove the entry for `key`. Returns `true` if an entry was removed.
pub fn evict(&self, key: &str) -> bool {
let mut entries = self
.entries
.lock()
.expect("model cache mutex should not be poisoned");
entries.remove(key).is_some()
}
/// Remove all entries that have been idle longer than the configured TTL.
/// Returns the number of entries removed.
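    ///
    /// A sketch of a periodic sweep on a background thread (the 60-second
    /// interval is an arbitrary example; the `Arc<ModelCache>` named `cache`
    /// is assumed to exist in the caller):
    ///
    /// ```ignore
    /// let sweeper = Arc::clone(&cache);
    /// std::thread::spawn(move || loop {
    ///     std::thread::sleep(Duration::from_secs(60));
    ///     let evicted = sweeper.evict_stale();
    ///     tracing::debug!(evicted, "stale-entry sweep");
    /// });
    /// ```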
pub fn evict_stale(&self) -> usize {
let mut entries = self
.entries
.lock()
.expect("model cache mutex should not be poisoned");
let ttl = self.config.ttl;
let before = entries.len();
entries.retain(|_, e| !e.is_stale(ttl));
before - entries.len()
}
/// Number of entries currently in the cache.
pub fn len(&self) -> usize {
self.entries
.lock()
.expect("model cache mutex should not be poisoned")
.len()
}
/// `true` when the cache holds no entries.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Cache hit rate as a fraction in `[0.0, 1.0]`.
///
/// Returns `0.0` when no lookups have been performed yet.
pub fn hit_rate(&self) -> f32 {
let hits = self.hits.load(Ordering::Relaxed);
let misses = self.misses.load(Ordering::Relaxed);
let total = hits + misses;
if total == 0 {
return 0.0;
}
hits as f32 / total as f32
}
/// Sum of `memory_bytes` across all cached entries.
pub fn total_memory_bytes(&self) -> usize {
self.entries
.lock()
.expect("model cache mutex should not be poisoned")
.values()
.map(|e| e.memory_bytes)
.sum()
}
/// Take a statistics snapshot of the current cache state.
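    ///
    /// The snapshot derives [`serde::Serialize`], so it can back a metrics
    /// endpoint directly; this sketch assumes `serde_json` is available as a
    /// dependency:
    ///
    /// ```ignore
    /// let body = serde_json::to_string(&cache.stats())?;
    /// ```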
pub fn stats(&self) -> ModelCacheStats {
let entries = self
.entries
.lock()
.expect("model cache mutex should not be poisoned");
let hits = self.hits.load(Ordering::Relaxed);
let misses = self.misses.load(Ordering::Relaxed);
let total = hits + misses;
let hit_rate = if total == 0 {
0.0
} else {
hits as f32 / total as f32
};
let total_memory_bytes: usize = entries.values().map(|e| e.memory_bytes).sum();
let oldest_entry_age_secs = entries.values().map(|e| e.age().as_secs()).max();
ModelCacheStats {
cached_models: entries.len(),
total_hits: hits,
total_misses: misses,
hit_rate,
total_memory_bytes,
oldest_entry_age_secs,
}
}
// ── Private helpers ────────────────────────────────────────────────────
/// Evict entries while over capacity or over memory budget.
///
/// Must be called with the mutex already held.
fn evict_if_needed_locked(
&self,
entries: &mut HashMap<String, ModelEntry>,
incoming_bytes: usize,
) {
        // Slot limit: make room for the incoming entry. The `is_empty` guard
        // protects against a pathological `max_models` of zero.
        while !entries.is_empty() && entries.len() >= self.config.max_models {
            Self::evict_lru(entries);
        }
        // Memory budget: recompute the projected total after each eviction so
        // the loop stops as soon as the incoming entry fits, instead of
        // draining the whole cache on a single oversized insert.
        if self.config.evict_on_memory_pressure {
            if let Some(budget) = self.config.memory_budget_bytes {
                loop {
                    let current: usize = entries.values().map(|e| e.memory_bytes).sum();
                    if current.saturating_add(incoming_bytes) <= budget || entries.is_empty() {
                        break;
                    }
                    Self::evict_lru(entries);
                }
            }
        }
}
/// Remove the entry with the longest idle time (LRU eviction policy).
fn evict_lru(entries: &mut HashMap<String, ModelEntry>) {
if entries.is_empty() {
return;
}
let lru_key = entries
.iter()
.max_by_key(|(_, e)| {
// Convert to a comparable integer (microseconds since last use).
e.idle_time().as_micros()
})
.map(|(k, _)| k.clone());
if let Some(key) = lru_key {
entries.remove(&key);
}
}
}
// ─────────────────────────────────────────────────────────────────────────────
// ModelWarmup
// ─────────────────────────────────────────────────────────────────────────────
/// Runs a small number of dummy inference passes on a freshly-initialised
/// [`InferenceEngine`] to prime internal allocation caches and JIT paths
/// before the first real request arrives.
pub struct ModelWarmup {
/// Number of tokens to generate during the warmup pass.
pub num_warmup_tokens: usize,
/// Prompt text fed to the engine during warmup.
pub warmup_prompt: String,
}
impl Default for ModelWarmup {
fn default() -> Self {
Self::new()
}
}
impl ModelWarmup {
/// Create a warmup helper with sensible defaults (32 tokens, generic prompt).
pub fn new() -> Self {
Self {
num_warmup_tokens: 32,
warmup_prompt: "Warm up the inference engine.".to_owned(),
}
}
/// Override the number of warmup tokens.
pub fn with_tokens(mut self, n: usize) -> Self {
self.num_warmup_tokens = n;
self
}
/// Override the warmup prompt text.
pub fn with_prompt(mut self, p: &str) -> Self {
self.warmup_prompt = p.to_owned();
self
}
/// Execute the warmup passes on `engine` using `params`.
///
/// Generates up to [`ModelWarmup::num_warmup_tokens`] tokens from a small
/// synthetic token sequence and discards the output. Returns the elapsed
/// wall-clock time in milliseconds.
///
/// Errors from the engine are logged as warnings but do **not** propagate —
/// warmup failure is non-fatal.
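    ///
    /// A sketch of warming up right after engine construction (engine and
    /// sampling-params setup elided):
    ///
    /// ```ignore
    /// let warmup = ModelWarmup::new().with_tokens(8).with_prompt("Hello");
    /// let elapsed_ms = warmup.run(&mut engine, &params);
    /// tracing::info!(elapsed_ms, "warmup complete");
    /// ```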
pub fn run(&self, engine: &mut InferenceEngine<'_>, params: &SamplingParams) -> u64 {
let start = Instant::now();
        // Build a minimal synthetic prompt from the warmup text. Without a
        // real tokenizer, derive placeholder token IDs from the prompt bytes,
        // clamped into a plausible vocabulary range.
let dummy_tokens: Vec<u32> = self
.warmup_prompt
.bytes()
.take(16)
.map(|b| u32::from(b) % 32000)
.collect();
let prompt_tokens = if dummy_tokens.is_empty() {
vec![151644u32] // <|im_start|>
} else {
dummy_tokens
};
        // Run a short generation with the caller-supplied sampling params;
        // the generated tokens are discarded.
match engine.generate_with_seed(&prompt_tokens, self.num_warmup_tokens, 0, params) {
Ok(toks) => {
tracing::debug!(generated = toks.len(), "warmup pass completed");
}
Err(e) => {
tracing::warn!(error = %e, "warmup pass encountered an error (non-fatal)");
}
}
// Reset state so the engine is clean for real requests.
engine.reset();
start.elapsed().as_millis() as u64
}
/// Heuristic: should this engine be warmed up?
///
/// The current implementation always returns `true` — callers are
/// responsible for deciding when to apply warmup (e.g. once after initial
/// load, or after a cache miss).
pub fn needs_warmup(_engine: &InferenceEngine<'_>) -> bool {
true
}
}
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use oxibonsai_core::config::Qwen3Config;
fn tiny_entry() -> ModelEntry {
ModelEntry::new(
Qwen3Config::tiny_test(),
Some(std::env::temp_dir().join("tiny.gguf").display().to_string()),
1024,
)
}
// ── ModelEntry ────────────────────────────────────────────────────────
#[test]
fn test_model_entry_age() {
let entry = tiny_entry();
let age = entry.age();
// Age should be very small (sub-second) immediately after creation.
assert!(age < Duration::from_secs(1));
}
#[test]
fn test_model_entry_is_stale() {
let entry = tiny_entry();
// With a 1-hour TTL the brand-new entry must not be stale.
assert!(!entry.is_stale(Duration::from_secs(3600)));
// With a zero TTL every entry is stale.
assert!(entry.is_stale(Duration::from_nanos(0)));
}
// ── ModelCache — miss path ─────────────────────────────────────────────
#[test]
fn test_model_cache_miss_calls_loader() {
let cache = ModelCache::new(ModelCacheConfig::default());
let mut loader_called = false;
let _entry = cache.get_or_insert("model-a", || {
loader_called = true;
tiny_entry()
});
assert!(loader_called, "loader should have been called on a miss");
assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
assert_eq!(cache.hits.load(Ordering::Relaxed), 0);
assert_eq!(cache.len(), 1);
}
// ── ModelCache — hit path ──────────────────────────────────────────────
#[test]
fn test_model_cache_hit_skips_loader() {
let cache = ModelCache::new(ModelCacheConfig::default());
// First call: miss.
cache.get_or_insert("model-b", tiny_entry);
// Second call: should be a hit.
let mut second_loader_called = false;
cache.get_or_insert("model-b", || {
second_loader_called = true;
tiny_entry()
});
assert!(!second_loader_called, "loader must not be called on a hit");
assert_eq!(cache.hits.load(Ordering::Relaxed), 1);
assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
}
// ── ModelCache — manual eviction ──────────────────────────────────────
#[test]
fn test_model_cache_evict() {
let cache = ModelCache::new(ModelCacheConfig::default());
cache.get_or_insert("model-c", tiny_entry);
assert!(cache.contains("model-c"));
let removed = cache.evict("model-c");
assert!(removed);
assert!(!cache.contains("model-c"));
assert_eq!(cache.len(), 0);
// Evicting a non-existent key returns false.
assert!(!cache.evict("no-such-model"));
}
// ── ModelCache — stale eviction ──────────────────────────────────────
#[test]
fn test_model_cache_evict_stale() {
// Use a zero TTL so every entry is immediately stale.
let cfg = ModelCacheConfig {
ttl: Duration::from_nanos(0),
..Default::default()
};
let cache = ModelCache::new(cfg);
        // Insert directly into the map, bypassing get_or_insert and its
        // hit/miss accounting.
{
let mut entries = cache.entries.lock().expect("mutex should not be poisoned");
entries.insert("model-d".to_owned(), tiny_entry());
}
assert_eq!(cache.len(), 1);
let evicted = cache.evict_stale();
assert_eq!(evicted, 1);
assert_eq!(cache.len(), 0);
}
// ── ModelCache — hit rate ─────────────────────────────────────────────
#[test]
fn test_model_cache_hit_rate() {
let cache = ModelCache::new(ModelCacheConfig::default());
// No lookups yet → 0.0.
assert!((cache.hit_rate() - 0.0).abs() < f32::EPSILON);
cache.get_or_insert("rate-model", tiny_entry); // miss
cache.get_or_insert("rate-model", tiny_entry); // hit
cache.get_or_insert("rate-model", tiny_entry); // hit
// 2 hits out of 3 total → ~0.667
let rate = cache.hit_rate();
assert!(rate > 0.6 && rate < 0.7, "expected ~0.667, got {rate}");
}
// ── ModelCache — stats snapshot ───────────────────────────────────────
#[test]
fn test_model_cache_stats() {
let cache = ModelCache::new(ModelCacheConfig::default());
cache.get_or_insert("stats-model", tiny_entry); // miss
let stats = cache.stats();
assert_eq!(stats.cached_models, 1);
assert_eq!(stats.total_misses, 1);
assert_eq!(stats.total_hits, 0);
assert_eq!(stats.total_memory_bytes, 1024);
assert!(stats.oldest_entry_age_secs.is_some());
}
// ── ModelWarmup ───────────────────────────────────────────────────────
#[test]
fn test_warmup_runs_without_panic() {
let config = Qwen3Config::tiny_test();
let params = SamplingParams::default();
let mut engine = InferenceEngine::new(config, params.clone(), 42);
let warmup = ModelWarmup::new().with_tokens(4).with_prompt("Hello");
        let elapsed_ms = warmup.run(&mut engine, &params);
// Warmup must complete (even if it generates 0 tokens on a tiny model).
// We just check it didn't panic and returned a sensible elapsed time.
assert!(elapsed_ms < 60_000, "warmup should complete in under 60 s");
assert!(ModelWarmup::needs_warmup(&engine));
}
}