Skip to main content

oxillama_server/router/
pool.rs

1//! Multi-model LRU warm-pool.
2//!
3//! The `ModelPool` holds up to `capacity` `InferenceEngine` instances in a
4//! `HashMap` keyed by model identifier. When a request arrives the pool's
5//! `acquire` method is called: the model is returned immediately if already
6//! loaded; otherwise it is loaded from disk (evicting the LRU entry if the
7//! pool is full or over its memory budget) and then returned.
8//!
9//! Architecture note: The pool itself is owned by the single inference worker
10//! thread.  All mutations happen on that thread — no `Arc<Mutex<...>>` is
11//! needed around the pool itself.  `Arc<RwLock<LoadedModel>>` is used for the
12//! per-model handle so that future multi-worker scenarios can share the engine
13//! while one caller holds a read lock.
14
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::{Arc, Mutex, RwLock};
18use std::time::Instant;
19
20use oxillama_runtime::engine::{EngineConfig, InferenceEngine};
21
22use crate::error::{ServerError, ServerResult};
23use crate::router::eviction::LruQueue;
24
/// Key used to identify a model throughout the router (an owned string).
pub type ModelId = String;
27
28/// A single loaded model with its engine and bookkeeping data.
29pub struct LoadedModel {
30    /// The owned inference engine.
31    pub engine: InferenceEngine,
32    /// Monotonic timestamp of the last request that used this model.
33    pub last_used: Instant,
34    /// Estimated resident memory in bytes:
35    /// `weights_size + max_batch * (kv_size_per_seq + state_size_per_seq)`.
36    ///
37    /// Used by the pool to enforce the memory budget.
38    pub mem_bytes: usize,
39    /// Number of requests currently using this model.
40    pub inflight: u64,
41}
42
43// Manual Debug impl because InferenceEngine does not derive Debug.
44impl std::fmt::Debug for LoadedModel {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("LoadedModel")
47            .field("last_used", &self.last_used)
48            .field("mem_bytes", &self.mem_bytes)
49            .field("inflight", &self.inflight)
50            .finish_non_exhaustive()
51    }
52}
53
54/// Status of a model in the pool (used by the Admin API).
55#[derive(Debug, Clone, serde::Serialize)]
56pub struct ModelStatus {
57    /// Model identifier.
58    pub id: String,
59    /// Load state.
60    pub status: ModelLoadStatus,
61    /// Estimated memory footprint in bytes (0 if not yet loaded).
62    pub mem_bytes: usize,
63    /// Last-used timestamp (seconds since UNIX epoch, 0 if never used).
64    pub last_used_secs: u64,
65    /// Number of requests currently using this model.
66    pub inflight: u64,
67}
68
69/// Load state of a model entry.
70#[derive(Debug, Clone, PartialEq, serde::Serialize)]
71#[serde(rename_all = "snake_case")]
72pub enum ModelLoadStatus {
73    /// The model is being loaded in a background task.
74    Loading,
75    /// The model is loaded and ready for inference.
76    Ready,
77    /// Loading failed; the model cannot be used.
78    Failed,
79}
80
/// Where to find a named model on disk, plus an optional quantisation hint.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// Path of the `.gguf` weights file.
    pub path: PathBuf,
    /// Quantisation label such as `"q4_0"`; informational only for now.
    pub quant: Option<String>,
}
89
90/// A registry that maps model IDs to filesystem specs.
91///
92/// Used by `ModelPool::acquire` to locate models it has not yet loaded.
93pub struct ModelLoader {
94    registry: HashMap<ModelId, ModelSpec>,
95    /// Default context size to pass to the engine.
96    pub default_context_size: Option<usize>,
97    /// Default thread count.
98    pub default_num_threads: usize,
99}
100
101impl ModelLoader {
102    /// Create a new loader with no registered models.
103    pub fn new() -> Self {
104        Self {
105            registry: HashMap::new(),
106            default_context_size: None,
107            default_num_threads: 4,
108        }
109    }
110
111    /// Register a model ID → spec mapping so the pool can load it on demand.
112    pub fn register(&mut self, id: impl Into<String>, spec: ModelSpec) {
113        self.registry.insert(id.into(), spec);
114    }
115
116    /// Look up the spec for a model ID.
117    pub fn lookup(&self, id: &str) -> Option<&ModelSpec> {
118        self.registry.get(id)
119    }
120
121    /// Build an `EngineConfig` for the given model spec.
122    pub fn build_engine_config(&self, id: &str, spec: &ModelSpec) -> EngineConfig {
123        tracing::debug!(model_id = id, path = %spec.path.display(), "building engine config");
124        EngineConfig {
125            model_path: spec.path.to_string_lossy().into_owned(),
126            context_size: self.default_context_size,
127            num_threads: self.default_num_threads,
128            ..EngineConfig::default()
129        }
130    }
131}
132
133impl Default for ModelLoader {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
/// State of a model that is not (yet) in the loaded map, as tracked while a
/// background load is running.
#[derive(Debug, Clone, PartialEq)]
pub enum PendingStatus {
    /// The background load is still running.
    Loading,
    /// The load failed; the payload is the failure reason.
    Failed(String),
}
145
146/// Entry in the pool for a model that is still loading (or failed).
147pub struct PendingEntry {
148    pub status: PendingStatus,
149    pub mem_bytes: usize,
150}
151
152/// The multi-model LRU warm-pool.
153///
154/// Owned by the inference worker thread; no `Send + Sync` requirement for the
155/// `HashMap` internals because all accesses happen on one thread.
156pub struct ModelPool {
157    /// Live engines, keyed by model ID.
158    loaded: HashMap<ModelId, Arc<RwLock<LoadedModel>>>,
159    /// Models being loaded in background or that failed to load.
160    pending: HashMap<ModelId, PendingEntry>,
161    /// LRU ordering of loaded model IDs.
162    lru: Mutex<LruQueue>,
163    /// Maximum number of concurrently loaded models.
164    capacity: usize,
165    /// Maximum total memory budget in bytes (0 = unlimited).
166    mem_budget_bytes: usize,
167    /// Internal model loader (registry of id → spec mappings).
168    loader: ModelLoader,
169}
170
171impl ModelPool {
172    /// Create a new empty pool.
173    ///
174    /// - `capacity`: maximum number of models that may be resident at once.
175    /// - `mem_budget_mb`: memory budget in MiB (0 = unlimited).
176    pub fn new(capacity: usize, mem_budget_mb: usize) -> Self {
177        Self {
178            loaded: HashMap::with_capacity(capacity),
179            pending: HashMap::new(),
180            lru: Mutex::new(LruQueue::with_capacity(capacity)),
181            capacity,
182            mem_budget_bytes: mem_budget_mb.saturating_mul(1024 * 1024),
183            loader: ModelLoader::new(),
184        }
185    }
186
187    /// Register a model spec so it can be loaded on demand via `acquire`.
188    ///
189    /// Called by the admin `POST /admin/models/load` route before initiating a
190    /// background load, and at startup for models listed in `[router] preload`.
191    pub fn loader_register(&mut self, id: impl Into<String>, spec: ModelSpec) {
192        self.loader.register(id, spec);
193    }
194
195    /// Access the embedded loader (read-only).
196    pub fn loader(&self) -> &ModelLoader {
197        &self.loader
198    }
199
200    /// Acquire an engine for `model_id`.
201    ///
202    /// If the model is already loaded it is promoted to the MRU position and
203    /// its `Arc<RwLock<LoadedModel>>` is returned immediately.
204    ///
205    /// Otherwise the model is loaded synchronously (blocking the calling thread)
206    /// after evicting LRU entries as needed.
207    ///
208    /// The optional `loader` parameter allows callers (like tests) to supply
209    /// an external loader; `None` uses the pool's embedded loader.
210    pub fn acquire(
211        &mut self,
212        model_id: &str,
213        ext_loader: Option<&ModelLoader>,
214    ) -> ServerResult<Arc<RwLock<LoadedModel>>> {
215        // Fast path — already loaded.
216        if let Some(handle) = self.loaded.get(model_id) {
217            self.touch_lru(model_id);
218            // Update inflight counter and last_used.
219            if let Ok(mut guard) = handle.write() {
220                guard.inflight = guard.inflight.saturating_add(1);
221                guard.last_used = Instant::now();
222            }
223            return Ok(Arc::clone(handle));
224        }
225
226        // Choose loader: external takes precedence over embedded.
227        // SAFETY: we know `self.loader` lives as long as `self`; to avoid
228        // the borrow issue we clone the spec out.
229        let spec = {
230            let ldr = ext_loader.unwrap_or(&self.loader);
231            ldr.lookup(model_id)
232                .cloned()
233                .ok_or_else(|| ServerError::InvalidRequest {
234                    message: format!("model '{model_id}' is not registered"),
235                })?
236        };
237
238        // Estimate memory before eviction so the budget check is accurate.
239        let estimated_mem = estimate_mem_bytes(&spec.path);
240
241        // Evict LRU models until we have room.
242        self.evict_until_budget(estimated_mem)?;
243        if self.loaded.len() >= self.capacity {
244            self.evict_one()?;
245        }
246
247        // Load the engine synchronously.
248        tracing::info!(model_id, "loading model into pool");
249        let engine_config = self.loader.build_engine_config(model_id, &spec);
250        let mut engine = InferenceEngine::new(engine_config);
251        engine.load_model().map_err(ServerError::Runtime)?;
252        tracing::info!(model_id, mem_bytes = estimated_mem, "model loaded");
253
254        let handle = Arc::new(RwLock::new(LoadedModel {
255            engine,
256            last_used: Instant::now(),
257            mem_bytes: estimated_mem,
258            inflight: 1,
259        }));
260
261        self.loaded
262            .insert(model_id.to_string(), Arc::clone(&handle));
263        self.touch_lru(model_id);
264
265        Ok(handle)
266    }
267
268    /// Decrement inflight count for a model when the caller is done with it.
269    pub fn release(&self, model_id: &str) {
270        if let Some(handle) = self.loaded.get(model_id) {
271            if let Ok(mut guard) = handle.write() {
272                guard.inflight = guard.inflight.saturating_sub(1);
273            }
274        }
275    }
276
277    /// Explicitly unload a model, freeing its memory.
278    ///
279    /// Returns an error if the model ID is not currently loaded.
280    pub fn unload(&mut self, model_id: &str) -> ServerResult<()> {
281        if self.loaded.remove(model_id).is_none() {
282            return Err(ServerError::InvalidRequest {
283                message: format!("model '{model_id}' is not loaded"),
284            });
285        }
286        self.pending.remove(model_id);
287        if let Ok(mut lru) = self.lru.lock() {
288            lru.remove(model_id);
289        }
290        tracing::info!(model_id, "model unloaded from pool");
291        Ok(())
292    }
293
294    /// List the status of all known models (loaded + pending).
295    pub fn list(&self) -> Vec<ModelStatus> {
296        let mut out = Vec::with_capacity(self.loaded.len() + self.pending.len());
297
298        for (id, handle) in &self.loaded {
299            let (mem_bytes, last_used_secs, inflight) = if let Ok(guard) = handle.read() {
300                let secs = guard.last_used.elapsed().as_secs();
301                (guard.mem_bytes, secs, guard.inflight)
302            } else {
303                (0, 0, 0)
304            };
305            out.push(ModelStatus {
306                id: id.clone(),
307                status: ModelLoadStatus::Ready,
308                mem_bytes,
309                last_used_secs,
310                inflight,
311            });
312        }
313
314        for (id, entry) in &self.pending {
315            let status = match &entry.status {
316                PendingStatus::Loading => ModelLoadStatus::Loading,
317                PendingStatus::Failed(_) => ModelLoadStatus::Failed,
318            };
319            out.push(ModelStatus {
320                id: id.clone(),
321                status,
322                mem_bytes: entry.mem_bytes,
323                last_used_secs: 0,
324                inflight: 0,
325            });
326        }
327
328        out
329    }
330
331    /// Mark a model as being loaded in a background task.
332    pub fn mark_loading(&mut self, model_id: impl Into<String>) {
333        let id = model_id.into();
334        self.pending.insert(
335            id,
336            PendingEntry {
337                status: PendingStatus::Loading,
338                mem_bytes: 0,
339            },
340        );
341    }
342
343    /// Mark a pending model as ready (called after a background load succeeds).
344    ///
345    /// Moves the engine from the temporary pending state into the loaded map.
346    pub fn mark_ready(
347        &mut self,
348        model_id: &str,
349        engine: InferenceEngine,
350        mem_bytes: usize,
351    ) -> ServerResult<()> {
352        // Evict if needed.
353        self.evict_until_budget(mem_bytes)?;
354        if self.loaded.len() >= self.capacity {
355            self.evict_one()?;
356        }
357
358        let handle = Arc::new(RwLock::new(LoadedModel {
359            engine,
360            last_used: Instant::now(),
361            mem_bytes,
362            inflight: 0,
363        }));
364        self.loaded
365            .insert(model_id.to_string(), Arc::clone(&handle));
366        self.pending.remove(model_id);
367        self.touch_lru(model_id);
368        Ok(())
369    }
370
371    /// Mark a pending model as failed to load.
372    pub fn mark_failed(&mut self, model_id: &str, reason: String) {
373        if let Some(entry) = self.pending.get_mut(model_id) {
374            entry.status = PendingStatus::Failed(reason);
375        }
376    }
377
378    /// Total estimated bytes currently consumed by loaded models.
379    pub fn current_mem_bytes(&self) -> usize {
380        self.loaded
381            .values()
382            .filter_map(|h| h.read().ok().map(|g| g.mem_bytes))
383            .sum()
384    }
385
386    // ── private helpers ──────────────────────────────────────────────────────
387
388    fn touch_lru(&self, model_id: &str) {
389        if let Ok(mut lru) = self.lru.lock() {
390            lru.touch(model_id);
391        }
392    }
393
394    /// Evict LRU models until `current + needed <= budget` (or budget is 0).
395    fn evict_until_budget(&mut self, needed_bytes: usize) -> ServerResult<()> {
396        if self.mem_budget_bytes == 0 {
397            return Ok(());
398        }
399        while self.current_mem_bytes() + needed_bytes > self.mem_budget_bytes {
400            self.evict_one().map_err(|_| ServerError::InvalidRequest {
401                message: "memory budget exceeded and no evictable model found".to_string(),
402            })?;
403        }
404        Ok(())
405    }
406
407    /// Evict the single LRU model.
408    fn evict_one(&mut self) -> ServerResult<()> {
409        let victim = {
410            let mut lru = self.lru.lock().map_err(|_| ServerError::InvalidRequest {
411                message: "LRU queue lock poisoned".to_string(),
412            })?;
413            lru.evict_lru()
414        };
415
416        let victim = victim.ok_or_else(|| ServerError::InvalidRequest {
417            message: "no model to evict — pool is empty".to_string(),
418        })?;
419
420        // Don't evict a model that has in-flight requests.
421        let inflight = self
422            .loaded
423            .get(&victim)
424            .and_then(|h| h.read().ok().map(|g| g.inflight))
425            .unwrap_or(0);
426
427        if inflight > 0 {
428            // Return the model to the LRU queue so it isn't lost.
429            self.touch_lru(&victim);
430            return Err(ServerError::InvalidRequest {
431                message: format!("cannot evict '{victim}': {inflight} request(s) in flight"),
432            });
433        }
434
435        tracing::info!(model_id = %victim, "evicting model from pool (LRU)");
436        self.loaded.remove(&victim);
437        Ok(())
438    }
439}
440
/// Rough memory estimate for a model, based on its GGUF file size on disk.
///
/// The weights are memory-mapped, so the on-disk size approximates resident
/// memory. A fixed 64 MiB is added for KV cache + buffers.
///
/// If the file's metadata cannot be read, the estimate is just the overhead.
/// A file size that does not fit in `usize` (possible on 32-bit targets) is
/// clamped to `usize::MAX` instead of being silently truncated. This way the
/// budget check errs towards "too big" rather than wrapping to a tiny value.
fn estimate_mem_bytes(path: &std::path::Path) -> usize {
    const KV_OVERHEAD: usize = 64 * 1024 * 1024;
    let file_size = std::fs::metadata(path)
        // Clamp before narrowing; on 64-bit targets the `min` is a no-op.
        .map(|m| m.len().min(usize::MAX as u64) as usize)
        .unwrap_or(0);
    file_size.saturating_add(KV_OVERHEAD)
}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fake, never-loaded engine handle with the given footprint, so
    /// tests don't need a real GGUF file on disk.
    fn fake_model(mem_bytes: usize) -> Arc<RwLock<LoadedModel>> {
        Arc::new(RwLock::new(LoadedModel {
            engine: InferenceEngine::new(EngineConfig::default()),
            last_used: Instant::now(),
            mem_bytes,
            inflight: 0,
        }))
    }

    /// Insert a fake model directly into `pool`, mark it MRU, and return its handle.
    fn insert_fake(pool: &mut ModelPool, id: &str, mem_bytes: usize) -> Arc<RwLock<LoadedModel>> {
        let handle = fake_model(mem_bytes);
        pool.loaded.insert(id.to_string(), Arc::clone(&handle));
        pool.touch_lru(id);
        handle
    }

    /// (a) pool_single_model_routes: acquire an already-loaded model twice;
    ///     both calls return the same Arc, and inflight tracks acquire/release.
    #[test]
    fn pool_single_model_routes() {
        let mut pool = ModelPool::new(2, 0); // 0 = unlimited budget
        insert_fake(&mut pool, "model-a", 0);

        let h1 = pool.acquire("model-a", None).expect("first acquire");
        let h2 = pool.acquire("model-a", None).expect("second acquire");

        assert!(
            Arc::ptr_eq(&h1, &h2),
            "both acquires should return the same Arc"
        );
        assert_eq!(h1.read().unwrap().inflight, 2, "two acquires → inflight 2");

        pool.release("model-a");
        pool.release("model-a");
        assert_eq!(h1.read().unwrap().inflight, 0, "releases balance acquires");
    }

    /// (b) pool_evicts_when_over_capacity: capacity=1; model-a is loaded;
    ///     mark_ready for model-b must evict model-a.
    #[test]
    fn pool_evicts_when_over_capacity() {
        let mut pool = ModelPool::new(1, 0); // capacity = 1
        insert_fake(&mut pool, "model-a", 0);

        let engine_b = InferenceEngine::new(EngineConfig::default());
        pool.mark_ready("model-b", engine_b, 0)
            .expect("mark_ready should succeed after evicting model-a");

        assert!(
            !pool.loaded.contains_key("model-a"),
            "model-a should have been evicted"
        );
        assert!(
            pool.loaded.contains_key("model-b"),
            "model-b should now be loaded"
        );
    }

    /// (c) pool_unknown_model_returns_error: acquire a model that was never
    ///     registered; expect Err with a descriptive message.
    #[test]
    fn pool_unknown_model_returns_error() {
        let mut pool = ModelPool::new(4, 0);

        let err = pool.acquire("unknown-model", None).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("not registered"),
            "error should mention 'not registered': {msg}"
        );
    }

    /// (d) pool_list_shows_loaded: insert two models; pool.list() returns both.
    #[test]
    fn pool_list_shows_loaded() {
        let mut pool = ModelPool::new(4, 0);
        for name in ["model-x", "model-y"] {
            insert_fake(&mut pool, name, 1024);
        }

        let statuses = pool.list();
        assert_eq!(statuses.len(), 2, "list should report both models");
        let ids: Vec<&str> = statuses.iter().map(|s| s.id.as_str()).collect();
        assert!(ids.contains(&"model-x"), "model-x should appear in list");
        assert!(ids.contains(&"model-y"), "model-y should appear in list");
        for s in &statuses {
            assert_eq!(s.status, ModelLoadStatus::Ready);
            assert_eq!(s.mem_bytes, 1024);
        }
    }

    /// Pending entries that failed are reported with the `Failed` status.
    #[test]
    fn pool_list_shows_failed_pending() {
        let mut pool = ModelPool::new(4, 0);
        pool.mark_loading("slow-model");
        pool.mark_failed("slow-model", "boom".to_string());

        let statuses = pool.list();
        assert_eq!(statuses.len(), 1);
        assert_eq!(statuses[0].id, "slow-model");
        assert_eq!(statuses[0].status, ModelLoadStatus::Failed);
    }

    /// Unloading a model that was never loaded is an error.
    #[test]
    fn pool_unload_unknown_model_errors() {
        let mut pool = ModelPool::new(1, 0);
        let err = pool.unload("ghost").unwrap_err();
        assert!(
            err.to_string().contains("not loaded"),
            "error should mention 'not loaded': {err}"
        );
    }

    /// LRU eviction order: insert alpha, beta, gamma (capacity=3), then touch
    /// alpha and beta so that gamma becomes least-recently used and is the
    /// one evicted.
    #[test]
    fn pool_lru_ordering() {
        let mut pool = ModelPool::new(3, 0);
        for name in ["alpha", "beta", "gamma"] {
            insert_fake(&mut pool, name, 0);
        }

        // Touch alpha and beta → gamma is now LRU.
        pool.touch_lru("alpha");
        pool.touch_lru("beta");

        pool.evict_one().expect("should evict gamma");
        assert!(
            !pool.loaded.contains_key("gamma"),
            "gamma should have been evicted"
        );
        assert!(pool.loaded.contains_key("alpha"), "alpha should remain");
        assert!(pool.loaded.contains_key("beta"), "beta should remain");
    }

    /// Memory-budget eviction: capacity=2 so the capacity check never fires,
    /// and only the 1 MiB budget can force big-model out when small-model
    /// arrives.
    #[test]
    fn pool_evicts_when_over_budget() {
        let one_mib = 1024 * 1024;
        let mut pool = ModelPool::new(2, 1); // 1-MiB budget, capacity=2

        pool.mark_ready(
            "big-model",
            InferenceEngine::new(EngineConfig::default()),
            one_mib,
        )
        .expect("big-model exactly fills the budget");
        assert_eq!(pool.current_mem_bytes(), one_mib);

        pool.mark_ready(
            "small-model",
            InferenceEngine::new(EngineConfig::default()),
            1,
        )
        .expect("small-model should fit after evicting big-model");

        assert!(
            !pool.loaded.contains_key("big-model"),
            "big-model should have been evicted to honour the memory budget"
        );
        assert!(
            pool.loaded.contains_key("small-model"),
            "small-model should now be in the pool"
        );
        assert_eq!(pool.current_mem_bytes(), 1);
    }
}