1use std::collections::HashMap;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::{Arc, Mutex};
14use std::time::{Duration, Instant};
15
16use oxibonsai_core::config::Qwen3Config;
17
18use crate::engine::InferenceEngine;
19use crate::sampling::SamplingParams;
20
/// A cached, loaded model together with usage bookkeeping used by
/// [`ModelCache`] for TTL expiry and LRU eviction decisions.
pub struct ModelEntry {
    /// Model architecture/hyperparameter configuration for this entry.
    pub config: Qwen3Config,
    /// Filesystem path the model was loaded from, if it came from disk.
    pub model_path: Option<String>,
    /// Creation timestamp; basis for [`ModelEntry::age`].
    pub loaded_at: Instant,
    /// Last time this entry served a cache hit; basis for [`ModelEntry::idle_time`].
    pub last_used: Instant,
    /// Number of cache hits this entry has served (starts at 0).
    pub use_count: u64,
    /// Approximate memory footprint of the loaded model, in bytes
    /// (used when enforcing the cache's memory budget).
    pub memory_bytes: usize,
}
44
45impl ModelEntry {
46 pub fn new(config: Qwen3Config, model_path: Option<String>, memory_bytes: usize) -> Self {
48 let now = Instant::now();
49 Self {
50 config,
51 model_path,
52 loaded_at: now,
53 last_used: now,
54 use_count: 0,
55 memory_bytes,
56 }
57 }
58
59 pub fn age(&self) -> Duration {
61 self.loaded_at.elapsed()
62 }
63
64 pub fn idle_time(&self) -> Duration {
66 self.last_used.elapsed()
67 }
68
69 pub fn is_stale(&self, ttl: Duration) -> bool {
71 self.idle_time() >= ttl
72 }
73}
74
/// Tunable policy knobs for a [`ModelCache`].
#[derive(Debug, Clone)]
pub struct ModelCacheConfig {
    /// Maximum number of models kept resident at once; LRU-evicted beyond this.
    pub max_models: usize,
    /// Idle time after which an entry is considered stale and dropped.
    pub ttl: Duration,
    /// When true (and a budget is set), evict LRU entries to stay under budget.
    pub evict_on_memory_pressure: bool,
    /// Optional total memory budget in bytes; `None` disables the budget check.
    pub memory_budget_bytes: Option<usize>,
}
94
95impl Default for ModelCacheConfig {
96 fn default() -> Self {
97 Self {
98 max_models: 4,
99 ttl: Duration::from_secs(3600),
100 evict_on_memory_pressure: true,
101 memory_budget_bytes: None,
102 }
103 }
104}
105
/// Point-in-time snapshot of cache metrics, serializable for status endpoints.
#[derive(Debug, serde::Serialize)]
pub struct ModelCacheStats {
    /// Number of models currently resident in the cache.
    pub cached_models: usize,
    /// Lifetime count of cache hits.
    pub total_hits: u64,
    /// Lifetime count of cache misses.
    pub total_misses: u64,
    /// hits / (hits + misses); 0.0 before any lookup.
    pub hit_rate: f32,
    /// Sum of `memory_bytes` over all resident entries.
    pub total_memory_bytes: usize,
    /// Age in whole seconds of the oldest resident entry; `None` when empty.
    pub oldest_entry_age_secs: Option<u64>,
}
126
/// Thread-safe (mutex-guarded) TTL + LRU cache of loaded model entries,
/// keyed by an arbitrary string id.
pub struct ModelCache {
    /// Resident entries; guarded by a single mutex for all operations.
    entries: Mutex<HashMap<String, ModelEntry>>,
    /// Eviction/TTL policy this cache enforces.
    config: ModelCacheConfig,
    /// Lifetime hit counter (Relaxed ordering — metrics only).
    pub hits: AtomicU64,
    /// Lifetime miss counter (Relaxed ordering — metrics only).
    pub misses: AtomicU64,
}
144
145impl ModelCache {
146 pub fn new(config: ModelCacheConfig) -> Self {
148 Self {
149 entries: Mutex::new(HashMap::new()),
150 config,
151 hits: AtomicU64::new(0),
152 misses: AtomicU64::new(0),
153 }
154 }
155
156 pub fn get_or_insert<F>(&self, key: &str, loader: F) -> Arc<ModelEntry>
163 where
164 F: FnOnce() -> ModelEntry,
165 {
166 let mut entries = self
167 .entries
168 .lock()
169 .expect("model cache mutex should not be poisoned");
170
171 if let Some(entry) = entries.get_mut(key) {
173 if !entry.is_stale(self.config.ttl) {
174 entry.last_used = Instant::now();
175 entry.use_count += 1;
176 self.hits.fetch_add(1, Ordering::Relaxed);
177 return Arc::new(ModelEntry {
181 config: entry.config.clone(),
182 model_path: entry.model_path.clone(),
183 loaded_at: entry.loaded_at,
184 last_used: entry.last_used,
185 use_count: entry.use_count,
186 memory_bytes: entry.memory_bytes,
187 });
188 }
189 entries.remove(key);
191 }
192
193 self.misses.fetch_add(1, Ordering::Relaxed);
195 let new_entry = loader();
196
197 self.evict_if_needed_locked(&mut entries, new_entry.memory_bytes);
199
200 let result = Arc::new(ModelEntry {
201 config: new_entry.config.clone(),
202 model_path: new_entry.model_path.clone(),
203 loaded_at: new_entry.loaded_at,
204 last_used: new_entry.last_used,
205 use_count: new_entry.use_count,
206 memory_bytes: new_entry.memory_bytes,
207 });
208
209 entries.insert(key.to_owned(), new_entry);
210 result
211 }
212
213 pub fn contains(&self, key: &str) -> bool {
215 let entries = self
216 .entries
217 .lock()
218 .expect("model cache mutex should not be poisoned");
219 entries
220 .get(key)
221 .map(|e| !e.is_stale(self.config.ttl))
222 .unwrap_or(false)
223 }
224
225 pub fn evict(&self, key: &str) -> bool {
227 let mut entries = self
228 .entries
229 .lock()
230 .expect("model cache mutex should not be poisoned");
231 entries.remove(key).is_some()
232 }
233
234 pub fn evict_stale(&self) -> usize {
237 let mut entries = self
238 .entries
239 .lock()
240 .expect("model cache mutex should not be poisoned");
241 let ttl = self.config.ttl;
242 let before = entries.len();
243 entries.retain(|_, e| !e.is_stale(ttl));
244 before - entries.len()
245 }
246
247 pub fn len(&self) -> usize {
249 self.entries
250 .lock()
251 .expect("model cache mutex should not be poisoned")
252 .len()
253 }
254
255 pub fn is_empty(&self) -> bool {
257 self.len() == 0
258 }
259
260 pub fn hit_rate(&self) -> f32 {
264 let hits = self.hits.load(Ordering::Relaxed);
265 let misses = self.misses.load(Ordering::Relaxed);
266 let total = hits + misses;
267 if total == 0 {
268 return 0.0;
269 }
270 hits as f32 / total as f32
271 }
272
273 pub fn total_memory_bytes(&self) -> usize {
275 self.entries
276 .lock()
277 .expect("model cache mutex should not be poisoned")
278 .values()
279 .map(|e| e.memory_bytes)
280 .sum()
281 }
282
283 pub fn stats(&self) -> ModelCacheStats {
285 let entries = self
286 .entries
287 .lock()
288 .expect("model cache mutex should not be poisoned");
289 let hits = self.hits.load(Ordering::Relaxed);
290 let misses = self.misses.load(Ordering::Relaxed);
291 let total = hits + misses;
292 let hit_rate = if total == 0 {
293 0.0
294 } else {
295 hits as f32 / total as f32
296 };
297 let total_memory_bytes: usize = entries.values().map(|e| e.memory_bytes).sum();
298 let oldest_entry_age_secs = entries.values().map(|e| e.age().as_secs()).max();
299
300 ModelCacheStats {
301 cached_models: entries.len(),
302 total_hits: hits,
303 total_misses: misses,
304 hit_rate,
305 total_memory_bytes,
306 oldest_entry_age_secs,
307 }
308 }
309
310 fn evict_if_needed_locked(
316 &self,
317 entries: &mut HashMap<String, ModelEntry>,
318 incoming_bytes: usize,
319 ) {
320 while entries.len() >= self.config.max_models {
322 Self::evict_lru(entries);
323 }
324
325 if self.config.evict_on_memory_pressure {
327 if let Some(budget) = self.config.memory_budget_bytes {
328 let current: usize = entries.values().map(|e| e.memory_bytes).sum();
329 let projected = current.saturating_add(incoming_bytes);
330 while projected > budget && !entries.is_empty() {
331 Self::evict_lru(entries);
332 }
333 }
334 }
335 }
336
337 fn evict_lru(entries: &mut HashMap<String, ModelEntry>) {
339 if entries.is_empty() {
340 return;
341 }
342 let lru_key = entries
343 .iter()
344 .max_by_key(|(_, e)| {
345 e.idle_time().as_micros()
347 })
348 .map(|(k, _)| k.clone());
349
350 if let Some(key) = lru_key {
351 entries.remove(&key);
352 }
353 }
354}
355
/// Configuration for a best-effort warmup generation pass that primes an
/// [`InferenceEngine`] before it serves real traffic.
pub struct ModelWarmup {
    /// How many tokens to generate during the warmup pass.
    pub num_warmup_tokens: usize,
    /// Text whose bytes seed the pseudo prompt tokens (see `run`).
    pub warmup_prompt: String,
}
369
370impl Default for ModelWarmup {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375
376impl ModelWarmup {
377 pub fn new() -> Self {
379 Self {
380 num_warmup_tokens: 32,
381 warmup_prompt: "Warm up the inference engine.".to_owned(),
382 }
383 }
384
385 pub fn with_tokens(mut self, n: usize) -> Self {
387 self.num_warmup_tokens = n;
388 self
389 }
390
391 pub fn with_prompt(mut self, p: &str) -> Self {
393 self.warmup_prompt = p.to_owned();
394 self
395 }
396
397 pub fn run(&self, engine: &mut InferenceEngine<'_>, params: &SamplingParams) -> u64 {
406 let start = Instant::now();
407
408 let dummy_tokens: Vec<u32> = self
411 .warmup_prompt
412 .bytes()
413 .take(16)
414 .map(|b| u32::from(b) % 32000)
415 .collect();
416
417 let prompt_tokens = if dummy_tokens.is_empty() {
418 vec![151644u32] } else {
420 dummy_tokens
421 };
422
423 match engine.generate_with_seed(&prompt_tokens, self.num_warmup_tokens, 0, params) {
425 Ok(toks) => {
426 tracing::debug!(generated = toks.len(), "warmup pass completed");
427 }
428 Err(e) => {
429 tracing::warn!(error = %e, "warmup pass encountered an error (non-fatal)");
430 }
431 }
432
433 engine.reset();
435
436 start.elapsed().as_millis() as u64
437 }
438
439 pub fn needs_warmup(_engine: &InferenceEngine<'_>) -> bool {
445 true
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Builds a small 1 KiB entry backed by a throwaway temp-dir path.
    fn tiny_entry() -> ModelEntry {
        ModelEntry::new(
            Qwen3Config::tiny_test(),
            Some(std::env::temp_dir().join("tiny.gguf").display().to_string()),
            1024,
        )
    }

    #[test]
    fn test_model_entry_age() {
        let entry = tiny_entry();
        let age = entry.age();
        assert!(age < Duration::from_secs(1));
    }

    #[test]
    fn test_model_entry_is_stale() {
        let entry = tiny_entry();
        assert!(!entry.is_stale(Duration::from_secs(3600)));
        assert!(entry.is_stale(Duration::from_nanos(0)));
    }

    #[test]
    fn test_model_cache_miss_calls_loader() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        let mut loader_called = false;

        let _entry = cache.get_or_insert("model-a", || {
            loader_called = true;
            tiny_entry()
        });

        assert!(loader_called, "loader should have been called on a miss");
        assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
        assert_eq!(cache.hits.load(Ordering::Relaxed), 0);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_model_cache_hit_skips_loader() {
        let cache = ModelCache::new(ModelCacheConfig::default());

        cache.get_or_insert("model-b", tiny_entry);

        let mut second_loader_called = false;
        cache.get_or_insert("model-b", || {
            second_loader_called = true;
            tiny_entry()
        });

        assert!(!second_loader_called, "loader must not be called on a hit");
        assert_eq!(cache.hits.load(Ordering::Relaxed), 1);
        assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_model_cache_evict() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        cache.get_or_insert("model-c", tiny_entry);
        assert!(cache.contains("model-c"));

        let removed = cache.evict("model-c");
        assert!(removed);
        assert!(!cache.contains("model-c"));
        assert_eq!(cache.len(), 0);

        assert!(!cache.evict("no-such-model"));
    }

    #[test]
    fn test_model_cache_evict_stale() {
        // TTL of zero makes every entry immediately stale.
        let cfg = ModelCacheConfig {
            ttl: Duration::from_nanos(0),
            ..Default::default()
        };
        let cache = ModelCache::new(cfg);

        // Insert directly so get_or_insert's TTL check doesn't interfere.
        {
            let mut entries = cache.entries.lock().expect("mutex should not be poisoned");
            entries.insert("model-d".to_owned(), tiny_entry());
        }

        assert_eq!(cache.len(), 1);
        let evicted = cache.evict_stale();
        assert_eq!(evicted, 1);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_model_cache_hit_rate() {
        let cache = ModelCache::new(ModelCacheConfig::default());

        assert!((cache.hit_rate() - 0.0).abs() < f32::EPSILON);

        // One miss followed by two hits => 2/3 hit rate.
        cache.get_or_insert("rate-model", tiny_entry);
        cache.get_or_insert("rate-model", tiny_entry);
        cache.get_or_insert("rate-model", tiny_entry);

        let rate = cache.hit_rate();
        assert!(rate > 0.6 && rate < 0.7, "expected ~0.667, got {rate}");
    }

    #[test]
    fn test_model_cache_stats() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        cache.get_or_insert("stats-model", tiny_entry);

        let stats = cache.stats();
        assert_eq!(stats.cached_models, 1);
        assert_eq!(stats.total_misses, 1);
        assert_eq!(stats.total_hits, 0);
        assert_eq!(stats.total_memory_bytes, 1024);
        assert!(stats.oldest_entry_age_secs.is_some());
    }

    #[test]
    fn test_warmup_runs_without_panic() {
        let config = Qwen3Config::tiny_test();
        let params = SamplingParams::default();
        let mut engine = InferenceEngine::new(config, params.clone(), 42);

        let warmup = ModelWarmup::new().with_tokens(4).with_prompt("Hello");
        let elapsed_ms = warmup.run(&mut engine, &params);

        assert!(elapsed_ms < 60_000, "warmup should complete in under 60 s");
        assert!(ModelWarmup::needs_warmup(&engine));
    }
}