Skip to main content

oxillama_runtime/offload/
pager.rs

1//! LRU weight pager — the core of the CPU/disk offload system.
2//!
3//! [`LayerPager`] manages a budget-limited resident set of tensor bytes.
4//! When a tensor is needed and not resident, it is loaded from the backing
5//! [`PagerSource`] and the least-recently-used non-pinned tensor is evicted
6//! first (if needed to stay within the RAM budget).
7//!
8//! # Thread safety
9//!
10//! [`LayerPager`] is `Send + Sync`. The resident map is protected by an
11//! [`RwLock`], the LRU queue by a [`Mutex`]. All lock operations use
12//! `.map_err(|_| RuntimeError::LockPoisoned)?` — no `unwrap()` anywhere.
13
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::{Arc, Mutex, RwLock};
17
18use crate::error::RuntimeError;
19
20// ─────────────────────────────────────────────────────────────────────────────
21// TensorId
22// ─────────────────────────────────────────────────────────────────────────────
23
/// Name-based handle that identifies one weight tensor.
///
/// The wrapped string is usually the tensor's name exactly as it is recorded
/// in the GGUF file, for example `"blk.0.attn_q.weight"`. Two ids are equal
/// (and hash identically) when their names are equal.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TensorId(pub String);
30
31impl TensorId {
32    /// Construct a [`TensorId`] from any string-like value.
33    pub fn new(name: impl Into<String>) -> Self {
34        Self(name.into())
35    }
36}
37
38impl std::fmt::Display for TensorId {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.write_str(&self.0)
41    }
42}
43
44// ─────────────────────────────────────────────────────────────────────────────
45// ResidentTensor
46// ─────────────────────────────────────────────────────────────────────────────
47
/// A weight tensor whose bytes are currently held in RAM.
///
/// The payload is a reference-counted byte slice, so cloning the owning
/// `Arc` never copies tensor data. The pager does not dequantize anything:
/// callers hand `data` to the fused GEMM kernels (or dequantize on the fly,
/// as the existing arch layer does).
pub struct ResidentTensor {
    /// Raw quantized bytes as read from the GGUF weight store.
    pub data: Arc<[u8]>,
    /// Length of `data` in bytes, cached alongside it to avoid the
    /// indirection through the `Arc`.
    pub size_bytes: usize,
}
59
60impl std::fmt::Debug for ResidentTensor {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("ResidentTensor")
63            .field("size_bytes", &self.size_bytes)
64            .field("data_len", &self.data.len())
65            .finish()
66    }
67}
68
69// ─────────────────────────────────────────────────────────────────────────────
70// TensorEntry
71// ─────────────────────────────────────────────────────────────────────────────
72
/// Where a tensor's bytes live inside the backing weight source.
#[derive(Clone, Debug)]
pub struct TensorEntry {
    /// Absolute byte offset of the tensor's first byte within the source.
    pub file_offset: u64,
    /// Number of bytes the tensor occupies, starting at `file_offset`.
    pub size_bytes: usize,
}
81
82// ─────────────────────────────────────────────────────────────────────────────
83// PagerSource
84// ─────────────────────────────────────────────────────────────────────────────
85
86/// Abstraction over the backing byte store for weight data.
87///
88/// Unlike the GGUF [`Source`][oxillama_gguf::source::Source] trait (which uses
89/// an associated error type), `PagerSource` uses [`RuntimeError`] directly for
90/// easy `dyn` dispatch.
91pub trait PagerSource: Send + Sync {
92    /// Read exactly `out.len()` bytes starting at `offset` into `out`.
93    ///
94    /// Returns `Err(RuntimeError::OffloadEof)` if `offset + out.len()` exceeds
95    /// the total size of the source.
96    fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError>;
97
98    /// Total size of the backing store in bytes.
99    fn total_size_bytes(&self) -> u64;
100}
101
102// ─────────────────────────────────────────────────────────────────────────────
103// FilePagerSource
104// ─────────────────────────────────────────────────────────────────────────────
105
/// [`PagerSource`] backed by a regular file, read with `std::fs::File`
/// seek + read.
///
/// Every [`read_bytes_at`] call opens its own file handle. That keeps the
/// type trivially shareable between threads, at the cost of an `open` per
/// read. When read throughput matters, prefer [`MmapPagerSource`] (requires
/// the `mmap` feature).
///
/// [`read_bytes_at`]: FilePagerSource::read_bytes_at
pub struct FilePagerSource {
    /// Location of the weight file; reopened on every read.
    path: std::path::PathBuf,
    /// File length in bytes, captured when the source was opened.
    total_bytes: u64,
}
117
118impl FilePagerSource {
119    /// Open a file-backed pager source at `path`.
120    ///
121    /// Reads the file metadata to determine total size; fails if the file
122    /// cannot be opened or its metadata cannot be queried.
123    pub fn open(path: impl Into<std::path::PathBuf>) -> Result<Self, RuntimeError> {
124        let path = path.into();
125        let meta = std::fs::metadata(&path)?;
126        Ok(Self {
127            total_bytes: meta.len(),
128            path,
129        })
130    }
131}
132
133impl PagerSource for FilePagerSource {
134    fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
135        use std::io::{Read, Seek, SeekFrom};
136        let mut file = std::fs::File::open(&self.path)?;
137        file.seek(SeekFrom::Start(offset))?;
138        file.read_exact(out)?;
139        Ok(())
140    }
141
142    fn total_size_bytes(&self) -> u64 {
143        self.total_bytes
144    }
145}
146
147// ─────────────────────────────────────────────────────────────────────────────
148// MmapPagerSource (optional, requires `mmap` feature)
149// ─────────────────────────────────────────────────────────────────────────────
150
/// Memory-mapped [`PagerSource`] — faster random access than seek + read.
///
/// Only available with the `mmap` feature. The file is mapped read-only and
/// the mapping is shared behind an `Arc`. Accessing the mapping can still
/// trigger OS-level page faults (demand paging); `read_bytes_at` needs no
/// per-call open or seek and simply copies the requested range out of the
/// mapping into the caller's buffer.
#[cfg(feature = "mmap")]
pub struct MmapPagerSource {
    /// Read-only mapping of the whole weight file.
    mmap: Arc<memmap2::Mmap>,
}
161
#[cfg(feature = "mmap")]
impl MmapPagerSource {
    /// Open the file at `path` and map it into memory read-only.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be opened or the mapping cannot be
    /// created.
    pub fn open(path: impl AsRef<std::path::Path>) -> Result<Self, RuntimeError> {
        let file = std::fs::File::open(path)?;
        // SAFETY: the mapping is created read-only and is only ever read
        // through shared references; this type never mutates it.
        // NOTE(review): memmap2 documents that modifying or truncating the
        // underlying file while it is mapped can be undefined behaviour, so
        // this relies on the weight file staying untouched for as long as
        // the source is alive — confirm that holds for every deployment.
        let mapping = unsafe { memmap2::Mmap::map(&file) }?;
        Ok(Self {
            mmap: Arc::new(mapping),
        })
    }
}
177
#[cfg(feature = "mmap")]
impl PagerSource for MmapPagerSource {
    /// Copy `out.len()` bytes starting at `offset` out of the mapping.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::OffloadEof`] when the requested range does
    /// not lie entirely inside the mapping. `available` reports how many
    /// bytes exist at or after `offset` (0 when `offset` itself is out of
    /// range).
    fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
        let needed = out.len();
        let mapped: &[u8] = &self.mmap;

        // Convert with a check: a plain `as usize` cast would silently
        // truncate large offsets on 32-bit targets and read the wrong bytes.
        let start = match usize::try_from(offset) {
            Ok(start) => start,
            Err(_) => {
                return Err(RuntimeError::OffloadEof {
                    offset,
                    needed,
                    available: 0,
                })
            }
        };

        // `get` yields `None` both when `start + needed` overflows and when
        // the range extends past the end of the mapping.
        let requested = start
            .checked_add(needed)
            .and_then(|end| mapped.get(start..end));
        match requested {
            Some(bytes) => {
                out.copy_from_slice(bytes);
                Ok(())
            }
            None => Err(RuntimeError::OffloadEof {
                offset,
                needed,
                available: mapped.len().saturating_sub(start),
            }),
        }
    }

    /// Length of the mapping in bytes.
    fn total_size_bytes(&self) -> u64 {
        self.mmap.len() as u64
    }
}
205
206// ─────────────────────────────────────────────────────────────────────────────
207// LayerPager
208// ─────────────────────────────────────────────────────────────────────────────
209
210/// LRU weight pager — evicts cold tensors to free RAM, loads from `source` on demand.
211///
212/// # Pinned tensors
213///
214/// Tensors whose [`TensorId`] is in the `pinned` set are never evicted.
215/// If pinned tensors alone exceed the budget, eviction will exhaust all
216/// non-pinned candidates and then stop (no error); the budget acts as a
217/// best-effort target.
218///
219/// # Acquiring tensors
220///
221/// Call [`acquire`][LayerPager::acquire] to get an `Arc<ResidentTensor>` for a
222/// tensor by ID. The result shares ownership with the pager's resident map, so
223/// the bytes stay alive as long as either the pager or the caller holds a
224/// reference — eviction only removes the pager's entry, not the data itself.
225pub struct LayerPager {
226    source: Arc<dyn PagerSource>,
227    tensor_map: HashMap<TensorId, TensorEntry>,
228    resident: RwLock<HashMap<TensorId, Arc<ResidentTensor>>>,
229    pinned: HashSet<TensorId>,
230    lru: Mutex<VecDeque<TensorId>>,
231    budget_bytes: u64,
232    resident_bytes: AtomicU64,
233}
234
235impl LayerPager {
236    /// Create a new pager.
237    ///
238    /// # Parameters
239    ///
240    /// - `source` — backing weight store (file or mmap).
241    /// - `tensor_map` — map from [`TensorId`] to file offset + size.
242    /// - `budget_bytes` — maximum resident bytes (use `u64::MAX` for unlimited).
243    /// - `pinned` — set of tensor IDs that are never evicted.
244    pub fn new(
245        source: Arc<dyn PagerSource>,
246        tensor_map: HashMap<TensorId, TensorEntry>,
247        budget_bytes: u64,
248        pinned: HashSet<TensorId>,
249    ) -> Self {
250        Self {
251            source,
252            tensor_map,
253            resident: RwLock::new(HashMap::new()),
254            pinned,
255            lru: Mutex::new(VecDeque::new()),
256            budget_bytes,
257            resident_bytes: AtomicU64::new(0),
258        }
259    }
260
261    /// Acquire a tensor, loading it from the source if not currently resident.
262    ///
263    /// The LRU order is updated on every successful acquire. If the tensor
264    /// must be loaded, the pager first evicts non-pinned tensors until the
265    /// budget allows the new allocation, then reads the bytes from `source`.
266    ///
267    /// # Errors
268    ///
269    /// - [`RuntimeError::TensorNotFound`] — `id` is not in the tensor map.
270    /// - [`RuntimeError::LockPoisoned`] — an internal lock was poisoned.
271    /// - [`RuntimeError::OffloadEof`] / [`RuntimeError::Io`] — source read failed.
272    pub fn acquire(&self, id: &TensorId) -> Result<Arc<ResidentTensor>, RuntimeError> {
273        // Fast path: already resident.
274        {
275            let guard = self
276                .resident
277                .read()
278                .map_err(|_| RuntimeError::LockPoisoned)?;
279            if let Some(_tensor) = guard.get(id) {
280                // Update LRU position.
281                drop(guard);
282                self.bump_lru(id)?;
283                // Re-read after dropping the read guard so we avoid holding
284                // both read and LRU lock simultaneously.
285                let guard2 = self
286                    .resident
287                    .read()
288                    .map_err(|_| RuntimeError::LockPoisoned)?;
289                // Tensor may theoretically have been evicted between drop and
290                // re-read in a race; fall through to slow path if that happened.
291                if let Some(tensor2) = guard2.get(id) {
292                    return Ok(Arc::clone(tensor2));
293                }
294            }
295        }
296
297        // Slow path: not resident — load from source.
298        let entry = self
299            .tensor_map
300            .get(id)
301            .ok_or_else(|| RuntimeError::TensorNotFound(id.0.clone()))?;
302
303        // Evict until the new tensor fits.
304        self.evict_to_fit(entry.size_bytes)?;
305
306        // Read from backing source.
307        let mut data = vec![0u8; entry.size_bytes];
308        self.source.read_bytes_at(entry.file_offset, &mut data)?;
309
310        let tensor = Arc::new(ResidentTensor {
311            data: data.into(),
312            size_bytes: entry.size_bytes,
313        });
314
315        // Insert into resident map and append to LRU tail.
316        {
317            let mut guard = self
318                .resident
319                .write()
320                .map_err(|_| RuntimeError::LockPoisoned)?;
321            guard.insert(id.clone(), Arc::clone(&tensor));
322        }
323        {
324            let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
325            lru.push_back(id.clone());
326        }
327        self.resident_bytes
328            .fetch_add(entry.size_bytes as u64, Ordering::Relaxed);
329
330        Ok(tensor)
331    }
332
333    /// Evict non-pinned LRU tensors until `needed_bytes` can fit within the budget.
334    fn evict_to_fit(&self, needed_bytes: usize) -> Result<(), RuntimeError> {
335        loop {
336            let current = self.resident_bytes.load(Ordering::Relaxed);
337            if current.saturating_add(needed_bytes as u64) <= self.budget_bytes {
338                break;
339            }
340
341            // Find the oldest non-pinned entry.
342            let victim = {
343                let lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
344                lru.iter().find(|id| !self.pinned.contains(*id)).cloned()
345            };
346
347            match victim {
348                // Nothing left to evict (all resident tensors are pinned).
349                None => break,
350                Some(victim_id) => {
351                    let removed = {
352                        let mut guard = self
353                            .resident
354                            .write()
355                            .map_err(|_| RuntimeError::LockPoisoned)?;
356                        guard.remove(&victim_id)
357                    };
358                    if let Some(evicted) = removed {
359                        self.resident_bytes
360                            .fetch_sub(evicted.size_bytes as u64, Ordering::Relaxed);
361                        let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
362                        lru.retain(|x| x != &victim_id);
363                    }
364                }
365            }
366        }
367        Ok(())
368    }
369
370    /// Move `id` to the tail of the LRU queue (mark as most-recently-used).
371    fn bump_lru(&self, id: &TensorId) -> Result<(), RuntimeError> {
372        let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
373        if let Some(pos) = lru.iter().position(|x| x == id) {
374            lru.remove(pos);
375        }
376        lru.push_back(id.clone());
377        Ok(())
378    }
379
380    /// Return the number of bytes currently resident in RAM.
381    pub fn resident_bytes(&self) -> u64 {
382        self.resident_bytes.load(Ordering::Relaxed)
383    }
384
385    /// Return the number of tensors currently resident in RAM.
386    pub fn resident_count(&self) -> usize {
387        self.resident.read().map(|g| g.len()).unwrap_or(0)
388    }
389
390    /// Return the configured RAM budget in bytes.
391    pub fn budget_bytes(&self) -> u64 {
392        self.budget_bytes
393    }
394
395    /// Check whether a tensor is currently resident.
396    pub fn is_resident(&self, id: &TensorId) -> bool {
397        self.resident
398            .read()
399            .map(|g| g.contains_key(id))
400            .unwrap_or(false)
401    }
402
403    /// Check whether a tensor is pinned (will never be evicted).
404    pub fn is_pinned(&self, id: &TensorId) -> bool {
405        self.pinned.contains(id)
406    }
407}
408
409// ─────────────────────────────────────────────────────────────────────────────
410// Tests
411// ─────────────────────────────────────────────────────────────────────────────
412
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};
    use std::io::Write;
    use std::sync::Arc;

    // ── In-memory PagerSource for tests ──────────────────────────────────────

    /// [`PagerSource`] over an in-memory byte vector.
    struct VecPagerSource(Vec<u8>);

    impl PagerSource for VecPagerSource {
        fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
            let start = offset as usize;
            let end = start
                .checked_add(out.len())
                .ok_or(RuntimeError::OffloadEof {
                    offset,
                    needed: out.len(),
                    available: 0,
                })?;
            if end > self.0.len() {
                let available = self.0.len().saturating_sub(start);
                return Err(RuntimeError::OffloadEof {
                    offset,
                    needed: out.len(),
                    available,
                });
            }
            out.copy_from_slice(&self.0[start..end]);
            Ok(())
        }

        fn total_size_bytes(&self) -> u64 {
            self.0.len() as u64
        }
    }

    // ── Helpers ──────────────────────────────────────────────────────────────

    /// Shorthand for building a [`TensorId`] from a string literal.
    fn tid(name: &str) -> TensorId {
        TensorId::new(name)
    }

    /// Build a pager over synthetic data.
    ///
    /// `tensors` lists `(name, size_bytes, file_offset)`. Byte `i` of every
    /// tensor is `i % 256`. Returns the pager together with a copy of the
    /// backing bytes so tests can compare what was read.
    fn make_pager(
        tensors: &[(&str, usize, u64)],
        budget: u64,
        pinned: &[&str],
    ) -> (LayerPager, Vec<u8>) {
        let total: usize = tensors
            .iter()
            .map(|(_, sz, offset)| *offset as usize + *sz)
            .max()
            .unwrap_or(0)
            + 1;
        let mut data = vec![0u8; total];
        let mut tensor_map = HashMap::new();
        for (id, size, offset) in tensors {
            for i in 0..*size {
                data[*offset as usize + i] = (i % 256) as u8;
            }
            tensor_map.insert(
                tid(id),
                TensorEntry {
                    file_offset: *offset,
                    size_bytes: *size,
                },
            );
        }
        let pinned_set: HashSet<TensorId> = pinned.iter().map(|s| tid(s)).collect();
        let pager = LayerPager::new(
            Arc::new(VecPagerSource(data.clone())),
            tensor_map,
            budget,
            pinned_set,
        );
        (pager, data)
    }

    // ── Test: basic eviction ──────────────────────────────────────────────────

    #[test]
    fn offload_budget_evicts_coldest() {
        // 3 tensors, 100 bytes each; budget = 200 bytes (fits 2)
        let (pager, _) = make_pager(
            &[
                ("layer_0", 100, 0),
                ("layer_1", 100, 100),
                ("layer_2", 100, 200),
            ],
            200,
            &[],
        );

        let t0 = pager.acquire(&tid("layer_0")).expect("t0");
        let t1 = pager.acquire(&tid("layer_1")).expect("t1");
        assert_eq!(pager.resident_count(), 2);
        drop(t0);
        drop(t1);

        // Acquiring layer_2 must evict the coldest (layer_0) and nothing else.
        let _t2 = pager.acquire(&tid("layer_2")).expect("t2");
        assert!(
            !pager.is_resident(&tid("layer_0")),
            "coldest tensor must have been evicted"
        );
        assert!(pager.is_resident(&tid("layer_1")));
        assert!(pager.is_resident(&tid("layer_2")));
        assert!(
            pager.resident_bytes() <= 200,
            "resident_bytes ({}) must be <= budget (200)",
            pager.resident_bytes()
        );
    }

    // ── Test: pinned tensors survive eviction ─────────────────────────────────

    #[test]
    fn offload_pinned_never_evicted() {
        // budget = 100, fits exactly 1; layer_0 is pinned
        let (pager, _) = make_pager(
            &[("layer_0", 100, 0), ("layer_1", 100, 100)],
            100,
            &["layer_0"],
        );

        let _t0 = pager.acquire(&tid("layer_0")).expect("pinned");
        // Acquiring layer_1 must not evict layer_0; the budget is best-effort,
        // so layer_1 is still loaded.
        let _t1 = pager.acquire(&tid("layer_1")).expect("cold");

        assert!(
            pager.is_resident(&tid("layer_0")),
            "pinned tensor must not be evicted"
        );
        assert!(
            pager.is_resident(&tid("layer_1")),
            "load must proceed even when only pinned tensors remain"
        );
    }

    // ── Test: bytes read correctly ────────────────────────────────────────────

    #[test]
    fn offload_acquire_reads_correct_bytes() {
        let (pager, data) = make_pager(&[("t0", 64, 128)], u64::MAX, &[]);
        let tensor = pager.acquire(&tid("t0")).expect("t0");
        assert_eq!(tensor.data.len(), 64);
        assert_eq!(tensor.size_bytes, 64);
        assert_eq!(&tensor.data[..], &data[128..192]);
    }

    // ── Test: unknown tensor returns error ────────────────────────────────────

    #[test]
    fn offload_unknown_tensor_returns_error() {
        let (pager, _) = make_pager(&[("t0", 10, 0)], u64::MAX, &[]);
        let res = pager.acquire(&tid("nonexistent"));
        assert!(
            matches!(res, Err(RuntimeError::TensorNotFound(_))),
            "expected TensorNotFound, got {res:?}"
        );
        assert_eq!(pager.resident_count(), 0);
    }

    // ── Test: double acquire returns same bytes ───────────────────────────────

    #[test]
    fn offload_double_acquire_returns_same_bytes() {
        let (pager, data) = make_pager(&[("t0", 32, 64)], u64::MAX, &[]);
        let a = pager.acquire(&tid("t0")).expect("a");
        let b = pager.acquire(&tid("t0")).expect("b");
        assert_eq!(&a.data[..], &b.data[..]);
        assert_eq!(&a.data[..], &data[64..96]);
        // The second acquire must be served from the resident set: same
        // allocation, and the tensor is counted only once.
        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(pager.resident_count(), 1);
        assert_eq!(pager.resident_bytes(), 32);
    }

    // ── Test: FilePagerSource reads correct bytes ─────────────────────────────

    #[test]
    fn offload_file_pager_source_reads_correctly() {
        let mut tmp = tempfile::NamedTempFile::new().expect("temp file");
        let payload: Vec<u8> = (0u8..=255u8).collect();
        tmp.write_all(&payload).expect("write");
        let source = FilePagerSource::open(tmp.path()).expect("open");
        assert_eq!(source.total_size_bytes(), 256);
        let mut buf = vec![0u8; 10];
        source.read_bytes_at(5, &mut buf).expect("read");
        assert_eq!(&buf, &payload[5..15]);
    }

    // ── Test: FilePagerSource EOF returns error ───────────────────────────────

    #[test]
    fn offload_file_source_eof_errors() {
        let mut tmp = tempfile::NamedTempFile::new().expect("temp file");
        tmp.write_all(b"short").expect("write");
        let source = FilePagerSource::open(tmp.path()).expect("open");
        let mut buf = vec![0u8; 100];
        let res = source.read_bytes_at(0, &mut buf);
        assert!(res.is_err(), "reading past end of file must return Err");
    }

    // ── Test: resident_count and resident_bytes ───────────────────────────────

    #[test]
    fn offload_resident_count_tracks_evictions() {
        let (pager, _) = make_pager(&[("a", 50, 0), ("b", 50, 50), ("c", 50, 100)], 100, &[]);
        assert_eq!(pager.resident_count(), 0);
        assert_eq!(pager.resident_bytes(), 0);
        let _a = pager.acquire(&tid("a")).expect("a");
        assert_eq!(pager.resident_count(), 1);
        let _b = pager.acquire(&tid("b")).expect("b");
        assert_eq!(pager.resident_count(), 2);
        // c should evict a (the coldest) and leave b in place.
        let _c = pager.acquire(&tid("c")).expect("c");
        assert_eq!(pager.resident_count(), 2, "budget limits to 2 tensors");
        assert!(!pager.is_resident(&tid("a")));
        assert!(pager.is_resident(&tid("b")));
        assert!(pager.is_resident(&tid("c")));
        assert!(
            pager.resident_bytes() <= 100,
            "resident_bytes must not exceed budget"
        );
    }

    // ── Test: is_pinned ───────────────────────────────────────────────────────

    #[test]
    fn offload_is_pinned_reflects_set() {
        let (pager, _) = make_pager(&[("a", 10, 0), ("b", 10, 10)], u64::MAX, &["a"]);
        assert!(pager.is_pinned(&tid("a")));
        assert!(!pager.is_pinned(&tid("b")));
    }

    // ── Test: multiple evictions stay within budget ───────────────────────────

    #[test]
    fn offload_budget_strictly_respected() {
        let budget = 50u64;
        let (pager, _) = make_pager(
            &[
                ("t0", 50, 0),
                ("t1", 50, 50),
                ("t2", 50, 100),
                ("t3", 50, 150),
            ],
            budget,
            &[],
        );
        assert_eq!(pager.budget_bytes(), budget);
        for name in ["t0", "t1", "t2", "t3"] {
            let _ = pager.acquire(&tid(name)).expect(name);
            assert!(
                pager.is_resident(&tid(name)),
                "{name} must be resident right after it was acquired"
            );
            assert!(
                pager.resident_bytes() <= budget,
                "after acquiring {name}, resident_bytes={} > budget={budget}",
                pager.resident_bytes()
            );
        }
    }

    // ── Test: tensor_id display ───────────────────────────────────────────────

    #[test]
    fn tensor_id_display() {
        let id = TensorId::new("blk.0.attn_q.weight");
        assert_eq!(id.to_string(), "blk.0.attn_q.weight");
    }
}