use super::format::BundleReader;
use super::manifest::{BundleManifest, ModelEntry};
use super::mmap::PageTable;
use super::{DEFAULT_MAX_MEMORY, DEFAULT_PAGE_SIZE};
use crate::error::{AprenderError, Result};
use std::collections::{HashMap, VecDeque};
use std::path::Path;
/// Tuning knobs for the in-memory model cache used by `PagedBundle`.
#[derive(Debug, Clone)]
pub struct PagingConfig {
    /// Cache budget in bytes of model data; loading a model that would push the
    /// cached total above this triggers eviction first.
    pub max_memory: usize,
    /// Page size; `with_page_size` enforces a minimum of 512. Not read by
    /// `PagedBundle` in this file.
    pub page_size: usize,
    /// Whether `PagedBundle::get_model` triggers access-history-based prefetching.
    pub prefetch: bool,
    /// Maximum number of follow-up candidates considered in one prefetch pass.
    pub prefetch_count: usize,
    /// Policy used to choose which cached model to evict.
    pub eviction: EvictionStrategy,
}
/// Policy used to choose which cached model to drop when the memory budget
/// would be exceeded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
// The variant names `LRU`/`LFU` are part of the public API; renaming them to
// satisfy clippy's acronym-casing lint would break callers, so it is silenced.
#[allow(clippy::upper_case_acronyms)]
pub enum EvictionStrategy {
    /// Evict the least recently used model (the default).
    #[default]
    LRU,
    /// Evict the least frequently used model, as reported by the page table.
    LFU,
}
impl Default for PagingConfig {
    /// Builds the library defaults: the parent module's `DEFAULT_MAX_MEMORY` and
    /// `DEFAULT_PAGE_SIZE`, prefetching enabled with two candidates, and the
    /// default eviction strategy (LRU).
    fn default() -> Self {
        let eviction = EvictionStrategy::default();
        Self {
            prefetch: true,
            prefetch_count: 2,
            eviction,
            max_memory: DEFAULT_MAX_MEMORY,
            page_size: DEFAULT_PAGE_SIZE,
        }
    }
}
impl PagingConfig {
    // Floors applied by the builder setters below.
    const MIN_MAX_MEMORY: usize = 1024;
    const MIN_PAGE_SIZE: usize = 512;

    /// Returns the default configuration (identical to `PagingConfig::default()`).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the cache memory budget; values below 1024 are raised to 1024.
    #[must_use]
    pub fn with_max_memory(self, max_memory: usize) -> Self {
        Self {
            max_memory: std::cmp::max(max_memory, Self::MIN_MAX_MEMORY),
            ..self
        }
    }

    /// Sets the page size; values below 512 are raised to 512.
    #[must_use]
    pub fn with_page_size(self, page_size: usize) -> Self {
        Self {
            page_size: std::cmp::max(page_size, Self::MIN_PAGE_SIZE),
            ..self
        }
    }

    /// Enables or disables access-history-based prefetching.
    #[must_use]
    pub fn with_prefetch(self, prefetch: bool) -> Self {
        Self { prefetch, ..self }
    }

    /// Sets how many follow-up candidates a single prefetch pass may consider.
    #[must_use]
    pub fn with_prefetch_count(self, count: usize) -> Self {
        Self {
            prefetch_count: count,
            ..self
        }
    }

    /// Sets the eviction policy.
    #[must_use]
    pub fn with_eviction(self, strategy: EvictionStrategy) -> Self {
        Self {
            eviction: strategy,
            ..self
        }
    }
}
/// Counters describing how the model cache has been used.
#[derive(Debug, Clone, Default)]
pub struct PagingStats {
    /// `get_model` calls answered from the cache.
    pub hits: usize,
    /// `get_model` calls for a model that was not cached.
    pub misses: usize,
    /// Models removed from the cache, either explicitly via `evict` or
    /// automatically to make room; `clear_cache` does not count here.
    pub evictions: usize,
    /// Cumulative bytes read from the bundle, including prefetched models.
    pub bytes_loaded: usize,
    /// Bytes of model data currently held in the cache.
    pub memory_used: usize,
    /// Prefetches triggered by `prefetch_hint` or automatic pattern-based prefetching.
    pub prefetches: usize,
}
impl PagingStats {
    /// Fraction of lookups served from the cache, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    #[must_use]
    pub fn hit_rate(&self) -> f32 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f32 / lookups as f32,
        }
    }

    /// Zeroes every counter.
    pub fn reset(&mut self) {
        *self = Default::default();
    }
}
/// A model bundle whose models are read on demand and kept in a size-bounded
/// in-memory cache.
pub struct PagedBundle {
    /// Reader for the underlying bundle; model bytes are read through it on misses.
    reader: BundleReader,
    /// Manifest read when the bundle is opened; source of model names, offsets and sizes.
    manifest: BundleManifest,
    /// Cached model bytes keyed by model name.
    cache: HashMap<String, Vec<u8>>,
    /// Recency order of cached model names: the front is the least recently used.
    lru_order: VecDeque<String>,
    /// Page bookkeeping keyed by each model's manifest offset; consulted for LFU eviction.
    page_table: PageTable,
    /// Configuration supplied to `open`.
    config: PagingConfig,
    /// Usage counters updated by cache operations.
    stats: PagingStats,
    /// The last 10 model names requested through `get_model`, oldest first;
    /// used to predict what to prefetch.
    access_history: VecDeque<String>,
}
impl PagedBundle {
    /// How many recent `get_model` calls are remembered for prefetch prediction.
    const ACCESS_HISTORY_LEN: usize = 10;

    /// Opens the bundle at `path` and reads its manifest.
    ///
    /// No model data is loaded until a model is requested.
    ///
    /// # Errors
    ///
    /// Returns an error if the bundle cannot be opened or its manifest cannot be read.
    pub fn open(path: impl AsRef<Path>, config: PagingConfig) -> Result<Self> {
        let mut reader = BundleReader::open(path)?;
        let manifest = reader.read_manifest()?;
        Ok(Self {
            reader,
            manifest,
            cache: HashMap::new(),
            lru_order: VecDeque::new(),
            page_table: PageTable::new(),
            config,
            stats: PagingStats::default(),
            access_history: VecDeque::with_capacity(Self::ACCESS_HISTORY_LEN),
        })
    }

    /// Returns the bytes of model `name`, loading it on a cache miss.
    ///
    /// A miss may evict other cached models (per [`PagingConfig::eviction`]) to
    /// make room under [`PagingConfig::max_memory`]. Every call updates the
    /// hit/miss counters and the access history. When [`PagingConfig::prefetch`]
    /// is set, it may also load models that previously followed `name`.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` is not in the manifest or its data cannot be read.
    pub fn get_model(&mut self, name: &str) -> Result<&[u8]> {
        if self.cache.contains_key(name) {
            self.stats.hits += 1;
            self.update_lru(name);
        } else {
            self.stats.misses += 1;
            self.load_model(name)?;
        }
        self.record_access(name);
        if self.config.prefetch {
            self.try_prefetch();
        }
        // Prefetching only loads models that fit without evicting, so `name` is
        // expected to still be cached; report an error instead of panicking if not.
        self.cache.get(name).map(Vec::as_slice).ok_or_else(|| {
            AprenderError::Other(format!(
                "Model '{name}' was evicted before it could be returned"
            ))
        })
    }

    /// Returns `true` if model `name` is currently held in the cache.
    #[must_use]
    pub fn is_cached(&self, name: &str) -> bool {
        self.cache.contains_key(name)
    }

    /// Names of all models listed in the bundle manifest.
    #[must_use]
    pub fn model_names(&self) -> Vec<&str> {
        self.manifest.model_names()
    }

    /// Manifest entry for model `name`, if the bundle contains it.
    #[must_use]
    pub fn get_metadata(&self, name: &str) -> Option<&ModelEntry> {
        self.manifest.get_model(name)
    }

    /// Cache usage counters gathered so far.
    #[must_use]
    pub fn stats(&self) -> &PagingStats {
        &self.stats
    }

    /// The configuration this bundle was opened with.
    #[must_use]
    pub fn config(&self) -> &PagingConfig {
        &self.config
    }

    /// Bytes of model data currently held in the cache.
    #[must_use]
    pub fn memory_used(&self) -> usize {
        self.stats.memory_used
    }

    /// Number of models currently cached.
    #[must_use]
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    /// Removes model `name` from the cache, returning `true` if it was cached.
    ///
    /// Counts as an eviction in the statistics.
    pub fn evict(&mut self, name: &str) -> bool {
        self.remove_cached(name)
    }

    /// Drops every cached model and resets recency and page-usage tracking.
    ///
    /// Unlike [`evict`](Self::evict), this does not count evictions.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
        self.lru_order.clear();
        // A stale page table would let LFU eviction pick models that are no
        // longer cached, stalling eviction and overrunning the memory budget.
        self.page_table = PageTable::new();
        self.stats.memory_used = 0;
    }

    /// Loads model `name` ahead of demand if it is known and not yet cached.
    ///
    /// Names missing from the manifest are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the model's data cannot be read.
    pub fn prefetch_hint(&mut self, name: &str) -> Result<()> {
        if !self.cache.contains_key(name) && self.manifest.get_model(name).is_some() {
            self.load_model(name)?;
            self.stats.prefetches += 1;
        }
        Ok(())
    }

    /// Reads model `name` into the cache, evicting others first if its manifest
    /// size would exceed the memory budget.
    ///
    /// If nothing more can be evicted, the model is loaded anyway, so a single
    /// model larger than the budget is still usable.
    fn load_model(&mut self, name: &str) -> Result<()> {
        let entry = self
            .manifest
            .get_model(name)
            .ok_or_else(|| AprenderError::Other(format!("Model '{name}' not found")))?
            .clone();
        while self.stats.memory_used + entry.size > self.config.max_memory {
            if !self.evict_lru() {
                break;
            }
        }
        let data = self.reader.read_model(&entry)?;
        let size = data.len();
        self.stats.bytes_loaded += size;
        self.stats.memory_used += size;
        self.cache.insert(name.to_string(), data);
        self.lru_order.push_back(name.to_string());
        self.page_table.add_page(entry.offset, size);
        Ok(())
    }

    /// Marks `name` as most recently used and records an access on its page.
    fn update_lru(&mut self, name: &str) {
        self.lru_order.retain(|n| n != name);
        self.lru_order.push_back(name.to_string());
        if let Some(entry) = self.manifest.get_model(name) {
            self.page_table.touch(entry.offset);
        }
    }

    /// Evicts one cached model chosen by the configured [`EvictionStrategy`].
    ///
    /// Returns `false` if no cached model could be found to evict.
    fn evict_lru(&mut self) -> bool {
        let candidate = match self.config.eviction {
            EvictionStrategy::LRU => self.lru_order.front().cloned(),
            // Only accept an LFU pick that is actually cached; otherwise fall
            // back to the least recently used model.
            EvictionStrategy::LFU => self
                .page_table
                .lfu_page()
                .and_then(|offset| {
                    self.manifest
                        .iter()
                        .find(|entry| {
                            entry.offset == offset
                                && self.cache.contains_key(entry.name.as_str())
                        })
                        .map(|entry| entry.name.clone())
                })
                .or_else(|| self.lru_order.front().cloned()),
        };
        match candidate {
            Some(name) => self.remove_cached(&name),
            None => false,
        }
    }

    /// Removes `name` from the cache and from the recency and page bookkeeping.
    ///
    /// Counts an eviction and returns `true` only if the model was cached.
    fn remove_cached(&mut self, name: &str) -> bool {
        // Purge the recency entry unconditionally so a stale name can never
        // block LRU eviction.
        self.lru_order.retain(|n| n != name);
        let data = match self.cache.remove(name) {
            Some(data) => data,
            None => return false,
        };
        self.stats.memory_used = self.stats.memory_used.saturating_sub(data.len());
        self.stats.evictions += 1;
        if let Some(entry) = self.manifest.get_model(name) {
            self.page_table.remove(entry.offset);
        }
        true
    }

    /// Appends `name` to the bounded access history, dropping the oldest entry when full.
    fn record_access(&mut self, name: &str) {
        if self.access_history.len() >= Self::ACCESS_HISTORY_LEN {
            self.access_history.pop_front();
        }
        self.access_history.push_back(name.to_string());
    }

    /// Best-effort prefetch of models that followed the latest access in the history.
    ///
    /// Only models that fit in the remaining budget without eviction are loaded.
    fn try_prefetch(&mut self) {
        if self.access_history.len() < 2 {
            return;
        }
        let Some(last_name) = self.access_history.back().cloned() else {
            return;
        };
        // Each consecutive (prev, next) pair whose `prev` equals the latest
        // access contributes `next` as a candidate, oldest transitions first.
        let candidates: Vec<String> = self
            .access_history
            .iter()
            .zip(self.access_history.iter().skip(1))
            .filter(|(prev, _)| *prev == &last_name)
            .map(|(_, next)| next.clone())
            .take(self.config.prefetch_count)
            .collect();
        for name in candidates {
            if !self.cache.contains_key(&name)
                && self.stats.memory_used + self.estimate_size(&name) <= self.config.max_memory
            {
                // Failures are deliberately ignored (prefetch is best-effort),
                // but only successful loads are counted.
                if self.load_model(&name).is_ok() {
                    self.stats.prefetches += 1;
                }
            }
        }
    }

    /// Manifest size of model `name`, or `0` if it is unknown.
    fn estimate_size(&self, name: &str) -> usize {
        self.manifest.get_model(name).map_or(0, |e| e.size)
    }
}
impl std::fmt::Debug for PagedBundle {
    /// Summarises cache state (counts, memory, hit rate) instead of dumping
    /// cached model bytes; remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            manifest,
            cache,
            config,
            stats,
            ..
        } = self;
        let mut summary = f.debug_struct("PagedBundle");
        summary.field("models", &manifest.len());
        summary.field("cached", &cache.len());
        summary.field("memory_used", &stats.memory_used);
        summary.field("max_memory", &config.max_memory);
        summary.field("hit_rate", &stats.hit_rate());
        summary.finish_non_exhaustive()
    }
}
// Unit tests for this module are loaded from `paging_tests.rs`.
#[cfg(test)]
#[path = "paging_tests.rs"]
mod tests;