Skip to main content

gam_solve/
streaming_border.rs

1//! Streaming, deterministic, out-of-core border-Gram accumulation (#973).
2//!
3//! Corpus-scale joint fits cannot hold the activation row set in memory: the
4//! Schur **border Gram** `G = Σ_n x_n x_nᵀ` (with `x_n ∈ ℝ^k` the row's border
5//! coordinates) must be accumulated over fixed-size row **chunks** streamed
6//! from disk shards. Because the methodological program (replicate nulls,
7//! resumable workflows) rests on determinism, the accumulation here is
8//! **bit-reproducible by construction**, not by luck:
9//!
10//! * The chunk partition is a pure function of `(n_rows, chunk_size)` — chunk
11//!   `j` covers rows `[j·chunk_size, min((j+1)·chunk_size, n_rows))`.
12//! * Each within-chunk Gram entry is a [`pairwise_sum`] over the chunk's rows
13//!   (the already-landed deterministic pairwise tree of
14//!   [`gam_linalg::pairwise_reduce`]).
15//! * Cross-chunk reduction follows the **same fixed pairwise tree** (the
16//!   [`StreamingPairwise`](gam_linalg::pairwise_reduce::StreamingPairwise)
17//!   cascade, applied entry-wise to whole chunk Grams): sequential base blocks
18//!   of [`CROSS_CHUNK_BASE`] chunk partials, then power-of-two cascade merges.
19//!   The tree shape depends only on the chunk count — never on values, device
20//!   timing, or thread scheduling. A unit test pins the cross-chunk
21//!   association bit-for-bit to [`pairwise_sum`] over the per-chunk entries.
22//! * Chunks may be **submitted in any order** (e.g. shards finishing on
23//!   different devices at different times): every chunk is keyed by its chunk
24//!   index, the in-order fold frontier advances eagerly, and out-of-order
25//!   arrivals wait in a pending buffer. The final Gram is a pure function of
26//!   the row content alone — identical bits for any submission order.
27//!
28//! All accumulation buffers are **f64** (the mixed-precision policy of #973:
29//! per-row kernels may run f32 upstream, but everything feeding evidence
30//! accumulates in f64 — this module exposes no f32 accumulation path at all).
31//!
32//! The accumulation state — partial Grams (in-order fold forest + pending
33//! out-of-order chunk partials) plus the chunk cursor — serializes to a
34//! [`BorderGramCheckpoint`] and resumes via [`StreamingBorderGram::resume`],
35//! with resume-equals-straight-through guaranteed (and unit-tested) at the
36//! bit level.
37//!
38//! Pure library: no SAE coupling, no flags, no environment variables. Drivers
39//! that also need a right-hand side `Σ_n x_n y_n` stack the response columns
40//! onto the border coordinates (`[X | Y]`) and read the cross block of the
41//! returned Gram; per-row weights `w_n` are pre-scaled into the rows as
42//! `√w_n · x_n` by the caller.
43
44use gam_linalg::pairwise_reduce::{BASE_CHUNK, pairwise_sum};
45use ndarray::{Array2, ArrayView2};
46use serde::{Deserialize, Serialize};
47use std::collections::BTreeMap;
48
/// Number of chunk partials per base block of the **cross-chunk** pairwise
/// tree.
///
/// Defined as the landed [`BASE_CHUNK`] of [`gam_linalg::pairwise_reduce`] so
/// that the entry-wise association order of the cross-chunk fold is
/// bit-identical to [`pairwise_sum`] over the per-chunk entry values
/// (unit-tested below). A pure compile-time constant: the tree shape never
/// depends on tuning, platform, or runtime conditions.
pub const CROSS_CHUNK_BASE: usize = BASE_CHUNK;
57
/// Serializable accumulation state of a [`StreamingBorderGram`]: the partial
/// Grams plus the chunk cursor. Writing this to disk after every accepted
/// chunk makes a preempted multi-hour pass resumable instead of restartable;
/// [`StreamingBorderGram::resume`] reconstructs the accumulator with
/// bit-identical future behavior (resume-equals-straight-through).
///
/// Every matrix-valued field stores a `k×k` partial flattened row-major into
/// `k·k` values, with `k = border_dim`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BorderGramCheckpoint {
    /// Border dimension `k` (columns of every submitted chunk).
    pub border_dim: usize,
    /// Total row count of the full pass.
    pub n_rows: usize,
    /// Fixed chunk size (rows per chunk; the last chunk may be shorter).
    pub chunk_size: usize,
    /// Chunk cursor: number of chunks already folded into the in-order
    /// cascade. Chunk indices `< frontier` are consumed; the next in-order
    /// fold is chunk `frontier`.
    pub frontier: usize,
    /// Sequential partial of the current (unsealed) cross-chunk base block,
    /// flattened `k·k` row-major. `None` iff `block_len == 0`.
    pub block_partial: Option<Vec<f64>>,
    /// Number of chunk partials folded into `block_partial`
    /// (`0..CROSS_CHUNK_BASE`): a block that reaches `CROSS_CHUNK_BASE`
    /// partials is sealed into `forest` and this counter resets to zero.
    pub block_len: usize,
    /// Completed cascade subtrees: `(weight in chunks, flattened k·k partial)`
    /// with strictly decreasing power-of-two-multiple-of-base weights, bottom
    /// to top — exactly the `StreamingPairwise` forest invariant.
    pub forest: Vec<(usize, Vec<f64>)>,
    /// Out-of-order chunk partials waiting for the frontier to reach them:
    /// `(chunk_index, flattened k·k chunk Gram)`, all indices `> frontier`
    /// (a chunk arriving exactly at the frontier is folded immediately and
    /// never parked).
    pub pending: Vec<(usize, Vec<f64>)>,
}
89
/// Chunked, out-of-core, bit-reproducible border-Gram accumulator.
///
/// Accumulates `G = Σ_n x_n x_nᵀ ∈ ℝ^{k×k}` over `n_rows` rows submitted as
/// fixed-size chunks (any submission order), with f64 accumulation throughout
/// and a deterministic pairwise reduction tree whose shape is a pure function
/// of `(n_rows, chunk_size)`. See the module docs for the determinism
/// contract.
///
/// All partials are stored flattened row-major (`k·k` values each).
pub struct StreamingBorderGram {
    /// Border dimension `k`: column count of every chunk, side of the Gram.
    border_dim: usize,
    /// Declared total row count of the pass.
    n_rows: usize,
    /// Rows per chunk (the final chunk may be shorter).
    chunk_size: usize,
    /// Next chunk index expected by the in-order cascade fold.
    frontier: usize,
    /// Sequential partial of the current cross-chunk base block.
    block_partial: Option<Vec<f64>>,
    /// Chunk partials folded into `block_partial` so far.
    block_len: usize,
    /// Completed cascade subtrees `(weight in chunks, partial)`.
    forest: Vec<(usize, Vec<f64>)>,
    /// Out-of-order chunk partials keyed by chunk index (all `> frontier`).
    pending: BTreeMap<usize, Vec<f64>>,
}
112
/// Adds `rhs` into `acc` element by element, in place (`acc[i] += rhs[i]`).
///
/// The two slices are walked in lockstep, so when their lengths differ only
/// the common prefix of `acc` is updated.
///
/// IEEE-754 addition is commutative, so `acc + rhs` and `rhs + acc` are
/// bit-identical; only the *association grouping* matters for reproducibility,
/// and that is fixed by the cascade structure of the caller.
fn add_into(acc: &mut [f64], rhs: &[f64]) {
    acc.iter_mut()
        .zip(rhs)
        .for_each(|(slot, addend)| *slot += *addend);
}
123
124/// Deterministic per-chunk Gram contribution, flattened `k·k` row-major, with
125/// `k = rows.ncols()`. Entry `(a, b)` is the [`pairwise_sum`] of
126/// `x_i[a]·x_i[b]` over the chunk's rows in row order; the symmetric mirror
127/// entry reuses the same products in the same order, so the matrix is bitwise
128/// symmetric.
129///
130/// Exposed as a free function so a **remote producer** (a worker node in the
131/// cross-node reduction, [`crate::cross_node`]) can compute exactly the
132/// partial this accumulator would have computed from the same rows, then ship
133/// the `k·k` partial instead of the rows. Bit-identical by construction to the
134/// in-process path: [`StreamingBorderGram::submit_chunk`] routes through this
135/// same function.
136pub fn chunk_gram_flat(rows: ArrayView2<'_, f64>) -> Vec<f64> {
137    let k = rows.ncols();
138    let r = rows.nrows();
139    let mut gram = vec![0.0_f64; k * k];
140    let mut products = vec![0.0_f64; r];
141    for a in 0..k {
142        for b in a..k {
143            for (i, p) in products.iter_mut().enumerate() {
144                *p = rows[[i, a]] * rows[[i, b]];
145            }
146            let s = pairwise_sum(&products);
147            gram[a * k + b] = s;
148            gram[b * k + a] = s;
149        }
150    }
151    gram
152}
153
154impl StreamingBorderGram {
155    /// Create an empty accumulator for `n_rows` total rows of border dimension
156    /// `border_dim`, streamed in chunks of `chunk_size` rows.
157    pub fn new(border_dim: usize, n_rows: usize, chunk_size: usize) -> Result<Self, String> {
158        if border_dim == 0 {
159            return Err("StreamingBorderGram: border_dim must be positive".to_string());
160        }
161        if chunk_size == 0 {
162            return Err("StreamingBorderGram: chunk_size must be positive".to_string());
163        }
164        Ok(Self {
165            border_dim,
166            n_rows,
167            chunk_size,
168            frontier: 0,
169            block_partial: None,
170            block_len: 0,
171            forest: Vec::new(),
172            pending: BTreeMap::new(),
173        })
174    }
175
176    /// Total number of chunks of the pass: `ceil(n_rows / chunk_size)`.
177    pub fn n_chunks(&self) -> usize {
178        self.n_rows.div_ceil(self.chunk_size)
179    }
180
181    /// Row range covered by chunk `chunk_index`:
182    /// `[chunk_index·chunk_size, min((chunk_index+1)·chunk_size, n_rows))`.
183    /// A pure function of the partition parameters — the caller slices its
184    /// shard rows with exactly this range.
185    pub fn chunk_rows(&self, chunk_index: usize) -> std::ops::Range<usize> {
186        let lo = chunk_index * self.chunk_size;
187        let hi = ((chunk_index + 1) * self.chunk_size).min(self.n_rows);
188        lo..hi
189    }
190
191    /// Number of chunks already consumed by the in-order cascade (the chunk
192    /// cursor). Pending out-of-order chunks are not counted.
193    pub fn frontier(&self) -> usize {
194        self.frontier
195    }
196
197    /// `true` once every chunk of the pass has been submitted.
198    pub fn is_complete(&self) -> bool {
199        self.frontier == self.n_chunks() && self.pending.is_empty()
200    }
201
202    /// Submit the rows of chunk `chunk_index` (shape
203    /// `(chunk_rows(chunk_index).len(), border_dim)`).
204    ///
205    /// Chunks may arrive in **any order**; each may be submitted exactly once.
206    /// The per-chunk Gram contribution is computed immediately (each entry a
207    /// [`pairwise_sum`] over the chunk's rows, in row order), so the caller's
208    /// row buffer can be dropped/remapped right after this returns.
209    pub fn submit_chunk(
210        &mut self,
211        chunk_index: usize,
212        rows: ArrayView2<'_, f64>,
213    ) -> Result<(), String> {
214        let n_chunks = self.n_chunks();
215        if chunk_index >= n_chunks {
216            return Err(format!(
217                "StreamingBorderGram: chunk index {chunk_index} out of range (n_chunks = {n_chunks})"
218            ));
219        }
220        if chunk_index < self.frontier || self.pending.contains_key(&chunk_index) {
221            return Err(format!(
222                "StreamingBorderGram: chunk {chunk_index} was already submitted"
223            ));
224        }
225        let expected_rows = self.chunk_rows(chunk_index).len();
226        if rows.nrows() != expected_rows || rows.ncols() != self.border_dim {
227            return Err(format!(
228                "StreamingBorderGram: chunk {chunk_index} has shape ({}, {}) but expected ({}, {})",
229                rows.nrows(),
230                rows.ncols(),
231                expected_rows,
232                self.border_dim
233            ));
234        }
235        let gram = self.chunk_gram(rows);
236        self.fold_or_park(chunk_index, gram);
237        Ok(())
238    }
239
240    /// Submit chunk `chunk_index` as a **precomputed** per-chunk Gram partial
241    /// (flattened `k·k` row-major), produced by [`chunk_gram_flat`] over exactly
242    /// the rows of [`Self::chunk_rows`]`(chunk_index)`.
243    ///
244    /// This is the cross-node ingestion seam ([`crate::cross_node`]):
245    /// a worker node computes its chunks' partials locally and ships the `k·k`
246    /// values; the coordinator folds them through the **same** fixed in-order
247    /// cascade as row-level submission, so the result is bit-identical to a
248    /// single process having seen all the rows. The validation here is
249    /// structural (index range, duplicate, partial length); the *content*
250    /// contract — that the partial really is `chunk_gram_flat` of the chunk's
251    /// rows — is the producer's, enforced by routing both producers through the
252    /// one free function.
253    pub fn submit_chunk_gram(&mut self, chunk_index: usize, gram: Vec<f64>) -> Result<(), String> {
254        let n_chunks = self.n_chunks();
255        if chunk_index >= n_chunks {
256            return Err(format!(
257                "StreamingBorderGram: chunk index {chunk_index} out of range (n_chunks = {n_chunks})"
258            ));
259        }
260        if chunk_index < self.frontier || self.pending.contains_key(&chunk_index) {
261            return Err(format!(
262                "StreamingBorderGram: chunk {chunk_index} was already submitted"
263            ));
264        }
265        let kk = self.border_dim * self.border_dim;
266        if gram.len() != kk {
267            return Err(format!(
268                "StreamingBorderGram: chunk {chunk_index} partial has len {} but expected {kk}",
269                gram.len()
270            ));
271        }
272        if !gram.iter().all(|v| v.is_finite()) {
273            return Err(format!(
274                "StreamingBorderGram: chunk {chunk_index} partial contains non-finite entries"
275            ));
276        }
277        self.fold_or_park(chunk_index, gram);
278        Ok(())
279    }
280
281    /// Fold an accepted chunk partial in-order, or park it in the pending
282    /// buffer until the frontier reaches it. Shared tail of the row-level and
283    /// gram-level submission paths so both produce identical fold behavior.
284    fn fold_or_park(&mut self, chunk_index: usize, gram: Vec<f64>) {
285        if chunk_index == self.frontier {
286            self.fold_chunk(gram);
287            self.frontier += 1;
288            // Drain any pending chunks the frontier has now reached.
289            while let Some(next) = self.pending.remove(&self.frontier) {
290                self.fold_chunk(next);
291                self.frontier += 1;
292            }
293        } else {
294            self.pending.insert(chunk_index, gram);
295        }
296    }
297
298    /// Per-chunk Gram contribution, flattened `k·k` row-major — delegates to
299    /// the shared free function [`chunk_gram_flat`] so the in-process and
300    /// cross-node producers are the same code path, bit for bit.
301    fn chunk_gram(&self, rows: ArrayView2<'_, f64>) -> Vec<f64> {
302        chunk_gram_flat(rows)
303    }
304
305    /// Fold one in-order chunk partial into the cross-chunk cascade. This is
306    /// the `StreamingPairwise` push, applied entry-wise to whole chunk Grams:
307    /// sequential accumulation within a [`CROSS_CHUNK_BASE`]-chunk base block
308    /// (seeded from the block's first partial), then power-of-two cascade
309    /// merges of completed blocks.
310    fn fold_chunk(&mut self, gram: Vec<f64>) {
311        match self.block_partial.as_mut() {
312            None => {
313                self.block_partial = Some(gram);
314                self.block_len = 1;
315            }
316            Some(acc) => {
317                add_into(acc, &gram);
318                self.block_len += 1;
319            }
320        }
321        if self.block_len == CROSS_CHUNK_BASE {
322            let block = self
323                .block_partial
324                .take()
325                .expect("block_len == CROSS_CHUNK_BASE implies a live block partial");
326            self.block_len = 0;
327            self.absorb(CROSS_CHUNK_BASE, block);
328        }
329    }
330
331    /// Merge a completed subtree partial of the given chunk-count `weight`
332    /// into the forest, cascading equal-weight merges — the exact
333    /// `StreamingPairwise::absorb` cascade, entry-wise on matrices.
334    fn absorb(&mut self, weight: usize, value: Vec<f64>) {
335        let mut w = weight;
336        let mut v = value;
337        while let Some((top_w, _)) = self.forest.last() {
338            if *top_w == w {
339                let (_, top_v) = self
340                    .forest
341                    .pop()
342                    .expect("forest top exists: just observed by last()");
343                // combine(left, right): entry-wise add (commutative bitwise).
344                v = {
345                    let mut merged = top_v;
346                    add_into(&mut merged, &v);
347                    merged
348                };
349                w = w.saturating_mul(2);
350            } else {
351                break;
352            }
353        }
354        self.forest.push((w, v));
355    }
356
357    /// Serialize the full accumulation state — partial Grams + chunk cursor —
358    /// for checkpointing. [`StreamingBorderGram::resume`] reconstructs an
359    /// accumulator whose future behavior is bit-identical to never having
360    /// stopped.
361    pub fn checkpoint(&self) -> BorderGramCheckpoint {
362        BorderGramCheckpoint {
363            border_dim: self.border_dim,
364            n_rows: self.n_rows,
365            chunk_size: self.chunk_size,
366            frontier: self.frontier,
367            block_partial: self.block_partial.clone(),
368            block_len: self.block_len,
369            forest: self.forest.clone(),
370            pending: self
371                .pending
372                .iter()
373                .map(|(idx, g)| (*idx, g.clone()))
374                .collect(),
375        }
376    }
377
378    /// Reconstruct an accumulator from a checkpoint. Validates the structural
379    /// invariants so a corrupted checkpoint is rejected loudly instead of
380    /// silently producing a wrong (but plausible-looking) Gram.
381    pub fn resume(state: BorderGramCheckpoint) -> Result<Self, String> {
382        if state.border_dim == 0 {
383            return Err("BorderGramCheckpoint: border_dim must be positive".to_string());
384        }
385        if state.chunk_size == 0 {
386            return Err("BorderGramCheckpoint: chunk_size must be positive".to_string());
387        }
388        let kk = state.border_dim * state.border_dim;
389        let n_chunks = state.n_rows.div_ceil(state.chunk_size);
390        if state.frontier > n_chunks {
391            return Err(format!(
392                "BorderGramCheckpoint: frontier {} exceeds n_chunks {n_chunks}",
393                state.frontier
394            ));
395        }
396        if state.block_len >= CROSS_CHUNK_BASE {
397            return Err(format!(
398                "BorderGramCheckpoint: block_len {} must be < CROSS_CHUNK_BASE {CROSS_CHUNK_BASE}",
399                state.block_len
400            ));
401        }
402        if state.block_partial.is_some() != (state.block_len > 0) {
403            return Err(
404                "BorderGramCheckpoint: block_partial presence inconsistent with block_len"
405                    .to_string(),
406            );
407        }
408        if let Some(b) = &state.block_partial {
409            if b.len() != kk {
410                return Err(format!(
411                    "BorderGramCheckpoint: block_partial has len {} but expected {kk}",
412                    b.len()
413                ));
414            }
415        }
416        for (w, g) in &state.forest {
417            if *w == 0 || g.len() != kk {
418                return Err(
419                    "BorderGramCheckpoint: malformed forest partial (zero weight or wrong len)"
420                        .to_string(),
421                );
422            }
423        }
424        let mut pending = BTreeMap::new();
425        for (idx, g) in state.pending {
426            if idx < state.frontier || idx >= n_chunks {
427                return Err(format!(
428                    "BorderGramCheckpoint: pending chunk index {idx} outside (frontier {}, n_chunks {n_chunks})",
429                    state.frontier
430                ));
431            }
432            if g.len() != kk {
433                return Err(format!(
434                    "BorderGramCheckpoint: pending chunk {idx} partial has len {} but expected {kk}",
435                    g.len()
436                ));
437            }
438            if pending.insert(idx, g).is_some() {
439                return Err(format!(
440                    "BorderGramCheckpoint: duplicate pending chunk index {idx}"
441                ));
442            }
443        }
444        Ok(Self {
445            border_dim: state.border_dim,
446            n_rows: state.n_rows,
447            chunk_size: state.chunk_size,
448            frontier: state.frontier,
449            block_partial: state.block_partial,
450            block_len: state.block_len,
451            forest: state.forest,
452            pending,
453        })
454    }
455
456    /// Finish the pass, returning the `k×k` border Gram. Errors if any chunk
457    /// is missing (out-of-order pending chunks the frontier never reached, or
458    /// chunks never submitted). The result is a pure function of the row
459    /// content: identical bits for any submission order and for any
460    /// checkpoint/resume history.
461    pub fn finish(mut self) -> Result<Array2<f64>, String> {
462        let n_chunks = self.n_chunks();
463        if self.frontier != n_chunks {
464            let missing: Vec<usize> = (self.frontier..n_chunks)
465                .filter(|idx| !self.pending.contains_key(idx))
466                .take(8)
467                .collect();
468            return Err(format!(
469                "StreamingBorderGram: finish() before all chunks were submitted \
470                 (frontier {}/{n_chunks}, first missing chunk indices {missing:?})",
471                self.frontier
472            ));
473        }
474        // Seal the trailing (short) base block, exactly like
475        // `StreamingPairwise::finish`.
476        if let Some(tail) = self.block_partial.take() {
477            let w = self.block_len;
478            self.block_len = 0;
479            self.forest.push((w, tail));
480        }
481        // Fold the forest right-to-left: each parent is
482        // combine(left_partial, accumulated_right).
483        let k = self.border_dim;
484        let mut iter = self.forest.into_iter().rev();
485        let flat = match iter.next() {
486            None => vec![0.0_f64; k * k],
487            Some((_, mut acc)) => {
488                for (_, left) in iter {
489                    add_into(&mut acc, &left);
490                }
491                acc
492            }
493        };
494        Array2::from_shape_vec((k, k), flat)
495            .map_err(|e| format!("StreamingBorderGram: Gram reshape failed: {e}"))
496    }
497}
498
/// Bridges arbitrary-length row batches onto the fixed chunk partition.
///
/// A streaming row source (`gam_sae::corpus`) yields batches whose
/// lengths are set by I/O policy (batch size, shard boundaries) — they do
/// **not** align with the deterministic chunk partition the accumulation tree
/// is keyed on. This assembler buffers incoming rows and submits exact chunks
/// in order, so the resulting Gram is bit-identical to having sliced the
/// partition directly: the batching of the producer can never leak into the
/// bits.
///
/// Checkpointing is exposed **at chunk granularity only**:
/// [`ChunkAssembler::checkpoint`] returns `Some` exactly when the internal
/// buffer is empty (a chunk boundary), because buffered raw rows are not part
/// of the accumulation state contract — a resumed pass re-reads its row
/// stream from the checkpointed chunk cursor
/// ([`StreamingBorderGram::chunk_rows`] of the frontier names the next row).
pub struct ChunkAssembler {
    /// Underlying accumulator that receives every completed chunk.
    gram: StreamingBorderGram,
    /// Row-major buffered rows (`buffered_rows × border_dim`), not yet a full
    /// chunk.
    buffer: Vec<f64>,
    /// Next chunk index to submit (in-order by construction).
    next_chunk: usize,
}
523
524impl ChunkAssembler {
525    /// New assembler over the same partition parameters as
526    /// [`StreamingBorderGram::new`].
527    pub fn new(border_dim: usize, n_rows: usize, chunk_size: usize) -> Result<Self, String> {
528        Ok(Self {
529            gram: StreamingBorderGram::new(border_dim, n_rows, chunk_size)?,
530            buffer: Vec::new(),
531            next_chunk: 0,
532        })
533    }
534
535    /// Number of buffered rows not yet folded into a chunk.
536    fn buffered_rows(&self) -> usize {
537        let k = self.gram.border_dim;
538        // Structural invariant: rows enter the buffer only via `push_rows`,
539        // which rejects any batch whose width is not `k`, and leave only in
540        // `need * k` drains — so the length is always a whole multiple of `k`.
541        // Kept as a real (release-surviving) assert rather than a `debug_assert`
542        // so a future buffer-maintenance bug cannot silently round the row count
543        // down.
544        assert!(
545            self.buffer.len() % k == 0,
546            "ChunkAssembler buffer length {} is not a multiple of border_dim {k}",
547            self.buffer.len()
548        );
549        self.buffer.len() / k
550    }
551
552    /// Append a batch of rows (any length, including empty) in stream order,
553    /// submitting every chunk the buffer completes.
554    pub fn push_rows(&mut self, rows: ArrayView2<'_, f64>) -> Result<(), String> {
555        let k = self.gram.border_dim;
556        if rows.ncols() != k {
557            return Err(format!(
558                "ChunkAssembler: batch has {} cols but border_dim is {k}",
559                rows.ncols()
560            ));
561        }
562        let n_chunks = self.gram.n_chunks();
563        // Rows consumed by completed chunks: the final chunk may be short, so
564        // clamp to the declared total.
565        let consumed = (self.gram.frontier() * self.gram.chunk_size).min(self.gram.n_rows);
566        let total_seen = consumed + self.buffered_rows() + rows.nrows();
567        if total_seen > self.gram.n_rows {
568            return Err(format!(
569                "ChunkAssembler: stream overran the declared row count ({} > {})",
570                total_seen, self.gram.n_rows
571            ));
572        }
573        for row in rows.outer_iter() {
574            self.buffer.extend(row.iter().copied());
575        }
576        // Submit every completed chunk in order.
577        while self.next_chunk < n_chunks {
578            let need = self.gram.chunk_rows(self.next_chunk).len();
579            if self.buffered_rows() < need {
580                break;
581            }
582            let chunk: Vec<f64> = self.buffer.drain(..need * k).collect();
583            let view = ndarray::ArrayView2::from_shape((need, k), &chunk)
584                .map_err(|e| format!("ChunkAssembler: chunk reshape failed: {e}"))?;
585            self.gram.submit_chunk(self.next_chunk, view)?;
586            self.next_chunk += 1;
587        }
588        Ok(())
589    }
590
591    /// Serialize the accumulation state — only at a chunk boundary. `None`
592    /// while rows are buffered mid-chunk (checkpoint after the next boundary,
593    /// or size batches to the chunk size for checkpoint-every-batch).
594    pub fn checkpoint(&self) -> Option<BorderGramCheckpoint> {
595        if self.buffer.is_empty() {
596            Some(self.gram.checkpoint())
597        } else {
598            None
599        }
600    }
601
602    /// Resume an assembler at the chunk boundary a checkpoint names. The
603    /// caller re-positions its row stream at row
604    /// `checkpoint.frontier * checkpoint.chunk_size` (the partition is pure,
605    /// so that index is exact) and replays from there.
606    pub fn resume(state: BorderGramCheckpoint) -> Result<Self, String> {
607        let gram = StreamingBorderGram::resume(state)?;
608        let next_chunk = gram.frontier();
609        Ok(Self {
610            gram,
611            buffer: Vec::new(),
612            next_chunk,
613        })
614    }
615
616    /// Finish the pass. Errors if the stream ended mid-chunk or short of the
617    /// declared row count — a truncated stream is rejected loudly, never
618    /// folded as a silently shorter corpus.
619    pub fn finish(self) -> Result<Array2<f64>, String> {
620        if !self.buffer.is_empty() {
621            let k = self.gram.border_dim;
622            return Err(format!(
623                "ChunkAssembler: stream ended mid-chunk with {} buffered rows \
624                 (declared n_rows = {})",
625                self.buffer.len() / k,
626                self.gram.n_rows
627            ));
628        }
629        self.gram.finish()
630    }
631}
632
633#[cfg(test)]
634mod tests {
635    use super::*;
636    use ndarray::Array2;
637
638    /// Deterministic pseudo-random row matrix keyed purely by index.
639    fn planted_rows(n: usize, k: usize) -> Array2<f64> {
640        Array2::from_shape_fn((n, k), |(i, j)| {
641            let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
642            (x.sin() * 43_758.547).fract() * 2.0 - 1.0
643        })
644    }
645
646    fn accumulate_in_order(
647        rows: &Array2<f64>,
648        chunk_size: usize,
649    ) -> (StreamingBorderGram, Vec<usize>) {
650        let acc =
651            StreamingBorderGram::new(rows.ncols(), rows.nrows(), chunk_size).expect("accumulator");
652        let order: Vec<usize> = (0..acc.n_chunks()).collect();
653        (acc, order)
654    }
655
656    fn run_with_order(rows: &Array2<f64>, chunk_size: usize, order: &[usize]) -> Array2<f64> {
657        let mut acc =
658            StreamingBorderGram::new(rows.ncols(), rows.nrows(), chunk_size).expect("accumulator");
659        for &j in order {
660            let range = acc.chunk_rows(j);
661            acc.submit_chunk(j, rows.slice(ndarray::s![range, ..]))
662                .expect("submit");
663        }
664        acc.finish().expect("finish")
665    }
666
667    fn assert_bit_identical(a: &Array2<f64>, b: &Array2<f64>, label: &str) {
668        assert_eq!(a.dim(), b.dim(), "{label}: shape mismatch");
669        for ((idx, x), y) in a.indexed_iter().zip(b.iter()) {
670            assert_eq!(
671                x.to_bits(),
672                y.to_bits(),
673                "{label}: entry {idx:?} differs bitwise: {x:?} vs {y:?}"
674            );
675        }
676    }
677
678    #[test]
679    fn gram_matches_naive_xtx() {
680        let n = 257; // deliberately not a multiple of the chunk size
681        let k = 5;
682        let rows = planted_rows(n, k);
683        let gram = run_with_order(&rows, 16, &(0..17).collect::<Vec<_>>());
684        let naive = rows.t().dot(&rows);
685        for i in 0..k {
686            for j in 0..k {
687                let d = (gram[[i, j]] - naive[[i, j]]).abs();
688                let scale = naive[[i, j]].abs().max(1.0);
689                assert!(
690                    d <= 1.0e-12 * scale,
691                    "Gram[{i},{j}] = {} vs naive {} (delta {d})",
692                    gram[[i, j]],
693                    naive[[i, j]]
694                );
695            }
696        }
697        // Bitwise symmetry: mirror entries reuse the same product sequence.
698        for i in 0..k {
699            for j in 0..k {
700                assert_eq!(gram[[i, j]].to_bits(), gram[[j, i]].to_bits());
701            }
702        }
703    }
704
705    #[test]
706    fn bit_reproducible_across_chunk_submission_orders() {
707        // Enough chunks (> CROSS_CHUNK_BASE) to exercise the base-block seal,
708        // the power-of-two cascade, AND the trailing short block.
709        let n = 2 * CROSS_CHUNK_BASE * 3 + 7; // 775 rows
710        let k = 4;
711        let chunk_size = 2; // 388 chunks
712        let rows = planted_rows(n, k);
713        let n_chunks = n.div_ceil(chunk_size);
714
715        let in_order: Vec<usize> = (0..n_chunks).collect();
716        let reversed: Vec<usize> = (0..n_chunks).rev().collect();
717        // Deterministic stride shuffle (388 is coprime to 129).
718        let strided: Vec<usize> = (0..n_chunks).map(|i| (i * 129) % n_chunks).collect();
719
720        let g0 = run_with_order(&rows, chunk_size, &in_order);
721        let g1 = run_with_order(&rows, chunk_size, &reversed);
722        let g2 = run_with_order(&rows, chunk_size, &strided);
723
724        assert_bit_identical(&g0, &g1, "in-order vs reversed submission");
725        assert_bit_identical(&g0, &g2, "in-order vs strided submission");
726    }
727
728    #[test]
729    fn cross_chunk_association_matches_landed_pairwise_sum() {
730        // The cross-chunk cascade must associate per-chunk Gram entries
731        // EXACTLY as the landed `pairwise_sum` tree does: for every entry,
732        // finish() == pairwise_sum(per-chunk entry values), bit for bit.
733        let n = 613;
734        let k = 3;
735        let chunk_size = 2; // 307 chunks: cascade + trailing block both live
736        let rows = planted_rows(n, k);
737        let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
738        let n_chunks = acc.n_chunks();
739        let mut per_chunk_entries: Vec<Vec<f64>> = vec![Vec::with_capacity(n_chunks); k * k];
740        for j in 0..n_chunks {
741            let range = acc.chunk_rows(j);
742            let chunk = rows.slice(ndarray::s![range, ..]);
743            let g = acc.chunk_gram(chunk);
744            for (e, vals) in g.iter().zip(per_chunk_entries.iter_mut()) {
745                vals.push(*e);
746            }
747            acc.submit_chunk(j, chunk).expect("submit");
748        }
749        let gram = acc.finish().expect("finish");
750        for a in 0..k {
751            for b in 0..k {
752                let expected = pairwise_sum(&per_chunk_entries[a * k + b]);
753                assert_eq!(
754                    gram[[a, b]].to_bits(),
755                    expected.to_bits(),
756                    "entry ({a},{b}): cascade {} vs pairwise_sum {}",
757                    gram[[a, b]],
758                    expected
759                );
760            }
761        }
762    }
763
764    #[test]
765    fn resume_equals_straight_through() {
766        let n = 491;
767        let k = 4;
768        let chunk_size = 3;
769        let rows = planted_rows(n, k);
770        let (acc, order) = accumulate_in_order(&rows, chunk_size);
771        let n_chunks = acc.n_chunks();
772        // Straight-through reference.
773        let straight = run_with_order(&rows, chunk_size, &order);
774
775        // Interrupted run: submit a mixed-order prefix (so the checkpoint
776        // carries a live base-block partial, forest entries, AND pending
777        // out-of-order chunks), checkpoint through a serde round-trip, resume,
778        // submit the rest.
779        let mut first = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
780        let mut submitted = vec![false; n_chunks];
781        // Prefix: chunks 0..60 in order, plus three far-ahead chunks.
782        let prefix: Vec<usize> = (0..60).chain([150, 100, 163]).collect();
783        for &j in &prefix {
784            let range = first.chunk_rows(j);
785            first
786                .submit_chunk(j, rows.slice(ndarray::s![range, ..]))
787                .expect("prefix submit");
788            submitted[j] = true;
789        }
790        assert!(
791            !first.pending.is_empty(),
792            "fixture must exercise pending out-of-order state"
793        );
794        let json = serde_json::to_string(&first.checkpoint()).expect("serialize checkpoint");
795        drop(first);
796        let restored: BorderGramCheckpoint =
797            serde_json::from_str(&json).expect("deserialize checkpoint");
798        let mut second = StreamingBorderGram::resume(restored).expect("resume");
799        for j in 0..n_chunks {
800            if submitted[j] {
801                continue;
802            }
803            let range = second.chunk_rows(j);
804            second
805                .submit_chunk(j, rows.slice(ndarray::s![range, ..]))
806                .expect("resumed submit");
807        }
808        let resumed = second.finish().expect("finish resumed");
809        assert_bit_identical(&straight, &resumed, "resume vs straight-through");
810    }
811
812    #[test]
813    fn rejects_duplicates_missing_chunks_and_bad_shapes() {
814        let n = 10;
815        let k = 2;
816        let chunk_size = 4; // chunks: [0,4), [4,8), [8,10)
817        let rows = planted_rows(n, k);
818        let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
819        assert_eq!(acc.n_chunks(), 3);
820
821        // Wrong shape (short chunk submitted at a full-chunk index).
822        let err = acc
823            .submit_chunk(0, rows.slice(ndarray::s![0..3, ..]))
824            .expect_err("short chunk must be rejected");
825        assert!(err.contains("expected (4, 2)"), "got: {err}");
826
827        acc.submit_chunk(0, rows.slice(ndarray::s![0..4, ..]))
828            .expect("chunk 0");
829        // Duplicate in-order chunk.
830        let err = acc
831            .submit_chunk(0, rows.slice(ndarray::s![0..4, ..]))
832            .expect_err("duplicate must be rejected");
833        assert!(err.contains("already submitted"), "got: {err}");
834
835        // Out-of-order chunk 2 (short trailing chunk, 2 rows), then duplicate.
836        acc.submit_chunk(2, rows.slice(ndarray::s![8..10, ..]))
837            .expect("chunk 2 out of order");
838        let err = acc
839            .submit_chunk(2, rows.slice(ndarray::s![8..10, ..]))
840            .expect_err("duplicate pending must be rejected");
841        assert!(err.contains("already submitted"), "got: {err}");
842
843        // Out-of-range chunk index.
844        let err = acc
845            .submit_chunk(3, rows.slice(ndarray::s![0..4, ..]))
846            .expect_err("out-of-range index must be rejected");
847        assert!(err.contains("out of range"), "got: {err}");
848
849        // finish() with chunk 1 missing must fail and name it.
850        let err = acc.finish().expect_err("missing chunk must fail finish");
851        assert!(
852            err.contains("[1]"),
853            "missing-chunk message must name chunk 1: {err}"
854        );
855    }
856
857    #[test]
858    fn checkpoint_validation_rejects_corruption() {
859        let mut acc = StreamingBorderGram::new(3, 100, 10).expect("accumulator");
860        let rows = planted_rows(100, 3);
861        acc.submit_chunk(0, rows.slice(ndarray::s![0..10, ..]))
862            .expect("chunk 0");
863        let good = acc.checkpoint();
864
865        let mut bad = good.clone();
866        bad.block_len = 0; // inconsistent with a live block_partial
867        assert!(StreamingBorderGram::resume(bad).is_err());
868
869        let mut bad = good.clone();
870        if let Some(b) = bad.block_partial.as_mut() {
871            b.pop(); // wrong partial length
872        }
873        assert!(StreamingBorderGram::resume(bad).is_err());
874
875        let mut bad = good.clone();
876        bad.pending.push((0, vec![0.0; 9])); // pending below the frontier
877        assert!(StreamingBorderGram::resume(bad).is_err());
878
879        let mut bad = good;
880        bad.frontier = 99; // beyond n_chunks
881        assert!(StreamingBorderGram::resume(bad).is_err());
882    }
883
884    #[test]
885    fn chunk_assembler_is_batching_invariant() {
886        // Whatever batch lengths the I/O layer produces, the assembled Gram is
887        // bit-identical to slicing the fixed partition directly.
888        let n = 463;
889        let k = 4;
890        let chunk_size = 16;
891        let rows = planted_rows(n, k);
892        let direct = {
893            let (acc, order) = accumulate_in_order(&rows, chunk_size);
894            drop(acc);
895            run_with_order(&rows, chunk_size, &order)
896        };
897
898        // Misaligned, varying batch lengths (3, 5, 7, 11, 13, cycling).
899        let mut asm = ChunkAssembler::new(k, n, chunk_size).expect("assembler");
900        let sizes = [3usize, 5, 7, 11, 13];
901        let mut at = 0usize;
902        let mut s = 0usize;
903        while at < n {
904            let take = sizes[s % sizes.len()].min(n - at);
905            asm.push_rows(rows.slice(ndarray::s![at..at + take, ..]))
906                .expect("push");
907            at += take;
908            s += 1;
909        }
910        let assembled = asm.finish().expect("finish");
911        assert_bit_identical(&direct, &assembled, "direct vs assembled batching");
912    }
913
914    #[test]
915    fn chunk_assembler_checkpoints_only_at_boundaries_and_resumes() {
916        let n = 200;
917        let k = 3;
918        let chunk_size = 10;
919        let rows = planted_rows(n, k);
920        let direct = run_with_order(&rows, chunk_size, &(0..20).collect::<Vec<_>>());
921
922        let mut asm = ChunkAssembler::new(k, n, chunk_size).expect("assembler");
923        // 7 rows: mid-chunk, no checkpoint available.
924        asm.push_rows(rows.slice(ndarray::s![0..7, ..]))
925            .expect("push");
926        assert!(
927            asm.checkpoint().is_none(),
928            "mid-chunk checkpoint must be None"
929        );
930        // Up to row 30: exactly 3 chunks folded, boundary checkpoint.
931        asm.push_rows(rows.slice(ndarray::s![7..30, ..]))
932            .expect("push");
933        let cp = asm.checkpoint().expect("boundary checkpoint");
934        assert_eq!(cp.frontier, 3);
935        drop(asm);
936
937        // Resume: re-read the stream from row frontier * chunk_size.
938        let mut resumed = ChunkAssembler::resume(cp).expect("resume");
939        resumed
940            .push_rows(rows.slice(ndarray::s![30..n, ..]))
941            .expect("push rest");
942        let gram = resumed.finish().expect("finish");
943        assert_bit_identical(&direct, &gram, "assembler resume vs straight-through");
944    }
945
946    #[test]
947    fn chunk_assembler_rejects_truncated_and_overrunning_streams() {
948        let k = 2;
949        let rows = planted_rows(30, k);
950        // Truncated: declared 30 rows, stream ends at 25 (mid-chunk).
951        let mut asm = ChunkAssembler::new(k, 30, 8).expect("assembler");
952        asm.push_rows(rows.slice(ndarray::s![0..25, ..]))
953            .expect("push");
954        let err = asm.finish().expect_err("truncated stream must fail finish");
955        assert!(err.contains("mid-chunk"), "got: {err}");
956        // Overrun: more rows than declared.
957        let mut asm = ChunkAssembler::new(k, 20, 8).expect("assembler");
958        let err = asm
959            .push_rows(rows.slice(ndarray::s![0..25, ..]))
960            .expect_err("overrun must be rejected");
961        assert!(err.contains("overran"), "got: {err}");
962    }
963
    /// Mixed-precision error budget (#973): rows stored `f32` (the shard
    /// format) and accumulated in `f64` must reproduce the all-`f64` border
    /// Gram within a **named tolerance**, entry-wise relative to the Gram's
    /// scale.
    ///
    /// Where the number comes from: `f32` storage rounds each value to ~6e-8
    /// relative (the `f32` unit roundoff, 2⁻²⁴); a product of two stored
    /// values roughly doubles that to ~1.2e-7; the deterministic `f64`
    /// pairwise accumulation adds nothing material on top. The budget below
    /// is ~100× that floor — tight enough to catch any `f32` accumulation
    /// sneaking into the path, loose enough to never flake on legitimate
    /// storage rounding.
    const MIXED_PRECISION_BORDER_RTOL: f64 = 1.0e-5;
973
974    #[test]
975    fn f32_storage_f64_accumulation_meets_the_error_budget() {
976        let n = 700;
977        let k = 5;
978        let chunk_size = 32;
979        let rows = planted_rows(n, k);
980        // The storage path: round every value through f32 (exactly what the
981        // shard writer + reader do), then accumulate in f64.
982        let stored = rows.mapv(|v| f64::from(v as f32));
983        let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
984        for j in 0..acc.n_chunks() {
985            let range = acc.chunk_rows(j);
986            acc.submit_chunk(j, stored.slice(ndarray::s![range, ..]))
987                .expect("submit");
988        }
989        let mixed = acc.finish().expect("finish");
990        // All-f64 reference.
991        let exact = rows.t().dot(&rows);
992        let scale = exact.iter().fold(0.0_f64, |m, &v| m.max(v.abs())).max(1.0);
993        for i in 0..k {
994            for j in 0..k {
995                let d = (mixed[[i, j]] - exact[[i, j]]).abs();
996                assert!(
997                    d <= MIXED_PRECISION_BORDER_RTOL * scale,
998                    "Gram[{i},{j}] mixed-precision delta {d:.3e} exceeds budget \
999                     {MIXED_PRECISION_BORDER_RTOL:.0e} × scale {scale:.3e}"
1000                );
1001            }
1002        }
1003    }
1004
1005    #[test]
1006    fn zero_rows_yields_zero_gram() {
1007        let acc = StreamingBorderGram::new(3, 0, 8).expect("accumulator");
1008        assert_eq!(acc.n_chunks(), 0);
1009        assert!(acc.is_complete());
1010        let gram = acc.finish().expect("finish empty");
1011        assert_eq!(gram.dim(), (3, 3));
1012        assert!(gram.iter().all(|v| v.to_bits() == 0.0_f64.to_bits()));
1013    }
1014}