// gam_sae/corpus/shard_reader.rs
//! mmap-backed activation-shard reader with bounded prefetch (#973).
//!
//! # On-disk shard format (`v1`)
//!
//! A shard is a single file holding a fixed header followed by a row-major
//! payload. All integers are little-endian. The header is exactly 32 bytes:
//!
//! ```text
//! offset  size  field
//! 0       8     magic            = SHARD_MAGIC  (b"GAMSAE01")
//! 8       8     n_rows  (u64)    number of activation rows in this shard
//! 16      8     p       (u64)    columns per row (activation width)
//! 24      4     dtype   (u32)    DTYPE_F32 = 0  (the only payload encoding)
//! 28      4     reserved(u32)    = 0  (header padding to 32 bytes; not checked by the reader)
//! 32      ..    payload          n_rows * p contiguous little-endian f32
//! ```
//!
//! The reader memory-maps the file read-only and reads rows directly out of the
//! mapped payload, upcasting each `f32` lane to `f64` on demand (the
//! mixed-precision-storage contract — see [`super::kernels`]). Storing `f32`
//! halves the on-disk and page-cache footprint relative to `f64`, so an
//! out-of-core corpus streams without ever being materialized as dense `f64`.
//!
//! # Determinism contract
//!
//! [`CorpusRowSource`] yields batches in a **fixed global row order**
//! (shard-0 rows `0..n0`, then shard-1 rows `0..n1`, …), assigning each row a
//! stable global `row_id`. This order is *independent of OS readahead*: the
//! bounded prefetch below only touches pages we are about to read in that same
//! deterministic order — it never reorders, drops, or duplicates rows, and the
//! sequence of `(row_id, row)` pairs is byte-identical across runs and
//! platforms. That stable `row_id` is what [`super::warm_state`] keys its
//! per-row warm starts on and what [`super::rho_cascade`] hashes to pick a
//! subsample.
//!
//! # Bounded prefetch
//!
//! `prefetch_window_bytes` caps how far ahead of the current read cursor the
//! reader warms pages. This keeps the resident set bounded (we do not want the
//! kernel pulling an entire multi-GiB shard into RAM) while still hiding fault
//! latency for the rows we are about to consume next. The warm-up is portable
//! and advisory: no `madvise` call is made. The reader simply performs a
//! volatile read of one byte in each page of the upcoming window. Correctness
//! never depends on it.

45use memmap2::Mmap;
46use ndarray::Array2;
47use std::fs::File;
48use std::path::{Path, PathBuf};
49use std::sync::Arc;
50
/// Eight-byte magic prefix that opens every `v1` activation shard.
pub const SHARD_MAGIC: [u8; 8] = *b"GAMSAE01";

/// Header `dtype` tag meaning "row-major little-endian `f32` payload" — the
/// only payload encoding the `v1` format defines.
pub const DTYPE_F32: u32 = 0;

/// Size of the fixed shard header in bytes: magic (8) + `n_rows` (8) +
/// `p` (8) + `dtype` (4) + reserved padding (4).
pub const HEADER_LEN: usize = 8 + 8 + 8 + 4 + 4;

58/// Default bounded read-ahead window. Auto-derived; not a CLI knob. Large
59/// enough to hide fault latency for a healthy batch, small enough that the
60/// resident set stays bounded regardless of shard size. Shared with the
61/// object-store source ([`super::object_store`]) so both backends apply the
62/// same bounded-prefetch policy.
63pub(super) const DEFAULT_PREFETCH_WINDOW_BYTES: usize = 8 * 1024 * 1024;
64
65/// Default number of rows handed back per [`CorpusRowSource::next_batch`].
66/// Auto-derived from the activation width at open time; this is the floor.
67/// Shared with [`super::object_store`].
68pub(super) const DEFAULT_BATCH_ROWS: usize = 1024;
69
/// Everything that can go wrong while opening or draining activation shards.
#[derive(Debug)]
pub enum ShardError {
    /// Filesystem or memory-mapping failure.
    Io(std::io::Error),
    /// The first eight bytes of the file are not [`SHARD_MAGIC`].
    BadMagic { path: PathBuf },
    /// The header `dtype` tag names an unsupported payload encoding.
    BadDtype { path: PathBuf, tag: u32 },
    /// The file is shorter than the fixed header, or shorter than the payload
    /// its header declares. `expected == usize::MAX` signals that the declared
    /// payload size overflows `usize`.
    Truncated {
        path: PathBuf,
        expected: usize,
        actual: usize,
    },
    /// Two shards in one source disagree on activation width `p`.
    WidthMismatch {
        expected: usize,
        found: usize,
        path: PathBuf,
    },
    /// After a window fill, the front resident shard is not the shard the read
    /// cursor points at. The window must be a contiguous run of shard indices
    /// starting at `cursor_shard`, so this is an internal window-maintenance
    /// logic error. It is reported instead of silently pairing one shard's
    /// payload with another shard's row metadata, so corruption never reaches
    /// a returned batch.
    ResidencyInvariant {
        cursor_shard: usize,
        front_shard: usize,
    },
    /// No shards were given, or the shards hold zero rows in total.
    Empty,
}

104impl std::fmt::Display for ShardError {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            ShardError::Io(e) => write!(f, "shard I/O error: {e}"),
108            ShardError::BadMagic { path } => {
109                write!(f, "shard '{}' has wrong magic header", path.display())
110            }
111            ShardError::BadDtype { path, tag } => write!(
112                f,
113                "shard '{}' has unsupported dtype tag {tag} (only f32={DTYPE_F32})",
114                path.display()
115            ),
116            ShardError::Truncated {
117                path,
118                expected,
119                actual,
120            } => write!(
121                f,
122                "shard '{}' is truncated: header expects {expected} bytes, file has {actual}",
123                path.display()
124            ),
125            ShardError::WidthMismatch {
126                expected,
127                found,
128                path,
129            } => write!(
130                f,
131                "shard '{}' has width p={found}, expected p={expected}",
132                path.display()
133            ),
134            ShardError::ResidencyInvariant {
135                cursor_shard,
136                front_shard,
137            } => write!(
138                f,
139                "shard window residency invariant violated: read cursor is at shard {cursor_shard} but the window front is shard {front_shard}"
140            ),
141            ShardError::Empty => write!(f, "shard source has no shards / no rows"),
142        }
143    }
144}
145
146impl std::error::Error for ShardError {}
147
148impl From<std::io::Error> for ShardError {
149    fn from(e: std::io::Error) -> Self {
150        ShardError::Io(e)
151    }
152}
153
154/// One deterministic batch of activation rows.
155///
156/// `row_ids[k]` is the stable global id of `rows.row(k)`; ids are contiguous
157/// within a batch and strictly increasing across the whole stream.
158#[derive(Debug, Clone)]
159pub struct RowBatch {
160    /// `(batch_rows × p)` upcast `f64` activations.
161    pub rows: Array2<f64>,
162    /// Global row id of each row in `rows`, same length as `rows.nrows()`.
163    pub row_ids: Vec<u64>,
164}
165
166impl RowBatch {
167    #[inline]
168    pub fn len(&self) -> usize {
169        self.row_ids.len()
170    }
171
172    #[inline]
173    pub fn is_empty(&self) -> bool {
174        self.row_ids.is_empty()
175    }
176}
177
178/// A deterministic, restartable source of activation row batches.
179///
180/// This is one half of the seam ([`super::warm_state::RowWarmCache`] is the
181/// other) that the streaming SAE term will consume. The contract:
182///
183/// * `next_batch` yields rows in a **fixed global order** with stable
184///   `row_id`s, independent of OS readahead, until the corpus is exhausted
185///   (then `Ok(None)`).
186/// * `reset` rewinds to the first row so a new outer ρ pass replays the exact
187///   same `(row_id, row)` sequence — the every-step-is-a-full-corpus-pass
188///   contract [`super::rho_cascade`] relies on.
189/// * `total_rows` / `width` are known up front (from shard headers) so callers
190///   can size accumulators before the first read.
191pub trait CorpusRowSource {
192    /// Total rows across every shard in this source.
193    fn total_rows(&self) -> u64;
194    /// Activation width `p` (columns per row).
195    fn width(&self) -> usize;
196    /// Yield the next deterministic batch, or `Ok(None)` at end of corpus.
197    fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError>;
198    /// Rewind to the first row so the next `next_batch` replays from the start.
199    fn reset(&mut self);
200    /// Rows handed back per `next_batch` (may be smaller for the final batch).
201    fn batch_rows(&self) -> usize;
202}
203
204/// A single memory-mapped shard.
205struct MappedShard {
206    mmap: Arc<Mmap>,
207    n_rows: usize,
208    p: usize,
209    data_offset: usize,
210    /// Global id of this shard's first row in the concatenated stream.
211    global_row_base: u64,
212}
213
214impl MappedShard {
215    fn open(path: PathBuf) -> Result<Self, ShardError> {
216        let file = File::open(&path)?;
217        // SAFETY: shards are read-only training artifacts; this module never
218        // mutates the mapping and the caller's contract is no concurrent
219        // writers. Matches the existing `PcaScoresMemmapDesignOperator` usage.
220        let mmap = unsafe { Mmap::map(&file)? };
221        if mmap.len() < HEADER_LEN {
222            return Err(ShardError::Truncated {
223                path,
224                expected: HEADER_LEN,
225                actual: mmap.len(),
226            });
227        }
228        if mmap[0..8] != SHARD_MAGIC {
229            return Err(ShardError::BadMagic { path });
230        }
231        let n_rows = u64::from_le_bytes(mmap[8..16].try_into().expect("8 bytes")) as usize;
232        let p = u64::from_le_bytes(mmap[16..24].try_into().expect("8 bytes")) as usize;
233        let dtype = u32::from_le_bytes(mmap[24..28].try_into().expect("4 bytes"));
234        if dtype != DTYPE_F32 {
235            return Err(ShardError::BadDtype { path, tag: dtype });
236        }
237        let payload_bytes = n_rows
238            .checked_mul(p)
239            .and_then(|cells| cells.checked_mul(std::mem::size_of::<f32>()))
240            .ok_or_else(|| ShardError::Truncated {
241                path: path.clone(),
242                expected: usize::MAX,
243                actual: mmap.len(),
244            })?;
245        let expected = HEADER_LEN + payload_bytes;
246        if mmap.len() < expected {
247            return Err(ShardError::Truncated {
248                path,
249                expected,
250                actual: mmap.len(),
251            });
252        }
253        Ok(Self {
254            mmap: Arc::new(mmap),
255            n_rows,
256            p,
257            data_offset: HEADER_LEN,
258            global_row_base: 0,
259        })
260    }
261
262    /// Read a single row's `p` `f32` lanes, upcasting each to `f64`.
263    #[inline]
264    fn read_row_into(&self, local_row: usize, out: &mut [f64]) {
265        assert_eq!(out.len(), self.p);
266        let byte_start = self.data_offset + local_row * self.p * std::mem::size_of::<f32>();
267        let bytes = &self.mmap[byte_start..byte_start + self.p * std::mem::size_of::<f32>()];
268        for (c, slot) in out.iter_mut().enumerate() {
269            let b = c * std::mem::size_of::<f32>();
270            let lane = f32::from_le_bytes(bytes[b..b + 4].try_into().expect("4 bytes"));
271            *slot = f64::from(lane);
272        }
273    }
274
275    /// Bounded read-ahead: warm pages from `byte_start` up to (but not past)
276    /// `byte_start + window`, clamped to the shard payload. Advisory only.
277    fn prefetch(&self, byte_start: usize, window: usize) {
278        let payload_end = self.data_offset + self.n_rows * self.p * std::mem::size_of::<f32>();
279        let end = byte_start.saturating_add(window).min(payload_end);
280        if end <= byte_start {
281            return;
282        }
283        // Touch one byte per page so the kernel faults exactly the bounded
284        // window in our deterministic read order, never the whole shard via
285        // speculative readahead. `read_volatile` keeps the loads from being
286        // optimized away without mutating the read-only mapping.
287        let page = 4096usize;
288        let base = self.mmap.as_ptr();
289        let mut off = byte_start;
290        while off < end {
291            // SAFETY: `off < end <= mmap.len()`, so `base.add(off)` is in
292            // bounds of the live read-only mapping; we only read.
293            unsafe {
294                std::ptr::read_volatile(base.add(off));
295            }
296            off += page;
297        }
298    }
299}
300
301/// A [`CorpusRowSource`] over one or many shards with a bounded prefetch
302/// window.
303pub struct MmapShardSource {
304    shards: Vec<MappedShard>,
305    p: usize,
306    total_rows: u64,
307    batch_rows: usize,
308    prefetch_window_bytes: usize,
309    /// Index into `shards` of the shard currently being drained.
310    cursor_shard: usize,
311    /// Local row within `shards[cursor_shard]` to read next.
312    cursor_local_row: usize,
313}
314
315impl MmapShardSource {
316    /// Open a source over an explicit, ordered list of shard paths. The order
317    /// of `paths` *defines* the deterministic global row order, so the caller
318    /// must pass a stable ordering (e.g. lexicographically sorted filenames).
319    pub fn open(paths: &[PathBuf]) -> Result<Self, ShardError> {
320        if paths.is_empty() {
321            return Err(ShardError::Empty);
322        }
323        let mut shards = Vec::with_capacity(paths.len());
324        let mut p: Option<usize> = None;
325        let mut running_base: u64 = 0;
326        for path in paths {
327            let mut shard = MappedShard::open(path.clone())?;
328            match p {
329                None => p = Some(shard.p),
330                Some(expected) if expected != shard.p => {
331                    return Err(ShardError::WidthMismatch {
332                        expected,
333                        found: shard.p,
334                        path: path.clone(),
335                    });
336                }
337                Some(_) => {}
338            }
339            shard.global_row_base = running_base;
340            running_base = running_base.saturating_add(shard.n_rows as u64);
341            shards.push(shard);
342        }
343        let p = p.ok_or(ShardError::Empty)?;
344        let total_rows = running_base;
345        if total_rows == 0 {
346            return Err(ShardError::Empty);
347        }
348        // Auto-derive a batch row count: at least DEFAULT_BATCH_ROWS rows, and
349        // never more than the whole corpus. No CLI knob.
350        let batch_rows = DEFAULT_BATCH_ROWS.min(total_rows as usize).max(1);
351        Ok(Self {
352            shards,
353            p,
354            total_rows,
355            batch_rows,
356            prefetch_window_bytes: DEFAULT_PREFETCH_WINDOW_BYTES,
357            cursor_shard: 0,
358            cursor_local_row: 0,
359        })
360    }
361
362    /// Open a source over every `*.shard` file in `dir`, ordered by file name
363    /// (the stable, OS-independent ordering that pins the deterministic global
364    /// row sequence).
365    pub fn open_dir(dir: &Path) -> Result<Self, ShardError> {
366        let mut paths: Vec<PathBuf> = Vec::new();
367        for entry in std::fs::read_dir(dir)? {
368            let entry = entry?;
369            let path = entry.path();
370            if path.extension().and_then(|e| e.to_str()) == Some("shard") {
371                paths.push(path);
372            }
373        }
374        // Sort by file name bytes: deterministic and independent of the order
375        // the OS returns directory entries in.
376        paths.sort_by(|a, b| a.file_name().cmp(&b.file_name()));
377        if paths.is_empty() {
378            return Err(ShardError::Empty);
379        }
380        Self::open(&paths)
381    }
382
383    /// True once every row of every shard has been yielded.
384    #[inline]
385    fn at_end(&self) -> bool {
386        self.cursor_shard >= self.shards.len()
387    }
388
389    /// Advance the cursor past any fully-drained trailing shards so
390    /// `cursor_shard` either points at a shard with remaining rows or is at
391    /// `shards.len()` (end of corpus).
392    fn skip_drained_shards(&mut self) {
393        while self.cursor_shard < self.shards.len()
394            && self.cursor_local_row >= self.shards[self.cursor_shard].n_rows
395        {
396            self.cursor_shard += 1;
397            self.cursor_local_row = 0;
398        }
399    }
400}
401
402impl CorpusRowSource for MmapShardSource {
403    fn total_rows(&self) -> u64 {
404        self.total_rows
405    }
406
407    fn width(&self) -> usize {
408        self.p
409    }
410
411    fn batch_rows(&self) -> usize {
412        self.batch_rows
413    }
414
415    fn reset(&mut self) {
416        self.cursor_shard = 0;
417        self.cursor_local_row = 0;
418    }
419
420    fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError> {
421        self.skip_drained_shards();
422        if self.at_end() {
423            return Ok(None);
424        }
425        // A batch never crosses a shard boundary: it stays within the current
426        // shard and is the smaller of the configured batch size and that
427        // shard's remaining rows. This keeps row reads contiguous in one
428        // mapping and keeps the prefetch window inside one shard's payload.
429        let shard_idx = self.cursor_shard;
430        let take = {
431            let shard = &self.shards[shard_idx];
432            let remaining = shard.n_rows - self.cursor_local_row;
433            self.batch_rows.min(remaining)
434        };
435
436        // Bounded prefetch over exactly the rows we are about to read, in the
437        // same deterministic order, before touching them.
438        {
439            let shard = &self.shards[shard_idx];
440            let first_byte =
441                shard.data_offset + self.cursor_local_row * shard.p * std::mem::size_of::<f32>();
442            let want = take * shard.p * std::mem::size_of::<f32>();
443            shard.prefetch(first_byte, want.min(self.prefetch_window_bytes));
444        }
445
446        let p = self.p;
447        let mut rows = Array2::<f64>::zeros((take, p));
448        let mut row_ids = Vec::with_capacity(take);
449        {
450            let shard = &self.shards[shard_idx];
451            for k in 0..take {
452                let local = self.cursor_local_row + k;
453                let mut row_view = rows.row_mut(k);
454                let slice = row_view
455                    .as_slice_mut()
456                    .expect("freshly allocated contiguous row");
457                shard.read_row_into(local, slice);
458                row_ids.push(shard.global_row_base + local as u64);
459            }
460        }
461        self.cursor_local_row += take;
462        self.skip_drained_shards();
463        Ok(Some(RowBatch { rows, row_ids }))
464    }
465}
466
467/// Serialize a `(n_rows × p)` `f64` activation matrix to the `v1` shard byte
468/// layout, downcasting each value to `f32`.
469///
470/// This is the writer counterpart to [`MmapShardSource`] — used by ingestion
471/// tooling and by the round-trip tests below to prove the reader reproduces the
472/// exact rows it was given (up to the `f32` storage rounding the format
473/// promises). Rows are written in row-major order.
474pub fn encode_shard_bytes(rows: ndarray::ArrayView2<'_, f64>) -> Vec<u8> {
475    let n_rows = rows.nrows();
476    let p = rows.ncols();
477    let mut out = Vec::with_capacity(HEADER_LEN + n_rows * p * std::mem::size_of::<f32>());
478    out.extend_from_slice(&SHARD_MAGIC);
479    out.extend_from_slice(&(n_rows as u64).to_le_bytes());
480    out.extend_from_slice(&(p as u64).to_le_bytes());
481    out.extend_from_slice(&DTYPE_F32.to_le_bytes());
482    out.extend_from_slice(&0u32.to_le_bytes());
483    for row in rows.outer_iter() {
484        for &v in row.iter() {
485            out.extend_from_slice(&(v as f32).to_le_bytes());
486        }
487    }
488    out
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use std::io::Write;

    /// A temp shard file removed on drop, so a failing assertion cannot leak
    /// files into the temp dir. Declare it *before* the source that maps it,
    /// so the mapping drops first (locals drop in reverse declaration order).
    struct TempShard(PathBuf);

    impl Drop for TempShard {
        fn drop(&mut self) {
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// Write `bytes` to a temp path unique to this process and `name`.
    fn write_temp_bytes(name: &str, bytes: &[u8]) -> TempShard {
        let mut path = std::env::temp_dir();
        path.push(format!(
            "gam-sae-corpus-test-{}-{}.shard",
            std::process::id(),
            name
        ));
        let mut f = File::create(&path).expect("create temp shard");
        f.write_all(bytes).expect("write shard");
        f.sync_all().expect("sync shard");
        TempShard(path)
    }

    /// Encode `rows` as a `v1` shard and write it to a temp file.
    fn write_temp_shard(name: &str, rows: ndarray::ArrayView2<'_, f64>) -> TempShard {
        write_temp_bytes(name, &encode_shard_bytes(rows))
    }

    /// Drain `src` to exhaustion, returning every batch in order.
    fn drain(src: &mut MmapShardSource) -> Vec<RowBatch> {
        let mut batches = Vec::new();
        while let Some(batch) = src.next_batch().expect("next_batch") {
            batches.push(batch);
        }
        batches
    }

    #[test]
    fn single_shard_round_trips_rows_and_ids() {
        let data = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let shard = write_temp_shard("single", data.view());
        let mut src = MmapShardSource::open(&[shard.0.clone()]).expect("open");
        assert_eq!(src.total_rows(), 3);
        assert_eq!(src.width(), 3);
        // A corpus smaller than DEFAULT_BATCH_ROWS comes back in one batch.
        assert_eq!(src.batch_rows(), 3);
        let batch = src.next_batch().expect("batch").expect("some");
        assert_eq!(batch.row_ids, vec![0, 1, 2]);
        assert_eq!(batch.rows, data);
        assert!(src.next_batch().expect("end").is_none());
    }

    #[test]
    fn multi_shard_global_ids_are_contiguous() {
        let a = array![[1.0_f64], [2.0]];
        let b = array![[3.0_f64], [4.0], [5.0]];
        let sa = write_temp_shard("multi-a", a.view());
        let sb = write_temp_shard("multi-b", b.view());
        let mut src = MmapShardSource::open(&[sa.0.clone(), sb.0.clone()]).expect("open");
        assert_eq!(src.total_rows(), 5);
        let batches = drain(&mut src);
        // Batches never cross a shard boundary.
        let lens: Vec<usize> = batches.iter().map(RowBatch::len).collect();
        assert_eq!(lens, vec![2, 3]);
        let ids: Vec<u64> = batches
            .iter()
            .flat_map(|b| b.row_ids.iter().copied())
            .collect();
        let vals: Vec<f64> = batches
            .iter()
            .flat_map(|b| b.rows.column(0).to_vec())
            .collect();
        assert_eq!(ids, vec![0, 1, 2, 3, 4]);
        assert_eq!(vals, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn reset_replays_identical_sequence() {
        let data = array![[1.0_f64, 1.0], [2.0, 2.0]];
        let shard = write_temp_shard("reset", data.view());
        let mut src = MmapShardSource::open(&[shard.0.clone()]).expect("open");
        let first = drain(&mut src);
        src.reset();
        let second = drain(&mut src);
        assert_eq!(first.len(), second.len());
        for (x, y) in first.iter().zip(&second) {
            assert_eq!(x.row_ids, y.row_ids);
            assert_eq!(x.rows, y.rows);
        }
    }

    #[test]
    fn batches_are_capped_at_default_batch_rows() {
        let n = DEFAULT_BATCH_ROWS + 476;
        let data = Array2::from_shape_fn((n, 1), |(r, _)| r as f64);
        let shard = write_temp_shard("batching", data.view());
        let mut src = MmapShardSource::open(&[shard.0.clone()]).expect("open");
        assert_eq!(src.batch_rows(), DEFAULT_BATCH_ROWS);
        let batches = drain(&mut src);
        let lens: Vec<usize> = batches.iter().map(RowBatch::len).collect();
        assert_eq!(lens, vec![DEFAULT_BATCH_ROWS, 476]);
        let ids: Vec<u64> = batches
            .iter()
            .flat_map(|b| b.row_ids.iter().copied())
            .collect();
        assert_eq!(ids, (0..n as u64).collect::<Vec<_>>());
        let vals: Vec<f64> = batches
            .iter()
            .flat_map(|b| b.rows.column(0).to_vec())
            .collect();
        assert_eq!(vals, data.column(0).to_vec());
    }

    #[test]
    fn bad_magic_is_rejected() {
        let shard = write_temp_bytes("badmagic", &[0u8; 64]);
        let err = MmapShardSource::open(&[shard.0.clone()]);
        assert!(matches!(err, Err(ShardError::BadMagic { .. })));
    }

    #[test]
    fn bad_dtype_is_rejected() {
        let mut bytes = encode_shard_bytes(array![[1.0_f64]].view());
        bytes[24..28].copy_from_slice(&7u32.to_le_bytes());
        let shard = write_temp_bytes("baddtype", &bytes);
        let err = MmapShardSource::open(&[shard.0.clone()]);
        assert!(matches!(err, Err(ShardError::BadDtype { tag: 7, .. })));
    }

    #[test]
    fn truncated_payload_is_rejected() {
        let mut bytes =
            encode_shard_bytes(array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]].view());
        assert_eq!(bytes.len(), HEADER_LEN + 3 * 2 * 4);
        bytes.truncate(bytes.len() - 4);
        let shard = write_temp_bytes("truncated", &bytes);
        let err = MmapShardSource::open(&[shard.0.clone()]);
        assert!(matches!(
            err,
            Err(ShardError::Truncated {
                expected: 56,
                actual: 52,
                ..
            })
        ));
    }

    #[test]
    fn short_header_is_rejected() {
        let shard = write_temp_bytes("shortheader", &SHARD_MAGIC);
        let err = MmapShardSource::open(&[shard.0.clone()]);
        assert!(matches!(
            err,
            Err(ShardError::Truncated {
                expected: 32,
                actual: 8,
                ..
            })
        ));
    }

    #[test]
    fn width_mismatch_is_rejected() {
        let sa = write_temp_shard("width-a", array![[1.0_f64]].view());
        let sb = write_temp_shard("width-b", array![[1.0_f64, 2.0]].view());
        let err = MmapShardSource::open(&[sa.0.clone(), sb.0.clone()]);
        assert!(matches!(
            err,
            Err(ShardError::WidthMismatch {
                expected: 1,
                found: 2,
                ..
            })
        ));
    }

    #[test]
    fn empty_sources_are_rejected() {
        assert!(matches!(MmapShardSource::open(&[]), Err(ShardError::Empty)));
        let shard = write_temp_shard("zero-rows", Array2::<f64>::zeros((0, 3)).view());
        let err = MmapShardSource::open(&[shard.0.clone()]);
        assert!(matches!(err, Err(ShardError::Empty)));
    }
}