aprender/bundle/
paging.rs

1//! Memory Paging for Model Bundles
2//!
3//! Implements LRU-based paging for loading models larger than available RAM.
4
5use super::format::BundleReader;
6use super::manifest::{BundleManifest, ModelEntry};
7use super::mmap::PageTable;
8use super::{DEFAULT_MAX_MEMORY, DEFAULT_PAGE_SIZE};
9use crate::error::{AprenderError, Result};
10use std::collections::{HashMap, VecDeque};
11use std::path::Path;
12
13// ============================================================================
14// Paging Configuration
15// ============================================================================
16
17/// Configuration for paged model loading.
18#[derive(Debug, Clone)]
19pub struct PagingConfig {
20    /// Maximum memory to use for cached model data (bytes).
21    pub max_memory: usize,
22    /// Page size for loading (bytes).
23    pub page_size: usize,
24    /// Enable pre-fetching of likely-needed pages.
25    pub prefetch: bool,
26    /// Number of pages to pre-fetch.
27    pub prefetch_count: usize,
28    /// Eviction strategy.
29    pub eviction: EvictionStrategy,
30}
31
/// Strategy for evicting pages when memory is full.
///
/// The default strategy is [`EvictionStrategy::LRU`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
// Keeps the conventional `LRU`/`LFU` spellings; clippy would otherwise ask
// for `Lru`/`Lfu`.
#[allow(clippy::upper_case_acronyms)]
pub enum EvictionStrategy {
    /// Least Recently Used - evict the entry whose last access is oldest.
    #[default]
    LRU,
    /// Least Frequently Used - evict the entry with the fewest accesses.
    ///
    /// When the page table reports no candidate, eviction falls back to
    /// least-recently-used order.
    LFU,
}
42
43impl Default for PagingConfig {
44    fn default() -> Self {
45        Self {
46            max_memory: DEFAULT_MAX_MEMORY,
47            page_size: DEFAULT_PAGE_SIZE,
48            prefetch: true,
49            prefetch_count: 2,
50            eviction: EvictionStrategy::default(),
51        }
52    }
53}
54
55impl PagingConfig {
56    /// Create a new paging configuration.
57    #[must_use]
58    pub fn new() -> Self {
59        Self::default()
60    }
61
62    /// Set maximum memory.
63    ///
64    /// Note: The minimum is 1024 bytes to ensure meaningful paging behavior.
65    #[must_use]
66    pub fn with_max_memory(mut self, max_memory: usize) -> Self {
67        self.max_memory = max_memory.max(1024);
68        self
69    }
70
71    /// Set page size.
72    #[must_use]
73    pub fn with_page_size(mut self, page_size: usize) -> Self {
74        self.page_size = page_size.max(512);
75        self
76    }
77
78    /// Enable or disable pre-fetching.
79    #[must_use]
80    pub fn with_prefetch(mut self, prefetch: bool) -> Self {
81        self.prefetch = prefetch;
82        self
83    }
84
85    /// Set pre-fetch count.
86    #[must_use]
87    pub fn with_prefetch_count(mut self, count: usize) -> Self {
88        self.prefetch_count = count;
89        self
90    }
91
92    /// Set eviction strategy.
93    #[must_use]
94    pub fn with_eviction(mut self, strategy: EvictionStrategy) -> Self {
95        self.eviction = strategy;
96        self
97    }
98}
99
100// ============================================================================
101// Paging Statistics
102// ============================================================================
103
/// Statistics for paged bundle access.
///
/// All counters start at zero ([`Default`]). The type is comparable so that
/// two snapshots can be checked for equality, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PagingStats {
    /// Number of page hits (data already in memory).
    pub hits: usize,
    /// Number of page misses (data loaded from disk).
    pub misses: usize,
    /// Number of page evictions.
    pub evictions: usize,
    /// Total bytes loaded from disk.
    pub bytes_loaded: usize,
    /// Current memory usage of cached data (bytes).
    pub memory_used: usize,
    /// Number of pre-fetches.
    pub prefetches: usize,
}
120
121impl PagingStats {
122    /// Calculate hit rate.
123    #[must_use]
124    pub fn hit_rate(&self) -> f32 {
125        let total = self.hits + self.misses;
126        if total == 0 {
127            0.0
128        } else {
129            self.hits as f32 / total as f32
130        }
131    }
132
133    /// Reset statistics.
134    pub fn reset(&mut self) {
135        *self = Self::default();
136    }
137}
138
139// ============================================================================
140// Paged Bundle
141// ============================================================================
142
/// A model bundle with memory paging support.
///
/// Enables loading models larger than available RAM by dynamically
/// loading and evicting model data as needed. Caching is done per model:
/// each cache entry holds the complete byte payload of one model.
pub struct PagedBundle {
    /// Bundle reader used to read the manifest and model data from disk.
    reader: BundleReader,
    /// Bundle manifest, read once when the bundle is opened.
    manifest: BundleManifest,
    /// Cached model data, keyed by model name.
    cache: HashMap<String, Vec<u8>>,
    /// Model names in access order for LRU eviction: the front is the least
    /// recently used entry, the back the most recently used one.
    lru_order: VecDeque<String>,
    /// Page table for tracking; a page is registered under the model's
    /// bundle offset when the model is loaded.
    page_table: PageTable,
    /// Paging configuration.
    config: PagingConfig,
    /// Paging statistics.
    stats: PagingStats,
    /// Recent access history (bounded to the last 10 names) used to predict
    /// which models to pre-fetch.
    access_history: VecDeque<String>,
}
165
166impl PagedBundle {
167    /// Open a bundle with paging enabled.
168    pub fn open(path: impl AsRef<Path>, config: PagingConfig) -> Result<Self> {
169        let mut reader = BundleReader::open(path)?;
170        let manifest = reader.read_manifest()?;
171
172        Ok(Self {
173            reader,
174            manifest,
175            cache: HashMap::new(),
176            lru_order: VecDeque::new(),
177            page_table: PageTable::new(),
178            config,
179            stats: PagingStats::default(),
180            access_history: VecDeque::with_capacity(10),
181        })
182    }
183
184    /// Get a model's data, loading from disk if needed.
185    pub fn get_model(&mut self, name: &str) -> Result<&[u8]> {
186        // Check cache first
187        if self.cache.contains_key(name) {
188            self.stats.hits += 1;
189            self.update_lru(name);
190            self.record_access(name);
191
192            // Pre-fetch if enabled
193            if self.config.prefetch {
194                self.try_prefetch();
195            }
196
197            return Ok(self.cache.get(name).expect("Key should exist"));
198        }
199
200        // Cache miss - load from disk
201        self.stats.misses += 1;
202        self.load_model(name)?;
203
204        // Record access
205        self.record_access(name);
206
207        // Pre-fetch if enabled
208        if self.config.prefetch {
209            self.try_prefetch();
210        }
211
212        Ok(self.cache.get(name).expect("Just loaded"))
213    }
214
215    /// Check if a model is currently in cache.
216    #[must_use]
217    pub fn is_cached(&self, name: &str) -> bool {
218        self.cache.contains_key(name)
219    }
220
221    /// Get all model names.
222    #[must_use]
223    pub fn model_names(&self) -> Vec<&str> {
224        self.manifest.model_names()
225    }
226
227    /// Get model metadata.
228    #[must_use]
229    pub fn get_metadata(&self, name: &str) -> Option<&ModelEntry> {
230        self.manifest.get_model(name)
231    }
232
233    /// Get paging statistics.
234    #[must_use]
235    pub fn stats(&self) -> &PagingStats {
236        &self.stats
237    }
238
239    /// Get paging configuration.
240    #[must_use]
241    pub fn config(&self) -> &PagingConfig {
242        &self.config
243    }
244
245    /// Get current memory usage.
246    #[must_use]
247    pub fn memory_used(&self) -> usize {
248        self.stats.memory_used
249    }
250
251    /// Get number of cached models.
252    #[must_use]
253    pub fn cached_count(&self) -> usize {
254        self.cache.len()
255    }
256
257    /// Explicitly evict a model from cache.
258    pub fn evict(&mut self, name: &str) -> bool {
259        if let Some(data) = self.cache.remove(name) {
260            self.stats.memory_used = self.stats.memory_used.saturating_sub(data.len());
261            self.stats.evictions += 1;
262            self.lru_order.retain(|n| n != name);
263            true
264        } else {
265            false
266        }
267    }
268
269    /// Clear all cached data.
270    pub fn clear_cache(&mut self) {
271        self.cache.clear();
272        self.lru_order.clear();
273        self.stats.memory_used = 0;
274    }
275
276    /// Hint that a model will be needed soon.
277    pub fn prefetch_hint(&mut self, name: &str) -> Result<()> {
278        if !self.cache.contains_key(name) && self.manifest.get_model(name).is_some() {
279            self.load_model(name)?;
280            self.stats.prefetches += 1;
281        }
282        Ok(())
283    }
284
285    /// Load a model into cache.
286    fn load_model(&mut self, name: &str) -> Result<()> {
287        let entry = self
288            .manifest
289            .get_model(name)
290            .ok_or_else(|| AprenderError::Other(format!("Model '{name}' not found")))?
291            .clone();
292
293        // Evict if necessary
294        while self.stats.memory_used + entry.size > self.config.max_memory {
295            if !self.evict_lru() {
296                // Can't evict anything, but try to load anyway
297                break;
298            }
299        }
300
301        // Load the data
302        let data = self.reader.read_model(&entry)?;
303        let size = data.len();
304
305        // Update stats
306        self.stats.bytes_loaded += size;
307        self.stats.memory_used += size;
308
309        // Add to cache
310        self.cache.insert(name.to_string(), data);
311        self.lru_order.push_back(name.to_string());
312
313        // Update page table
314        self.page_table.add_page(entry.offset, size);
315
316        Ok(())
317    }
318
319    /// Update LRU order for a model.
320    fn update_lru(&mut self, name: &str) {
321        self.lru_order.retain(|n| n != name);
322        self.lru_order.push_back(name.to_string());
323
324        // Update page table timestamp
325        if let Some(entry) = self.manifest.get_model(name) {
326            self.page_table.touch(entry.offset);
327        }
328    }
329
330    /// Evict the least recently used model.
331    fn evict_lru(&mut self) -> bool {
332        let to_evict = match self.config.eviction {
333            EvictionStrategy::LRU => self.lru_order.pop_front(),
334            EvictionStrategy::LFU => {
335                if let Some(offset) = self.page_table.lfu_page() {
336                    // Find model with this offset
337                    self.manifest
338                        .iter()
339                        .find(|e| e.offset == offset)
340                        .map(|e| e.name.clone())
341                } else {
342                    self.lru_order.pop_front()
343                }
344            }
345        };
346
347        if let Some(name) = to_evict {
348            if let Some(data) = self.cache.remove(&name) {
349                self.stats.memory_used = self.stats.memory_used.saturating_sub(data.len());
350                self.stats.evictions += 1;
351
352                // Remove from page table
353                if let Some(entry) = self.manifest.get_model(&name) {
354                    self.page_table.remove(entry.offset);
355                }
356
357                return true;
358            }
359        }
360
361        false
362    }
363
364    /// Record an access for prediction.
365    fn record_access(&mut self, name: &str) {
366        if self.access_history.len() >= 10 {
367            self.access_history.pop_front();
368        }
369        self.access_history.push_back(name.to_string());
370    }
371
372    /// Try to pre-fetch likely-needed models.
373    fn try_prefetch(&mut self) {
374        if self.access_history.len() < 2 {
375            return;
376        }
377
378        // Simple pattern: if A -> B happened before, pre-fetch B after A
379        let last = self.access_history.back().cloned();
380        if let Some(last_name) = last {
381            // Look for patterns in history
382            let patterns: Vec<_> = self
383                .access_history
384                .iter()
385                .zip(self.access_history.iter().skip(1))
386                .filter(|(prev, _)| *prev == &last_name)
387                .map(|(_, next)| next.clone())
388                .take(self.config.prefetch_count)
389                .collect();
390
391            for name in patterns {
392                if !self.cache.contains_key(&name)
393                    && self.stats.memory_used + self.estimate_size(&name) <= self.config.max_memory
394                {
395                    let _ = self.load_model(&name);
396                    self.stats.prefetches += 1;
397                }
398            }
399        }
400    }
401
402    /// Estimate size of a model.
403    fn estimate_size(&self, name: &str) -> usize {
404        self.manifest.get_model(name).map_or(0, |e| e.size)
405    }
406}
407
408impl std::fmt::Debug for PagedBundle {
409    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410        f.debug_struct("PagedBundle")
411            .field("models", &self.manifest.len())
412            .field("cached", &self.cache.len())
413            .field("memory_used", &self.stats.memory_used)
414            .field("max_memory", &self.config.max_memory)
415            .field("hit_rate", &self.stats.hit_rate())
416            .finish_non_exhaustive()
417    }
418}
419
420// ============================================================================
421// Tests
422// ============================================================================
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bundle::{BundleManifest, BundleWriter, ModelEntry};
    use std::collections::HashMap;
    use tempfile::NamedTempFile;

    /// Write a bundle containing `models` to a temporary file and return it.
    ///
    /// The returned handle must be kept alive for as long as the bundle is
    /// read; dropping it deletes the file.
    fn create_test_bundle(models: &[(&str, Vec<u8>)]) -> NamedTempFile {
        let temp = NamedTempFile::new().expect("Failed to create temp file");

        let mut manifest = BundleManifest::new();
        let mut model_map = HashMap::new();

        for (name, data) in models {
            manifest.add_model(ModelEntry::new(*name, data.len()));
            model_map.insert((*name).to_string(), data.clone());
        }

        let writer = BundleWriter::create(temp.path()).expect("Failed to create writer");
        writer
            .write_bundle(&manifest, &model_map)
            .expect("Failed to write bundle");

        temp
    }

    #[test]
    fn test_paging_config_default() {
        let config = PagingConfig::default();
        assert_eq!(config.max_memory, DEFAULT_MAX_MEMORY);
        assert_eq!(config.page_size, DEFAULT_PAGE_SIZE);
        assert!(config.prefetch);
        assert_eq!(config.prefetch_count, 2);
        assert_eq!(config.eviction, EvictionStrategy::LRU);
    }

    #[test]
    fn test_paging_config_builder() {
        let config = PagingConfig::new()
            .with_max_memory(50_000)
            .with_page_size(8192)
            .with_prefetch(false)
            .with_prefetch_count(4)
            .with_eviction(EvictionStrategy::LFU);

        assert_eq!(config.max_memory, 50_000);
        assert_eq!(config.page_size, 8192);
        assert!(!config.prefetch);
        assert_eq!(config.prefetch_count, 4);
        assert_eq!(config.eviction, EvictionStrategy::LFU);
    }

    #[test]
    fn test_paging_config_minimums() {
        // Too-small values are raised to the documented floors.
        let config = PagingConfig::new().with_max_memory(1).with_page_size(1);
        assert_eq!(config.max_memory, 1024);
        assert_eq!(config.page_size, 512);

        // Values at or above the floor are kept as given.
        let config = PagingConfig::new()
            .with_max_memory(1024)
            .with_page_size(513);
        assert_eq!(config.max_memory, 1024);
        assert_eq!(config.page_size, 513);
    }

    #[test]
    fn test_paging_stats() {
        let mut stats = PagingStats::default();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.hits = 3;
        stats.misses = 1;
        assert!((stats.hit_rate() - 0.75).abs() < f32::EPSILON);
    }

    #[test]
    fn test_paging_stats_reset() {
        let mut stats = PagingStats {
            hits: 5,
            misses: 2,
            evictions: 1,
            bytes_loaded: 4096,
            memory_used: 2048,
            prefetches: 3,
        };

        stats.reset();

        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.bytes_loaded, 0);
        assert_eq!(stats.memory_used, 0);
        assert_eq!(stats.prefetches, 0);
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[test]
    fn test_paged_bundle_open() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3]), ("model2", vec![4, 5, 6, 7])]);

        let bundle =
            PagedBundle::open(temp.path(), PagingConfig::default()).expect("Failed to open bundle");

        assert_eq!(bundle.model_names().len(), 2);
        assert_eq!(bundle.cached_count(), 0);
        assert_eq!(bundle.memory_used(), 0);
    }

    #[test]
    fn test_paged_bundle_get_model() {
        let temp = create_test_bundle(&[("weights", vec![10, 20, 30, 40, 50])]);

        let mut bundle = PagedBundle::open(temp.path(), PagingConfig::new().with_prefetch(false))
            .expect("Failed to open");

        let data = bundle.get_model("weights").expect("Failed to get model");
        assert_eq!(data, &[10, 20, 30, 40, 50]);
        assert_eq!(bundle.cached_count(), 1);
        assert_eq!(bundle.stats().misses, 1);
        assert_eq!(bundle.stats().hits, 0);

        // Second access should be a hit
        let _ = bundle.get_model("weights").expect("Failed to get model");
        assert_eq!(bundle.stats().hits, 1);
    }

    #[test]
    fn test_paged_bundle_eviction() {
        // Use 1000-byte models with 1500-byte max memory
        // This forces eviction after the first model since 2000 > 1500
        let temp = create_test_bundle(&[
            ("model1", vec![1; 1000]),
            ("model2", vec![2; 1000]),
            ("model3", vec![3; 1000]),
        ]);

        // Small max memory to force eviction (1500 = 1.5 models worth)
        let mut bundle = PagedBundle::open(
            temp.path(),
            PagingConfig::new()
                .with_max_memory(1500)
                .with_prefetch(false),
        )
        .expect("Failed to open");

        // Load first model - fits in memory
        let _ = bundle.get_model("model1").expect("Failed");
        assert_eq!(bundle.cached_count(), 1);
        assert_eq!(bundle.memory_used(), 1000);

        // Load second model - should trigger eviction of model1
        // 1000 + 1000 = 2000 > 1500, must evict
        let _ = bundle.get_model("model2").expect("Failed");
        assert!(
            bundle.stats().evictions > 0,
            "Expected evictions > 0, got {}",
            bundle.stats().evictions
        );
        assert!(bundle.memory_used() <= 1500);

        // Load third model - should trigger another eviction
        let _ = bundle.get_model("model3").expect("Failed");
        assert!(bundle.stats().evictions >= 2);
        assert!(bundle.memory_used() <= 1500);
    }

    #[test]
    fn test_paged_bundle_explicit_evict() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3])]);

        let mut bundle = PagedBundle::open(temp.path(), PagingConfig::new().with_prefetch(false))
            .expect("Failed to open");

        // Load model
        let _ = bundle.get_model("model1").expect("Failed");
        assert!(bundle.is_cached("model1"));

        // Explicitly evict
        let evicted = bundle.evict("model1");
        assert!(evicted);
        assert!(!bundle.is_cached("model1"));

        // Accounting reflects the eviction
        assert_eq!(bundle.memory_used(), 0);
        assert_eq!(bundle.stats().evictions, 1);

        // Evicting something that is not cached reports `false`
        assert!(!bundle.evict("model1"));
        assert_eq!(bundle.stats().evictions, 1);
    }

    #[test]
    fn test_paged_bundle_clear_cache() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3]), ("model2", vec![4, 5, 6])]);

        let mut bundle = PagedBundle::open(temp.path(), PagingConfig::new().with_prefetch(false))
            .expect("Failed to open");

        let _ = bundle.get_model("model1").expect("Failed");
        let _ = bundle.get_model("model2").expect("Failed");
        assert_eq!(bundle.cached_count(), 2);

        bundle.clear_cache();
        assert_eq!(bundle.cached_count(), 0);
        assert_eq!(bundle.memory_used(), 0);
    }

    #[test]
    fn test_paged_bundle_prefetch_hint() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3])]);

        let mut bundle =
            PagedBundle::open(temp.path(), PagingConfig::new()).expect("Failed to open");

        // Pre-fetch
        bundle.prefetch_hint("model1").expect("Prefetch failed");
        assert!(bundle.is_cached("model1"));

        // Access should now be a hit
        let _ = bundle.get_model("model1").expect("Failed");
        assert_eq!(bundle.stats().hits, 1);
        assert_eq!(bundle.stats().misses, 0);
    }

    #[test]
    fn test_paged_bundle_nonexistent_model() {
        let temp = create_test_bundle(&[("model1", vec![1, 2, 3])]);

        let mut bundle = PagedBundle::open(temp.path(), PagingConfig::new().with_prefetch(false))
            .expect("Failed to open");

        let result = bundle.get_model("nonexistent");
        assert!(result.is_err());
    }
}