Skip to main content

gam_sae/corpus/
designed_target.rs

1//! Designed corpus target collection — the #991 bridge from a streaming
2//! [`CorpusRowSource`] to the in-memory row set + honesty weights the SAE term
3//! fits on.
4//!
5//! # The architecture this realizes
6//!
7//! At frontier scale the fit never sees the whole corpus: it sees a **designed
8//! sample** whose inclusion weights ride into the likelihood so the criterion
9//! stays unbiased (#987 / #973). That makes "fit the corpus" a two-step
10//! pipeline with a bounded memory footprint by construction:
11//!
12//! 1. **Design** — a [`RowSamplingMeasure`] over the corpus (uniform on a first
13//!    harvest; [`TieredHarvest::corpus_measure`]-driven once Fisher factors
14//!    exist) picks `budget` rows via
15//!    [`RowSamplingMeasure::designed_subsample`] (deterministic, seeded, honest `1/π`
16//!    weights).
17//! 2. **Collect** — one deterministic streaming pass over the source
18//!    materializes exactly those rows (the only dense `f64` block the fit ever
19//!    holds: `budget × p`, not `N × p`), aligned with their weights and global
20//!    `row_id`s.
21//!
22//! The term consumes the result as `(target, set_row_loss_weights)`; the
23//! weights enter the objective through the term's single `√w` honesty seam.
24//!
25//! # Exactness degeneracy (the bit-identity contract)
26//!
27//! `budget ≥ corpus rows` (always the case below
28//! [`designed_sampling_mandatory`]'s threshold unless a caller narrows it)
29//! collects **every** row in stream order with weight exactly `1.0` — and the
30//! term stores all-equal weights as `None`, so a shard-backed full-budget fit
31//! is **bit-for-bit** the in-memory fit of the same rows. Selectivity is then
32//! purely a budget decision, not a code path: drivers call this
33//! unconditionally and let [`auto_designed_budget`] decide.
34
35use ndarray::Array2;
36
37use super::object_store::designed_sampling_mandatory;
38use super::shard_reader::CorpusRowSource;
39use crate::inference::harvest::TieredHarvest;
40use gam_solve::row_sampling_measure::{MeasureProvenance, RowSamplingMeasure};
41
/// Default designed-sample budget once [`designed_sampling_mandatory`] fires.
///
/// Auto-derived policy, not a knob. The collected target is a dense `f64`
/// block, so its footprint is `2e6 × p × 8 B`:
/// - ≈ 61 GiB at the extreme `p = 4096`;
/// - ≈ 11.4 GiB at GPT-2-small's `p = 768`.
///
/// Even the extreme stays in memory. 2·10⁶ rows is large enough that
/// designed-sample SEs on shared structure are far below fit noise. It is
/// small enough that an outer iteration's full pass over the *sample* takes
/// minutes, not days.
pub const DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS: usize = 2_000_000;
49
50/// Auto-derive the collection budget from the corpus size (#991,
51/// magic-by-default): below the [`designed_sampling_mandatory`] threshold the
52/// budget is the whole corpus (the exact pass); at or above it, the designed
53/// default budget.
54pub fn auto_designed_budget(total_rows: u64) -> usize {
55    if designed_sampling_mandatory(total_rows) {
56        DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
57    } else {
58        total_rows as usize
59    }
60}
61
/// The collected designed row set: the dense fit target plus everything needed
/// to keep the fit honest and traceable back to the corpus.
///
/// Built by [`collect_designed_target`]. That function fills `target` row `k`
/// from the corpus row whose id it pushes as `row_ids[k]`, so those two fields
/// are index-aligned by construction.
#[derive(Debug, Clone)]
pub struct DesignedCorpusTarget {
    /// `(n_selected × p)` upcast activations of exactly the designed rows, in
    /// ascending global row order.
    pub target: Array2<f64>,
    /// Global corpus `row_id` of each target row (ascending). These are the
    /// keys for warm-state reuse ([`super::warm_state`]) and for aligning a
    /// [`TieredHarvest`] Fisher tier with the fitted rows.
    pub row_ids: Vec<u64>,
    /// Per-selected-row Horvitz–Thompson likelihood weight `1/π`, aligned with
    /// `target` rows. Hand these to `SaeManifoldTerm::set_row_loss_weights`,
    /// which mean-normalizes. An exact full pass yields all-`1.0` here and
    /// takes the unweighted path there.
    pub likelihood_weights: Vec<f64>,
    /// Provenance of the measure that shaped the design.
    pub provenance: MeasureProvenance,
    /// Total corpus rows the design was drawn from (the source's declared row
    /// count). Compared against the collected count to tell a designed
    /// subsample from the exact full pass.
    pub corpus_rows: u64,
}
83
84impl DesignedCorpusTarget {
85    /// Number of collected rows.
86    pub fn len(&self) -> usize {
87        self.row_ids.len()
88    }
89
90    pub fn is_empty(&self) -> bool {
91        self.row_ids.is_empty()
92    }
93
94    /// Whether selectivity actually engaged (a proper subsample) or the
95    /// collection was the exact full pass.
96    pub fn is_designed_subsample(&self) -> bool {
97        (self.len() as u64) < self.corpus_rows
98    }
99}
100
101/// Collect a designed target from a streaming source.
102///
103/// `measure` is the design measure over the corpus rows (`None` ⇒ uniform —
104/// the first-harvest cold start). `budget` rows are selected via
105/// [`RowSamplingMeasure::designed_subsample`] (deterministic in `(measure, budget,
106/// seed)`), then materialized in one deterministic pass. The source is
107/// `reset()` before reading, so the call is idempotent across ρ passes.
108pub fn collect_designed_target(
109    source: &mut dyn CorpusRowSource,
110    measure: Option<&RowSamplingMeasure>,
111    budget: usize,
112    seed: u64,
113) -> Result<DesignedCorpusTarget, String> {
114    let corpus_rows = source.total_rows();
115    let p = source.width();
116    let n = usize::try_from(corpus_rows)
117        .map_err(|_| "collect_designed_target: corpus row count exceeds usize".to_string())?;
118    let uniform;
119    let measure = match measure {
120        Some(m) => {
121            if m.n_rows() != n {
122                return Err(format!(
123                    "collect_designed_target: measure covers {} rows but the corpus has {n}",
124                    m.n_rows()
125                ));
126            }
127            m
128        }
129        None => {
130            uniform = RowSamplingMeasure::uniform(n);
131            &uniform
132        }
133    };
134    let sample = measure.designed_subsample(budget, seed);
135    let n_sel = sample.rows.len();
136    let mut target = Array2::<f64>::zeros((n_sel, p));
137    let mut row_ids = Vec::with_capacity(n_sel);
138
139    source.reset();
140    // Two-pointer walk: batches arrive in ascending global row order and
141    // `sample.rows` is ascending, so each selected row is matched exactly once.
142    let mut next_sel = 0usize;
143    while next_sel < n_sel {
144        let Some(batch) = source
145            .next_batch()
146            .map_err(|e| format!("collect_designed_target: shard read failed: {e}"))?
147        else {
148            break;
149        };
150        for (k, &rid) in batch.row_ids.iter().enumerate() {
151            if next_sel >= n_sel {
152                break;
153            }
154            if rid == sample.rows[next_sel] as u64 {
155                target.row_mut(next_sel).assign(&batch.rows.row(k));
156                row_ids.push(rid);
157                next_sel += 1;
158            }
159        }
160    }
161    if next_sel != n_sel {
162        return Err(format!(
163            "collect_designed_target: stream ended after matching {next_sel} of {n_sel} \
164             designed rows (corpus declared {corpus_rows} rows)"
165        ));
166    }
167    Ok(DesignedCorpusTarget {
168        target,
169        row_ids,
170        likelihood_weights: sample.likelihood_weights,
171        provenance: sample.provenance,
172        corpus_rows,
173    })
174}
175
176/// Fully magic entry point: budget from [`auto_designed_budget`], uniform
177/// first-harvest measure. Below the mandatory-selectivity threshold this is
178/// the exact full pass (weights ≡ 1.0).
179pub fn collect_designed_target_auto(
180    source: &mut dyn CorpusRowSource,
181    seed: u64,
182) -> Result<DesignedCorpusTarget, String> {
183    let budget = auto_designed_budget(source.total_rows());
184    collect_designed_target(source, None, budget, seed)
185}
186
187/// Harvest-loop entry point: design the collection from a previous harvest's
188/// lifted Fisher measure ([`TieredHarvest::corpus_measure`] — uniform when the
189/// harvest has no Fisher tier, so the cold start degenerates to
190/// [`collect_designed_target_auto`]'s design).
191pub fn collect_designed_target_from_harvest(
192    source: &mut dyn CorpusRowSource,
193    harvest: &TieredHarvest,
194    budget: usize,
195    seed: u64,
196) -> Result<DesignedCorpusTarget, String> {
197    let measure = harvest.corpus_measure();
198    collect_designed_target(source, Some(&measure), budget, seed)
199}
200
#[cfg(test)]
mod tests {
    use super::super::shard_reader::{MmapShardSource, encode_shard_bytes};
    use super::*;
    use ndarray::Array2 as NdArray2;
    use std::io::Write;
    use std::path::PathBuf;

    /// Deterministic pseudo-random `(n × p)` rows from a `sin` hash, so each
    /// test gets a reproducible corpus without an RNG.
    fn planted_rows(n: usize, p: usize) -> NdArray2<f64> {
        NdArray2::from_shape_fn((n, p), |(i, j)| {
            let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
            (x.sin() * 43_758.547).fract() * 2.0 - 1.0
        })
    }

    /// Per-test scratch directory that is removed on drop. A failing assertion
    /// unwinds through it, so shard files are no longer leaked into the
    /// system temp dir.
    struct TempShardDir(PathBuf);

    impl Drop for TempShardDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover directory must not mask the
            // test's own outcome.
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    /// Write `rows` as two shards into a fresh per-process, per-test
    /// directory: `a.shard` holds rows `..split_at` and `b.shard` the rest.
    fn temp_shard_dir(name: &str, rows: &NdArray2<f64>, split_at: usize) -> TempShardDir {
        let mut dir = std::env::temp_dir();
        dir.push(format!(
            "gam-designed-target-test-{}-{}",
            std::process::id(),
            name
        ));
        // Start from an empty directory so stale files from an aborted earlier
        // run cannot be picked up by `open_dir`.
        std::fs::remove_dir_all(&dir).ok();
        std::fs::create_dir_all(&dir).expect("create dir");
        // Take ownership immediately so a failure while writing still cleans up.
        let guard = TempShardDir(dir);
        let parts = [
            ("a.shard", rows.slice(ndarray::s![..split_at, ..])),
            ("b.shard", rows.slice(ndarray::s![split_at.., ..])),
        ];
        for (key, part) in parts {
            let bytes = encode_shard_bytes(part);
            let mut f = std::fs::File::create(guard.0.join(key)).expect("create shard");
            f.write_all(&bytes).expect("write shard");
            f.sync_all().expect("sync");
        }
        guard
    }

    #[test]
    fn full_budget_collects_every_row_bit_for_bit_with_unit_weights() {
        let n = 137;
        let p = 5;
        let rows = planted_rows(n, p);
        // `dir` is declared before `src`, so the source is dropped (unmapped)
        // before the directory is removed.
        let dir = temp_shard_dir("full", &rows, 60);
        let mut src = MmapShardSource::open_dir(&dir.0).expect("open");
        let collected = collect_designed_target_auto(&mut src, 7).expect("collect");

        assert!(!collected.is_designed_subsample());
        assert_eq!(collected.row_ids, (0..n as u64).collect::<Vec<_>>());
        assert!(collected.likelihood_weights.iter().all(|&w| w == 1.0));
        // Bit-identity to the f32-storage round-trip of the source rows: the
        // collection adds nothing on top of the shard format's own rounding.
        let stored = rows.mapv(|v| f64::from(v as f32));
        for (a, b) in collected.target.iter().zip(stored.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    #[test]
    fn designed_budget_collects_exactly_the_designed_rows_with_their_weights() {
        let n = 200;
        let p = 3;
        let rows = planted_rows(n, p);
        let dir = temp_shard_dir("designed", &rows, 90);
        let mut src = MmapShardSource::open_dir(&dir.0).expect("open");

        let budget = 40usize;
        let seed = 17u64;
        let collected = collect_designed_target(&mut src, None, budget, seed).expect("collect");
        assert!(collected.is_designed_subsample());

        // The selection must be the measure's own design, row for row,
        // weight for weight.
        let sample = RowSamplingMeasure::uniform(n).designed_subsample(budget, seed);
        assert_eq!(
            collected.row_ids,
            sample.rows.iter().map(|&r| r as u64).collect::<Vec<_>>()
        );
        assert_eq!(collected.likelihood_weights, sample.likelihood_weights);

        // Each collected row is bitwise the corpus row it claims to be.
        let stored = rows.mapv(|v| f64::from(v as f32));
        for (k, &rid) in collected.row_ids.iter().enumerate() {
            for c in 0..p {
                assert_eq!(
                    collected.target[[k, c]].to_bits(),
                    stored[[rid as usize, c]].to_bits(),
                    "row {rid} col {c}"
                );
            }
        }

        // Deterministic: same (measure, budget, seed) ⇒ identical collection.
        let again = collect_designed_target(&mut src, None, budget, seed).expect("collect again");
        assert_eq!(again.row_ids, collected.row_ids);
        for (a, b) in again.target.iter().zip(collected.target.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    #[test]
    fn measure_dimension_mismatch_is_rejected() {
        let rows = planted_rows(20, 2);
        let dir = temp_shard_dir("mismatch", &rows, 10);
        let mut src = MmapShardSource::open_dir(&dir.0).expect("open");
        let wrong = RowSamplingMeasure::uniform(7);
        let err = collect_designed_target(&mut src, Some(&wrong), 5, 1)
            .expect_err("mismatched measure must be rejected");
        assert!(err.contains("covers 7 rows"), "got: {err}");
    }

    #[test]
    fn auto_budget_is_exact_below_threshold_and_bounded_above_it() {
        assert_eq!(auto_designed_budget(1_000), 1_000);
        assert_eq!(
            auto_designed_budget(99_999_999),
            99_999_999,
            "below the mandatory threshold the budget is the whole corpus"
        );
        assert_eq!(
            auto_designed_budget(100_000_000),
            DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
        );
        assert_eq!(
            auto_designed_budget(u64::MAX),
            DESIGNED_SAMPLE_DEFAULT_BUDGET_ROWS
        );
    }
}