Skip to main content

gam_sae/corpus/
object_store.rs

1//! Object-store shard streaming with bounded prefetch (#987, extending #973).
2//!
3//! The mmap reader ([`super::shard_reader::MmapShardSource`]) assumes the whole
4//! corpus lives on a local filesystem. A frontier activation corpus
5//! (10⁹–10¹¹ rows, hundreds of TB) lives in **object storage**; no node ever
6//! holds more than a bounded window of it. This module is the seam that makes
7//! that regime a [`CorpusRowSource`] like any other — the SAE term, the ρ
8//! cascade, and the streaming border-Gram accumulation are all unchanged
9//! consumers.
10//!
11//! ## The trait, not the SDK
12//!
13//! [`ObjectStore`] is two methods: list the shard keys, fetch (a range of) one
14//! object. gam takes no cloud-SDK dependency; an S3/GCS/Azure backend is a
15//! ~20-line implementor on the caller's side. The in-tree
16//! [`FsObjectStore`] (a directory of `*.shard` files) is the reference
17//! implementation and the test double — and also the honest way to run the
18//! object-store code path against a network filesystem mount.
19//!
20//! ## Determinism contract (inherited, not re-invented)
21//!
22//! The global row order is pinned by the **lexicographically sorted key list**,
23//! exactly as the mmap source pins it by sorted file names: shard-0 rows
24//! `0..n0`, then shard-1 rows `0..n1`, … with stable global `row_id`s. Fetch
25//! latency, retries, and prefetch depth can never reorder, drop, or duplicate
26//! rows; the `(row_id, row)` sequence is byte-identical across runs, fleets,
27//! and backends, so warm-start keys ([`super::warm_state`]), subsample hashes
28//! ([`super::rho_cascade`]), and the cross-node chunk partition
29//! ([`gam_solve::cross_node`]) all agree with a local-disk run.
30//!
31//! ## Bounded prefetch, never materialize
32//!
33//! At most [`PREFETCH_SHARDS_AHEAD`] shards beyond the one being drained are
34//! resident at a time. `next_batch` serves rows out of the front shard's
35//! fetched bytes and tops the window up as shards drain; the corpus is never
36//! materialized, and the resident set is bounded by the window regardless of
37//! corpus size. (An async/multipart store hides its latency *inside*
38//! [`ObjectStore::fetch_range`]; this driver only promises it will never ask
39//! for more than the window.)
40//!
41//! ## Mandatory selectivity at this scale
42//!
43//! At object-store scale, fitting every row is not on the table: the #987
44//! contract is that the fit sees a **designed sample** whose inclusion weights
45//! are carried into the likelihood so the criterion stays unbiased (the #973
46//! subsample-honesty contract, mechanized by
47//! [`gam_solve::row_sampling_measure::RowSamplingMeasure::designed_subsample`]).
48//! [`designed_sampling_mandatory`] is the auto-derived predicate drivers
49//! consult: above the threshold a full-corpus pass is refused as a default and
50//! the designed-sample path is the only sanctioned one — selectivity is a
51//! correctness-of-economics requirement here, not an optimization.
52
53use ndarray::Array2;
54use std::collections::VecDeque;
55use std::fs::File;
56use std::io::{Read, Seek, SeekFrom};
57use std::path::PathBuf;
58use std::sync::Arc;
59
60use super::shard_reader::{
61    CorpusRowSource, DEFAULT_BATCH_ROWS, DEFAULT_PREFETCH_WINDOW_BYTES, DTYPE_F32, HEADER_LEN,
62    RowBatch, SHARD_MAGIC, ShardError,
63};
64
/// Look-ahead depth of the prefetch window: how many fully fetched shards may
/// be resident *in addition to* the shard currently being drained.
///
/// Derived policy rather than a tuning knob. The first look-ahead slot hides
/// the fetch latency of the deterministic next shard; the second covers a
/// short shard draining before the following fetch has finished. Combined
/// with the per-shard payload size, this caps the resident set no matter how
/// large the corpus is.
pub const PREFETCH_SHARDS_AHEAD: usize = 2;

/// Row count from which a fit **must** use a designed (importance-weighted)
/// subsample instead of a full-corpus exact pass.
///
/// A driver whose corpus has at least this many rows is expected to route
/// through
/// [`gam_solve::row_sampling_measure::RowSamplingMeasure::designed_subsample`]
/// and carry the inclusion weights into the likelihood. Derived threshold: at
/// 10⁸ rows even a single linear pass per outer iteration dominates the whole
/// fit budget, so the #973 cascade's honest-subsample arms stop being an
/// optimization and become the only affordable unbiased estimator.
pub const DESIGNED_SAMPLE_MANDATORY_MIN_ROWS: u64 = 10_u64.pow(8);

/// Auto-switch predicate (#987): `true` exactly when a corpus of
/// `total_rows` rows reaches [`DESIGNED_SAMPLE_MANDATORY_MIN_ROWS`] and must
/// therefore be fit through a designed, honesty-weighted subsample. Pure
/// function of the row count; no flag.
#[inline]
pub fn designed_sampling_mandatory(total_rows: u64) -> bool {
    DESIGNED_SAMPLE_MANDATORY_MIN_ROWS <= total_rows
}
90
91/// Minimal object-store abstraction: list shard keys, fetch object bytes.
92///
93/// Implementors must be deterministic in *content* (the same key always yields
94/// the same bytes during one pass) but are free to be remote, retried, cached,
95/// or parallel inside. `fetch_range` has a correct-by-default implementation
96/// over `fetch`; backends with native range reads (HTTP `Range`, `pread`)
97/// should override it so header probing does not pull whole objects.
98pub trait ObjectStore: Send + Sync {
99    /// Keys of every shard object in the store. Order is irrelevant — the
100    /// source sorts lexicographically to pin the global row order.
101    fn list_shards(&self) -> Result<Vec<String>, ShardError>;
102
103    /// Fetch the full bytes of one object.
104    fn fetch(&self, key: &str) -> Result<Vec<u8>, ShardError>;
105
106    /// Fetch `len` bytes starting at `offset` (clamped to object end). The
107    /// default fetches the whole object and slices — correct everywhere,
108    /// efficient nowhere; override for real backends.
109    fn fetch_range(&self, key: &str, offset: u64, len: usize) -> Result<Vec<u8>, ShardError> {
110        let full = self.fetch(key)?;
111        let start = (offset as usize).min(full.len());
112        let end = start.saturating_add(len).min(full.len());
113        Ok(full[start..end].to_vec())
114    }
115}
116
/// Reference [`ObjectStore`]: a local directory of `*.shard` files. Doubles as
/// the test backend and as the adapter for network-filesystem mounts.
///
/// Keys are file names relative to the root directory.
#[derive(Clone, Debug)]
pub struct FsObjectStore {
    /// Directory whose `*.shard` files are the store's objects.
    root: PathBuf,
}

impl FsObjectStore {
    /// Create a store rooted at `root`. Nothing is read from disk here; the
    /// directory is first touched when shards are listed or fetched.
    pub fn new(root: PathBuf) -> Self {
        Self { root }
    }

    /// Filesystem path of the object named `key` under the root.
    fn path_of(&self, key: &str) -> PathBuf {
        self.root.join(key)
    }
}
132
133impl ObjectStore for FsObjectStore {
134    fn list_shards(&self) -> Result<Vec<String>, ShardError> {
135        let mut keys = Vec::new();
136        for entry in std::fs::read_dir(&self.root)? {
137            let entry = entry?;
138            let path = entry.path();
139            if path.extension().and_then(|e| e.to_str()) == Some("shard") {
140                if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
141                    keys.push(name.to_string());
142                }
143            }
144        }
145        Ok(keys)
146    }
147
148    fn fetch(&self, key: &str) -> Result<Vec<u8>, ShardError> {
149        let mut bytes = Vec::new();
150        File::open(self.path_of(key))?.read_to_end(&mut bytes)?;
151        Ok(bytes)
152    }
153
154    fn fetch_range(&self, key: &str, offset: u64, len: usize) -> Result<Vec<u8>, ShardError> {
155        let mut file = File::open(self.path_of(key))?;
156        let total = file.metadata()?.len();
157        let start = offset.min(total);
158        let take = (len as u64).min(total - start) as usize;
159        file.seek(SeekFrom::Start(start))?;
160        let mut buf = vec![0u8; take];
161        file.read_exact(&mut buf)?;
162        Ok(buf)
163    }
164}
165
/// Per-shard facts captured from a `v1` header when the source is opened
/// (see [`super::shard_reader`] for the on-disk layout).
#[derive(Debug, Clone)]
struct ShardMeta {
    /// Key the shard object was listed under.
    key: String,
    /// Row count declared by the shard's header.
    n_rows: usize,
    /// Global `row_id` of this shard's row 0 in the concatenated stream:
    /// the running total of `n_rows` over every earlier shard in key order.
    global_row_base: u64,
}

/// A shard whose payload has been fetched and currently sits in the prefetch
/// window.
struct ResidentShard {
    /// Position of this shard within `ObjectStoreShardSource::shards`.
    shard_idx: usize,
    /// Payload bytes with the header already stripped: `n_rows * p`
    /// little-endian `f32` values laid out row after row.
    payload: Vec<u8>,
}
183
184/// Parse and validate a `v1` shard header out of its first [`HEADER_LEN`]
185/// bytes, returning `(n_rows, p)`.
186fn parse_header(key: &str, header: &[u8]) -> Result<(usize, usize), ShardError> {
187    let path = PathBuf::from(key);
188    if header.len() < HEADER_LEN {
189        return Err(ShardError::Truncated {
190            path,
191            expected: HEADER_LEN,
192            actual: header.len(),
193        });
194    }
195    if header[0..8] != SHARD_MAGIC {
196        return Err(ShardError::BadMagic { path });
197    }
198    let n_rows = u64::from_le_bytes(header[8..16].try_into().expect("8 bytes")) as usize;
199    let p = u64::from_le_bytes(header[16..24].try_into().expect("8 bytes")) as usize;
200    let dtype = u32::from_le_bytes(header[24..28].try_into().expect("4 bytes"));
201    if dtype != DTYPE_F32 {
202        return Err(ShardError::BadDtype { path, tag: dtype });
203    }
204    Ok((n_rows, p))
205}
206
/// A [`CorpusRowSource`] streaming `v1` shards out of an [`ObjectStore`] with a
/// bounded prefetch window. See the module docs for the determinism and
/// residency contracts.
///
/// Rows are yielded shard by shard in `shards` order; a batch never spans two
/// shards.
pub struct ObjectStoreShardSource {
    /// Backend used for the header probes at open time and for payload
    /// fetches while streaming.
    store: Arc<dyn ObjectStore>,
    /// Shard metadata in lexicographic key order — the pinned global row order.
    shards: Vec<ShardMeta>,
    /// Row width (values per row); `open` rejects shards that disagree on it.
    p: usize,
    /// Sum of `n_rows` over every shard.
    total_rows: u64,
    /// Maximum rows per yielded batch, fixed when the source is opened.
    batch_rows: usize,
    /// Resident fetched shards, front = the shard currently being drained.
    /// Invariant (after each `fill_window`): contiguous shard indices starting
    /// at `cursor_shard`, length ≤ 1 + [`PREFETCH_SHARDS_AHEAD`]. Empty after
    /// `reset` or once the stream is exhausted.
    window: VecDeque<ResidentShard>,
    /// Index into `shards` of the shard currently being drained;
    /// `shards.len()` once every row has been yielded.
    cursor_shard: usize,
    /// Local row within `shards[cursor_shard]` to read next.
    cursor_local_row: usize,
}
226
227impl ObjectStoreShardSource {
228    /// Open a source over every shard the store lists. Headers are probed with
229    /// 32-byte range reads so `total_rows` / `width` are known up front without
230    /// fetching any payload.
231    pub fn open(store: Arc<dyn ObjectStore>) -> Result<Self, ShardError> {
232        let mut keys = store.list_shards()?;
233        // Sort by key bytes: deterministic and independent of the order the
234        // store returns listings in — the exact discipline of the mmap source.
235        keys.sort();
236        if keys.is_empty() {
237            return Err(ShardError::Empty);
238        }
239        let mut shards = Vec::with_capacity(keys.len());
240        let mut p: Option<usize> = None;
241        let mut running_base: u64 = 0;
242        for key in keys {
243            let header = store.fetch_range(&key, 0, HEADER_LEN)?;
244            let (n_rows, shard_p) = parse_header(&key, &header)?;
245            match p {
246                None => p = Some(shard_p),
247                Some(expected) if expected != shard_p => {
248                    return Err(ShardError::WidthMismatch {
249                        expected,
250                        found: shard_p,
251                        path: PathBuf::from(&key),
252                    });
253                }
254                Some(_) => {}
255            }
256            shards.push(ShardMeta {
257                key,
258                n_rows,
259                global_row_base: running_base,
260            });
261            running_base = running_base.saturating_add(n_rows as u64);
262        }
263        let p = p.ok_or(ShardError::Empty)?;
264        let total_rows = running_base;
265        if total_rows == 0 {
266            return Err(ShardError::Empty);
267        }
268        // Same auto-derived batch sizing as the mmap source; additionally cap
269        // by the shared prefetch-window byte budget so one batch's rows always
270        // fit inside the policy window even for very wide activations.
271        let row_bytes = p.max(1) * std::mem::size_of::<f32>();
272        let window_rows = (DEFAULT_PREFETCH_WINDOW_BYTES / row_bytes).max(1);
273        let batch_rows = DEFAULT_BATCH_ROWS
274            .min(total_rows as usize)
275            .min(window_rows)
276            .max(1);
277        Ok(Self {
278            store,
279            shards,
280            p,
281            total_rows,
282            batch_rows,
283            window: VecDeque::new(),
284            cursor_shard: 0,
285            cursor_local_row: 0,
286        })
287    }
288
289    /// True once every row of every shard has been yielded.
290    #[inline]
291    fn at_end(&self) -> bool {
292        self.cursor_shard >= self.shards.len()
293    }
294
295    /// Fetch one shard's payload and validate its length against the header.
296    fn fetch_shard(&self, shard_idx: usize) -> Result<ResidentShard, ShardError> {
297        let meta = &self.shards[shard_idx];
298        let payload_len = meta
299            .n_rows
300            .checked_mul(self.p)
301            .and_then(|cells| cells.checked_mul(std::mem::size_of::<f32>()))
302            .ok_or_else(|| ShardError::Truncated {
303                path: PathBuf::from(&meta.key),
304                expected: usize::MAX,
305                actual: 0,
306            })?;
307        let payload = self
308            .store
309            .fetch_range(&meta.key, HEADER_LEN as u64, payload_len)?;
310        if payload.len() < payload_len {
311            return Err(ShardError::Truncated {
312                path: PathBuf::from(&meta.key),
313                expected: HEADER_LEN + payload_len,
314                actual: HEADER_LEN + payload.len(),
315            });
316        }
317        Ok(ResidentShard { shard_idx, payload })
318    }
319
320    /// Ensure the window holds the current shard plus up to
321    /// [`PREFETCH_SHARDS_AHEAD`] deterministic successors. Skips empty shards
322    /// at the cursor (zero-row shards are legal; they contribute no rows).
323    fn fill_window(&mut self) -> Result<(), ShardError> {
324        // Advance the cursor past drained / empty shards first.
325        while self.cursor_shard < self.shards.len()
326            && self.cursor_local_row >= self.shards[self.cursor_shard].n_rows
327        {
328            self.cursor_shard += 1;
329            self.cursor_local_row = 0;
330            if let Some(front) = self.window.front() {
331                if front.shard_idx < self.cursor_shard {
332                    self.window.pop_front();
333                }
334            }
335        }
336        if self.at_end() {
337            self.window.clear();
338            return Ok(());
339        }
340        // Drop any stale front (defensive; reset() clears outright).
341        while let Some(front) = self.window.front() {
342            if front.shard_idx < self.cursor_shard {
343                self.window.pop_front();
344            } else {
345                break;
346            }
347        }
348        // Top up: current shard first, then look-ahead, bounded by the window.
349        let want_last = (self.cursor_shard + PREFETCH_SHARDS_AHEAD).min(self.shards.len() - 1);
350        let mut next_fetch = match self.window.back() {
351            Some(back) => back.shard_idx + 1,
352            None => self.cursor_shard,
353        };
354        while next_fetch <= want_last {
355            let resident = self.fetch_shard(next_fetch)?;
356            self.window.push_back(resident);
357            next_fetch += 1;
358        }
359        Ok(())
360    }
361}
362
363impl CorpusRowSource for ObjectStoreShardSource {
364    fn total_rows(&self) -> u64 {
365        self.total_rows
366    }
367
368    fn width(&self) -> usize {
369        self.p
370    }
371
372    fn batch_rows(&self) -> usize {
373        self.batch_rows
374    }
375
376    fn reset(&mut self) {
377        self.cursor_shard = 0;
378        self.cursor_local_row = 0;
379        self.window.clear();
380    }
381
382    fn next_batch(&mut self) -> Result<Option<RowBatch>, ShardError> {
383        self.fill_window()?;
384        if self.at_end() {
385            return Ok(None);
386        }
387        let meta = &self.shards[self.cursor_shard];
388        let front = self
389            .window
390            .front()
391            .expect("fill_window leaves the current shard resident");
392        // The window is a contiguous run of shard indices starting at the read
393        // cursor, so after a fill the front must be `cursor_shard`. A release
394        // build cannot drop this check (a stale front would read the wrong
395        // payload against `meta`'s row metadata and silently corrupt the
396        // batch), so it is a real error rather than a `debug_assert`.
397        if front.shard_idx != self.cursor_shard {
398            return Err(ShardError::ResidencyInvariant {
399                cursor_shard: self.cursor_shard,
400                front_shard: front.shard_idx,
401            });
402        }
403
404        // A batch never crosses a shard boundary (same contract as the mmap
405        // source): contiguous rows of one payload, bounded by the batch size.
406        let remaining = meta.n_rows - self.cursor_local_row;
407        let take = self.batch_rows.min(remaining);
408        let row_bytes = self.p * std::mem::size_of::<f32>();
409        let mut rows = Array2::<f64>::zeros((take, self.p));
410        let mut row_ids = Vec::with_capacity(take);
411        for k in 0..take {
412            let local = self.cursor_local_row + k;
413            let start = local * row_bytes;
414            let bytes = &front.payload[start..start + row_bytes];
415            let mut row_view = rows.row_mut(k);
416            let slice = row_view
417                .as_slice_mut()
418                .expect("freshly allocated contiguous row");
419            for (c, slot) in slice.iter_mut().enumerate() {
420                let b = c * std::mem::size_of::<f32>();
421                let lane = f32::from_le_bytes(bytes[b..b + 4].try_into().expect("4 bytes"));
422                *slot = f64::from(lane);
423            }
424            row_ids.push(meta.global_row_base + local as u64);
425        }
426        self.cursor_local_row += take;
427        Ok(Some(RowBatch { rows, row_ids }))
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::super::shard_reader::{MmapShardSource, encode_shard_bytes};
    use super::*;
    use ndarray::array;
    use std::io::Write;
    use std::sync::Mutex;

    /// Write each `(key, rows)` pair as a `v1` shard into a per-process,
    /// per-test temp directory and return that directory.
    ///
    /// A directory left over from an earlier run that reused this PID (e.g.
    /// one that panicked before its cleanup line) is removed first. A stale
    /// `*.shard` file would otherwise be listed alongside the fresh ones and
    /// silently change the row sequence under test.
    fn temp_store(name: &str, shards: &[(&str, Array2<f64>)]) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "gam-sae-objstore-test-{}-{}",
            std::process::id(),
            name
        ));
        // Best-effort: the directory normally does not exist yet.
        std::fs::remove_dir_all(&dir).ok();
        std::fs::create_dir_all(&dir).expect("create store dir");
        for (key, rows) in shards {
            let bytes = encode_shard_bytes(rows.view());
            let mut f = File::create(dir.join(key)).expect("create shard");
            f.write_all(&bytes).expect("write shard");
            f.sync_all().expect("sync shard");
        }
        dir
    }

    /// Pull every remaining batch, returning the row ids and the row values
    /// flattened in row-major order.
    fn drain(src: &mut dyn CorpusRowSource) -> (Vec<u64>, Vec<f64>) {
        let mut ids = Vec::new();
        let mut vals = Vec::new();
        while let Some(batch) = src.next_batch().expect("batch") {
            ids.extend(batch.row_ids.iter().copied());
            vals.extend(batch.rows.iter().copied());
        }
        (ids, vals)
    }

    #[test]
    fn object_store_replays_the_mmap_row_sequence_exactly() {
        // The same shard set, read via the object-store source and via the
        // mmap source, must yield byte-identical (row_id, row) sequences —
        // the backend is invisible to every downstream determinism contract.
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let b = array![[7.0_f64, 8.0], [9.0, 10.0]];
        let dir = temp_store("parity", &[("a.shard", a), ("b.shard", b)]);

        let store = Arc::new(FsObjectStore::new(dir.clone()));
        let mut remote = ObjectStoreShardSource::open(store).expect("open object-store source");
        let mut local = MmapShardSource::open_dir(&dir).expect("open mmap source");

        assert_eq!(remote.total_rows(), local.total_rows());
        assert_eq!(remote.width(), local.width());
        let (ids_r, vals_r) = drain(&mut remote);
        let (ids_l, vals_l) = drain(&mut local);
        assert_eq!(ids_r, ids_l);
        assert_eq!(
            vals_r.iter().map(|v| v.to_bits()).collect::<Vec<_>>(),
            vals_l.iter().map(|v| v.to_bits()).collect::<Vec<_>>(),
            "object-store rows must be bit-identical to mmap rows"
        );

        // reset() replays the identical sequence.
        remote.reset();
        let (ids_again, vals_again) = drain(&mut remote);
        assert_eq!(ids_again, ids_r);
        assert_eq!(vals_again, vals_r);
        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn reset_mid_stream_restarts_from_row_zero() {
        // Single-row shards: each batch is exactly one row because batches
        // never cross a shard boundary.
        let mk = |v: f64| array![[v, -v]];
        let dir = temp_store(
            "reset",
            &[
                ("r0.shard", mk(1.0)),
                ("r1.shard", mk(2.0)),
                ("r2.shard", mk(3.0)),
            ],
        );
        let store = Arc::new(FsObjectStore::new(dir.clone()));
        let mut src = ObjectStoreShardSource::open(store).expect("open");
        let first = src.next_batch().expect("batch").expect("some");
        assert_eq!(first.row_ids, vec![0]);
        let second = src.next_batch().expect("batch").expect("some");
        assert_eq!(second.row_ids, vec![1]);

        src.reset();
        let (ids, vals) = drain(&mut src);
        assert_eq!(ids, vec![0, 1, 2]);
        assert_eq!(vals, vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0]);
        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn truncated_payload_is_an_error_not_a_short_batch() {
        let dir = temp_store(
            "truncated",
            &[("t.shard", array![[1.0_f64, 2.0], [3.0, 4.0]])],
        );
        // Keep the header intact but leave only one f32 of the four-value
        // payload, so `open` succeeds and the shortfall surfaces on fetch.
        std::fs::OpenOptions::new()
            .write(true)
            .open(dir.join("t.shard"))
            .expect("open shard for truncation")
            .set_len(HEADER_LEN as u64 + 4)
            .expect("truncate shard");
        let store = Arc::new(FsObjectStore::new(dir.clone()));
        let mut src = ObjectStoreShardSource::open(store).expect("open");
        assert!(matches!(
            src.next_batch(),
            Err(ShardError::Truncated { .. })
        ));
        std::fs::remove_dir_all(&dir).ok();
    }

    /// A store wrapper that records payload fetches so tests can check the
    /// bounded-window contract. The source never holds fetch results for more
    /// than `1 + PREFETCH_SHARDS_AHEAD` shards at once. Its memory cannot be
    /// observed directly, but we can observe that fetches happen lazily and in
    /// key order rather than all up front.
    struct CountingStore {
        inner: FsObjectStore,
        payload_fetches: Mutex<Vec<String>>,
    }

    impl ObjectStore for CountingStore {
        fn list_shards(&self) -> Result<Vec<String>, ShardError> {
            self.inner.list_shards()
        }
        fn fetch(&self, key: &str) -> Result<Vec<u8>, ShardError> {
            self.inner.fetch(key)
        }
        fn fetch_range(&self, key: &str, offset: u64, len: usize) -> Result<Vec<u8>, ShardError> {
            // Reads starting at or past the header are payload fetches; the
            // offset-0 reads are header probes from `open`.
            if offset as usize >= HEADER_LEN {
                self.payload_fetches.lock().unwrap().push(key.to_string());
            }
            self.inner.fetch_range(key, offset, len)
        }
    }

    #[test]
    fn prefetch_is_bounded_and_in_key_order() {
        // 6 single-row shards; with PREFETCH_SHARDS_AHEAD = 2 the first batch
        // may trigger at most 3 payload fetches, and fetches arrive in key
        // order.
        let mk = |v: f64| array![[v]];
        let dir = temp_store(
            "bounded",
            &[
                ("s0.shard", mk(0.0)),
                ("s1.shard", mk(1.0)),
                ("s2.shard", mk(2.0)),
                ("s3.shard", mk(3.0)),
                ("s4.shard", mk(4.0)),
                ("s5.shard", mk(5.0)),
            ],
        );
        let store = Arc::new(CountingStore {
            inner: FsObjectStore::new(dir.clone()),
            payload_fetches: Mutex::new(Vec::new()),
        });
        let mut src =
            ObjectStoreShardSource::open(Arc::clone(&store) as Arc<dyn ObjectStore>).expect("open");
        let first = src.next_batch().expect("batch").expect("some");
        assert_eq!(first.row_ids, vec![0]);
        {
            let fetched = store.payload_fetches.lock().unwrap();
            assert!(
                fetched.len() <= 1 + PREFETCH_SHARDS_AHEAD,
                "first batch fetched {} shard payloads; window allows {}",
                fetched.len(),
                1 + PREFETCH_SHARDS_AHEAD
            );
            let mut sorted = fetched.clone();
            sorted.sort();
            assert_eq!(*fetched, sorted, "payload fetches must be in key order");
        }
        // Draining the rest touches every shard exactly once.
        let (ids, _) = drain(&mut src);
        assert_eq!(ids, vec![1, 2, 3, 4, 5]);
        let fetched = store.payload_fetches.lock().unwrap();
        assert_eq!(fetched.len(), 6, "each shard payload fetched exactly once");
        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn mandatory_selectivity_threshold_is_pure_and_monotone() {
        assert!(!designed_sampling_mandatory(0));
        assert!(!designed_sampling_mandatory(
            DESIGNED_SAMPLE_MANDATORY_MIN_ROWS - 1
        ));
        assert!(designed_sampling_mandatory(
            DESIGNED_SAMPLE_MANDATORY_MIN_ROWS
        ));
        assert!(designed_sampling_mandatory(u64::MAX));
    }

    #[test]
    fn width_mismatch_is_rejected() {
        let dir = temp_store(
            "width",
            &[
                ("a.shard", array![[1.0_f64, 2.0]]),
                ("b.shard", array![[3.0_f64]]),
            ],
        );
        let store = Arc::new(FsObjectStore::new(dir.clone()));
        let err = ObjectStoreShardSource::open(store);
        assert!(matches!(err, Err(ShardError::WidthMismatch { .. })));
        std::fs::remove_dir_all(&dir).ok();
    }
}