use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use oxibonsai_core::config::Qwen3Config;
use crate::engine::InferenceEngine;
use crate::sampling::SamplingParams;
/// Metadata for one cached model: its configuration, provenance, and the
/// usage/recency bookkeeping that drives TTL and LRU eviction decisions.
pub struct ModelEntry {
    /// Model hyperparameter configuration.
    pub config: Qwen3Config,
    /// Filesystem path the model was loaded from, if known.
    pub model_path: Option<String>,
    /// When the model was loaded into the cache.
    pub loaded_at: Instant,
    /// When the entry was last returned by a cache lookup.
    pub last_used: Instant,
    /// Number of cache hits served by this entry.
    pub use_count: u64,
    /// Approximate resident size of the model, in bytes (caller-supplied).
    pub memory_bytes: usize,
}
impl ModelEntry {
    /// Builds a fresh entry. `loaded_at` and `last_used` both start at the
    /// same "now" instant and the hit counter starts at zero.
    pub fn new(config: Qwen3Config, model_path: Option<String>, memory_bytes: usize) -> Self {
        let created = Instant::now();
        Self {
            config,
            model_path,
            loaded_at: created,
            last_used: created,
            use_count: 0,
            memory_bytes,
        }
    }

    /// Wall-clock time since the model was loaded.
    pub fn age(&self) -> Duration {
        Instant::now().duration_since(self.loaded_at)
    }

    /// Wall-clock time since the entry last served a lookup.
    pub fn idle_time(&self) -> Duration {
        Instant::now().duration_since(self.last_used)
    }

    /// True when the entry has sat idle for at least `ttl`
    /// (a zero `ttl` marks every entry stale).
    pub fn is_stale(&self, ttl: Duration) -> bool {
        ttl <= self.idle_time()
    }
}
/// Tunable eviction policy for the model cache.
#[derive(Debug, Clone)]
pub struct ModelCacheConfig {
    /// Maximum number of models kept simultaneously.
    pub max_models: usize,
    /// Idle time after which an entry is considered stale.
    pub ttl: Duration,
    /// Whether to evict LRU entries when over the memory budget.
    pub evict_on_memory_pressure: bool,
    /// Optional total memory budget in bytes; `None` disables the check.
    pub memory_budget_bytes: Option<usize>,
}
impl Default for ModelCacheConfig {
fn default() -> Self {
Self {
max_models: 4,
ttl: Duration::from_secs(3600),
evict_on_memory_pressure: true,
memory_budget_bytes: None,
}
}
}
/// Point-in-time snapshot of cache metrics, serializable for telemetry.
#[derive(Debug, serde::Serialize)]
pub struct ModelCacheStats {
    /// Number of entries currently held (including any not-yet-evicted stale ones).
    pub cached_models: usize,
    /// Total lookups served from the cache since construction.
    pub total_hits: u64,
    /// Total lookups that had to invoke a loader since construction.
    pub total_misses: u64,
    /// hits / (hits + misses); 0.0 when there has been no traffic.
    pub hit_rate: f32,
    /// Sum of `memory_bytes` across all cached entries.
    pub total_memory_bytes: usize,
    /// Age in whole seconds of the longest-loaded entry; `None` when empty.
    pub oldest_entry_age_secs: Option<u64>,
}
/// Thread-safe cache of loaded models keyed by an arbitrary string id.
/// Entries are evicted by TTL, by count (`max_models`), and optionally by a
/// memory budget. Hit/miss counters are updated without taking the map lock.
pub struct ModelCache {
    /// Cached entries, guarded by a mutex for interior mutability.
    entries: Mutex<HashMap<String, ModelEntry>>,
    /// Eviction policy, fixed at construction.
    config: ModelCacheConfig,
    /// Count of lookups served from the cache.
    pub hits: AtomicU64,
    /// Count of lookups that required invoking a loader.
    pub misses: AtomicU64,
}
impl ModelCache {
    /// Creates an empty cache with the given eviction policy.
    pub fn new(config: ModelCacheConfig) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            config,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Returns the cached entry for `key`, or invokes `loader` to build and
    /// insert a new one. A stale entry (idle longer than the configured TTL)
    /// counts as a miss and is replaced.
    ///
    /// The returned `Arc` wraps a point-in-time snapshot of the entry, not a
    /// shared handle; later cache updates are not visible through it.
    ///
    /// NOTE: `loader` runs while the cache mutex is held, so concurrent
    /// callers serialize on model loading (which also prevents duplicate
    /// loads of the same key).
    pub fn get_or_insert<F>(&self, key: &str, loader: F) -> Arc<ModelEntry>
    where
        F: FnOnce() -> ModelEntry,
    {
        let mut entries = self
            .entries
            .lock()
            .expect("model cache mutex should not be poisoned");
        if let Some(entry) = entries.get_mut(key) {
            if !entry.is_stale(self.config.ttl) {
                // Hit: refresh recency bookkeeping before snapshotting.
                entry.last_used = Instant::now();
                entry.use_count += 1;
                self.hits.fetch_add(1, Ordering::Relaxed);
                return Self::snapshot(entry);
            }
            // Stale entry: drop it and fall through to the miss path.
            entries.remove(key);
        }
        self.misses.fetch_add(1, Ordering::Relaxed);
        let new_entry = loader();
        self.evict_if_needed_locked(&mut entries, new_entry.memory_bytes);
        let result = Self::snapshot(&new_entry);
        entries.insert(key.to_owned(), new_entry);
        result
    }

    /// True if `key` is cached and its entry is not stale. Does not refresh
    /// recency and does not remove a stale entry it finds.
    pub fn contains(&self, key: &str) -> bool {
        let entries = self
            .entries
            .lock()
            .expect("model cache mutex should not be poisoned");
        entries
            .get(key)
            .map(|e| !e.is_stale(self.config.ttl))
            .unwrap_or(false)
    }

    /// Removes `key` from the cache. Returns true if an entry was present.
    pub fn evict(&self, key: &str) -> bool {
        let mut entries = self
            .entries
            .lock()
            .expect("model cache mutex should not be poisoned");
        entries.remove(key).is_some()
    }

    /// Drops every entry whose idle time meets the configured TTL.
    /// Returns the number of entries removed.
    pub fn evict_stale(&self) -> usize {
        let mut entries = self
            .entries
            .lock()
            .expect("model cache mutex should not be poisoned");
        let ttl = self.config.ttl;
        let before = entries.len();
        entries.retain(|_, e| !e.is_stale(ttl));
        before - entries.len()
    }

    /// Number of entries currently held (stale entries included until evicted).
    pub fn len(&self) -> usize {
        self.entries
            .lock()
            .expect("model cache mutex should not be poisoned")
            .len()
    }

    /// True when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// hits / (hits + misses) since construction; 0.0 with no traffic.
    pub fn hit_rate(&self) -> f32 {
        Self::rate(
            self.hits.load(Ordering::Relaxed),
            self.misses.load(Ordering::Relaxed),
        )
    }

    /// Sum of `memory_bytes` across all cached entries.
    pub fn total_memory_bytes(&self) -> usize {
        self.entries
            .lock()
            .expect("model cache mutex should not be poisoned")
            .values()
            .map(|e| e.memory_bytes)
            .sum()
    }

    /// Builds a consistent snapshot of cache metrics under a single lock hold.
    pub fn stats(&self) -> ModelCacheStats {
        let entries = self
            .entries
            .lock()
            .expect("model cache mutex should not be poisoned");
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        let total_memory_bytes: usize = entries.values().map(|e| e.memory_bytes).sum();
        let oldest_entry_age_secs = entries.values().map(|e| e.age().as_secs()).max();
        ModelCacheStats {
            cached_models: entries.len(),
            total_hits: hits,
            total_misses: misses,
            hit_rate: Self::rate(hits, misses),
            total_memory_bytes,
            oldest_entry_age_secs,
        }
    }

    /// Computes a hit rate from raw counters; 0.0 when there is no traffic.
    fn rate(hits: u64, misses: u64) -> f32 {
        let total = hits + misses;
        if total == 0 {
            0.0
        } else {
            hits as f32 / total as f32
        }
    }

    /// Builds an owned point-in-time copy of `entry` behind an `Arc`.
    fn snapshot(entry: &ModelEntry) -> Arc<ModelEntry> {
        Arc::new(ModelEntry {
            config: entry.config.clone(),
            model_path: entry.model_path.clone(),
            loaded_at: entry.loaded_at,
            last_used: entry.last_used,
            use_count: entry.use_count,
            memory_bytes: entry.memory_bytes,
        })
    }

    /// Makes room for one incoming entry of `incoming_bytes`, first by count
    /// and then (optionally) by memory budget. Caller must hold the lock.
    fn evict_if_needed_locked(
        &self,
        entries: &mut HashMap<String, ModelEntry>,
        incoming_bytes: usize,
    ) {
        // Evict by count until inserting one more keeps us within
        // `max_models`. The emptiness guard prevents an infinite loop when
        // `max_models` is 0 (evict_lru is a no-op on an empty map, so the
        // original `len() >= 0` condition could never become false).
        while !entries.is_empty() && entries.len() >= self.config.max_models {
            Self::evict_lru(entries);
        }
        // Evict by memory. Recompute the projected footprint after every
        // eviction: the original computed `projected` once before the loop,
        // so a single over-budget insert drained the entire cache instead of
        // evicting only as many entries as needed.
        if self.config.evict_on_memory_pressure {
            if let Some(budget) = self.config.memory_budget_bytes {
                loop {
                    let current: usize = entries.values().map(|e| e.memory_bytes).sum();
                    if current.saturating_add(incoming_bytes) <= budget || entries.is_empty() {
                        break;
                    }
                    Self::evict_lru(entries);
                }
            }
        }
    }

    /// Removes the least-recently-used entry (largest idle time). No-op on an
    /// empty map; ties are broken arbitrarily by map iteration order.
    fn evict_lru(entries: &mut HashMap<String, ModelEntry>) {
        if entries.is_empty() {
            return;
        }
        let lru_key = entries
            .iter()
            .max_by_key(|(_, e)| e.idle_time().as_micros())
            .map(|(k, _)| k.clone());
        if let Some(key) = lru_key {
            entries.remove(&key);
        }
    }
}
/// Configuration for a throwaway generation pass that primes an inference
/// engine before it serves real traffic.
pub struct ModelWarmup {
    /// Number of tokens to generate during the warmup pass.
    pub num_warmup_tokens: usize,
    /// Prompt text whose leading bytes seed the warmup token ids.
    pub warmup_prompt: String,
}
impl Default for ModelWarmup {
fn default() -> Self {
Self::new()
}
}
impl ModelWarmup {
pub fn new() -> Self {
Self {
num_warmup_tokens: 32,
warmup_prompt: "Warm up the inference engine.".to_owned(),
}
}
pub fn with_tokens(mut self, n: usize) -> Self {
self.num_warmup_tokens = n;
self
}
pub fn with_prompt(mut self, p: &str) -> Self {
self.warmup_prompt = p.to_owned();
self
}
pub fn run(&self, engine: &mut InferenceEngine<'_>, params: &SamplingParams) -> u64 {
let start = Instant::now();
let dummy_tokens: Vec<u32> = self
.warmup_prompt
.bytes()
.take(16)
.map(|b| u32::from(b) % 32000)
.collect();
let prompt_tokens = if dummy_tokens.is_empty() {
vec![151644u32] } else {
dummy_tokens
};
match engine.generate_with_seed(&prompt_tokens, self.num_warmup_tokens, 0, params) {
Ok(toks) => {
tracing::debug!(generated = toks.len(), "warmup pass completed");
}
Err(e) => {
tracing::warn!(error = %e, "warmup pass encountered an error (non-fatal)");
}
}
engine.reset();
start.elapsed().as_millis() as u64
}
pub fn needs_warmup(_engine: &InferenceEngine<'_>) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Builds a small throwaway entry backed by the tiny test config.
    fn tiny_entry() -> ModelEntry {
        ModelEntry::new(
            Qwen3Config::tiny_test(),
            Some(std::env::temp_dir().join("tiny.gguf").display().to_string()),
            1024,
        )
    }

    #[test]
    fn test_model_entry_age() {
        let entry = tiny_entry();
        let age = entry.age();
        assert!(age < Duration::from_secs(1));
    }

    #[test]
    fn test_model_entry_is_stale() {
        let entry = tiny_entry();
        assert!(!entry.is_stale(Duration::from_secs(3600)));
        // A zero TTL marks every entry stale (idle_time >= 0 always holds).
        assert!(entry.is_stale(Duration::from_nanos(0)));
    }

    #[test]
    fn test_model_cache_miss_calls_loader() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        let mut loader_called = false;
        let _entry = cache.get_or_insert("model-a", || {
            loader_called = true;
            tiny_entry()
        });
        assert!(loader_called, "loader should have been called on a miss");
        assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
        assert_eq!(cache.hits.load(Ordering::Relaxed), 0);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_model_cache_hit_skips_loader() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        cache.get_or_insert("model-b", tiny_entry);
        let mut second_loader_called = false;
        cache.get_or_insert("model-b", || {
            second_loader_called = true;
            tiny_entry()
        });
        assert!(!second_loader_called, "loader must not be called on a hit");
        assert_eq!(cache.hits.load(Ordering::Relaxed), 1);
        assert_eq!(cache.misses.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_model_cache_evict() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        cache.get_or_insert("model-c", tiny_entry);
        assert!(cache.contains("model-c"));
        let removed = cache.evict("model-c");
        assert!(removed);
        assert!(!cache.contains("model-c"));
        assert_eq!(cache.len(), 0);
        assert!(!cache.evict("no-such-model"));
    }

    #[test]
    fn test_model_cache_evict_stale() {
        let cfg = ModelCacheConfig {
            ttl: Duration::from_nanos(0),
            ..Default::default()
        };
        let cache = ModelCache::new(cfg);
        {
            // Insert directly so the zero-TTL entry survives until evict_stale.
            let mut entries = cache.entries.lock().expect("mutex should not be poisoned");
            entries.insert("model-d".to_owned(), tiny_entry());
        }
        assert_eq!(cache.len(), 1);
        let evicted = cache.evict_stale();
        assert_eq!(evicted, 1);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_model_cache_hit_rate() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        assert!((cache.hit_rate() - 0.0).abs() < f32::EPSILON);
        // One miss followed by two hits -> rate of 2/3.
        cache.get_or_insert("rate-model", tiny_entry);
        cache.get_or_insert("rate-model", tiny_entry);
        cache.get_or_insert("rate-model", tiny_entry);
        let rate = cache.hit_rate();
        assert!(rate > 0.6 && rate < 0.7, "expected ~0.667, got {rate}");
    }

    #[test]
    fn test_model_cache_stats() {
        let cache = ModelCache::new(ModelCacheConfig::default());
        cache.get_or_insert("stats-model", tiny_entry);
        let stats = cache.stats();
        assert_eq!(stats.cached_models, 1);
        assert_eq!(stats.total_misses, 1);
        assert_eq!(stats.total_hits, 0);
        assert_eq!(stats.total_memory_bytes, 1024);
        assert!(stats.oldest_entry_age_secs.is_some());
    }

    #[test]
    fn test_warmup_runs_without_panic() {
        let config = Qwen3Config::tiny_test();
        let params = SamplingParams::default();
        let mut engine = InferenceEngine::new(config, params.clone(), 42);
        let warmup = ModelWarmup::new().with_tokens(4).with_prompt("Hello");
        // Fixed mojibake: `¶ms` was a corrupted rendering of `&params`
        // (an `&para`+`ms` HTML-entity mangling) and did not compile.
        let elapsed_ms = warmup.run(&mut engine, &params);
        assert!(elapsed_ms < 60_000, "warmup should complete in under 60 s");
        assert!(ModelWarmup::needs_warmup(&engine));
    }
}