use crate::linalg::pairwise_reduce::{BASE_CHUNK, pairwise_sum};
use ndarray::{Array2, ArrayView2};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Number of consecutive chunk partials that `StreamingBorderGram` sums linearly
/// into one block before handing that block to its merge forest.
///
/// This is an alias of `BASE_CHUNK` from `crate::linalg::pairwise_reduce`.
// NOTE(review): presumably kept equal to the leaf size of `pairwise_sum` so the
// cross-chunk cascade reproduces its association order — confirm against
// `pairwise_reduce`.
pub const CROSS_CHUNK_BASE: usize = BASE_CHUNK;
/// Serializable snapshot of a `StreamingBorderGram`, produced by
/// `StreamingBorderGram::checkpoint` and consumed by `StreamingBorderGram::resume`.
///
/// Every partial stored here is a flat buffer that `resume` requires to hold
/// exactly `border_dim * border_dim` values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BorderGramCheckpoint {
/// Number of columns of the streamed rows; the Gram matrix is `border_dim × border_dim`.
pub border_dim: usize,
/// Declared total number of rows across all chunks.
pub n_rows: usize,
/// Rows per chunk; the final chunk may be shorter.
pub chunk_size: usize,
/// Index of the next chunk to fold; every chunk below it has already been folded.
pub frontier: usize,
/// Running sum of the chunk partials in the current, not-yet-full block
/// (`None` exactly when `block_len` is zero).
pub block_partial: Option<Vec<f64>>,
/// Number of chunk partials summed into `block_partial`; must be below `CROSS_CHUNK_BASE`.
pub block_len: usize,
/// Stack of merged partials as `(weight, partial)`, where `weight` counts the
/// chunks merged into that partial.
pub forest: Vec<(usize, Vec<f64>)>,
/// Chunks submitted ahead of the frontier, as `(chunk_index, partial)`.
pub pending: Vec<(usize, Vec<f64>)>,
}
/// Accumulates the Gram matrix of a row stream that arrives in fixed-size chunks.
///
/// Chunks may be submitted in any order: a chunk ahead of the frontier is parked
/// in `pending` and folded only once every earlier chunk has been folded, so the
/// sequence of floating-point additions does not depend on submission order.
pub struct StreamingBorderGram {
// Number of columns per row; the result is `border_dim × border_dim`.
border_dim: usize,
// Declared total row count of the stream.
n_rows: usize,
// Rows per chunk (the last chunk may be shorter).
chunk_size: usize,
// Index of the next chunk to fold.
frontier: usize,
// Linear running sum of the current block of chunk partials, if any.
block_partial: Option<Vec<f64>>,
// How many chunk partials are in `block_partial`.
block_len: usize,
// Stack of `(weight, partial)`; equal-weight neighbours are merged on push.
forest: Vec<(usize, Vec<f64>)>,
// Out-of-order chunk partials keyed by chunk index, waiting for the frontier.
pending: BTreeMap<usize, Vec<f64>>,
}
/// Adds `rhs` element-wise into `acc`, in place.
///
/// Only the overlapping prefix is updated: when the slices differ in length,
/// the surplus elements of the longer one are left untouched / ignored.
fn add_into(acc: &mut [f64], rhs: &[f64]) {
    acc.iter_mut()
        .zip(rhs)
        .for_each(|(slot, term)| *slot += *term);
}
/// Computes the Gram matrix `rowsᵀ · rows` of one chunk as a flat, row-major
/// `k × k` buffer, where `k = rows.ncols()`.
///
/// Each entry is the `pairwise_sum` of the per-row products of two columns.
/// Only the upper triangle is computed; every value is then mirrored, so the
/// result is bitwise symmetric.
pub fn chunk_gram_flat(rows: ArrayView2<'_, f64>) -> Vec<f64> {
    let dim = rows.ncols();
    let mut out = vec![0.0_f64; dim * dim];
    // One scratch buffer, reused for every column pair.
    let mut scratch = vec![0.0_f64; rows.nrows()];
    for left in 0..dim {
        let col_left = rows.column(left);
        for right in left..dim {
            let col_right = rows.column(right);
            let pairs = col_left.iter().zip(col_right.iter());
            for (slot, (x, y)) in scratch.iter_mut().zip(pairs) {
                *slot = *x * *y;
            }
            let total = pairwise_sum(&scratch);
            out[left * dim + right] = total;
            out[right * dim + left] = total;
        }
    }
    out
}
impl StreamingBorderGram {
/// Creates an empty accumulator for `n_rows` rows of `border_dim` columns,
/// streamed in chunks of `chunk_size` rows.
///
/// # Errors
/// Returns an error when `border_dim` or `chunk_size` is zero (`border_dim`
/// is checked first).
pub fn new(border_dim: usize, n_rows: usize, chunk_size: usize) -> Result<Self, String> {
    match (border_dim, chunk_size) {
        (0, _) => Err("StreamingBorderGram: border_dim must be positive".to_string()),
        (_, 0) => Err("StreamingBorderGram: chunk_size must be positive".to_string()),
        _ => Ok(Self {
            border_dim,
            n_rows,
            chunk_size,
            frontier: 0,
            block_partial: None,
            block_len: 0,
            forest: Vec::new(),
            pending: BTreeMap::new(),
        }),
    }
}
/// Total number of chunks needed to cover `n_rows` rows, counting a trailing
/// partial chunk as one chunk.
pub fn n_chunks(&self) -> usize {
    let full_chunks = self.n_rows / self.chunk_size;
    let has_partial_tail = self.n_rows % self.chunk_size != 0;
    full_chunks + usize::from(has_partial_tail)
}
/// Row range `[lo, hi)` of the stream covered by chunk `chunk_index`.
///
/// For a valid index this is `chunk_index * chunk_size .. min((chunk_index + 1) *
/// chunk_size, n_rows)`. The arithmetic saturates instead of overflowing and both
/// ends are clamped to `n_rows`, so an out-of-range index yields the empty range
/// `n_rows..n_rows` rather than panicking, wrapping, or returning a range whose
/// start exceeds its end.
pub fn chunk_rows(&self, chunk_index: usize) -> std::ops::Range<usize> {
    let hi = chunk_index
        .saturating_add(1)
        .saturating_mul(self.chunk_size)
        .min(self.n_rows);
    // Clamping to `hi` keeps `lo <= hi` even for indices past the last chunk.
    let lo = chunk_index.saturating_mul(self.chunk_size).min(hi);
    lo..hi
}
/// Index of the next chunk that will be folded; all chunks with a smaller
/// index have already been folded into the running sum.
pub fn frontier(&self) -> usize {
self.frontier
}
/// Returns `true` once every chunk has been folded and nothing is left parked.
pub fn is_complete(&self) -> bool {
    let all_folded = self.frontier == self.n_chunks();
    all_folded && self.pending.is_empty()
}
/// Submits the rows of chunk `chunk_index`; chunks may arrive in any order.
///
/// The chunk's Gram partial is computed immediately. It is folded right away if
/// the chunk sits at the frontier, otherwise it is parked until its turn.
///
/// # Errors
/// Fails when the index is out of range, when the chunk was already submitted
/// (folded or parked), or when `rows` does not have the expected shape
/// `(rows in this chunk, border_dim)`.
pub fn submit_chunk(
    &mut self,
    chunk_index: usize,
    rows: ArrayView2<'_, f64>,
) -> Result<(), String> {
    let n_chunks = self.n_chunks();
    if chunk_index >= n_chunks {
        return Err(format!(
            "StreamingBorderGram: chunk index {chunk_index} out of range (n_chunks = {n_chunks})"
        ));
    }

    let already_folded = chunk_index < self.frontier;
    if already_folded || self.pending.contains_key(&chunk_index) {
        return Err(format!(
            "StreamingBorderGram: chunk {chunk_index} was already submitted"
        ));
    }

    let (got_rows, got_cols) = rows.dim();
    let want_rows = self.chunk_rows(chunk_index).len();
    let want_cols = self.border_dim;
    if (got_rows, got_cols) != (want_rows, want_cols) {
        return Err(format!(
            "StreamingBorderGram: chunk {chunk_index} has shape ({}, {}) but expected ({}, {})",
            got_rows, got_cols, want_rows, want_cols
        ));
    }

    let partial = self.chunk_gram(rows);
    self.fold_or_park(chunk_index, partial);
    Ok(())
}
/// Submits a Gram partial for chunk `chunk_index` that was computed elsewhere.
///
/// `gram` must be a flat buffer of `border_dim * border_dim` finite values.
///
/// # Errors
/// Fails when the index is out of range, when the chunk was already submitted
/// (folded or parked), when `gram` has the wrong length, or when it contains a
/// non-finite entry.
pub fn submit_chunk_gram(&mut self, chunk_index: usize, gram: Vec<f64>) -> Result<(), String> {
    let n_chunks = self.n_chunks();
    if chunk_index >= n_chunks {
        return Err(format!(
            "StreamingBorderGram: chunk index {chunk_index} out of range (n_chunks = {n_chunks})"
        ));
    }

    let already_folded = chunk_index < self.frontier;
    if already_folded || self.pending.contains_key(&chunk_index) {
        return Err(format!(
            "StreamingBorderGram: chunk {chunk_index} was already submitted"
        ));
    }

    let kk = self.border_dim * self.border_dim;
    let actual_len = gram.len();
    if actual_len != kk {
        return Err(format!(
            "StreamingBorderGram: chunk {chunk_index} partial has len {} but expected {kk}",
            actual_len
        ));
    }

    if gram.iter().any(|v| !v.is_finite()) {
        return Err(format!(
            "StreamingBorderGram: chunk {chunk_index} partial contains non-finite entries"
        ));
    }

    self.fold_or_park(chunk_index, gram);
    Ok(())
}
/// Folds `gram` if it belongs to the frontier chunk, otherwise parks it.
///
/// After folding the frontier chunk, any parked chunks that have become
/// contiguous with the frontier are folded too, in index order.
fn fold_or_park(&mut self, chunk_index: usize, gram: Vec<f64>) {
    if chunk_index != self.frontier {
        self.pending.insert(chunk_index, gram);
        return;
    }
    let mut ready = Some(gram);
    while let Some(partial) = ready {
        self.fold_chunk(partial);
        self.frontier += 1;
        // Pull the next chunk only if it is already waiting at the new frontier.
        ready = self.pending.remove(&self.frontier);
    }
}
/// Computes the flat Gram partial of one chunk by delegating to
/// `chunk_gram_flat`; `self` is not consulted.
fn chunk_gram(&self, rows: ArrayView2<'_, f64>) -> Vec<f64> {
chunk_gram_flat(rows)
}
/// Adds one chunk partial to the current block.
///
/// Partials are summed linearly into `block_partial`. Once `CROSS_CHUNK_BASE`
/// of them have been collected, the block is handed to `absorb` and a fresh
/// block starts.
fn fold_chunk(&mut self, gram: Vec<f64>) {
    if let Some(running) = self.block_partial.as_mut() {
        add_into(running, &gram);
        self.block_len += 1;
    } else {
        self.block_partial = Some(gram);
        self.block_len = 1;
    }

    if self.block_len != CROSS_CHUNK_BASE {
        return;
    }
    let full_block = self
        .block_partial
        .take()
        .expect("block_len == CROSS_CHUNK_BASE implies a live block partial");
    self.block_len = 0;
    self.absorb(CROSS_CHUNK_BASE, full_block);
}
/// Pushes a partial of the given `weight` onto the forest stack.
///
/// While the top of the stack has the same weight as the incoming partial, the
/// two are merged (older partial plus newer one) and the weight doubles, much
/// like carrying in a binary counter.
fn absorb(&mut self, weight: usize, value: Vec<f64>) {
    let mut span = weight;
    let mut carry = value;
    while self
        .forest
        .last()
        .is_some_and(|(top_w, _)| *top_w == span)
    {
        let (_, mut older) = self
            .forest
            .pop()
            .expect("forest top exists: just observed by last()");
        // Keep the older partial on the left of the addition.
        add_into(&mut older, &carry);
        carry = older;
        span = span.saturating_mul(2);
    }
    self.forest.push((span, carry));
}
/// Captures the full accumulator state as a serializable `BorderGramCheckpoint`.
///
/// Parked chunks are written out in ascending chunk-index order.
pub fn checkpoint(&self) -> BorderGramCheckpoint {
    let mut parked = Vec::with_capacity(self.pending.len());
    for (idx, partial) in &self.pending {
        parked.push((*idx, partial.clone()));
    }
    BorderGramCheckpoint {
        border_dim: self.border_dim,
        n_rows: self.n_rows,
        chunk_size: self.chunk_size,
        frontier: self.frontier,
        block_partial: self.block_partial.clone(),
        block_len: self.block_len,
        forest: self.forest.clone(),
        pending: parked,
    }
}
/// Rebuilds an accumulator from a checkpoint, validating it first.
///
/// # Errors
/// Returns an error when the checkpoint violates an invariant that a live
/// accumulator always upholds:
/// - `border_dim` or `chunk_size` is zero;
/// - `frontier` exceeds the number of chunks;
/// - `block_len` is not below `CROSS_CHUNK_BASE`, or disagrees with the
///   presence of `block_partial`;
/// - any partial does not hold `border_dim * border_dim` values, or a forest
///   weight is zero;
/// - the forest weights plus `block_len` do not add up to `frontier` (every
///   folded chunk lives in exactly one of the two);
/// - a pending index is not strictly between `frontier` and the chunk count,
///   or appears twice.
pub fn resume(state: BorderGramCheckpoint) -> Result<Self, String> {
    if state.border_dim == 0 {
        return Err("BorderGramCheckpoint: border_dim must be positive".to_string());
    }
    if state.chunk_size == 0 {
        return Err("BorderGramCheckpoint: chunk_size must be positive".to_string());
    }
    let kk = state.border_dim * state.border_dim;
    let n_chunks = state.n_rows.div_ceil(state.chunk_size);
    if state.frontier > n_chunks {
        return Err(format!(
            "BorderGramCheckpoint: frontier {} exceeds n_chunks {n_chunks}",
            state.frontier
        ));
    }
    if state.block_len >= CROSS_CHUNK_BASE {
        return Err(format!(
            "BorderGramCheckpoint: block_len {} must be < CROSS_CHUNK_BASE {CROSS_CHUNK_BASE}",
            state.block_len
        ));
    }
    if state.block_partial.is_some() != (state.block_len > 0) {
        return Err(
            "BorderGramCheckpoint: block_partial presence inconsistent with block_len"
                .to_string(),
        );
    }
    if let Some(b) = &state.block_partial {
        if b.len() != kk {
            return Err(format!(
                "BorderGramCheckpoint: block_partial has len {} but expected {kk}",
                b.len()
            ));
        }
    }

    // Count the chunks accounted for by the block and the forest; checked
    // addition guards against absurd weights in a corrupted checkpoint.
    let mut folded = state.block_len;
    for (w, g) in &state.forest {
        if *w == 0 || g.len() != kk {
            return Err(
                "BorderGramCheckpoint: malformed forest partial (zero weight or wrong len)"
                    .to_string(),
            );
        }
        folded = folded.checked_add(*w).ok_or_else(|| {
            "BorderGramCheckpoint: forest weights overflow the chunk counter".to_string()
        })?;
    }
    if folded != state.frontier {
        return Err(format!(
            "BorderGramCheckpoint: forest weights plus block_len account for {folded} chunks but frontier is {}",
            state.frontier
        ));
    }

    let mut pending = BTreeMap::new();
    for (idx, g) in state.pending {
        // A chunk sitting exactly at the frontier is never parked by a live
        // accumulator (it is folded on arrival). Accepting one would create a
        // state that can neither fold that chunk nor accept it again.
        if idx <= state.frontier || idx >= n_chunks {
            return Err(format!(
                "BorderGramCheckpoint: pending chunk index {idx} outside (frontier {}, n_chunks {n_chunks})",
                state.frontier
            ));
        }
        if g.len() != kk {
            return Err(format!(
                "BorderGramCheckpoint: pending chunk {idx} partial has len {} but expected {kk}",
                g.len()
            ));
        }
        if pending.insert(idx, g).is_some() {
            return Err(format!(
                "BorderGramCheckpoint: duplicate pending chunk index {idx}"
            ));
        }
    }

    Ok(Self {
        border_dim: state.border_dim,
        n_rows: state.n_rows,
        chunk_size: state.chunk_size,
        frontier: state.frontier,
        block_partial: state.block_partial,
        block_len: state.block_len,
        forest: state.forest,
        pending,
    })
}
/// Consumes the accumulator and returns the `border_dim × border_dim` Gram matrix.
///
/// Any partially filled block is pushed onto the forest as its final entry,
/// then the forest is collapsed from the newest partial towards the oldest.
/// A stream with no chunks yields the all-zero matrix.
///
/// # Errors
/// Fails when some chunk has not been submitted yet; the message lists up to
/// eight of the missing chunk indices.
pub fn finish(mut self) -> Result<Array2<f64>, String> {
    let n_chunks = self.n_chunks();
    if self.frontier != n_chunks {
        let missing: Vec<usize> = (self.frontier..n_chunks)
            .filter(|idx| !self.pending.contains_key(idx))
            .take(8)
            .collect();
        return Err(format!(
            "StreamingBorderGram: finish() before all chunks were submitted \
             (frontier {}/{n_chunks}, first missing chunk indices {missing:?})",
            self.frontier
        ));
    }

    if let Some(tail) = self.block_partial.take() {
        let tail_weight = self.block_len;
        self.block_len = 0;
        self.forest.push((tail_weight, tail));
    }

    let k = self.border_dim;
    let flat = self
        .forest
        .into_iter()
        .rev()
        .map(|(_, partial)| partial)
        .reduce(|mut acc, left| {
            add_into(&mut acc, &left);
            acc
        })
        .unwrap_or_else(|| vec![0.0_f64; k * k]);

    Array2::from_shape_vec((k, k), flat)
        .map_err(|e| format!("StreamingBorderGram: Gram reshape failed: {e}"))
}
}
/// Feeds a `StreamingBorderGram` from row batches of arbitrary size.
///
/// Incoming rows are buffered until a whole chunk is available; complete
/// chunks are then submitted to the accumulator strictly in index order.
pub struct ChunkAssembler {
// The underlying accumulator that receives whole chunks.
gram: StreamingBorderGram,
// Row-major values of rows received but not yet forming a complete chunk.
buffer: Vec<f64>,
// Index of the next chunk to cut from `buffer` and submit.
next_chunk: usize,
}
impl ChunkAssembler {
/// Creates an assembler with an empty buffer around a fresh accumulator.
///
/// # Errors
/// Propagates the error from `StreamingBorderGram::new`.
pub fn new(border_dim: usize, n_rows: usize, chunk_size: usize) -> Result<Self, String> {
    let gram = StreamingBorderGram::new(border_dim, n_rows, chunk_size)?;
    Ok(Self {
        gram,
        buffer: Vec::new(),
        next_chunk: 0,
    })
}
/// Number of whole rows currently held in the buffer.
///
/// # Panics
/// Panics if the buffer length is not a multiple of `border_dim`.
fn buffered_rows(&self) -> usize {
    let k = self.gram.border_dim;
    let buffered_values = self.buffer.len();
    assert!(
        buffered_values % k == 0,
        "ChunkAssembler buffer length {} is not a multiple of border_dim {k}",
        buffered_values
    );
    buffered_values / k
}
pub fn push_rows(&mut self, rows: ArrayView2<'_, f64>) -> Result<(), String> {
let k = self.gram.border_dim;
if rows.ncols() != k {
return Err(format!(
"ChunkAssembler: batch has {} cols but border_dim is {k}",
rows.ncols()
));
}
let n_chunks = self.gram.n_chunks();
let consumed = (self.gram.frontier() * self.gram.chunk_size).min(self.gram.n_rows);
let total_seen = consumed + self.buffered_rows() + rows.nrows();
if total_seen > self.gram.n_rows {
return Err(format!(
"ChunkAssembler: stream overran the declared row count ({} > {})",
total_seen, self.gram.n_rows
));
}
for row in rows.outer_iter() {
self.buffer.extend(row.iter().copied());
}
while self.next_chunk < n_chunks {
let need = self.gram.chunk_rows(self.next_chunk).len();
if self.buffered_rows() < need {
break;
}
let chunk: Vec<f64> = self.buffer.drain(..need * k).collect();
let view = ndarray::ArrayView2::from_shape((need, k), &chunk)
.map_err(|e| format!("ChunkAssembler: chunk reshape failed: {e}"))?;
self.gram.submit_chunk(self.next_chunk, view)?;
self.next_chunk += 1;
}
Ok(())
}
/// Returns a checkpoint of the underlying accumulator, but only while no rows
/// are buffered; with rows waiting in the buffer it returns `None`.
pub fn checkpoint(&self) -> Option<BorderGramCheckpoint> {
    self.buffer
        .is_empty()
        .then(|| self.gram.checkpoint())
}
/// Rebuilds an assembler from a checkpoint, starting with an empty buffer at
/// the checkpoint's frontier.
///
/// # Errors
/// Fails when the checkpoint carries parked out-of-order chunks. The
/// assembler submits chunks strictly in sequence, so such a checkpoint would
/// let the accumulator's frontier jump past the assembler's own chunk cursor
/// and make a later `push_rows` fail with a misleading duplicate-chunk
/// error. Also propagates any validation error from
/// `StreamingBorderGram::resume`.
pub fn resume(state: BorderGramCheckpoint) -> Result<Self, String> {
    if !state.pending.is_empty() {
        return Err(format!(
            "ChunkAssembler: checkpoint carries {} out-of-order pending chunk(s); \
             a sequential assembler can only resume from an in-order checkpoint",
            state.pending.len()
        ));
    }
    let gram = StreamingBorderGram::resume(state)?;
    let next_chunk = gram.frontier();
    Ok(Self {
        gram,
        buffer: Vec::new(),
        next_chunk,
    })
}
/// Finishes the stream and returns the Gram matrix.
///
/// # Errors
/// Fails when rows are still buffered (the stream stopped in the middle of a
/// chunk), and otherwise propagates any error from
/// `StreamingBorderGram::finish`.
pub fn finish(self) -> Result<Array2<f64>, String> {
    if self.buffer.is_empty() {
        return self.gram.finish();
    }
    let k = self.gram.border_dim;
    Err(format!(
        "ChunkAssembler: stream ended mid-chunk with {} buffered rows \
         (declared n_rows = {})",
        self.buffer.len() / k,
        self.gram.n_rows
    ))
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
/// Builds a deterministic `n × k` fixture: each entry is derived from the
/// fractional part of a scaled sine of its indices, mapped by `* 2.0 - 1.0`.
fn planted_rows(n: usize, k: usize) -> Array2<f64> {
    let entry = |(i, j): (usize, usize)| -> f64 {
        let phase = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
        let scrambled = (phase.sin() * 43_758.547).fract();
        scrambled * 2.0 - 1.0
    };
    Array2::from_shape_fn((n, k), entry)
}
/// Creates a fresh accumulator sized for `rows` and returns it together with
/// the in-order chunk sequence `0..n_chunks`. Nothing is submitted here.
fn accumulate_in_order(
    rows: &Array2<f64>,
    chunk_size: usize,
) -> (StreamingBorderGram, Vec<usize>) {
    let (n, k) = rows.dim();
    let acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
    let order = Vec::from_iter(0..acc.n_chunks());
    (acc, order)
}
/// Submits the chunks of `rows` in the given `order` to a fresh accumulator
/// and returns the finished Gram matrix.
fn run_with_order(rows: &Array2<f64>, chunk_size: usize, order: &[usize]) -> Array2<f64> {
    let (n, k) = rows.dim();
    let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
    for j in order.iter().copied() {
        let span = acc.chunk_rows(j);
        let chunk = rows.slice(ndarray::s![span, ..]);
        acc.submit_chunk(j, chunk).expect("submit");
    }
    acc.finish().expect("finish")
}
/// Asserts that two matrices have the same shape and that every pair of
/// corresponding entries has the same bit pattern.
fn assert_bit_identical(a: &Array2<f64>, b: &Array2<f64>, label: &str) {
    assert_eq!(a.dim(), b.dim(), "{label}: shape mismatch");
    for (idx, x) in a.indexed_iter() {
        // Shapes match, so `idx` is valid in `b` as well.
        let y = &b[idx];
        assert_eq!(
            x.to_bits(),
            y.to_bits(),
            "{label}: entry {idx:?} differs bitwise: {x:?} vs {y:?}"
        );
    }
}
// The streamed Gram must agree with the dense `XᵀX` product to 1e-12
// (relative, floored at 1.0) and must be bitwise symmetric.
#[test]
fn gram_matches_naive_xtx() {
    let n = 257;
    let k = 5;
    let rows = planted_rows(n, k);
    // 257 rows in chunks of 16 gives 17 chunks, the last holding one row.
    let order: Vec<usize> = (0..17).collect();
    let gram = run_with_order(&rows, 16, &order);
    let naive = rows.t().dot(&rows);

    let cells: Vec<(usize, usize)> = (0..k)
        .flat_map(|i| (0..k).map(move |j| (i, j)))
        .collect();

    for &(i, j) in &cells {
        let reference = naive[[i, j]];
        let d = (gram[[i, j]] - reference).abs();
        let scale = reference.abs().max(1.0);
        assert!(
            d <= 1.0e-12 * scale,
            "Gram[{i},{j}] = {} vs naive {} (delta {d})",
            gram[[i, j]],
            naive[[i, j]]
        );
    }
    for &(i, j) in &cells {
        assert_eq!(gram[[i, j]].to_bits(), gram[[j, i]].to_bits());
    }
}
// Submitting the same chunks in order, reversed, or strided must produce
// bit-identical Gram matrices.
#[test]
fn bit_reproducible_across_chunk_submission_orders() {
    let n = 2 * CROSS_CHUNK_BASE * 3 + 7;
    let k = 4;
    let chunk_size = 2;
    let rows = planted_rows(n, k);
    let n_chunks = n.div_ceil(chunk_size);

    let in_order = Vec::from_iter(0..n_chunks);
    let reversed = Vec::from_iter((0..n_chunks).rev());
    let strided = Vec::from_iter((0..n_chunks).map(|i| (i * 129) % n_chunks));

    let baseline = run_with_order(&rows, chunk_size, &in_order);
    let variants = [
        (reversed, "in-order vs reversed submission"),
        (strided, "in-order vs strided submission"),
    ];
    for (order, label) in &variants {
        let shuffled = run_with_order(&rows, chunk_size, order);
        assert_bit_identical(&baseline, &shuffled, label);
    }
}
// Each Gram entry produced by the streaming cascade must be bit-identical to
// `pairwise_sum` applied to that entry's per-chunk partials.
#[test]
fn cross_chunk_association_matches_landed_pairwise_sum() {
    let n = 613;
    let k = 3;
    let chunk_size = 2;
    let rows = planted_rows(n, k);
    let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
    let n_chunks = acc.n_chunks();

    // One series of per-chunk values for every flat Gram entry.
    let mut per_entry: Vec<Vec<f64>> = (0..k * k)
        .map(|_| Vec::with_capacity(n_chunks))
        .collect();
    for j in 0..n_chunks {
        let chunk = rows.slice(ndarray::s![acc.chunk_rows(j), ..]);
        let partial = acc.chunk_gram(chunk);
        for (series, value) in per_entry.iter_mut().zip(partial) {
            series.push(value);
        }
        acc.submit_chunk(j, chunk).expect("submit");
    }

    let gram = acc.finish().expect("finish");
    for ((a, b), got) in gram.indexed_iter() {
        let expected = pairwise_sum(&per_entry[a * k + b]);
        assert_eq!(
            got.to_bits(),
            expected.to_bits(),
            "entry ({a},{b}): cascade {} vs pairwise_sum {}",
            gram[[a, b]],
            expected
        );
    }
}
// A run that is checkpointed (with out-of-order chunks parked), serialized to
// JSON, restored and completed must match an uninterrupted run bit for bit.
#[test]
fn resume_equals_straight_through() {
    let n = 491;
    let k = 4;
    let chunk_size = 3;
    let rows = planted_rows(n, k);
    let (template, order) = accumulate_in_order(&rows, chunk_size);
    let n_chunks = template.n_chunks();
    let straight = run_with_order(&rows, chunk_size, &order);

    // First leg: an in-order prefix plus three chunks far ahead of the frontier.
    let mut first = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
    let mut submitted = vec![false; n_chunks];
    for j in (0..60).chain([150, 100, 163]) {
        let chunk = rows.slice(ndarray::s![first.chunk_rows(j), ..]);
        first.submit_chunk(j, chunk).expect("prefix submit");
        submitted[j] = true;
    }
    assert!(
        !first.pending.is_empty(),
        "fixture must exercise pending out-of-order state"
    );

    // Round-trip the checkpoint through JSON and discard the original.
    let json = serde_json::to_string(&first.checkpoint()).expect("serialize checkpoint");
    drop(first);
    let restored: BorderGramCheckpoint =
        serde_json::from_str(&json).expect("deserialize checkpoint");

    // Second leg: submit whatever the first leg did not.
    let mut second = StreamingBorderGram::resume(restored).expect("resume");
    let remaining = submitted
        .iter()
        .enumerate()
        .filter(|(_, done)| !**done)
        .map(|(j, _)| j);
    for j in remaining {
        let chunk = rows.slice(ndarray::s![second.chunk_rows(j), ..]);
        second.submit_chunk(j, chunk).expect("resumed submit");
    }
    let resumed = second.finish().expect("finish resumed");
    assert_bit_identical(&straight, &resumed, "resume vs straight-through");
}
// Walks one accumulator through every rejection path of `submit_chunk` and
// `finish`: wrong shape, duplicate of a folded chunk, duplicate of a parked
// chunk, out-of-range index, and finishing with a chunk still missing.
#[test]
fn rejects_duplicates_missing_chunks_and_bad_shapes() {
let n = 10;
let k = 2;
// 10 rows in chunks of 4 gives chunks of 4, 4 and 2 rows.
let chunk_size = 4; let rows = planted_rows(n, k);
let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
assert_eq!(acc.n_chunks(), 3);
// Chunk 0 needs 4 rows; offering 3 must fail and leave the state untouched.
let err = acc
.submit_chunk(0, rows.slice(ndarray::s![0..3, ..]))
.expect_err("short chunk must be rejected");
assert!(err.contains("expected (4, 2)"), "got: {err}");
acc.submit_chunk(0, rows.slice(ndarray::s![0..4, ..]))
.expect("chunk 0");
// Chunk 0 is now folded (below the frontier), so a repeat is a duplicate.
let err = acc
.submit_chunk(0, rows.slice(ndarray::s![0..4, ..]))
.expect_err("duplicate must be rejected");
assert!(err.contains("already submitted"), "got: {err}");
// Chunk 2 arrives before chunk 1 and is parked.
acc.submit_chunk(2, rows.slice(ndarray::s![8..10, ..]))
.expect("chunk 2 out of order");
// A repeat of a parked chunk is also a duplicate.
let err = acc
.submit_chunk(2, rows.slice(ndarray::s![8..10, ..]))
.expect_err("duplicate pending must be rejected");
assert!(err.contains("already submitted"), "got: {err}");
// Only indices 0..3 exist.
let err = acc
.submit_chunk(3, rows.slice(ndarray::s![0..4, ..]))
.expect_err("out-of-range index must be rejected");
assert!(err.contains("out of range"), "got: {err}");
// Chunk 1 was never submitted, so finishing must fail and name it.
let err = acc.finish().expect_err("missing chunk must fail finish");
assert!(
err.contains("[1]"),
"missing-chunk message must name chunk 1: {err}"
);
}
// Takes a valid checkpoint after one folded chunk and checks that `resume`
// rejects four independently corrupted copies of it.
#[test]
fn checkpoint_validation_rejects_corruption() {
let mut acc = StreamingBorderGram::new(3, 100, 10).expect("accumulator");
let rows = planted_rows(100, 3);
acc.submit_chunk(0, rows.slice(ndarray::s![0..10, ..]))
.expect("chunk 0");
let good = acc.checkpoint();
// Corruption 1: `block_len` zeroed.
// NOTE(review): this is only an inconsistency if `block_partial` is still
// present after one chunk, i.e. if CROSS_CHUNK_BASE > 1 — confirm.
let mut bad = good.clone();
bad.block_len = 0; assert!(StreamingBorderGram::resume(bad).is_err());
// Corruption 2: block partial shortened by one value (wrong length).
let mut bad = good.clone();
if let Some(b) = bad.block_partial.as_mut() {
b.pop(); }
assert!(StreamingBorderGram::resume(bad).is_err());
// Corruption 3: a pending entry for chunk 0, which lies below the frontier.
let mut bad = good.clone();
bad.pending.push((0, vec![0.0; 9])); assert!(StreamingBorderGram::resume(bad).is_err());
// Corruption 4: frontier far beyond the 10 chunks that exist.
let mut bad = good;
bad.frontier = 99; assert!(StreamingBorderGram::resume(bad).is_err());
}
// Feeding the assembler batches of irregular sizes must give the same bits
// as submitting whole chunks directly.
#[test]
fn chunk_assembler_is_batching_invariant() {
    let n = 463;
    let k = 4;
    let chunk_size = 16;
    let rows = planted_rows(n, k);
    let direct = {
        let (acc, order) = accumulate_in_order(&rows, chunk_size);
        drop(acc);
        run_with_order(&rows, chunk_size, &order)
    };

    let mut asm = ChunkAssembler::new(k, n, chunk_size).expect("assembler");
    let sizes = [3usize, 5, 7, 11, 13];
    let mut at = 0usize;
    // Cycle through the batch sizes until every row has been pushed.
    for size in sizes.iter().copied().cycle() {
        if at >= n {
            break;
        }
        let take = size.min(n - at);
        asm.push_rows(rows.slice(ndarray::s![at..at + take, ..]))
            .expect("push");
        at += take;
    }
    let assembled = asm.finish().expect("finish");
    assert_bit_identical(&direct, &assembled, "direct vs assembled batching");
}
// The assembler must refuse to checkpoint while rows are buffered, allow it
// on a chunk boundary, and produce identical bits after a resume.
#[test]
fn chunk_assembler_checkpoints_only_at_boundaries_and_resumes() {
let n = 200;
let k = 3;
let chunk_size = 10;
let rows = planted_rows(n, k);
// Reference result: all 20 chunks submitted in order.
let direct = run_with_order(&rows, chunk_size, &(0..20).collect::<Vec<_>>());
let mut asm = ChunkAssembler::new(k, n, chunk_size).expect("assembler");
// 7 rows are fewer than one chunk, so they all stay buffered.
asm.push_rows(rows.slice(ndarray::s![0..7, ..]))
.expect("push");
assert!(
asm.checkpoint().is_none(),
"mid-chunk checkpoint must be None"
);
// Reaching row 30 completes exactly three chunks and empties the buffer.
asm.push_rows(rows.slice(ndarray::s![7..30, ..]))
.expect("push");
let cp = asm.checkpoint().expect("boundary checkpoint");
assert_eq!(cp.frontier, 3);
drop(asm);
// Continue from the checkpoint with the remaining rows.
let mut resumed = ChunkAssembler::resume(cp).expect("resume");
resumed
.push_rows(rows.slice(ndarray::s![30..n, ..]))
.expect("push rest");
let gram = resumed.finish().expect("finish");
assert_bit_identical(&direct, &gram, "assembler resume vs straight-through");
}
// A stream that stops mid-chunk must fail at `finish`; a batch that exceeds
// the declared row count must fail at `push_rows`.
#[test]
fn chunk_assembler_rejects_truncated_and_overrunning_streams() {
let k = 2;
let rows = planted_rows(30, k);
// 25 of 30 declared rows: three chunks of 8 are submitted, one row stays buffered.
let mut asm = ChunkAssembler::new(k, 30, 8).expect("assembler");
asm.push_rows(rows.slice(ndarray::s![0..25, ..]))
.expect("push");
let err = asm.finish().expect_err("truncated stream must fail finish");
assert!(err.contains("mid-chunk"), "got: {err}");
// 25 rows pushed into a stream declared to hold only 20.
let mut asm = ChunkAssembler::new(k, 20, 8).expect("assembler");
let err = asm
.push_rows(rows.slice(ndarray::s![0..25, ..]))
.expect_err("overrun must be rejected");
assert!(err.contains("overran"), "got: {err}");
}
// Error budget for the f32-storage test: the largest allowed absolute
// deviation of any Gram entry, relative to the largest-magnitude entry of the
// exact Gram (floored at 1.0).
const MIXED_PRECISION_BORDER_RTOL: f64 = 1.0e-5;
// Rows rounded through `f32` but accumulated in `f64` must stay within
// `MIXED_PRECISION_BORDER_RTOL` of the Gram of the unrounded rows.
#[test]
fn f32_storage_f64_accumulation_meets_the_error_budget() {
    let n = 700;
    let k = 5;
    let chunk_size = 32;
    let rows = planted_rows(n, k);
    // Simulate f32 storage: round every value to f32 and widen it back.
    let stored = rows.mapv(|v| f64::from(v as f32));

    let mut acc = StreamingBorderGram::new(k, n, chunk_size).expect("accumulator");
    let n_chunks = acc.n_chunks();
    for j in 0..n_chunks {
        let chunk = stored.slice(ndarray::s![acc.chunk_rows(j), ..]);
        acc.submit_chunk(j, chunk).expect("submit");
    }
    let mixed = acc.finish().expect("finish");

    let exact = rows.t().dot(&rows);
    let scale = exact
        .iter()
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max)
        .max(1.0);
    for ((i, j), got) in mixed.indexed_iter() {
        let d = (got - exact[[i, j]]).abs();
        assert!(
            d <= MIXED_PRECISION_BORDER_RTOL * scale,
            "Gram[{i},{j}] mixed-precision delta {d:.3e} exceeds budget \
             {MIXED_PRECISION_BORDER_RTOL:.0e} × scale {scale:.3e}"
        );
    }
}
// An empty stream has no chunks, is complete from the start, and finishes
// to an all-positive-zero matrix.
#[test]
fn zero_rows_yields_zero_gram() {
    let acc = StreamingBorderGram::new(3, 0, 8).expect("accumulator");
    assert_eq!(acc.n_chunks(), 0);
    assert!(acc.is_complete());

    let gram = acc.finish().expect("finish empty");
    assert_eq!(gram.dim(), (3, 3));
    let zero_bits = 0.0_f64.to_bits();
    assert!(!gram.iter().any(|v| v.to_bits() != zero_bits));
}
}