1use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::{Arc, Mutex, RwLock};
18use std::time::Instant;
19
20use oxillama_runtime::engine::{EngineConfig, InferenceEngine};
21
22use crate::error::{ServerError, ServerResult};
23use crate::router::eviction::LruQueue;
24
/// Identifier under which a model is registered with a loader and tracked by the pool.
pub type ModelId = String;
27
/// A model that is resident in the pool together with its bookkeeping data.
pub struct LoadedModel {
    /// Inference engine that serves requests for this model.
    pub engine: InferenceEngine,
    /// Moment the entry was created or last handed out by `ModelPool::acquire`.
    pub last_used: Instant,
    /// Memory attributed to this model, in bytes (an estimate or a caller-supplied figure).
    pub mem_bytes: usize,
    /// Number of outstanding acquisitions; a model with a non-zero count is not evicted.
    pub inflight: u64,
}
42
43impl std::fmt::Debug for LoadedModel {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("LoadedModel")
47 .field("last_used", &self.last_used)
48 .field("mem_bytes", &self.mem_bytes)
49 .field("inflight", &self.inflight)
50 .finish_non_exhaustive()
51 }
52}
53
/// Serializable snapshot of one pool entry, as produced by `ModelPool::list`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelStatus {
    /// Identifier of the model.
    pub id: String,
    /// Whether the model is loading, ready, or failed to load.
    pub status: ModelLoadStatus,
    /// Memory attributed to the model, in bytes.
    pub mem_bytes: usize,
    /// Whole seconds elapsed since the model was last used; `0` for entries that are not loaded.
    pub last_used_secs: u64,
    /// Number of outstanding acquisitions; `0` for entries that are not loaded.
    pub inflight: u64,
}
68
/// Coarse load state reported for a model; serialized in `snake_case`.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelLoadStatus {
    /// The model has been marked as loading and is not yet available.
    Loading,
    /// The model is resident in the pool.
    Ready,
    /// A load attempt was recorded as failed.
    Failed,
}
80
/// Description of a model that can be loaded on demand.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// Location of the model file; used as the engine's model path and for size estimation.
    pub path: PathBuf,
    /// Optional label that is not read anywhere in this file.
    /// NOTE(review): presumably the quantisation scheme — confirm against callers.
    pub quant: Option<String>,
}
89
90pub struct ModelLoader {
94 registry: HashMap<ModelId, ModelSpec>,
95 pub default_context_size: Option<usize>,
97 pub default_num_threads: usize,
99}
100
101impl ModelLoader {
102 pub fn new() -> Self {
104 Self {
105 registry: HashMap::new(),
106 default_context_size: None,
107 default_num_threads: 4,
108 }
109 }
110
111 pub fn register(&mut self, id: impl Into<String>, spec: ModelSpec) {
113 self.registry.insert(id.into(), spec);
114 }
115
116 pub fn lookup(&self, id: &str) -> Option<&ModelSpec> {
118 self.registry.get(id)
119 }
120
121 pub fn build_engine_config(&self, id: &str, spec: &ModelSpec) -> EngineConfig {
123 tracing::debug!(model_id = id, path = %spec.path.display(), "building engine config");
124 EngineConfig {
125 model_path: spec.path.to_string_lossy().into_owned(),
126 context_size: self.default_context_size,
127 num_threads: self.default_num_threads,
128 ..EngineConfig::default()
129 }
130 }
131}
132
133impl Default for ModelLoader {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
/// State of a model that has been announced to the pool but is not resident.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PendingStatus {
    /// A load has been started and has not finished yet.
    Loading,
    /// The load failed; the payload carries the reason.
    Failed(String),
}

/// Bookkeeping record for a model that is loading or has failed to load.
///
/// Derives the same traits as [`PendingStatus`] so entries can be logged,
/// copied and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingEntry {
    /// Current state of the load attempt.
    pub status: PendingStatus,
    /// Memory attributed to the entry, in bytes.
    pub mem_bytes: usize,
}
151
/// Bounded pool of loaded models with least-recently-used eviction.
pub struct ModelPool {
    /// Models that are resident, keyed by id.
    loaded: HashMap<ModelId, Arc<RwLock<LoadedModel>>>,
    /// Models that were marked as loading or failed, keyed by id.
    pending: HashMap<ModelId, PendingEntry>,
    /// Recency queue used to choose eviction victims; wrapped in a `Mutex` so
    /// it can be updated through `&self`.
    lru: Mutex<LruQueue>,
    /// Number of resident models at which a new insertion first evicts one.
    capacity: usize,
    /// Memory budget in bytes; `0` turns budget-based eviction off.
    mem_budget_bytes: usize,
    /// Built-in registry used when no external loader is supplied.
    loader: ModelLoader,
}
170
171impl ModelPool {
172 pub fn new(capacity: usize, mem_budget_mb: usize) -> Self {
177 Self {
178 loaded: HashMap::with_capacity(capacity),
179 pending: HashMap::new(),
180 lru: Mutex::new(LruQueue::with_capacity(capacity)),
181 capacity,
182 mem_budget_bytes: mem_budget_mb.saturating_mul(1024 * 1024),
183 loader: ModelLoader::new(),
184 }
185 }
186
187 pub fn loader_register(&mut self, id: impl Into<String>, spec: ModelSpec) {
192 self.loader.register(id, spec);
193 }
194
195 pub fn loader(&self) -> &ModelLoader {
197 &self.loader
198 }
199
200 pub fn acquire(
211 &mut self,
212 model_id: &str,
213 ext_loader: Option<&ModelLoader>,
214 ) -> ServerResult<Arc<RwLock<LoadedModel>>> {
215 if let Some(handle) = self.loaded.get(model_id) {
217 self.touch_lru(model_id);
218 if let Ok(mut guard) = handle.write() {
220 guard.inflight = guard.inflight.saturating_add(1);
221 guard.last_used = Instant::now();
222 }
223 return Ok(Arc::clone(handle));
224 }
225
226 let spec = {
230 let ldr = ext_loader.unwrap_or(&self.loader);
231 ldr.lookup(model_id)
232 .cloned()
233 .ok_or_else(|| ServerError::InvalidRequest {
234 message: format!("model '{model_id}' is not registered"),
235 })?
236 };
237
238 let estimated_mem = estimate_mem_bytes(&spec.path);
240
241 self.evict_until_budget(estimated_mem)?;
243 if self.loaded.len() >= self.capacity {
244 self.evict_one()?;
245 }
246
247 tracing::info!(model_id, "loading model into pool");
249 let engine_config = self.loader.build_engine_config(model_id, &spec);
250 let mut engine = InferenceEngine::new(engine_config);
251 engine.load_model().map_err(ServerError::Runtime)?;
252 tracing::info!(model_id, mem_bytes = estimated_mem, "model loaded");
253
254 let handle = Arc::new(RwLock::new(LoadedModel {
255 engine,
256 last_used: Instant::now(),
257 mem_bytes: estimated_mem,
258 inflight: 1,
259 }));
260
261 self.loaded
262 .insert(model_id.to_string(), Arc::clone(&handle));
263 self.touch_lru(model_id);
264
265 Ok(handle)
266 }
267
268 pub fn release(&self, model_id: &str) {
270 if let Some(handle) = self.loaded.get(model_id) {
271 if let Ok(mut guard) = handle.write() {
272 guard.inflight = guard.inflight.saturating_sub(1);
273 }
274 }
275 }
276
277 pub fn unload(&mut self, model_id: &str) -> ServerResult<()> {
281 if self.loaded.remove(model_id).is_none() {
282 return Err(ServerError::InvalidRequest {
283 message: format!("model '{model_id}' is not loaded"),
284 });
285 }
286 self.pending.remove(model_id);
287 if let Ok(mut lru) = self.lru.lock() {
288 lru.remove(model_id);
289 }
290 tracing::info!(model_id, "model unloaded from pool");
291 Ok(())
292 }
293
294 pub fn list(&self) -> Vec<ModelStatus> {
296 let mut out = Vec::with_capacity(self.loaded.len() + self.pending.len());
297
298 for (id, handle) in &self.loaded {
299 let (mem_bytes, last_used_secs, inflight) = if let Ok(guard) = handle.read() {
300 let secs = guard.last_used.elapsed().as_secs();
301 (guard.mem_bytes, secs, guard.inflight)
302 } else {
303 (0, 0, 0)
304 };
305 out.push(ModelStatus {
306 id: id.clone(),
307 status: ModelLoadStatus::Ready,
308 mem_bytes,
309 last_used_secs,
310 inflight,
311 });
312 }
313
314 for (id, entry) in &self.pending {
315 let status = match &entry.status {
316 PendingStatus::Loading => ModelLoadStatus::Loading,
317 PendingStatus::Failed(_) => ModelLoadStatus::Failed,
318 };
319 out.push(ModelStatus {
320 id: id.clone(),
321 status,
322 mem_bytes: entry.mem_bytes,
323 last_used_secs: 0,
324 inflight: 0,
325 });
326 }
327
328 out
329 }
330
331 pub fn mark_loading(&mut self, model_id: impl Into<String>) {
333 let id = model_id.into();
334 self.pending.insert(
335 id,
336 PendingEntry {
337 status: PendingStatus::Loading,
338 mem_bytes: 0,
339 },
340 );
341 }
342
343 pub fn mark_ready(
347 &mut self,
348 model_id: &str,
349 engine: InferenceEngine,
350 mem_bytes: usize,
351 ) -> ServerResult<()> {
352 self.evict_until_budget(mem_bytes)?;
354 if self.loaded.len() >= self.capacity {
355 self.evict_one()?;
356 }
357
358 let handle = Arc::new(RwLock::new(LoadedModel {
359 engine,
360 last_used: Instant::now(),
361 mem_bytes,
362 inflight: 0,
363 }));
364 self.loaded
365 .insert(model_id.to_string(), Arc::clone(&handle));
366 self.pending.remove(model_id);
367 self.touch_lru(model_id);
368 Ok(())
369 }
370
371 pub fn mark_failed(&mut self, model_id: &str, reason: String) {
373 if let Some(entry) = self.pending.get_mut(model_id) {
374 entry.status = PendingStatus::Failed(reason);
375 }
376 }
377
378 pub fn current_mem_bytes(&self) -> usize {
380 self.loaded
381 .values()
382 .filter_map(|h| h.read().ok().map(|g| g.mem_bytes))
383 .sum()
384 }
385
386 fn touch_lru(&self, model_id: &str) {
389 if let Ok(mut lru) = self.lru.lock() {
390 lru.touch(model_id);
391 }
392 }
393
394 fn evict_until_budget(&mut self, needed_bytes: usize) -> ServerResult<()> {
396 if self.mem_budget_bytes == 0 {
397 return Ok(());
398 }
399 while self.current_mem_bytes() + needed_bytes > self.mem_budget_bytes {
400 self.evict_one().map_err(|_| ServerError::InvalidRequest {
401 message: "memory budget exceeded and no evictable model found".to_string(),
402 })?;
403 }
404 Ok(())
405 }
406
407 fn evict_one(&mut self) -> ServerResult<()> {
409 let victim = {
410 let mut lru = self.lru.lock().map_err(|_| ServerError::InvalidRequest {
411 message: "LRU queue lock poisoned".to_string(),
412 })?;
413 lru.evict_lru()
414 };
415
416 let victim = victim.ok_or_else(|| ServerError::InvalidRequest {
417 message: "no model to evict — pool is empty".to_string(),
418 })?;
419
420 let inflight = self
422 .loaded
423 .get(&victim)
424 .and_then(|h| h.read().ok().map(|g| g.inflight))
425 .unwrap_or(0);
426
427 if inflight > 0 {
428 self.touch_lru(&victim);
430 return Err(ServerError::InvalidRequest {
431 message: format!("cannot evict '{victim}': {inflight} request(s) in flight"),
432 });
433 }
434
435 tracing::info!(model_id = %victim, "evicting model from pool (LRU)");
436 self.loaded.remove(&victim);
437 Ok(())
438 }
439}
440
/// Estimates the memory a model will occupy: the size of the file at `path`
/// plus a fixed 64 MiB allowance.
///
/// A file whose metadata cannot be read counts as zero bytes, so the result is
/// then just the allowance. All arithmetic saturates at `usize::MAX`; in
/// particular a file larger than `usize::MAX` bytes (possible on 32-bit
/// targets) is clamped instead of being truncated.
fn estimate_mem_bytes(path: &std::path::Path) -> usize {
    // NOTE(review): the name suggests this covers the KV cache — confirm.
    const KV_OVERHEAD: usize = 64 * 1024 * 1024;
    let file_size = std::fs::metadata(path)
        .map(|m| usize::try_from(m.len()).unwrap_or(usize::MAX))
        .unwrap_or(0);
    file_size.saturating_add(KV_OVERHEAD)
}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Places an idle model straight into the pool, bypassing the loader, so
    /// tests do not need a model file on disk.
    fn insert_idle(pool: &mut ModelPool, name: &str, mem_bytes: usize) {
        let engine = InferenceEngine::new(EngineConfig::default());
        let handle = Arc::new(RwLock::new(LoadedModel {
            engine,
            last_used: Instant::now(),
            mem_bytes,
            inflight: 0,
        }));
        pool.loaded.insert(name.to_string(), handle);
        pool.touch_lru(name);
    }

    /// Acquiring a resident model twice yields the same shared handle.
    #[test]
    fn pool_single_model_routes() {
        let mut pool = ModelPool::new(2, 0);
        insert_idle(&mut pool, "model-a", 0);

        let h1 = pool.acquire("model-a", None).expect("first acquire");
        let h2 = pool.acquire("model-a", None).expect("second acquire");

        assert!(
            Arc::ptr_eq(&h1, &h2),
            "both acquires should return the same Arc"
        );
    }

    /// With a single slot, installing a second model evicts the first.
    #[test]
    fn pool_evicts_when_over_capacity() {
        let mut pool = ModelPool::new(1, 0);
        insert_idle(&mut pool, "model-a", 0);

        let engine_b = InferenceEngine::new(EngineConfig::default());
        pool.mark_ready("model-b", engine_b, 0)
            .expect("mark_ready should succeed after evicting model-a");

        assert!(
            !pool.loaded.contains_key("model-a"),
            "model-a should have been evicted"
        );
        assert!(
            pool.loaded.contains_key("model-b"),
            "model-b should now be loaded"
        );
    }

    /// Acquiring an id that no loader knows about is rejected.
    #[test]
    fn pool_unknown_model_returns_error() {
        let mut pool = ModelPool::new(4, 0);
        let err = pool.acquire("unknown-model", None).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("not registered"),
            "error should mention 'not registered': {msg}"
        );
    }

    /// `list` reports every resident model as ready with its memory figure.
    #[test]
    fn pool_list_shows_loaded() {
        let mut pool = ModelPool::new(4, 0);

        for name in ["model-x", "model-y"] {
            insert_idle(&mut pool, name, 1024);
        }

        let statuses = pool.list();
        assert_eq!(statuses.len(), 2, "list should report both models");
        let ids: Vec<&str> = statuses.iter().map(|s| s.id.as_str()).collect();
        assert!(ids.contains(&"model-x"), "model-x should appear in list");
        assert!(ids.contains(&"model-y"), "model-y should appear in list");
        for s in &statuses {
            assert_eq!(s.status, ModelLoadStatus::Ready);
            assert_eq!(s.mem_bytes, 1024);
        }
    }

    /// The model that was touched least recently is the one evicted.
    #[test]
    fn pool_lru_ordering() {
        let mut pool = ModelPool::new(3, 0);

        for name in ["alpha", "beta", "gamma"] {
            insert_idle(&mut pool, name, 0);
        }

        // Refresh alpha and beta so that gamma becomes the oldest entry.
        pool.touch_lru("alpha");
        pool.touch_lru("beta");

        pool.evict_one().expect("should evict gamma");
        assert!(
            !pool.loaded.contains_key("gamma"),
            "gamma should have been evicted"
        );
        assert!(pool.loaded.contains_key("alpha"), "alpha should remain");
        assert!(pool.loaded.contains_key("beta"), "beta should remain");
    }

    /// A model that would push the pool past its memory budget evicts an
    /// older one even though a free slot is available.
    #[test]
    fn pool_evicts_when_over_budget() {
        const MIB: usize = 1024 * 1024;

        // Two slots but only 1 MiB of budget: the slot limit is never reached,
        // so any eviction observed below is caused by the budget alone.
        let mut pool = ModelPool::new(2, 1);

        let engine_a = InferenceEngine::new(EngineConfig::default());
        pool.mark_ready("big-model", engine_a, MIB)
            .expect("first mark_ready should succeed");

        assert!(
            pool.loaded.contains_key("big-model"),
            "big-model should be in pool after mark_ready"
        );

        let engine_b = InferenceEngine::new(EngineConfig::default());
        pool.mark_ready("small-model", engine_b, 1024)
            .expect("second mark_ready should evict big-model and succeed");

        assert!(
            !pool.loaded.contains_key("big-model"),
            "big-model should have been evicted when small-model was loaded"
        );
        assert!(
            pool.loaded.contains_key("small-model"),
            "small-model should now be in the pool"
        );
    }
}