aprender-core 0.29.1

Next-generation machine learning library in pure Rust
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
//! Memory Paging for Model Bundles
//!
//! Implements LRU-based paging for loading models larger than available RAM.

use super::format::BundleReader;
use super::manifest::{BundleManifest, ModelEntry};
use super::mmap::PageTable;
use super::{DEFAULT_MAX_MEMORY, DEFAULT_PAGE_SIZE};
use crate::error::{AprenderError, Result};
use std::collections::{HashMap, VecDeque};
use std::path::Path;

// ============================================================================
// Paging Configuration
// ============================================================================

/// Configuration for paged model loading.
///
/// Construct with [`PagingConfig::new`] (or `Default`) and customize via the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct PagingConfig {
    /// Maximum memory to use for cached model data (bytes).
    /// The `with_max_memory` builder clamps this to a 1024-byte floor.
    pub max_memory: usize,
    /// Page size for loading (bytes).
    /// The `with_page_size` builder clamps this to a 512-byte floor.
    pub page_size: usize,
    /// Enable pre-fetching of likely-needed pages.
    pub prefetch: bool,
    /// Number of pages to pre-fetch per access when `prefetch` is enabled.
    pub prefetch_count: usize,
    /// Eviction strategy used when the cache exceeds `max_memory`.
    pub eviction: EvictionStrategy,
}

/// Strategy for evicting pages when memory is full.
// LRU/LFU are established cache-policy acronyms, so keep them uppercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[allow(clippy::upper_case_acronyms)]
pub enum EvictionStrategy {
    /// Least Recently Used - evict the page accessed longest ago (default).
    #[default]
    LRU,
    /// Least Frequently Used - evict the page with the fewest accesses.
    LFU,
}

impl Default for PagingConfig {
    fn default() -> Self {
        Self {
            max_memory: DEFAULT_MAX_MEMORY,
            page_size: DEFAULT_PAGE_SIZE,
            prefetch: true,
            prefetch_count: 2,
            eviction: EvictionStrategy::default(),
        }
    }
}

impl PagingConfig {
    /// Create a new paging configuration with default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum memory.
    ///
    /// Note: The minimum is 1024 bytes to ensure meaningful paging behavior.
    #[must_use]
    pub fn with_max_memory(self, max_memory: usize) -> Self {
        // Clamp to the 1 KiB floor before storing.
        let max_memory = if max_memory < 1024 { 1024 } else { max_memory };
        Self { max_memory, ..self }
    }

    /// Set page size.
    ///
    /// Note: The minimum is 512 bytes.
    #[must_use]
    pub fn with_page_size(self, page_size: usize) -> Self {
        // Clamp to the 512-byte floor before storing.
        let page_size = if page_size < 512 { 512 } else { page_size };
        Self { page_size, ..self }
    }

    /// Enable or disable pre-fetching.
    #[must_use]
    pub fn with_prefetch(self, prefetch: bool) -> Self {
        Self { prefetch, ..self }
    }

    /// Set pre-fetch count.
    #[must_use]
    pub fn with_prefetch_count(self, count: usize) -> Self {
        Self {
            prefetch_count: count,
            ..self
        }
    }

    /// Set eviction strategy.
    #[must_use]
    pub fn with_eviction(self, strategy: EvictionStrategy) -> Self {
        Self {
            eviction: strategy,
            ..self
        }
    }
}

// ============================================================================
// Paging Statistics
// ============================================================================

/// Statistics for paged bundle access.
#[derive(Debug, Clone, Default)]
pub struct PagingStats {
    /// Number of page hits (data already in memory).
    pub hits: usize,
    /// Number of page misses (data loaded from disk).
    pub misses: usize,
    /// Number of page evictions.
    pub evictions: usize,
    /// Total bytes loaded.
    pub bytes_loaded: usize,
    /// Current memory usage.
    pub memory_used: usize,
    /// Number of pre-fetches.
    pub prefetches: usize,
}

impl PagingStats {
    /// Fraction of accesses served from cache, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no accesses have been recorded yet.
    #[must_use]
    pub fn hit_rate(&self) -> f32 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }

    /// Reset every counter back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}

// ============================================================================
// Paged Bundle
// ============================================================================

/// A model bundle with memory paging support.
///
/// Enables loading models larger than available RAM by dynamically
/// loading and evicting model data as needed.
pub struct PagedBundle {
    /// Bundle reader used to pull model bytes from disk.
    reader: BundleReader,
    /// Bundle manifest (model names, offsets, sizes).
    manifest: BundleManifest,
    /// Cached model data, keyed by model name.
    cache: HashMap<String, Vec<u8>>,
    /// LRU order for eviction: front = least recently used.
    lru_order: VecDeque<String>,
    /// Page table tracking loaded pages (keyed by bundle offset).
    page_table: PageTable,
    /// Paging configuration (memory budget, eviction strategy, prefetch).
    config: PagingConfig,
    /// Paging statistics (hits, misses, evictions, memory usage).
    stats: PagingStats,
    /// Bounded access history used to detect A->B patterns for pre-fetching.
    access_history: VecDeque<String>,
}

impl PagedBundle {
    /// Open a bundle with paging enabled.
    ///
    /// # Errors
    ///
    /// Returns an error if the bundle file cannot be opened or its
    /// manifest cannot be read.
    pub fn open(path: impl AsRef<Path>, config: PagingConfig) -> Result<Self> {
        let mut reader = BundleReader::open(path)?;
        let manifest = reader.read_manifest()?;

        Ok(Self {
            reader,
            manifest,
            cache: HashMap::new(),
            lru_order: VecDeque::new(),
            page_table: PageTable::new(),
            config,
            stats: PagingStats::default(),
            access_history: VecDeque::with_capacity(10),
        })
    }

    /// Get a model's data, loading from disk if needed.
    ///
    /// Records the access for pre-fetch pattern detection and, when
    /// pre-fetching is enabled, may opportunistically load models that
    /// historically followed this one.
    ///
    /// # Errors
    ///
    /// Returns an error if the model is not in the manifest or cannot be
    /// read from disk.
    pub fn get_model(&mut self, name: &str) -> Result<&[u8]> {
        if self.cache.contains_key(name) {
            // Cache hit: bump recency so this model is evicted last.
            self.stats.hits += 1;
            self.update_lru(name);
        } else {
            // Cache miss: load from disk (may evict other models first).
            self.stats.misses += 1;
            self.load_model(name)?;
        }

        self.record_access(name);

        // Pre-fetching only loads into free headroom and never evicts,
        // so the entry we just touched/loaded stays resident.
        if self.config.prefetch {
            self.try_prefetch();
        }

        Ok(self.cache.get(name).expect("model was cached or just loaded"))
    }

    /// Check if a model is currently in cache.
    #[must_use]
    pub fn is_cached(&self, name: &str) -> bool {
        self.cache.contains_key(name)
    }

    /// Get all model names.
    #[must_use]
    pub fn model_names(&self) -> Vec<&str> {
        self.manifest.model_names()
    }

    /// Get model metadata.
    #[must_use]
    pub fn get_metadata(&self, name: &str) -> Option<&ModelEntry> {
        self.manifest.get_model(name)
    }

    /// Get paging statistics.
    #[must_use]
    pub fn stats(&self) -> &PagingStats {
        &self.stats
    }

    /// Get paging configuration.
    #[must_use]
    pub fn config(&self) -> &PagingConfig {
        &self.config
    }

    /// Get current memory usage.
    #[must_use]
    pub fn memory_used(&self) -> usize {
        self.stats.memory_used
    }

    /// Get number of cached models.
    #[must_use]
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    /// Explicitly evict a model from cache.
    ///
    /// Returns `true` if the model was cached and has now been evicted.
    pub fn evict(&mut self, name: &str) -> bool {
        if let Some(data) = self.cache.remove(name) {
            self.stats.memory_used = self.stats.memory_used.saturating_sub(data.len());
            self.stats.evictions += 1;
            self.lru_order.retain(|n| n != name);
            // FIX: keep the page table in sync with the cache, mirroring
            // `evict_lru` (explicit eviction previously left stale pages
            // behind, skewing LFU tracking).
            if let Some(entry) = self.manifest.get_model(name) {
                self.page_table.remove(entry.offset);
            }
            true
        } else {
            false
        }
    }

    /// Clear all cached data.
    pub fn clear_cache(&mut self) {
        // FIX: drop page-table entries for every cached model so LFU
        // tracking does not reference pages that are no longer resident.
        for name in self.cache.keys() {
            if let Some(entry) = self.manifest.get_model(name) {
                self.page_table.remove(entry.offset);
            }
        }
        self.cache.clear();
        self.lru_order.clear();
        self.stats.memory_used = 0;
    }

    /// Hint that a model will be needed soon.
    ///
    /// Loads the model into cache immediately if it exists in the
    /// manifest and is not already resident.
    ///
    /// # Errors
    ///
    /// Returns an error if the model exists but cannot be loaded.
    pub fn prefetch_hint(&mut self, name: &str) -> Result<()> {
        if !self.cache.contains_key(name) && self.manifest.get_model(name).is_some() {
            self.load_model(name)?;
            self.stats.prefetches += 1;
        }
        Ok(())
    }

    /// Load a model into cache, evicting older entries to make room.
    fn load_model(&mut self, name: &str) -> Result<()> {
        // Guard against double-counting `memory_used` if a caller ever
        // skips the cache check.
        if self.cache.contains_key(name) {
            return Ok(());
        }

        let entry = self
            .manifest
            .get_model(name)
            .ok_or_else(|| AprenderError::Other(format!("Model '{name}' not found")))?
            .clone();

        // Evict until the new model fits, sized by its manifest entry.
        while self.stats.memory_used + entry.size > self.config.max_memory {
            if !self.evict_lru() {
                // Nothing left to evict (the model may exceed max_memory
                // on its own); load anyway rather than fail.
                break;
            }
        }

        // Load the data
        let data = self.reader.read_model(&entry)?;
        let size = data.len();

        // Update stats
        self.stats.bytes_loaded += size;
        self.stats.memory_used += size;

        // Add to cache
        self.cache.insert(name.to_string(), data);
        self.lru_order.push_back(name.to_string());

        // Update page table
        self.page_table.add_page(entry.offset, size);

        Ok(())
    }

    /// Mark a model as most-recently-used.
    fn update_lru(&mut self, name: &str) {
        self.lru_order.retain(|n| n != name);
        self.lru_order.push_back(name.to_string());

        // Update page table timestamp
        if let Some(entry) = self.manifest.get_model(name) {
            self.page_table.touch(entry.offset);
        }
    }

    /// Evict one model according to the configured eviction strategy.
    ///
    /// Returns `true` if a model was evicted.
    fn evict_lru(&mut self) -> bool {
        let to_evict = match self.config.eviction {
            EvictionStrategy::LRU => self.lru_order.pop_front(),
            EvictionStrategy::LFU => self
                .page_table
                .lfu_page()
                .and_then(|offset| {
                    // Map the page offset back to its model name.
                    self.manifest
                        .iter()
                        .find(|e| e.offset == offset)
                        .map(|e| e.name.clone())
                })
                // FIX: fall back to LRU when the LFU page cannot be mapped
                // to a model (previously this silently evicted nothing).
                .or_else(|| self.lru_order.pop_front()),
        };

        if let Some(name) = to_evict {
            // BUG FIX: the LFU path previously left `name` inside
            // `lru_order`, so stale entries accumulated and a later LRU
            // pop could return an already-evicted name, aborting eviction
            // prematurely. Always purge the name from the LRU queue.
            self.lru_order.retain(|n| n != &name);

            if let Some(data) = self.cache.remove(&name) {
                self.stats.memory_used = self.stats.memory_used.saturating_sub(data.len());
                self.stats.evictions += 1;

                // Remove from page table
                if let Some(entry) = self.manifest.get_model(&name) {
                    self.page_table.remove(entry.offset);
                }

                return true;
            }
        }

        false
    }

    /// Record an access in the bounded (10-entry) history used for
    /// pre-fetch pattern detection.
    fn record_access(&mut self, name: &str) {
        if self.access_history.len() >= 10 {
            self.access_history.pop_front();
        }
        self.access_history.push_back(name.to_string());
    }

    /// Try to pre-fetch models that historically followed the last access.
    fn try_prefetch(&mut self) {
        if self.access_history.len() < 2 {
            return;
        }

        // Simple pattern: if A -> B happened before, pre-fetch B after A
        let last = self.access_history.back().cloned();
        if let Some(last_name) = last {
            // Collect successors of `last_name` seen in the history window.
            let patterns: Vec<_> = self
                .access_history
                .iter()
                .zip(self.access_history.iter().skip(1))
                .filter(|(prev, _)| *prev == &last_name)
                .map(|(_, next)| next.clone())
                .take(self.config.prefetch_count)
                .collect();

            for name in patterns {
                // Only pre-fetch into free headroom; never trigger eviction.
                if !self.cache.contains_key(&name)
                    && self.stats.memory_used + self.estimate_size(&name) <= self.config.max_memory
                {
                    // BUG FIX: only count pre-fetches that actually loaded
                    // (failed best-effort loads were previously counted).
                    if self.load_model(&name).is_ok() {
                        self.stats.prefetches += 1;
                    }
                }
            }
        }
    }

    /// Estimate size of a model from its manifest entry (0 if unknown).
    fn estimate_size(&self, name: &str) -> usize {
        self.manifest.get_model(name).map_or(0, |e| e.size)
    }
}

impl std::fmt::Debug for PagedBundle {
    /// Summarize the bundle without dumping cached model bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PagedBundle");
        builder.field("models", &self.manifest.len());
        builder.field("cached", &self.cache.len());
        builder.field("memory_used", &self.stats.memory_used);
        builder.field("max_memory", &self.config.max_memory);
        builder.field("hit_rate", &self.stats.hit_rate());
        builder.finish_non_exhaustive()
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
#[path = "paging_tests.rs"]
mod tests;