use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, RwLock};
use std::time::Instant;
use oxillama_runtime::engine::{EngineConfig, InferenceEngine};
use crate::error::{ServerError, ServerResult};
use crate::router::eviction::LruQueue;
/// Key under which models are registered with a [`ModelLoader`] and stored in a [`ModelPool`].
pub type ModelId = String;
/// A model resident in a [`ModelPool`], plus the bookkeeping the pool uses for eviction.
///
/// The pool stores and hands these out as `Arc<RwLock<LoadedModel>>`.
pub struct LoadedModel {
    /// Engine the model was loaded into.
    pub engine: InferenceEngine,
    /// Set when the model is inserted into the pool and on every `ModelPool::acquire`.
    pub last_used: Instant,
    /// Memory attributed to this model when enforcing the pool's memory budget.
    pub mem_bytes: usize,
    /// Acquisitions not yet matched by a `ModelPool::release`; a model with a non-zero
    /// count is not evicted.
    pub inflight: u64,
}
// `Debug` is implemented by hand so that `engine` is left out (presumably because
// `InferenceEngine` does not implement `Debug` — TODO confirm); `finish_non_exhaustive`
// renders the omitted field as `..`.
impl std::fmt::Debug for LoadedModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoadedModel")
            .field("last_used", &self.last_used)
            .field("mem_bytes", &self.mem_bytes)
            .field("inflight", &self.inflight)
            .finish_non_exhaustive()
    }
}
/// Point-in-time report for one model, as produced by `ModelPool::list`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelStatus {
    /// Model identifier.
    pub id: String,
    /// Whether the model is loading, ready, or failed.
    pub status: ModelLoadStatus,
    /// Memory attributed to the model (taken from the pending entry for non-resident models).
    pub mem_bytes: usize,
    /// Whole seconds since the model was last used; always 0 for pending entries.
    pub last_used_secs: u64,
    /// Outstanding acquisitions; always 0 for pending entries.
    pub inflight: u64,
}
/// Lifecycle state reported in [`ModelStatus`].
///
/// Serialized in snake_case: `"loading"`, `"ready"`, `"failed"`.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelLoadStatus {
    /// Marked as loading but not yet resident.
    Loading,
    /// Resident in the pool.
    Ready,
    /// Loading was marked as failed.
    Failed,
}
/// Where to find a model on disk and how it should be loaded.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// Path of the model file; also used to estimate the model's memory footprint.
    pub path: PathBuf,
    /// Presumably a quantization label (TODO confirm). It is not read when building the
    /// engine configuration in this file.
    pub quant: Option<String>,
}
/// Registry of known models plus the engine defaults applied when one is loaded.
pub struct ModelLoader {
    /// Registered specs keyed by model id.
    registry: HashMap<ModelId, ModelSpec>,
    /// Copied into `EngineConfig::context_size` for every model built by this loader.
    pub default_context_size: Option<usize>,
    /// Copied into `EngineConfig::num_threads` for every model built by this loader.
    pub default_num_threads: usize,
}
impl ModelLoader {
pub fn new() -> Self {
Self {
registry: HashMap::new(),
default_context_size: None,
default_num_threads: 4,
}
}
pub fn register(&mut self, id: impl Into<String>, spec: ModelSpec) {
self.registry.insert(id.into(), spec);
}
pub fn lookup(&self, id: &str) -> Option<&ModelSpec> {
self.registry.get(id)
}
pub fn build_engine_config(&self, id: &str, spec: &ModelSpec) -> EngineConfig {
tracing::debug!(model_id = id, path = %spec.path.display(), "building engine config");
EngineConfig {
model_path: spec.path.to_string_lossy().into_owned(),
context_size: self.default_context_size,
num_threads: self.default_num_threads,
..EngineConfig::default()
}
}
}
impl Default for ModelLoader {
fn default() -> Self {
Self::new()
}
}
/// Progress of a model that is tracked by the pool but not resident in it.
#[derive(Debug, Clone, PartialEq)]
pub enum PendingStatus {
    /// Set by `ModelPool::mark_loading`.
    Loading,
    /// Set by `ModelPool::mark_failed`; carries the failure reason.
    Failed(String),
}
/// Bookkeeping for an entry in the pool's pending map.
pub struct PendingEntry {
    /// Current loading state.
    pub status: PendingStatus,
    /// Reported by `ModelPool::list`; `ModelPool::mark_loading` initialises it to 0.
    pub mem_bytes: usize,
}
/// Bounded set of resident models with LRU eviction and an optional memory budget.
///
/// It also tracks models that are loading or failed to load, so `list` can report them.
pub struct ModelPool {
    /// Resident models, shared with callers as `Arc<RwLock<_>>`.
    loaded: HashMap<ModelId, Arc<RwLock<LoadedModel>>>,
    /// Models marked as loading or failed that are not (yet) resident.
    pending: HashMap<ModelId, PendingEntry>,
    /// Recency order of resident models; eviction victims are taken from here.
    lru: Mutex<LruQueue>,
    /// Maximum number of resident models.
    capacity: usize,
    /// Memory budget in bytes; `0` disables budget enforcement.
    mem_budget_bytes: usize,
    /// Built-in registry, used by `acquire` when no external loader is supplied.
    loader: ModelLoader,
}
impl ModelPool {
pub fn new(capacity: usize, mem_budget_mb: usize) -> Self {
Self {
loaded: HashMap::with_capacity(capacity),
pending: HashMap::new(),
lru: Mutex::new(LruQueue::with_capacity(capacity)),
capacity,
mem_budget_bytes: mem_budget_mb.saturating_mul(1024 * 1024),
loader: ModelLoader::new(),
}
}
pub fn loader_register(&mut self, id: impl Into<String>, spec: ModelSpec) {
self.loader.register(id, spec);
}
pub fn loader(&self) -> &ModelLoader {
&self.loader
}
pub fn acquire(
&mut self,
model_id: &str,
ext_loader: Option<&ModelLoader>,
) -> ServerResult<Arc<RwLock<LoadedModel>>> {
if let Some(handle) = self.loaded.get(model_id) {
self.touch_lru(model_id);
if let Ok(mut guard) = handle.write() {
guard.inflight = guard.inflight.saturating_add(1);
guard.last_used = Instant::now();
}
return Ok(Arc::clone(handle));
}
let spec = {
let ldr = ext_loader.unwrap_or(&self.loader);
ldr.lookup(model_id)
.cloned()
.ok_or_else(|| ServerError::InvalidRequest {
message: format!("model '{model_id}' is not registered"),
})?
};
let estimated_mem = estimate_mem_bytes(&spec.path);
self.evict_until_budget(estimated_mem)?;
if self.loaded.len() >= self.capacity {
self.evict_one()?;
}
tracing::info!(model_id, "loading model into pool");
let engine_config = self.loader.build_engine_config(model_id, &spec);
let mut engine = InferenceEngine::new(engine_config);
engine.load_model().map_err(ServerError::Runtime)?;
tracing::info!(model_id, mem_bytes = estimated_mem, "model loaded");
let handle = Arc::new(RwLock::new(LoadedModel {
engine,
last_used: Instant::now(),
mem_bytes: estimated_mem,
inflight: 1,
}));
self.loaded
.insert(model_id.to_string(), Arc::clone(&handle));
self.touch_lru(model_id);
Ok(handle)
}
pub fn release(&self, model_id: &str) {
if let Some(handle) = self.loaded.get(model_id) {
if let Ok(mut guard) = handle.write() {
guard.inflight = guard.inflight.saturating_sub(1);
}
}
}
pub fn unload(&mut self, model_id: &str) -> ServerResult<()> {
if self.loaded.remove(model_id).is_none() {
return Err(ServerError::InvalidRequest {
message: format!("model '{model_id}' is not loaded"),
});
}
self.pending.remove(model_id);
if let Ok(mut lru) = self.lru.lock() {
lru.remove(model_id);
}
tracing::info!(model_id, "model unloaded from pool");
Ok(())
}
pub fn list(&self) -> Vec<ModelStatus> {
let mut out = Vec::with_capacity(self.loaded.len() + self.pending.len());
for (id, handle) in &self.loaded {
let (mem_bytes, last_used_secs, inflight) = if let Ok(guard) = handle.read() {
let secs = guard.last_used.elapsed().as_secs();
(guard.mem_bytes, secs, guard.inflight)
} else {
(0, 0, 0)
};
out.push(ModelStatus {
id: id.clone(),
status: ModelLoadStatus::Ready,
mem_bytes,
last_used_secs,
inflight,
});
}
for (id, entry) in &self.pending {
let status = match &entry.status {
PendingStatus::Loading => ModelLoadStatus::Loading,
PendingStatus::Failed(_) => ModelLoadStatus::Failed,
};
out.push(ModelStatus {
id: id.clone(),
status,
mem_bytes: entry.mem_bytes,
last_used_secs: 0,
inflight: 0,
});
}
out
}
pub fn mark_loading(&mut self, model_id: impl Into<String>) {
let id = model_id.into();
self.pending.insert(
id,
PendingEntry {
status: PendingStatus::Loading,
mem_bytes: 0,
},
);
}
pub fn mark_ready(
&mut self,
model_id: &str,
engine: InferenceEngine,
mem_bytes: usize,
) -> ServerResult<()> {
self.evict_until_budget(mem_bytes)?;
if self.loaded.len() >= self.capacity {
self.evict_one()?;
}
let handle = Arc::new(RwLock::new(LoadedModel {
engine,
last_used: Instant::now(),
mem_bytes,
inflight: 0,
}));
self.loaded
.insert(model_id.to_string(), Arc::clone(&handle));
self.pending.remove(model_id);
self.touch_lru(model_id);
Ok(())
}
pub fn mark_failed(&mut self, model_id: &str, reason: String) {
if let Some(entry) = self.pending.get_mut(model_id) {
entry.status = PendingStatus::Failed(reason);
}
}
pub fn current_mem_bytes(&self) -> usize {
self.loaded
.values()
.filter_map(|h| h.read().ok().map(|g| g.mem_bytes))
.sum()
}
fn touch_lru(&self, model_id: &str) {
if let Ok(mut lru) = self.lru.lock() {
lru.touch(model_id);
}
}
fn evict_until_budget(&mut self, needed_bytes: usize) -> ServerResult<()> {
if self.mem_budget_bytes == 0 {
return Ok(());
}
while self.current_mem_bytes() + needed_bytes > self.mem_budget_bytes {
self.evict_one().map_err(|_| ServerError::InvalidRequest {
message: "memory budget exceeded and no evictable model found".to_string(),
})?;
}
Ok(())
}
fn evict_one(&mut self) -> ServerResult<()> {
let victim = {
let mut lru = self.lru.lock().map_err(|_| ServerError::InvalidRequest {
message: "LRU queue lock poisoned".to_string(),
})?;
lru.evict_lru()
};
let victim = victim.ok_or_else(|| ServerError::InvalidRequest {
message: "no model to evict — pool is empty".to_string(),
})?;
let inflight = self
.loaded
.get(&victim)
.and_then(|h| h.read().ok().map(|g| g.inflight))
.unwrap_or(0);
if inflight > 0 {
self.touch_lru(&victim);
return Err(ServerError::InvalidRequest {
message: format!("cannot evict '{victim}': {inflight} request(s) in flight"),
});
}
tracing::info!(model_id = %victim, "evicting model from pool (LRU)");
self.loaded.remove(&victim);
Ok(())
}
}
/// Estimates how much memory a model will need once loaded.
///
/// The estimate is the on-disk size of the file at `path` plus a fixed 64 MiB allowance
/// (`KV_OVERHEAD`; presumably for the KV cache and runtime buffers — TODO confirm).
///
/// This is best-effort: if the file's metadata cannot be read, the file size counts as
/// zero and only the fixed allowance is returned. File sizes that do not fit in `usize`
/// (possible on 32-bit targets) are clamped to `usize::MAX` rather than silently
/// truncated, and the final addition saturates.
fn estimate_mem_bytes(path: &std::path::Path) -> usize {
    const KV_OVERHEAD: usize = 64 * 1024 * 1024;
    let file_size = std::fs::metadata(path)
        // Clamp before the cast so it is lossless on targets where usize is narrower than u64.
        .map(|m| m.len().min(usize::MAX as u64) as usize)
        .unwrap_or(0);
    file_size.saturating_add(KV_OVERHEAD)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an idle, ready model handle attributed `mem_bytes` of memory.
    fn idle_handle(mem_bytes: usize) -> Arc<RwLock<LoadedModel>> {
        Arc::new(RwLock::new(LoadedModel {
            engine: InferenceEngine::new(EngineConfig::default()),
            last_used: Instant::now(),
            mem_bytes,
            inflight: 0,
        }))
    }

    /// Inserts an idle model straight into `pool` (bypassing loading) and marks it used.
    fn insert_idle(pool: &mut ModelPool, name: &str, mem_bytes: usize) {
        pool.loaded.insert(name.to_string(), idle_handle(mem_bytes));
        pool.touch_lru(name);
    }

    #[test]
    fn pool_single_model_routes() {
        let mut pool = ModelPool::new(2, 0);
        insert_idle(&mut pool, "model-a", 0);
        let h1 = pool.acquire("model-a", None).expect("first acquire");
        let h2 = pool.acquire("model-a", None).expect("second acquire");
        assert!(
            Arc::ptr_eq(&h1, &h2),
            "both acquires should return the same Arc"
        );
        assert_eq!(
            h1.read().expect("handle lock").inflight,
            2,
            "each acquire should count as one in-flight request"
        );
        pool.release("model-a");
        assert_eq!(
            h1.read().expect("handle lock").inflight,
            1,
            "release should undo exactly one acquire"
        );
    }

    #[test]
    fn pool_evicts_when_over_capacity() {
        let mut pool = ModelPool::new(1, 0);
        insert_idle(&mut pool, "model-a", 0);
        let engine_b = InferenceEngine::new(EngineConfig::default());
        pool.mark_ready("model-b", engine_b, 0)
            .expect("mark_ready should succeed after evicting model-a");
        assert!(
            !pool.loaded.contains_key("model-a"),
            "model-a should have been evicted"
        );
        assert!(
            pool.loaded.contains_key("model-b"),
            "model-b should now be loaded"
        );
    }

    #[test]
    fn pool_unknown_model_returns_error() {
        let mut pool = ModelPool::new(4, 0);
        let err = pool.acquire("unknown-model", None).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("not registered"),
            "error should mention 'not registered': {msg}"
        );
    }

    #[test]
    fn pool_list_shows_loaded() {
        let mut pool = ModelPool::new(4, 0);
        for name in ["model-x", "model-y"] {
            insert_idle(&mut pool, name, 1024);
        }
        let statuses = pool.list();
        assert_eq!(statuses.len(), 2, "list should report both models");
        let ids: Vec<&str> = statuses.iter().map(|s| s.id.as_str()).collect();
        assert!(ids.contains(&"model-x"), "model-x should appear in list");
        assert!(ids.contains(&"model-y"), "model-y should appear in list");
        for s in &statuses {
            assert_eq!(s.status, ModelLoadStatus::Ready);
            assert_eq!(s.mem_bytes, 1024);
        }
    }

    #[test]
    fn pool_lru_ordering() {
        let mut pool = ModelPool::new(3, 0);
        for name in ["alpha", "beta", "gamma"] {
            insert_idle(&mut pool, name, 0);
        }
        pool.touch_lru("alpha");
        pool.touch_lru("beta");
        pool.evict_one().expect("should evict gamma");
        assert!(
            !pool.loaded.contains_key("gamma"),
            "gamma should have been evicted"
        );
        assert!(pool.loaded.contains_key("alpha"), "alpha should remain");
        assert!(pool.loaded.contains_key("beta"), "beta should remain");
    }

    #[test]
    fn pool_evicts_when_over_budget() {
        // Capacity 2, so only the 1 MiB memory budget can force the eviction:
        // 768 KiB + 512 KiB exceeds the budget, while each model alone fits.
        let mut pool = ModelPool::new(2, 1);
        let engine_a = InferenceEngine::new(EngineConfig::default());
        pool.mark_ready("big-model", engine_a, 768 * 1024)
            .expect("first mark_ready should succeed");
        assert!(
            pool.loaded.contains_key("big-model"),
            "big-model should be in pool after mark_ready"
        );
        let engine_b = InferenceEngine::new(EngineConfig::default());
        pool.mark_ready("small-model", engine_b, 512 * 1024)
            .expect("second mark_ready should evict big-model and succeed");
        assert!(
            !pool.loaded.contains_key("big-model"),
            "big-model should have been evicted when small-model was loaded"
        );
        assert!(
            pool.loaded.contains_key("small-model"),
            "small-model should now be in the pool"
        );
        assert!(
            pool.current_mem_bytes() <= 1024 * 1024,
            "pool should stay within its memory budget"
        );
    }

    #[test]
    fn pool_keeps_models_within_budget() {
        // Two 256 KiB models fit a 1 MiB budget and a capacity of 2: nothing is evicted.
        let mut pool = ModelPool::new(2, 1);
        for name in ["first", "second"] {
            let engine = InferenceEngine::new(EngineConfig::default());
            pool.mark_ready(name, engine, 256 * 1024)
                .expect("mark_ready within budget should succeed");
        }
        assert!(pool.loaded.contains_key("first"), "first should remain");
        assert!(pool.loaded.contains_key("second"), "second should remain");
        assert_eq!(pool.current_mem_bytes(), 512 * 1024);
    }
}