1use gam_linalg::pairwise_reduce::{BASE_CHUNK, pairwise_sum};
45use ndarray::{Array2, ArrayView2};
46use serde::{Deserialize, Serialize};
47use std::collections::BTreeMap;
48
/// Number of consecutive chunk Grams that are summed sequentially into one
/// block before that block is pushed onto the pairwise cascade.
///
/// It is tied to `pairwise_sum`'s `BASE_CHUNK`; the test
/// `cross_chunk_association_matches_landed_pairwise_sum` relies on this to get
/// results bitwise identical to `pairwise_sum` over the per-chunk values.
pub const CROSS_CHUNK_BASE: usize = BASE_CHUNK;
57
/// Serializable snapshot of a [`StreamingBorderGram`].
///
/// Produced by `StreamingBorderGram::checkpoint` and validated/restored by
/// `StreamingBorderGram::resume`. Every partial is a flattened `border_dim *
/// border_dim` buffer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BorderGramCheckpoint {
    /// Number of border columns `k`.
    pub border_dim: usize,
    /// Declared total number of rows in the stream.
    pub n_rows: usize,
    /// Rows per chunk (the final chunk may be shorter).
    pub chunk_size: usize,
    /// Number of chunks folded in index order; the next chunk to fold has this index.
    pub frontier: usize,
    /// Sequential running sum of the currently open block of chunk Grams
    /// (`None` when the block is empty).
    pub block_partial: Option<Vec<f64>>,
    /// Number of chunk Grams summed into `block_partial`.
    pub block_len: usize,
    /// Cascade stack of `(weight, partial)` pairs; the weight counts chunks.
    pub forest: Vec<(usize, Vec<f64>)>,
    /// Chunk Grams received ahead of the frontier as `(chunk_index, gram)`,
    /// written in ascending index order by `checkpoint`.
    pub pending: Vec<(usize, Vec<f64>)>,
}
89
/// Streaming accumulator of the border Gram matrix `X^T X` for a row stream
/// `X` with `border_dim` columns.
///
/// Rows arrive as fixed-size chunks that may be submitted in any order. Chunk
/// Grams are folded strictly in chunk-index order (early arrivals wait in
/// `pending`). The floating-point association therefore does not depend on
/// submission order, and the result is bit-reproducible.
pub struct StreamingBorderGram {
    /// Number of columns `k`; all partials are flattened `k * k` buffers.
    border_dim: usize,
    /// Declared total number of rows.
    n_rows: usize,
    /// Rows per chunk (the final chunk may be shorter).
    chunk_size: usize,
    /// Count of chunks folded in order; the next chunk to fold has this index.
    frontier: usize,
    /// Sequential running sum of the open block of chunk Grams (`None` when empty).
    block_partial: Option<Vec<f64>>,
    /// Number of chunk Grams in `block_partial`; reset to 0 when it reaches
    /// `CROSS_CHUNK_BASE` and the block moves into `forest`.
    block_len: usize,
    /// Cascade stack of `(weight_in_chunks, partial)`; equal-weight neighbours
    /// on top of the stack are merged pairwise.
    forest: Vec<(usize, Vec<f64>)>,
    /// Chunk Grams received ahead of the frontier, keyed by chunk index.
    pending: BTreeMap<usize, Vec<f64>>,
}
112
/// Element-wise in-place addition: `acc[i] += rhs[i]` for every `i`.
///
/// Both slices must have the same length (the callers in this file pass
/// flattened `k * k` partials). A mismatch is a bug. Without the check,
/// `zip` would silently ignore the surplus entries and corrupt the Gram with
/// no signal.
///
/// # Panics
/// In debug builds, panics if `acc.len() != rhs.len()`.
fn add_into(acc: &mut [f64], rhs: &[f64]) {
    debug_assert_eq!(
        acc.len(),
        rhs.len(),
        "add_into: partial length mismatch"
    );
    for (a, r) in acc.iter_mut().zip(rhs) {
        *a += *r;
    }
}
123
124pub fn chunk_gram_flat(rows: ArrayView2<'_, f64>) -> Vec<f64> {
137 let k = rows.ncols();
138 let r = rows.nrows();
139 let mut gram = vec![0.0_f64; k * k];
140 let mut products = vec![0.0_f64; r];
141 for a in 0..k {
142 for b in a..k {
143 for (i, p) in products.iter_mut().enumerate() {
144 *p = rows[[i, a]] * rows[[i, b]];
145 }
146 let s = pairwise_sum(&products);
147 gram[a * k + b] = s;
148 gram[b * k + a] = s;
149 }
150 }
151 gram
152}
153
154impl StreamingBorderGram {
155 pub fn new(border_dim: usize, n_rows: usize, chunk_size: usize) -> Result<Self, String> {
158 if border_dim == 0 {
159 return Err("StreamingBorderGram: border_dim must be positive".to_string());
160 }
161 if chunk_size == 0 {
162 return Err("StreamingBorderGram: chunk_size must be positive".to_string());
163 }
164 Ok(Self {
165 border_dim,
166 n_rows,
167 chunk_size,
168 frontier: 0,
169 block_partial: None,
170 block_len: 0,
171 forest: Vec::new(),
172 pending: BTreeMap::new(),
173 })
174 }
175
176 pub fn n_chunks(&self) -> usize {
178 self.n_rows.div_ceil(self.chunk_size)
179 }
180
181 pub fn chunk_rows(&self, chunk_index: usize) -> std::ops::Range<usize> {
186 let lo = chunk_index * self.chunk_size;
187 let hi = ((chunk_index + 1) * self.chunk_size).min(self.n_rows);
188 lo..hi
189 }
190
191 pub fn frontier(&self) -> usize {
194 self.frontier
195 }
196
197 pub fn is_complete(&self) -> bool {
199 self.frontier == self.n_chunks() && self.pending.is_empty()
200 }
201
202 pub fn submit_chunk(
210 &mut self,
211 chunk_index: usize,
212 rows: ArrayView2<'_, f64>,
213 ) -> Result<(), String> {
214 let n_chunks = self.n_chunks();
215 if chunk_index >= n_chunks {
216 return Err(format!(
217 "StreamingBorderGram: chunk index {chunk_index} out of range (n_chunks = {n_chunks})"
218 ));
219 }
220 if chunk_index < self.frontier || self.pending.contains_key(&chunk_index) {
221 return Err(format!(
222 "StreamingBorderGram: chunk {chunk_index} was already submitted"
223 ));
224 }
225 let expected_rows = self.chunk_rows(chunk_index).len();
226 if rows.nrows() != expected_rows || rows.ncols() != self.border_dim {
227 return Err(format!(
228 "StreamingBorderGram: chunk {chunk_index} has shape ({}, {}) but expected ({}, {})",
229 rows.nrows(),
230 rows.ncols(),
231 expected_rows,
232 self.border_dim
233 ));
234 }
235 let gram = self.chunk_gram(rows);
236 self.fold_or_park(chunk_index, gram);
237 Ok(())
238 }
239
240 pub fn submit_chunk_gram(&mut self, chunk_index: usize, gram: Vec<f64>) -> Result<(), String> {
254 let n_chunks = self.n_chunks();
255 if chunk_index >= n_chunks {
256 return Err(format!(
257 "StreamingBorderGram: chunk index {chunk_index} out of range (n_chunks = {n_chunks})"
258 ));
259 }
260 if chunk_index < self.frontier || self.pending.contains_key(&chunk_index) {
261 return Err(format!(
262 "StreamingBorderGram: chunk {chunk_index} was already submitted"
263 ));
264 }
265 let kk = self.border_dim * self.border_dim;
266 if gram.len() != kk {
267 return Err(format!(
268 "StreamingBorderGram: chunk {chunk_index} partial has len {} but expected {kk}",
269 gram.len()
270 ));
271 }
272 if !gram.iter().all(|v| v.is_finite()) {
273 return Err(format!(
274 "StreamingBorderGram: chunk {chunk_index} partial contains non-finite entries"
275 ));
276 }
277 self.fold_or_park(chunk_index, gram);
278 Ok(())
279 }
280
281 fn fold_or_park(&mut self, chunk_index: usize, gram: Vec<f64>) {
285 if chunk_index == self.frontier {
286 self.fold_chunk(gram);
287 self.frontier += 1;
288 while let Some(next) = self.pending.remove(&self.frontier) {
290 self.fold_chunk(next);
291 self.frontier += 1;
292 }
293 } else {
294 self.pending.insert(chunk_index, gram);
295 }
296 }
297
298 fn chunk_gram(&self, rows: ArrayView2<'_, f64>) -> Vec<f64> {
302 chunk_gram_flat(rows)
303 }
304
305 fn fold_chunk(&mut self, gram: Vec<f64>) {
311 match self.block_partial.as_mut() {
312 None => {
313 self.block_partial = Some(gram);
314 self.block_len = 1;
315 }
316 Some(acc) => {
317 add_into(acc, &gram);
318 self.block_len += 1;
319 }
320 }
321 if self.block_len == CROSS_CHUNK_BASE {
322 let block = self
323 .block_partial
324 .take()
325 .expect("block_len == CROSS_CHUNK_BASE implies a live block partial");
326 self.block_len = 0;
327 self.absorb(CROSS_CHUNK_BASE, block);
328 }
329 }
330
331 fn absorb(&mut self, weight: usize, value: Vec<f64>) {
335 let mut w = weight;
336 let mut v = value;
337 while let Some((top_w, _)) = self.forest.last() {
338 if *top_w == w {
339 let (_, top_v) = self
340 .forest
341 .pop()
342 .expect("forest top exists: just observed by last()");
343 v = {
345 let mut merged = top_v;
346 add_into(&mut merged, &v);
347 merged
348 };
349 w = w.saturating_mul(2);
350 } else {
351 break;
352 }
353 }
354 self.forest.push((w, v));
355 }
356
357 pub fn checkpoint(&self) -> BorderGramCheckpoint {
362 BorderGramCheckpoint {
363 border_dim: self.border_dim,
364 n_rows: self.n_rows,
365 chunk_size: self.chunk_size,
366 frontier: self.frontier,
367 block_partial: self.block_partial.clone(),
368 block_len: self.block_len,
369 forest: self.forest.clone(),
370 pending: self
371 .pending
372 .iter()
373 .map(|(idx, g)| (*idx, g.clone()))
374 .collect(),
375 }
376 }
377
378 pub fn resume(state: BorderGramCheckpoint) -> Result<Self, String> {
382 if state.border_dim == 0 {
383 return Err("BorderGramCheckpoint: border_dim must be positive".to_string());
384 }
385 if state.chunk_size == 0 {
386 return Err("BorderGramCheckpoint: chunk_size must be positive".to_string());
387 }
388 let kk = state.border_dim * state.border_dim;
389 let n_chunks = state.n_rows.div_ceil(state.chunk_size);
390 if state.frontier > n_chunks {
391 return Err(format!(
392 "BorderGramCheckpoint: frontier {} exceeds n_chunks {n_chunks}",
393 state.frontier
394 ));
395 }
396 if state.block_len >= CROSS_CHUNK_BASE {
397 return Err(format!(
398 "BorderGramCheckpoint: block_len {} must be < CROSS_CHUNK_BASE {CROSS_CHUNK_BASE}",
399 state.block_len
400 ));
401 }
402 if state.block_partial.is_some() != (state.block_len > 0) {
403 return Err(
404 "BorderGramCheckpoint: block_partial presence inconsistent with block_len"
405 .to_string(),
406 );
407 }
408 if let Some(b) = &state.block_partial {
409 if b.len() != kk {
410 return Err(format!(
411 "BorderGramCheckpoint: block_partial has len {} but expected {kk}",
412 b.len()
413 ));
414 }
415 }
416 for (w, g) in &state.forest {
417 if *w == 0 || g.len() != kk {
418 return Err(
419 "BorderGramCheckpoint: malformed forest partial (zero weight or wrong len)"
420 .to_string(),
421 );
422 }
423 }
424 let mut pending = BTreeMap::new();
425 for (idx, g) in state.pending {
426 if idx < state.frontier || idx >= n_chunks {
427 return Err(format!(
428 "BorderGramCheckpoint: pending chunk index {idx} outside (frontier {}, n_chunks {n_chunks})",
429 state.frontier
430 ));
431 }
432 if g.len() != kk {
433 return Err(format!(
434 "BorderGramCheckpoint: pending chunk {idx} partial has len {} but expected {kk}",
435 g.len()
436 ));
437 }
438 if pending.insert(idx, g).is_some() {
439 return Err(format!(
440 "BorderGramCheckpoint: duplicate pending chunk index {idx}"
441 ));
442 }
443 }
444 Ok(Self {
445 border_dim: state.border_dim,
446 n_rows: state.n_rows,
447 chunk_size: state.chunk_size,
448 frontier: state.frontier,
449 block_partial: state.block_partial,
450 block_len: state.block_len,
451 forest: state.forest,
452 pending,
453 })
454 }
455
456 pub fn finish(mut self) -> Result<Array2<f64>, String> {
462 let n_chunks = self.n_chunks();
463 if self.frontier != n_chunks {
464 let missing: Vec<usize> = (self.frontier..n_chunks)
465 .filter(|idx| !self.pending.contains_key(idx))
466 .take(8)
467 .collect();
468 return Err(format!(
469 "StreamingBorderGram: finish() before all chunks were submitted \
470 (frontier {}/{n_chunks}, first missing chunk indices {missing:?})",
471 self.frontier
472 ));
473 }
474 if let Some(tail) = self.block_partial.take() {
477 let w = self.block_len;
478 self.block_len = 0;
479 self.forest.push((w, tail));
480 }
481 let k = self.border_dim;
484 let mut iter = self.forest.into_iter().rev();
485 let flat = match iter.next() {
486 None => vec![0.0_f64; k * k],
487 Some((_, mut acc)) => {
488 for (_, left) in iter {
489 add_into(&mut acc, &left);
490 }
491 acc
492 }
493 };
494 Array2::from_shape_vec((k, k), flat)
495 .map_err(|e| format!("StreamingBorderGram: Gram reshape failed: {e}"))
496 }
497}
498
/// Adapter that turns an arbitrarily batched, in-order row stream into the
/// fixed-size chunks expected by [`StreamingBorderGram`].
///
/// Rows are buffered until a whole chunk is available. Chunks are then
/// submitted strictly in index order.
pub struct ChunkAssembler {
    /// Underlying accumulator.
    gram: StreamingBorderGram,
    /// Row-major buffer of received rows that have not yet been submitted; its
    /// length is always a multiple of `border_dim` (checked by `buffered_rows`).
    buffer: Vec<f64>,
    /// Index of the next chunk to submit.
    next_chunk: usize,
}
523
524impl ChunkAssembler {
525 pub fn new(border_dim: usize, n_rows: usize, chunk_size: usize) -> Result<Self, String> {
528 Ok(Self {
529 gram: StreamingBorderGram::new(border_dim, n_rows, chunk_size)?,
530 buffer: Vec::new(),
531 next_chunk: 0,
532 })
533 }
534
535 fn buffered_rows(&self) -> usize {
537 let k = self.gram.border_dim;
538 assert!(
545 self.buffer.len() % k == 0,
546 "ChunkAssembler buffer length {} is not a multiple of border_dim {k}",
547 self.buffer.len()
548 );
549 self.buffer.len() / k
550 }
551
552 pub fn push_rows(&mut self, rows: ArrayView2<'_, f64>) -> Result<(), String> {
555 let k = self.gram.border_dim;
556 if rows.ncols() != k {
557 return Err(format!(
558 "ChunkAssembler: batch has {} cols but border_dim is {k}",
559 rows.ncols()
560 ));
561 }
562 let n_chunks = self.gram.n_chunks();
563 let consumed = (self.gram.frontier() * self.gram.chunk_size).min(self.gram.n_rows);
566 let total_seen = consumed + self.buffered_rows() + rows.nrows();
567 if total_seen > self.gram.n_rows {
568 return Err(format!(
569 "ChunkAssembler: stream overran the declared row count ({} > {})",
570 total_seen, self.gram.n_rows
571 ));
572 }
573 for row in rows.outer_iter() {
574 self.buffer.extend(row.iter().copied());
575 }
576 while self.next_chunk < n_chunks {
578 let need = self.gram.chunk_rows(self.next_chunk).len();
579 if self.buffered_rows() < need {
580 break;
581 }
582 let chunk: Vec<f64> = self.buffer.drain(..need * k).collect();
583 let view = ndarray::ArrayView2::from_shape((need, k), &chunk)
584 .map_err(|e| format!("ChunkAssembler: chunk reshape failed: {e}"))?;
585 self.gram.submit_chunk(self.next_chunk, view)?;
586 self.next_chunk += 1;
587 }
588 Ok(())
589 }
590
591 pub fn checkpoint(&self) -> Option<BorderGramCheckpoint> {
595 if self.buffer.is_empty() {
596 Some(self.gram.checkpoint())
597 } else {
598 None
599 }
600 }
601
602 pub fn resume(state: BorderGramCheckpoint) -> Result<Self, String> {
607 let gram = StreamingBorderGram::resume(state)?;
608 let next_chunk = gram.frontier();
609 Ok(Self {
610 gram,
611 buffer: Vec::new(),
612 next_chunk,
613 })
614 }
615
616 pub fn finish(self) -> Result<Array2<f64>, String> {
620 if !self.buffer.is_empty() {
621 let k = self.gram.border_dim;
622 return Err(format!(
623 "ChunkAssembler: stream ended mid-chunk with {} buffered rows \
624 (declared n_rows = {})",
625 self.buffer.len() / k,
626 self.gram.n_rows
627 ));
628 }
629 self.gram.finish()
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Deterministic pseudo-random `n × k` fixture built from a sin-hash of the
    /// (row, column) index, so every test is reproducible without an RNG.
    fn planted_rows(n: usize, k: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, k), |(i, j)| {
            let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
            (x.sin() * 43_758.547).fract() * 2.0 - 1.0
        })
    }

    /// Returns a fresh accumulator for `rows` (nothing submitted yet) together
    /// with the in-order submission sequence `0..n_chunks`.
    fn accumulate_in_order(
        rows: &Array2<f64>,
        chunk_size: usize,
    ) -> (StreamingBorderGram, Vec<usize>) {
        let acc =
            StreamingBorderGram::new(rows.ncols(), rows.nrows(), chunk_size).expect("accumulator");
        let order: Vec<usize> = (0..acc.n_chunks()).collect();
        (acc, order)
    }

    /// Submits the chunks of `rows` in the given `order` and returns the finished Gram.
    fn run_with_order(rows: &Array2<f64>, chunk_size: usize, order: &[usize]) -> Array2<f64> {
        let mut acc =
            StreamingBorderGram::new(rows.ncols(), rows.nrows(), chunk_size).expect("accumulator");
        for &j in order {
            let range = acc.chunk_rows(j);
            acc.submit_chunk(j, rows.slice(ndarray::s![range, ..]))
                .expect("submit");
        }
        acc.finish().expect("finish")
    }

    /// Asserts equal shape and bitwise-equal entries. This is stricter than
    /// `==`, which treats `0.0` and `-0.0` as equal.
    fn assert_bit_identical(a: &Array2<f64>, b: &Array2<f64>, label: &str) {
        assert_eq!(a.dim(), b.dim(), "{label}: shape mismatch");
        for ((idx, x), y) in a.indexed_iter().zip(b.iter()) {
            assert_eq!(
                x.to_bits(),
                y.to_bits(),
                "{label}: entry {idx:?} differs bitwise: {x:?} vs {y:?}"
            );
        }
    }

    #[test]
    fn gram_matches_naive_xtx() {
        // 257 = 16 * 16 + 1 rows: 17 chunks of 16, the last holding a single row.
        let n = 257;
        let k = 5;
        let rows = planted_rows(n, k);
        let gram = run_with_order(&rows, 16, &(0..17).collect::<Vec<_>>());
        let naive = rows.t().dot(&rows);
        for i in 0..k {
            for j in 0..k {
                let d = (gram[[i, j]] - naive[[i, j]]).abs();
                let scale = naive[[i, j]].abs().max(1.0);
                assert!(
                    d <= 1.0e-12 * scale,
                    "Gram[{i},{j}] = {} vs naive {} (delta {d})",
                    gram[[i, j]],
                    naive[[i, j]]
                );
            }
        }
        // Each chunk Gram is mirrored, and entries (i, j) and (j, i) go through
        // the same cross-chunk additions, so the result must be exactly symmetric.
        for i in 0..k {
            for j in 0..k {
                assert_eq!(gram[[i, j]].to_bits(), gram[[j, i]].to_bits());
            }
        }
    }

    #[test]
    fn bit_reproducible_across_chunk_submission_orders() {
        // Enough 2-row chunks for several full CROSS_CHUNK_BASE blocks plus a ragged tail.
        let n = 2 * CROSS_CHUNK_BASE * 3 + 7;
        let k = 4;
        let chunk_size = 2;
        let rows = planted_rows(n, k);
        let n_chunks = n.div_ceil(chunk_size);

        let in_order: Vec<usize> = (0..n_chunks).collect();
        let reversed: Vec<usize> = (0..n_chunks).rev().collect();
        // Stride order. It is a permutation only when gcd(129, n_chunks) == 1;
        // otherwise run_with_order panics on a duplicate submission.
        let strided: Vec<usize> = (0..n_chunks).map(|i| (i * 129) % n_chunks).collect();

        let g0 = run_with_order(&rows, chunk_size, &in_order);
        let g1 = run_with_order(&rows, chunk_size, &reversed);
        let g2 = run_with_order(&rows, chunk_size, &strided);

        assert_bit_identical(&g0, &g1, "in-order vs reversed submission");
        assert_bit_identical(&g0, &g2, "in-order vs strided submission");
    }

    #[test]
    fn cross_chunk_association_matches_landed_pairwise_sum() {
        // 613 rows in chunks of 2: 307 chunks, the last holding one row.
        let n = 613;
        let k = 3;
        let chunk_size = 2;
        let rows = planted_rows(n, k);
        let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
        let n_chunks = acc.n_chunks();
        // per_chunk_entries[a * k + b] collects entry (a, b) of every chunk Gram.
        let mut per_chunk_entries: Vec<Vec<f64>> = vec![Vec::with_capacity(n_chunks); k * k];
        for j in 0..n_chunks {
            let range = acc.chunk_rows(j);
            let chunk = rows.slice(ndarray::s![range, ..]);
            let g = acc.chunk_gram(chunk);
            for (e, vals) in g.iter().zip(per_chunk_entries.iter_mut()) {
                vals.push(*e);
            }
            acc.submit_chunk(j, chunk).expect("submit");
        }
        let gram = acc.finish().expect("finish");
        // The block + cascade association must reproduce pairwise_sum over the
        // per-chunk values bit-for-bit.
        for a in 0..k {
            for b in 0..k {
                let expected = pairwise_sum(&per_chunk_entries[a * k + b]);
                assert_eq!(
                    gram[[a, b]].to_bits(),
                    expected.to_bits(),
                    "entry ({a},{b}): cascade {} vs pairwise_sum {}",
                    gram[[a, b]],
                    expected
                );
            }
        }
    }

    #[test]
    fn resume_equals_straight_through() {
        let n = 491;
        let k = 4;
        let chunk_size = 3;
        let rows = planted_rows(n, k);
        let (acc, order) = accumulate_in_order(&rows, chunk_size);
        let n_chunks = acc.n_chunks();
        // Reference result: every chunk submitted in order in one session.
        let straight = run_with_order(&rows, chunk_size, &order);

        // First session: chunks 0..60 in order, plus 150, 100 and 163 ahead of
        // the frontier so they are parked in `pending` at checkpoint time.
        let mut first = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
        let mut submitted = vec![false; n_chunks];
        let prefix: Vec<usize> = (0..60).chain([150, 100, 163]).collect();
        for &j in &prefix {
            let range = first.chunk_rows(j);
            first
                .submit_chunk(j, rows.slice(ndarray::s![range, ..]))
                .expect("prefix submit");
            submitted[j] = true;
        }
        assert!(
            !first.pending.is_empty(),
            "fixture must exercise pending out-of-order state"
        );
        // Round-trip the checkpoint through JSON, as a real restart would.
        let json = serde_json::to_string(&first.checkpoint()).expect("serialize checkpoint");
        drop(first);
        let restored: BorderGramCheckpoint =
            serde_json::from_str(&json).expect("deserialize checkpoint");
        let mut second = StreamingBorderGram::resume(restored).expect("resume");
        // Second session: submit every chunk the first session did not.
        for j in 0..n_chunks {
            if submitted[j] {
                continue;
            }
            let range = second.chunk_rows(j);
            second
                .submit_chunk(j, rows.slice(ndarray::s![range, ..]))
                .expect("resumed submit");
        }
        let resumed = second.finish().expect("finish resumed");
        assert_bit_identical(&straight, &resumed, "resume vs straight-through");
    }

    #[test]
    fn rejects_duplicates_missing_chunks_and_bad_shapes() {
        // 10 rows in chunks of 4: chunks of 4, 4 and 2 rows.
        let n = 10;
        let k = 2;
        let chunk_size = 4;
        let rows = planted_rows(n, k);
        let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
        assert_eq!(acc.n_chunks(), 3);

        // Wrong row count for chunk 0.
        let err = acc
            .submit_chunk(0, rows.slice(ndarray::s![0..3, ..]))
            .expect_err("short chunk must be rejected");
        assert!(err.contains("expected (4, 2)"), "got: {err}");

        // Re-submitting a chunk that has already been folded.
        acc.submit_chunk(0, rows.slice(ndarray::s![0..4, ..]))
            .expect("chunk 0");
        let err = acc
            .submit_chunk(0, rows.slice(ndarray::s![0..4, ..]))
            .expect_err("duplicate must be rejected");
        assert!(err.contains("already submitted"), "got: {err}");

        // Chunk 2 arrives early and is parked; re-submitting it is also a duplicate.
        acc.submit_chunk(2, rows.slice(ndarray::s![8..10, ..]))
            .expect("chunk 2 out of order");
        let err = acc
            .submit_chunk(2, rows.slice(ndarray::s![8..10, ..]))
            .expect_err("duplicate pending must be rejected");
        assert!(err.contains("already submitted"), "got: {err}");

        // Index 3 is past n_chunks.
        let err = acc
            .submit_chunk(3, rows.slice(ndarray::s![0..4, ..]))
            .expect_err("out-of-range index must be rejected");
        assert!(err.contains("out of range"), "got: {err}");

        // Chunk 1 was never submitted, so finish must fail and name it.
        let err = acc.finish().expect_err("missing chunk must fail finish");
        assert!(
            err.contains("[1]"),
            "missing-chunk message must name chunk 1: {err}"
        );
    }

    #[test]
    fn checkpoint_validation_rejects_corruption() {
        // State after one chunk: a live block_partial with block_len == 1
        // (assuming CROSS_CHUNK_BASE > 1).
        let mut acc = StreamingBorderGram::new(3, 100, 10).expect("accumulator");
        let rows = planted_rows(100, 3);
        acc.submit_chunk(0, rows.slice(ndarray::s![0..10, ..]))
            .expect("chunk 0");
        let good = acc.checkpoint();

        // block_len 0 contradicts a present block_partial.
        let mut bad = good.clone();
        bad.block_len = 0;
        assert!(StreamingBorderGram::resume(bad).is_err());

        // Truncated block partial.
        let mut bad = good.clone();
        if let Some(b) = bad.block_partial.as_mut() {
            b.pop();
        }
        assert!(StreamingBorderGram::resume(bad).is_err());

        // Pending chunk index below the frontier.
        let mut bad = good.clone();
        bad.pending.push((0, vec![0.0; 9]));
        assert!(StreamingBorderGram::resume(bad).is_err());

        // Frontier beyond n_chunks (100 rows / 10 = 10 chunks).
        let mut bad = good;
        bad.frontier = 99;
        assert!(StreamingBorderGram::resume(bad).is_err());
    }

    #[test]
    fn chunk_assembler_is_batching_invariant() {
        let n = 463;
        let k = 4;
        let chunk_size = 16;
        let rows = planted_rows(n, k);
        // Reference: direct in-order chunk submission.
        let direct = {
            let (acc, order) = accumulate_in_order(&rows, chunk_size);
            drop(acc);
            run_with_order(&rows, chunk_size, &order)
        };

        // Feed the same rows in cycling batch sizes that straddle chunk boundaries.
        let mut asm = ChunkAssembler::new(k, n, chunk_size).expect("assembler");
        let sizes = [3usize, 5, 7, 11, 13];
        let mut at = 0usize;
        let mut s = 0usize;
        while at < n {
            let take = sizes[s % sizes.len()].min(n - at);
            asm.push_rows(rows.slice(ndarray::s![at..at + take, ..]))
                .expect("push");
            at += take;
            s += 1;
        }
        let assembled = asm.finish().expect("finish");
        assert_bit_identical(&direct, &assembled, "direct vs assembled batching");
    }

    #[test]
    fn chunk_assembler_checkpoints_only_at_boundaries_and_resumes() {
        let n = 200;
        let k = 3;
        let chunk_size = 10;
        let rows = planted_rows(n, k);
        let direct = run_with_order(&rows, chunk_size, &(0..20).collect::<Vec<_>>());

        // 7 rows: chunk 0 is incomplete, so no checkpoint is possible.
        let mut asm = ChunkAssembler::new(k, n, chunk_size).expect("assembler");
        asm.push_rows(rows.slice(ndarray::s![0..7, ..]))
            .expect("push");
        assert!(
            asm.checkpoint().is_none(),
            "mid-chunk checkpoint must be None"
        );
        // Through row 30: chunks 0..3 are complete and the buffer is empty.
        asm.push_rows(rows.slice(ndarray::s![7..30, ..]))
            .expect("push");
        let cp = asm.checkpoint().expect("boundary checkpoint");
        assert_eq!(cp.frontier, 3);
        drop(asm);

        // Resume and feed the remaining rows in one batch.
        let mut resumed = ChunkAssembler::resume(cp).expect("resume");
        resumed
            .push_rows(rows.slice(ndarray::s![30..n, ..]))
            .expect("push rest");
        let gram = resumed.finish().expect("finish");
        assert_bit_identical(&direct, &gram, "assembler resume vs straight-through");
    }

    #[test]
    fn chunk_assembler_rejects_truncated_and_overrunning_streams() {
        let k = 2;
        let rows = planted_rows(30, k);
        // 30 rows in chunks of 8: 25 rows complete three chunks and leave one row buffered.
        let mut asm = ChunkAssembler::new(k, 30, 8).expect("assembler");
        asm.push_rows(rows.slice(ndarray::s![0..25, ..]))
            .expect("push");
        let err = asm.finish().expect_err("truncated stream must fail finish");
        assert!(err.contains("mid-chunk"), "got: {err}");
        // Declared 20 rows but 25 pushed.
        let mut asm = ChunkAssembler::new(k, 20, 8).expect("assembler");
        let err = asm
            .push_rows(rows.slice(ndarray::s![0..25, ..]))
            .expect_err("overrun must be rejected");
        assert!(err.contains("overran"), "got: {err}");
    }

    /// Allowed deviation for Grams of rows stored as `f32` but accumulated in
    /// `f64`, relative to the largest absolute entry of the exact Gram.
    const MIXED_PRECISION_BORDER_RTOL: f64 = 1.0e-5;

    #[test]
    fn f32_storage_f64_accumulation_meets_the_error_budget() {
        let n = 700;
        let k = 5;
        let chunk_size = 32;
        let rows = planted_rows(n, k);
        // Round every entry through f32 to mimic reduced-precision storage.
        let stored = rows.mapv(|v| f64::from(v as f32));
        let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
        for j in 0..acc.n_chunks() {
            let range = acc.chunk_rows(j);
            acc.submit_chunk(j, stored.slice(ndarray::s![range, ..]))
                .expect("submit");
        }
        let mixed = acc.finish().expect("finish");
        // Exact reference computed from the unrounded f64 rows.
        let exact = rows.t().dot(&rows);
        let scale = exact.iter().fold(0.0_f64, |m, &v| m.max(v.abs())).max(1.0);
        for i in 0..k {
            for j in 0..k {
                let d = (mixed[[i, j]] - exact[[i, j]]).abs();
                assert!(
                    d <= MIXED_PRECISION_BORDER_RTOL * scale,
                    "Gram[{i},{j}] mixed-precision delta {d:.3e} exceeds budget \
                     {MIXED_PRECISION_BORDER_RTOL:.0e} × scale {scale:.3e}"
                );
            }
        }
    }

    #[test]
    fn zero_rows_yields_zero_gram() {
        // An empty stream has no chunks, is complete immediately, and yields a
        // k × k matrix of +0.0.
        let acc = StreamingBorderGram::new(3, 0, 8).expect("accumulator");
        assert_eq!(acc.n_chunks(), 0);
        assert!(acc.is_complete());
        let gram = acc.finish().expect("finish empty");
        assert_eq!(gram.dim(), (3, 3));
        assert!(gram.iter().all(|v| v.to_bits() == 0.0_f64.to_bits()));
    }
}