1use super::format::BundleReader;
6use super::manifest::{BundleManifest, ModelEntry};
7use super::mmap::PageTable;
8use super::{DEFAULT_MAX_MEMORY, DEFAULT_PAGE_SIZE};
9use crate::error::{AprenderError, Result};
10use std::collections::{HashMap, VecDeque};
11use std::path::Path;
12
/// Tuning knobs for the in-memory model cache of a [`PagedBundle`].
///
/// Build one with [`PagingConfig::new`] (or `Default`) and refine it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct PagingConfig {
    /// Memory budget, in bytes, for cached model data; eviction tries to keep
    /// the cache at or below this size.
    pub max_memory: usize,
    /// Page size setting (clamped to at least 512 by `with_page_size`).
    /// NOTE(review): not read by `PagedBundle` in this file; units presumably bytes — confirm.
    pub page_size: usize,
    /// Whether `PagedBundle::get_model` attempts pattern-based prefetching after each access.
    pub prefetch: bool,
    /// Maximum number of models considered per prefetch attempt.
    pub prefetch_count: usize,
    /// Policy used to choose which cached model to evict when memory runs short.
    pub eviction: EvictionStrategy,
}
31
/// Policy for picking which cached model to evict when the memory budget is exceeded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
// The variant names intentionally use the conventional all-caps abbreviations
// (LRU / LFU); renaming them would break the public API, so silence the lint.
#[allow(clippy::upper_case_acronyms)]
pub enum EvictionStrategy {
    /// Least recently used: evict the model whose last access is oldest. This is the default.
    #[default]
    LRU,
    /// Least frequently used: evict the model whose page the page table reports
    /// as least used.
    LFU,
}
42
43impl Default for PagingConfig {
44 fn default() -> Self {
45 Self {
46 max_memory: DEFAULT_MAX_MEMORY,
47 page_size: DEFAULT_PAGE_SIZE,
48 prefetch: true,
49 prefetch_count: 2,
50 eviction: EvictionStrategy::default(),
51 }
52 }
53}
54
55impl PagingConfig {
56 #[must_use]
58 pub fn new() -> Self {
59 Self::default()
60 }
61
62 #[must_use]
66 pub fn with_max_memory(mut self, max_memory: usize) -> Self {
67 self.max_memory = max_memory.max(1024);
68 self
69 }
70
71 #[must_use]
73 pub fn with_page_size(mut self, page_size: usize) -> Self {
74 self.page_size = page_size.max(512);
75 self
76 }
77
78 #[must_use]
80 pub fn with_prefetch(mut self, prefetch: bool) -> Self {
81 self.prefetch = prefetch;
82 self
83 }
84
85 #[must_use]
87 pub fn with_prefetch_count(mut self, count: usize) -> Self {
88 self.prefetch_count = count;
89 self
90 }
91
92 #[must_use]
94 pub fn with_eviction(mut self, strategy: EvictionStrategy) -> Self {
95 self.eviction = strategy;
96 self
97 }
98}
99
/// Counters describing the cache activity of a [`PagedBundle`].
#[derive(Debug, Clone, Default)]
pub struct PagingStats {
    /// `get_model` calls served directly from the cache.
    pub hits: usize,
    /// `get_model` calls that had to load the model from the bundle.
    pub misses: usize,
    /// Models removed from the cache by eviction, either explicit or to free
    /// memory. `clear_cache` does not count here.
    pub evictions: usize,
    /// Cumulative bytes read from the bundle, including prefetched models.
    pub bytes_loaded: usize,
    /// Bytes of model data currently held in the cache.
    pub memory_used: usize,
    /// Models loaded ahead of an explicit `get_model` request.
    pub prefetches: usize,
}
120
121impl PagingStats {
122 #[must_use]
124 pub fn hit_rate(&self) -> f32 {
125 let total = self.hits + self.misses;
126 if total == 0 {
127 0.0
128 } else {
129 self.hits as f32 / total as f32
130 }
131 }
132
133 pub fn reset(&mut self) {
135 *self = Self::default();
136 }
137}
138
/// A model bundle whose model bytes are loaded on demand and cached within a memory budget.
///
/// The first request for a model reads it from the bundle file. The model then stays
/// in memory until it is evicted under the configured [`EvictionStrategy`], evicted
/// explicitly, or the cache is cleared.
pub struct PagedBundle {
    // Reader over the bundle file; used to load model bytes on demand.
    reader: BundleReader,
    // Manifest read when the bundle is opened; the source of model names, sizes and offsets.
    manifest: BundleManifest,
    // Cached model bytes keyed by model name.
    cache: HashMap<String, Vec<u8>>,
    // Model names in recency order: front = least recently used, back = most recent.
    lru_order: VecDeque<String>,
    // Page bookkeeping keyed by the model's manifest offset; consulted for LFU eviction.
    page_table: PageTable,
    // Paging configuration supplied at open time.
    config: PagingConfig,
    // Running cache statistics.
    stats: PagingStats,
    // Bounded history of the most recently accessed model names, used to pick prefetch candidates.
    access_history: VecDeque<String>,
}
165
166impl PagedBundle {
167 pub fn open(path: impl AsRef<Path>, config: PagingConfig) -> Result<Self> {
169 let mut reader = BundleReader::open(path)?;
170 let manifest = reader.read_manifest()?;
171
172 Ok(Self {
173 reader,
174 manifest,
175 cache: HashMap::new(),
176 lru_order: VecDeque::new(),
177 page_table: PageTable::new(),
178 config,
179 stats: PagingStats::default(),
180 access_history: VecDeque::with_capacity(10),
181 })
182 }
183
184 pub fn get_model(&mut self, name: &str) -> Result<&[u8]> {
186 if self.cache.contains_key(name) {
188 self.stats.hits += 1;
189 self.update_lru(name);
190 self.record_access(name);
191
192 if self.config.prefetch {
194 self.try_prefetch();
195 }
196
197 return Ok(self.cache.get(name).expect("Key should exist"));
198 }
199
200 self.stats.misses += 1;
202 self.load_model(name)?;
203
204 self.record_access(name);
206
207 if self.config.prefetch {
209 self.try_prefetch();
210 }
211
212 Ok(self.cache.get(name).expect("Just loaded"))
213 }
214
215 #[must_use]
217 pub fn is_cached(&self, name: &str) -> bool {
218 self.cache.contains_key(name)
219 }
220
221 #[must_use]
223 pub fn model_names(&self) -> Vec<&str> {
224 self.manifest.model_names()
225 }
226
227 #[must_use]
229 pub fn get_metadata(&self, name: &str) -> Option<&ModelEntry> {
230 self.manifest.get_model(name)
231 }
232
233 #[must_use]
235 pub fn stats(&self) -> &PagingStats {
236 &self.stats
237 }
238
239 #[must_use]
241 pub fn config(&self) -> &PagingConfig {
242 &self.config
243 }
244
245 #[must_use]
247 pub fn memory_used(&self) -> usize {
248 self.stats.memory_used
249 }
250
251 #[must_use]
253 pub fn cached_count(&self) -> usize {
254 self.cache.len()
255 }
256
257 pub fn evict(&mut self, name: &str) -> bool {
259 if let Some(data) = self.cache.remove(name) {
260 self.stats.memory_used = self.stats.memory_used.saturating_sub(data.len());
261 self.stats.evictions += 1;
262 self.lru_order.retain(|n| n != name);
263 true
264 } else {
265 false
266 }
267 }
268
269 pub fn clear_cache(&mut self) {
271 self.cache.clear();
272 self.lru_order.clear();
273 self.stats.memory_used = 0;
274 }
275
276 pub fn prefetch_hint(&mut self, name: &str) -> Result<()> {
278 if !self.cache.contains_key(name) && self.manifest.get_model(name).is_some() {
279 self.load_model(name)?;
280 self.stats.prefetches += 1;
281 }
282 Ok(())
283 }
284
285 fn load_model(&mut self, name: &str) -> Result<()> {
287 let entry = self
288 .manifest
289 .get_model(name)
290 .ok_or_else(|| AprenderError::Other(format!("Model '{name}' not found")))?
291 .clone();
292
293 while self.stats.memory_used + entry.size > self.config.max_memory {
295 if !self.evict_lru() {
296 break;
298 }
299 }
300
301 let data = self.reader.read_model(&entry)?;
303 let size = data.len();
304
305 self.stats.bytes_loaded += size;
307 self.stats.memory_used += size;
308
309 self.cache.insert(name.to_string(), data);
311 self.lru_order.push_back(name.to_string());
312
313 self.page_table.add_page(entry.offset, size);
315
316 Ok(())
317 }
318
319 fn update_lru(&mut self, name: &str) {
321 self.lru_order.retain(|n| n != name);
322 self.lru_order.push_back(name.to_string());
323
324 if let Some(entry) = self.manifest.get_model(name) {
326 self.page_table.touch(entry.offset);
327 }
328 }
329
330 fn evict_lru(&mut self) -> bool {
332 let to_evict = match self.config.eviction {
333 EvictionStrategy::LRU => self.lru_order.pop_front(),
334 EvictionStrategy::LFU => {
335 if let Some(offset) = self.page_table.lfu_page() {
336 self.manifest
338 .iter()
339 .find(|e| e.offset == offset)
340 .map(|e| e.name.clone())
341 } else {
342 self.lru_order.pop_front()
343 }
344 }
345 };
346
347 if let Some(name) = to_evict {
348 if let Some(data) = self.cache.remove(&name) {
349 self.stats.memory_used = self.stats.memory_used.saturating_sub(data.len());
350 self.stats.evictions += 1;
351
352 if let Some(entry) = self.manifest.get_model(&name) {
354 self.page_table.remove(entry.offset);
355 }
356
357 return true;
358 }
359 }
360
361 false
362 }
363
364 fn record_access(&mut self, name: &str) {
366 if self.access_history.len() >= 10 {
367 self.access_history.pop_front();
368 }
369 self.access_history.push_back(name.to_string());
370 }
371
372 fn try_prefetch(&mut self) {
374 if self.access_history.len() < 2 {
375 return;
376 }
377
378 let last = self.access_history.back().cloned();
380 if let Some(last_name) = last {
381 let patterns: Vec<_> = self
383 .access_history
384 .iter()
385 .zip(self.access_history.iter().skip(1))
386 .filter(|(prev, _)| *prev == &last_name)
387 .map(|(_, next)| next.clone())
388 .take(self.config.prefetch_count)
389 .collect();
390
391 for name in patterns {
392 if !self.cache.contains_key(&name)
393 && self.stats.memory_used + self.estimate_size(&name) <= self.config.max_memory
394 {
395 let _ = self.load_model(&name);
396 self.stats.prefetches += 1;
397 }
398 }
399 }
400 }
401
402 fn estimate_size(&self, name: &str) -> usize {
404 self.manifest.get_model(name).map_or(0, |e| e.size)
405 }
406}
407
408impl std::fmt::Debug for PagedBundle {
409 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 f.debug_struct("PagedBundle")
411 .field("models", &self.manifest.len())
412 .field("cached", &self.cache.len())
413 .field("memory_used", &self.stats.memory_used)
414 .field("max_memory", &self.config.max_memory)
415 .field("hit_rate", &self.stats.hit_rate())
416 .finish_non_exhaustive()
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bundle::{BundleManifest, BundleWriter, ModelEntry};
    use std::collections::HashMap;
    use tempfile::NamedTempFile;

    /// Writes a bundle holding `models` (name, payload) to a fresh temp file.
    ///
    /// Keep the returned handle alive while the bundle is in use, because
    /// `NamedTempFile` removes the file when dropped.
    fn create_test_bundle(models: &[(&str, Vec<u8>)]) -> NamedTempFile {
        let file = NamedTempFile::new().expect("Failed to create temp file");

        let mut manifest = BundleManifest::new();
        for (name, payload) in models {
            manifest.add_model(ModelEntry::new(*name, payload.len()));
        }
        let payloads: HashMap<String, Vec<u8>> = models
            .iter()
            .map(|(name, payload)| ((*name).to_string(), payload.clone()))
            .collect();

        BundleWriter::create(file.path())
            .expect("Failed to create writer")
            .write_bundle(&manifest, &payloads)
            .expect("Failed to write bundle");

        file
    }

    /// Opens `temp` with prefetching disabled, which keeps cache contents predictable.
    fn open_without_prefetch(temp: &NamedTempFile) -> PagedBundle {
        PagedBundle::open(temp.path(), PagingConfig::new().with_prefetch(false))
            .expect("Failed to open")
    }

    #[test]
    fn test_paging_config_default() {
        let PagingConfig {
            max_memory,
            page_size,
            prefetch,
            eviction,
            ..
        } = PagingConfig::default();
        assert_eq!(max_memory, DEFAULT_MAX_MEMORY);
        assert_eq!(page_size, DEFAULT_PAGE_SIZE);
        assert!(prefetch);
        assert_eq!(eviction, EvictionStrategy::LRU);
    }

    #[test]
    fn test_paging_config_builder() {
        let PagingConfig {
            max_memory,
            page_size,
            prefetch,
            eviction,
            ..
        } = PagingConfig::new()
            .with_max_memory(50_000)
            .with_page_size(8192)
            .with_prefetch(false)
            .with_eviction(EvictionStrategy::LFU);

        assert_eq!(max_memory, 50_000);
        assert_eq!(page_size, 8192);
        assert!(!prefetch);
        assert_eq!(eviction, EvictionStrategy::LFU);
    }

    #[test]
    fn test_paging_stats() {
        assert_eq!(PagingStats::default().hit_rate(), 0.0);

        let stats = PagingStats {
            hits: 3,
            misses: 1,
            ..PagingStats::default()
        };
        assert!((stats.hit_rate() - 0.75).abs() < f32::EPSILON);
    }

    #[test]
    fn test_paged_bundle_open() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3]), ("model2", vec![4, 5, 6, 7])]);

        let bundle =
            PagedBundle::open(temp.path(), PagingConfig::default()).expect("Failed to open bundle");

        // Opening reads only the manifest; nothing is cached yet.
        assert_eq!(bundle.model_names().len(), 2);
        assert_eq!(bundle.cached_count(), 0);
        assert_eq!(bundle.memory_used(), 0);
    }

    #[test]
    fn test_paged_bundle_get_model() {
        let temp = create_test_bundle(&[("weights", vec![10, 20, 30, 40, 50])]);
        let mut bundle = open_without_prefetch(&temp);

        // First access is a miss that loads the model.
        let data = bundle.get_model("weights").expect("Failed to get model");
        assert_eq!(data, &[10, 20, 30, 40, 50]);
        assert_eq!(bundle.cached_count(), 1);
        assert_eq!(bundle.stats().misses, 1);
        assert_eq!(bundle.stats().hits, 0);

        // Second access is served from the cache.
        let _ = bundle.get_model("weights").expect("Failed to get model");
        assert_eq!(bundle.stats().hits, 1);
    }

    #[test]
    fn test_paged_bundle_eviction() {
        // Three 1000-byte models under a 1500-byte budget: only one fits at a time.
        let temp = create_test_bundle(&[
            ("model1", vec![1; 1000]),
            ("model2", vec![2; 1000]),
            ("model3", vec![3; 1000]),
        ]);
        let config = PagingConfig::new()
            .with_max_memory(1500)
            .with_prefetch(false);
        let mut bundle = PagedBundle::open(temp.path(), config).expect("Failed to open");

        let _ = bundle.get_model("model1").expect("Failed");
        assert_eq!(bundle.cached_count(), 1);
        assert_eq!(bundle.memory_used(), 1000);

        let _ = bundle.get_model("model2").expect("Failed");
        let evictions = bundle.stats().evictions;
        assert!(evictions > 0, "Expected evictions > 0, got {}", evictions);
        assert!(bundle.memory_used() <= 1500);

        let _ = bundle.get_model("model3").expect("Failed");
        assert!(bundle.stats().evictions >= 2);
        assert!(bundle.memory_used() <= 1500);
    }

    #[test]
    fn test_paged_bundle_explicit_evict() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3])]);
        let mut bundle = open_without_prefetch(&temp);

        let _ = bundle.get_model("model1").expect("Failed");
        assert!(bundle.is_cached("model1"));

        assert!(bundle.evict("model1"));
        assert!(!bundle.is_cached("model1"));
    }

    #[test]
    fn test_paged_bundle_clear_cache() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3]), ("model2", vec![4, 5, 6])]);
        let mut bundle = open_without_prefetch(&temp);

        for name in ["model1", "model2"] {
            let _ = bundle.get_model(name).expect("Failed");
        }
        assert_eq!(bundle.cached_count(), 2);

        bundle.clear_cache();
        assert_eq!(bundle.cached_count(), 0);
        assert_eq!(bundle.memory_used(), 0);
    }

    #[test]
    fn test_paged_bundle_prefetch_hint() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3])]);
        let mut bundle =
            PagedBundle::open(temp.path(), PagingConfig::new()).expect("Failed to open");

        bundle.prefetch_hint("model1").expect("Prefetch failed");
        assert!(bundle.is_cached("model1"));

        // The hinted model is already resident, so the lookup is a pure hit.
        let _ = bundle.get_model("model1").expect("Failed");
        assert_eq!(bundle.stats().hits, 1);
        assert_eq!(bundle.stats().misses, 0);
    }

    #[test]
    fn test_paged_bundle_nonexistent_model() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3])]);
        let mut bundle = open_without_prefetch(&temp);

        assert!(bundle.get_model("nonexistent").is_err());
    }
}