gam_terms/basis/measure_jet_moments.rs
1//! Measure-jet frame data interface: per-cell frozen-weight polynomial
2//! moment tables with a binomial-shift merge monoid
3//! (`docs/measure_jet_frame.md`, §2 "Data interface: moments or
4//! nothing").
5//!
6//! This module aggregates caller-computed weights into order-0..2 coordinate
7//! moments. Those tables exactly determine polynomial couplings under the
8//! same frozen weights, including the local affine sufficient statistics used
9//! by `measure_jet_smooth.rs`. They do NOT exactly determine Gaussian
10//! transforms at moved kernel centers: support curves, Gaussian Gram entries,
11//! and Gaussian `XᵀWX` products need their own kernel pass or a separately
12//! controlled approximation. Truncation does NOT live here either: the caller
13//! computes the Gaussian weights `w_i` (mass × kernel profile, with whatever
14//! cutoff its explicit `e^{−ρ²/2}` tolerance budget licenses) and this module
15//! only aggregates what it is handed.
16//!
17//! # The monoid
18//!
19//! A table holds, per response channel `g` and per coordinate multi-index
20//! `α` with `|α| ≤ 2`, the centered moment `μ_α = Σ_i w_i g_i (x_i − c)^α`
21//! about the cell reference point `c`. The binomial shift
22//!
23//! ```text
24//! μ′_α = Σ_{β ≤ α} C(α, β) (c − c′)^{α−β} μ_β
25//! ```
26//!
27//! re-expresses the same frozen-weight polynomial table about any other
28//! center `c′` exactly as a finite polynomial identity. It does not move the
29//! Gaussian kernel center or recompute weights. Merging two tables with
30//! already-compatible frozen weights is therefore "recenter to a common
31//! reference, add componentwise":
32//! an associative, commutative monoid whose identity is the empty (all-zero)
33//! table at any center. Exact distributed fitting, exact online updates, and
34//! bit-reproducibility under sorted reduction are corollaries of that one
35//! algebraic fact ([`merge_moment_tables`] is a monoid homomorphism from
36//! disjoint row sets under union to tables under ⊕).
37//!
38//! # Determinism / bit-exactness convention (sorted reduction)
39//!
40//! Floating-point addition is commutative but not associative, so the monoid
41//! laws hold algebraically while bit-patterns depend on reduction ORDER.
42//! This module pins one order everywhere:
43//!
44//! - [`accumulate_moment_table`] splits rows into fixed-size chunks
45//! ([`MEASURE_JET_MOMENT_CHUNK_ROWS`], never derived from thread count),
46//! accumulates each chunk sequentially in row order, and folds the chunk
47//! partials sequentially in chunk-index order — the sorted reduction. The
48//! result is bit-identical across runs, machines, and rayon pool sizes.
49//! - [`recenter_moment_table`] evaluates the shift in ONE fixed expression
50//! order (documented at the site).
51//! - [`merge_moment_tables`] canonically orients its operands by the
52//! lexicographic total order on centers (`f64::total_cmp` per coordinate),
53//! so `a ⊕ b` and `b ⊕ a` execute the SAME instruction stream and are
54//! bit-identical for arbitrary inputs.
55//!
56//! Cross-GROUPING bit-identity — `(A⊕B)⊕C` vs `A⊕(B⊕C)` — additionally
57//! requires the moment arithmetic itself to be exact; the in-module tests
58//! pin it on dyadic lattices (integer coordinates/channels, dyadic weights),
59//! where every product and sum is exactly representable, and callers
60//! reducing many chunks get run-to-run determinism by folding in chunk-index
61//! order exactly as the accumulator does.
62//!
63//! # 1:1 contract with `assemble_weighted_forms`
64//!
65//! [`jet_sufficient_stats`] reproduces, in closed form from a stored table
66//! whose weights were computed for the same center and scale, exactly the
67//! local-fit quantities the current workhorse
68//! (`measure_jet_smooth.rs::assemble_weighted_forms`) computes from raw
69//! points per (center, scale) block: the kernel mass `q`, the dimensionless
70//! weighted feature mean `a_mean`, the dimensionless slope Gram
71//! `G = Φ̃ᵀWΦ̃/q`, the weighted channel mean `uᵀv`, and the exact-projection
72//! right-hand side `Bᵀv/q` — so the substrate can later replace that
73//! same-center point loop without changing a single number.
74
75use std::cmp::Ordering;
76use std::ops::Range;
77
78use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
79use rayon::prelude::*;
80
81use super::BasisError;
82
/// Number of rows handed to each chunk of the streaming accumulation in
/// [`accumulate_moment_table`].
///
/// This is a compile-time constant and is never derived from the thread
/// count, so the chunk partition — and with it the sorted-reduction bit
/// pattern — is the same on every machine and for every rayon pool size.
/// The size follows the design evaluators' streaming blocks: big enough to
/// amortize per-chunk setup, small enough that the per-chunk partial tables
/// stay cache-resident for the `d ≤ 8` regimes the jet order targets.
pub(crate) const MEASURE_JET_MOMENT_CHUNK_ROWS: usize = 8 * 1024;
90
91/// Per-cell moment table: Gaussian-weighted coordinate moments of orders
92/// 0..=2 crossed with response channels, all centered at the cell's
93/// reference point `c`.
94///
95/// Channel convention: channel 0 is the UNIT channel (`g ≡ 1`); further
96/// channels carry responses (`y`, and later `y²`, PIRLS working `z`, `w` per
97/// the frame notes). The table itself never enforces the convention — it
98/// aggregates whatever the caller hands it — but [`jet_sufficient_stats`]
99/// reads `q`, `a_mean`, and the Gram off channel 0.
100///
101/// `m2` is stored as the full (symmetric-by-construction) `d×d` second
102/// moment per channel.
103#[derive(Debug, Clone, PartialEq)]
104pub struct MeasureJetMomentTable {
105 /// Reference point `c` (length `d`).
106 pub center: Array1<f64>,
107 /// Per channel: `Σ_i w_i g_i`.
108 pub m0: Array1<f64>,
109 /// Per channel × d: `Σ_i w_i g_i (x_i − c)`.
110 pub m1: Array2<f64>,
111 /// Per channel: `d×d` matrix `Σ_i w_i g_i (x_i − c)(x_i − c)ᵀ`.
112 pub m2: Vec<Array2<f64>>,
113}
114
115impl MeasureJetMomentTable {
116 /// The monoid identity at `center`: an all-zero table over `n_channels`
117 /// channels. Merging it (at ANY center) into another table leaves that
118 /// table's moments unchanged up to the exact zero shift.
119 pub fn zero(center: Array1<f64>, n_channels: usize) -> Self {
120 let d = center.len();
121 Self {
122 center,
123 m0: Array1::zeros(n_channels),
124 m1: Array2::zeros((n_channels, d)),
125 m2: (0..n_channels).map(|_| Array2::zeros((d, d))).collect(),
126 }
127 }
128
129 /// Ambient dimension `d` of the cell.
130 pub fn dim(&self) -> usize {
131 self.center.len()
132 }
133
134 /// Number of response channels stored (channel 0 = unit by convention).
135 pub fn n_channels(&self) -> usize {
136 self.m0.len()
137 }
138}
139
140/// Shape/finiteness self-consistency of a (publicly constructible) table:
141/// returns `(n_channels, d)`. Single validation source for the fallible
142/// consumers ([`merge_moment_tables`], [`jet_sufficient_stats`]).
143pub(crate) fn validate_table_shape(
144 t: &MeasureJetMomentTable,
145 label: &str,
146) -> Result<(usize, usize), BasisError> {
147 let d = t.center.len();
148 let n_channels = t.m0.len();
149 if t.center.iter().any(|v| !v.is_finite()) {
150 crate::bail_invalid_basis!("measure-jet moment table `{label}` has a non-finite center");
151 }
152 if t.m1.dim() != (n_channels, d) {
153 crate::bail_dim_basis!(
154 "measure-jet moment table `{label}` m1 shape {:?} does not match (channels, d) = ({n_channels}, {d})",
155 t.m1.dim()
156 );
157 }
158 if t.m2.len() != n_channels {
159 crate::bail_dim_basis!(
160 "measure-jet moment table `{label}` has {} m2 blocks for {n_channels} channels",
161 t.m2.len()
162 );
163 }
164 for (ch, block) in t.m2.iter().enumerate() {
165 if block.dim() != (d, d) {
166 crate::bail_dim_basis!(
167 "measure-jet moment table `{label}` m2[{ch}] shape {:?} is not ({d}, {d})",
168 block.dim()
169 );
170 }
171 }
172 Ok((n_channels, d))
173}
174
175/// Sequential moment accumulation over one row chunk, in row order. The
176/// per-entry update order is fixed — `wg = w·g`, then `m1 += wg·dx_k`, then
177/// `m2 += (wg·dx_k)·dx_l` with that exact association — as part of the
178/// module's bit-determinism contract.
179pub(crate) fn accumulate_chunk(
180 coords: ArrayView2<'_, f64>,
181 weights: ArrayView1<'_, f64>,
182 channels: &[ArrayView1<'_, f64>],
183 center: ArrayView1<'_, f64>,
184 rows: Range<usize>,
185) -> Result<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>), BasisError> {
186 let d = center.len();
187 let n_channels = channels.len();
188 let mut m0 = Array1::<f64>::zeros(n_channels);
189 let mut m1 = Array2::<f64>::zeros((n_channels, d));
190 let mut m2: Vec<Array2<f64>> = (0..n_channels).map(|_| Array2::zeros((d, d))).collect();
191 let mut dx = vec![0.0_f64; d];
192 for r in rows {
193 let w = weights[r];
194 if !(w.is_finite() && w >= 0.0) {
195 crate::bail_invalid_basis!(
196 "measure-jet moment accumulation needs finite nonnegative weights; got {w} at row {r}"
197 );
198 }
199 for k in 0..d {
200 let x = coords[(r, k)];
201 if !x.is_finite() {
202 crate::bail_invalid_basis!(
203 "measure-jet moment accumulation hit a non-finite coordinate at row {r}, axis {k}"
204 );
205 }
206 dx[k] = x - center[k];
207 }
208 for (ch, g) in channels.iter().enumerate() {
209 let gv = g[r];
210 if !gv.is_finite() {
211 crate::bail_invalid_basis!(
212 "measure-jet moment accumulation hit a non-finite channel value at row {r}, channel {ch}"
213 );
214 }
215 let wg = w * gv;
216 m0[ch] += wg;
217 let m2_ch = &mut m2[ch];
218 for k in 0..d {
219 let wg_dk = wg * dx[k];
220 m1[(ch, k)] += wg_dk;
221 for l in 0..d {
222 m2_ch[(k, l)] += wg_dk * dx[l];
223 }
224 }
225 }
226 }
227 Ok((m0, m1, m2))
228}
229
230/// Accumulate one cell's moment table from raw rows. The single point where
231/// data rows are read; everything downstream is closed-form algebra on the
232/// result.
233///
234/// `weights` are the caller-computed Gaussian kernel weights
235/// `w_i = mass_i · exp(−‖x_i − c‖²/(2ε²))` (or their truncated variant — the
236/// cutoff and its `e^{−ρ²/2}` budget are the caller's responsibility);
237/// `channels` are the per-row response channels `g_i`, typically
238/// `[ones, y]`, with channel 0 = unit by convention.
239///
240/// Streaming/parallel layout: rows are split into fixed
241/// [`MEASURE_JET_MOMENT_CHUNK_ROWS`]-sized chunks (the bms-style chunked row
242/// reduction), each chunk is accumulated sequentially in row order, and the
243/// chunk partials are folded sequentially in chunk-index order — the sorted
244/// reduction that makes the output bit-deterministic regardless of thread
245/// scheduling. `rows == 0` is allowed and yields the monoid identity at
246/// `center`.
247pub fn accumulate_moment_table(
248 coords: ArrayView2<'_, f64>,
249 weights: ArrayView1<'_, f64>,
250 channels: &[ArrayView1<'_, f64>],
251 center: ArrayView1<'_, f64>,
252) -> Result<MeasureJetMomentTable, BasisError> {
253 let n = coords.nrows();
254 let d = coords.ncols();
255 if d == 0 {
256 crate::bail_invalid_basis!(
257 "measure-jet moment accumulation needs at least one coordinate axis"
258 );
259 }
260 if center.len() != d {
261 crate::bail_dim_basis!(
262 "measure-jet moment center length {} does not match coordinate dimension {d}",
263 center.len()
264 );
265 }
266 if center.iter().any(|v| !v.is_finite()) {
267 crate::bail_invalid_basis!("measure-jet moment accumulation needs a finite center");
268 }
269 if weights.len() != n {
270 crate::bail_dim_basis!(
271 "measure-jet moment weights length {} does not match {n} rows",
272 weights.len()
273 );
274 }
275 if channels.is_empty() {
276 crate::bail_invalid_basis!(
277 "measure-jet moment accumulation needs at least one response channel (channel 0 = unit)"
278 );
279 }
280 for (ch, g) in channels.iter().enumerate() {
281 if g.len() != n {
282 crate::bail_dim_basis!(
283 "measure-jet moment channel {ch} length {} does not match {n} rows",
284 g.len()
285 );
286 }
287 }
288 let n_chunks = n.div_ceil(MEASURE_JET_MOMENT_CHUNK_ROWS).max(1);
289 let partials: Vec<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> = if n_chunks == 1 {
290 vec![accumulate_chunk(coords, weights, channels, center, 0..n)?]
291 } else {
292 (0..n_chunks)
293 .into_par_iter()
294 .map(|chunk| {
295 let start = chunk * MEASURE_JET_MOMENT_CHUNK_ROWS;
296 let end = (start + MEASURE_JET_MOMENT_CHUNK_ROWS).min(n);
297 accumulate_chunk(coords, weights, channels, center, start..end)
298 })
299 .collect::<Result<Vec<_>, BasisError>>()?
300 };
301 // Sorted reduction: fold chunk partials in chunk-index order. All
302 // partials share `center`, so the fold is plain componentwise addition.
303 let mut iter = partials.into_iter();
304 let (mut m0, mut m1, mut m2) = iter
305 .next()
306 .expect("chunk count is clamped to at least one partial");
307 for (p0, p1, p2) in iter {
308 m0 += &p0;
309 m1 += &p1;
310 for (total, part) in m2.iter_mut().zip(&p2) {
311 *total += part;
312 }
313 }
314 Ok(MeasureJetMomentTable {
315 center: center.to_owned(),
316 m0,
317 m1,
318 m2,
319 })
320}
321
322/// Exact recentering via the binomial shift: the same frozen-weight
323/// polynomial table re-expressed about `new_center`, with no kernel
324/// re-evaluation. This is not a moving-kernel identity; if the Gaussian
325/// center changes, the caller must recompute or approximate the weights.
326///
327/// Derivation (per channel; write `Δ = c − c′` so `x − c′ = (x − c) + Δ`):
328///
329/// - order 0: `μ′_0 = Σ w g = μ_0` — unchanged;
330/// - order 1: `μ′_1 = Σ w g ((x−c) + Δ) = μ_1 + Δ·μ_0`;
331/// - order 2: `μ′_2 = Σ w g ((x−c)+Δ)((x−c)+Δ)ᵀ
332/// = μ_2 + Δ·μ_1ᵀ + μ_1·Δᵀ + ΔΔᵀ·μ_0`,
333///
334/// which is the multi-index binomial identity
335/// `μ′_α = Σ_{β≤α} C(α,β)(c−c′)^{α−β} μ_β` specialized to `|α| ≤ 2`. Every
336/// term is a finite product of stored moments and `Δ`, so the shift is an
337/// algebraic identity of the frozen-weight table — exact up to floating-point
338/// rounding, and exactly exact whenever the arithmetic is (dyadic lattices;
339/// pinned in the tests).
340///
341/// Bit-determinism: the order-2 entry is evaluated in the ONE fixed order
342/// `((μ_2 + Δ_k·μ_{1,l}) + μ_{1,k}·Δ_l) + (Δ_k·Δ_l)·μ_0`; same inputs always
343/// produce the same bits.
344pub fn recenter_moment_table(
345 t: &MeasureJetMomentTable,
346 new_center: ArrayView1<'_, f64>,
347) -> MeasureJetMomentTable {
348 let d = t.center.len();
349 assert_eq!(
350 new_center.len(),
351 d,
352 "measure-jet recenter: new center length {} does not match table dimension {d}",
353 new_center.len()
354 );
355 let n_channels = t.m0.len();
356 let mut delta = Array1::<f64>::zeros(d);
357 for k in 0..d {
358 delta[k] = t.center[k] - new_center[k];
359 }
360 let m0 = t.m0.clone();
361 let mut m1 = Array2::<f64>::zeros((n_channels, d));
362 for ch in 0..n_channels {
363 for k in 0..d {
364 m1[(ch, k)] = t.m1[(ch, k)] + delta[k] * t.m0[ch];
365 }
366 }
367 let mut m2 = Vec::with_capacity(n_channels);
368 for ch in 0..n_channels {
369 let src = &t.m2[ch];
370 let mut out = Array2::<f64>::zeros((d, d));
371 for k in 0..d {
372 for l in 0..d {
373 out[(k, l)] = ((src[(k, l)] + delta[k] * t.m1[(ch, l)]) + t.m1[(ch, k)] * delta[l])
374 + (delta[k] * delta[l]) * t.m0[ch];
375 }
376 }
377 m2.push(out);
378 }
379 MeasureJetMomentTable {
380 center: new_center.to_owned(),
381 m0,
382 m1,
383 m2,
384 }
385}
386
387/// Lexicographic total order on cell centers (`f64::total_cmp` per
388/// coordinate). The canonical-orientation key that makes the merge bitwise
389/// argument-order-independent.
390pub(crate) fn lex_cmp_centers(a: &Array1<f64>, b: &Array1<f64>) -> Ordering {
391 for (x, y) in a.iter().zip(b.iter()) {
392 let ord = x.total_cmp(y);
393 if ord != Ordering::Equal {
394 return ord;
395 }
396 }
397 Ordering::Equal
398}
399
400/// Monoid merge: recenter compatible frozen-weight tables onto a common
401/// reference, then add componentwise. Exact for those polynomial moments
402/// (pure binomial shift, no kernel re-evaluation) and deterministic.
403///
404/// Canonical orientation: the merged table lives at the lexicographically
405/// SMALLER of the two operand centers ([`lex_cmp_centers`]), and the other
406/// operand is the one recentered. Because the (host, guest) roles depend
407/// only on the centers — never on argument position — `merge(a, b)` and
408/// `merge(b, a)` execute identical arithmetic and agree BITWISE for
409/// arbitrary inputs (IEEE addition is commutative; only grouping is not).
410/// This is a deliberate strengthening of the naive "recenter `b` onto
411/// `a.center`" rule, which is only commutative up to a recentering.
412pub fn merge_moment_tables(
413 a: &MeasureJetMomentTable,
414 b: &MeasureJetMomentTable,
415) -> Result<MeasureJetMomentTable, BasisError> {
416 let (a_channels, a_dim) = validate_table_shape(a, "a")?;
417 let (b_channels, b_dim) = validate_table_shape(b, "b")?;
418 if a_dim != b_dim || a_channels != b_channels {
419 crate::bail_dim_basis!(
420 "measure-jet merge needs matching tables; got (channels, d) = ({a_channels}, {a_dim}) vs ({b_channels}, {b_dim})"
421 );
422 }
423 let (host, guest) = if lex_cmp_centers(&a.center, &b.center) != Ordering::Greater {
424 (a, b)
425 } else {
426 (b, a)
427 };
428 let moved = recenter_moment_table(guest, host.center.view());
429 let mut m2 = Vec::with_capacity(a_channels);
430 for (h, g) in host.m2.iter().zip(&moved.m2) {
431 m2.push(h + g);
432 }
433 Ok(MeasureJetMomentTable {
434 center: host.center.clone(),
435 m0: &host.m0 + &moved.m0,
436 m1: &host.m1 + &moved.m1,
437 m2,
438 })
439}
440
441/// The local jet-fit sufficient statistics read off one table — exactly the
442/// per-block quantities `assemble_weighted_forms` (measure_jet_smooth.rs)
443/// computes from raw points when the table weights are frozen at the same
444/// center and scale, reproduced in closed form from stored moments.
445#[derive(Debug, Clone, PartialEq)]
446pub struct MeasureJetJetStats {
447 /// Kernel mass `q = Σ w_i` (unit-channel zeroth moment).
448 pub q: f64,
449 /// Weighted mean of the requested value channel: `uᵀv = m0[ch]/q`.
450 pub mean: f64,
451 /// Dimensionless slope Gram `G = Φ̃ᵀWΦ̃/q = m2[0]/(qε²) − ā·āᵀ` with
452 /// `ā = m1[0]/(qε)` (`Φ` rows are `(x_i − c)/ε`).
453 pub gram: Array2<f64>,
454 /// Local-fit right-hand side `Bᵀv/q = m1[ch]/(qε) − ā·(m0[ch]/q)` — the
455 /// vector the exact weighted affine projection consumes.
456 pub cross: Array1<f64>,
457}
458
459/// Read the local jet-fit sufficient statistics off a moment table at scale
460/// `eps`, for value channel `channel`.
461///
462/// 1:1 with `assemble_weighted_forms`' per-block math (its symbols on the
463/// right), under the energy convention that local features are the
464/// ε-SCALED offsets `Φ_{jk} = (x_{jk} − c_k)/ε`:
465///
466/// - `q = m0[0]` ↔ `q = Σ_j w_j`,
467/// - `ā_k = m1[0,k]/(q·ε)` ↔ `a_mean = Φᵀw/q`,
468/// - `G_kl = m2[0][k,l]/(q·ε²) − ā_k·ā_l` ↔ `G = (ΦᵀWΦ)/q − a·aᵀ`,
469/// - `mean = m0[ch]/q` ↔ `uᵀv` (the weighted-centering
470/// projection `Cv = v − (uᵀv)·1` of the constant-annihilation contract),
471/// - `cross_k = m1[ch,k]/(q·ε) − ā_k·mean` ↔ `Bᵀv/q` with
472/// `B = WΦ − w·aᵀ` (column-centering makes `Φ̃ᵀW·1 = 0`, so
473/// `Φ̃ᵀWCv/q = Bᵀv/q` — the exact RHS of the local affine projection).
474///
475/// For `channel == 0` (the unit channel) `mean` is exactly `1.0` and `cross`
476/// is identically `+0.0` (the same division is subtracted from itself) —
477/// the moment-level restatement of exact constant annihilation.
478pub fn jet_sufficient_stats(
479 t: &MeasureJetMomentTable,
480 eps: f64,
481 channel: usize,
482) -> Result<MeasureJetJetStats, BasisError> {
483 let (n_channels, d) = validate_table_shape(t, "t")?;
484 if !(eps.is_finite() && eps > 0.0) {
485 crate::bail_invalid_basis!(
486 "measure-jet jet stats need a finite positive scale eps; got {eps}"
487 );
488 }
489 if channel >= n_channels {
490 crate::bail_invalid_basis!(
491 "measure-jet jet stats channel {channel} out of range for {n_channels} channels"
492 );
493 }
494 let q = t.m0[0];
495 if !(q.is_finite() && q > 0.0) {
496 crate::bail_invalid_basis!(
497 "measure-jet jet stats need positive unit-channel kernel mass q; got {q}"
498 );
499 }
500 let q_eps = q * eps;
501 let mut a_mean = Array1::<f64>::zeros(d);
502 for k in 0..d {
503 a_mean[k] = t.m1[(0, k)] / q_eps;
504 }
505 let q_eps2 = q * eps * eps;
506 let m2_unit = &t.m2[0];
507 let mut gram = Array2::<f64>::zeros((d, d));
508 for k in 0..d {
509 for l in 0..d {
510 gram[(k, l)] = m2_unit[(k, l)] / q_eps2 - a_mean[k] * a_mean[l];
511 }
512 }
513 let mean = t.m0[channel] / q;
514 let mut cross = Array1::<f64>::zeros(d);
515 for k in 0..d {
516 cross[k] = t.m1[(channel, k)] / q_eps - a_mean[k] * mean;
517 }
518 Ok(MeasureJetJetStats {
519 q,
520 mean,
521 gram,
522 cross,
523 })
524}
525
526#[cfg(test)]
527mod tests {
528 use super::*;
529 use ndarray::s;
530
531 /// Closeness metric for the recenter-exactness gate: relative at scale,
532 /// absolute `tol` below unit scale (`|x−y| ≤ tol·(1 + max(|x|,|y|))`).
533 pub(crate) fn assert_tables_close(
534 a: &MeasureJetMomentTable,
535 b: &MeasureJetMomentTable,
536 tol: f64,
537 ) {
538 let pairs = |xs: &[f64], ys: &[f64], label: &str| {
539 assert_eq!(xs.len(), ys.len(), "{label}: length mismatch");
540 for (i, (x, y)) in xs.iter().zip(ys.iter()).enumerate() {
541 let scale = 1.0 + x.abs().max(y.abs());
542 assert!(
543 (x - y).abs() <= tol * scale,
544 "{label}[{i}]: {x} vs {y} differ beyond {tol} rel"
545 );
546 }
547 };
548 pairs(
549 a.center.as_slice().expect("contiguous center"),
550 b.center.as_slice().expect("contiguous center"),
551 "center",
552 );
553 pairs(
554 a.m0.as_slice().expect("contiguous m0"),
555 b.m0.as_slice().expect("contiguous m0"),
556 "m0",
557 );
558 pairs(
559 a.m1.as_slice().expect("contiguous m1"),
560 b.m1.as_slice().expect("contiguous m1"),
561 "m1",
562 );
563 assert_eq!(a.m2.len(), b.m2.len(), "m2: channel count mismatch");
564 for (ch, (x, y)) in a.m2.iter().zip(b.m2.iter()).enumerate() {
565 pairs(
566 x.as_slice().expect("contiguous m2"),
567 y.as_slice().expect("contiguous m2"),
568 &format!("m2[{ch}]"),
569 );
570 }
571 }
572
573 /// Bit-identity gate: every stored f64 must agree by `to_bits`.
574 pub(crate) fn assert_tables_bit_identical(
575 a: &MeasureJetMomentTable,
576 b: &MeasureJetMomentTable,
577 ) {
578 let bits = |xs: &[f64], ys: &[f64], label: &str| {
579 assert_eq!(xs.len(), ys.len(), "{label}: length mismatch");
580 for (i, (x, y)) in xs.iter().zip(ys.iter()).enumerate() {
581 assert_eq!(
582 x.to_bits(),
583 y.to_bits(),
584 "{label}[{i}]: {x} vs {y} differ bitwise"
585 );
586 }
587 };
588 bits(
589 a.center.as_slice().expect("contiguous center"),
590 b.center.as_slice().expect("contiguous center"),
591 "center",
592 );
593 bits(
594 a.m0.as_slice().expect("contiguous m0"),
595 b.m0.as_slice().expect("contiguous m0"),
596 "m0",
597 );
598 bits(
599 a.m1.as_slice().expect("contiguous m1"),
600 b.m1.as_slice().expect("contiguous m1"),
601 "m1",
602 );
603 assert_eq!(a.m2.len(), b.m2.len(), "m2: channel count mismatch");
604 for (ch, (x, y)) in a.m2.iter().zip(b.m2.iter()).enumerate() {
605 bits(
606 x.as_slice().expect("contiguous m2"),
607 y.as_slice().expect("contiguous m2"),
608 &format!("m2[{ch}]"),
609 );
610 }
611 }
612
613 /// Deterministic generic-float dataset (no RNG): low-discrepancy
614 /// fractional parts, d = 3, with a unit channel and one value channel.
615 pub(crate) fn float_dataset(n: usize) -> (Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>) {
616 let mut coords = Array2::<f64>::zeros((n, 3));
617 let mut weights = Array1::<f64>::zeros(n);
618 let mut ones = Array1::<f64>::zeros(n);
619 let mut y = Array1::<f64>::zeros(n);
620 for i in 0..n {
621 let t = (i + 1) as f64;
622 coords[(i, 0)] = (t * 0.618034).fract() * 4.0 - 2.0;
623 coords[(i, 1)] = (t * 0.414214).fract() * 3.0 - 1.0;
624 coords[(i, 2)] = (t * 0.732051).fract() * 2.0 - 1.5;
625 weights[i] = 0.05 + (t * 0.292893).fract();
626 ones[i] = 1.0;
627 y[i] = (t * 0.539345).fract() * 6.0 - 3.0;
628 }
629 (coords, weights, ones, y)
630 }
631
632 /// Dyadic-lattice dataset: integer coordinates and channel values,
633 /// dyadic weights — every moment product and sum is exactly
634 /// representable in f64, so the algebraic monoid laws become BIT
635 /// identities and the tests below can pin them with `to_bits`.
636 pub(crate) fn dyadic_dataset() -> (Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>) {
637 let coords = ndarray::array![
638 [3.0, -2.0],
639 [1.0, 4.0],
640 [-5.0, 2.0],
641 [2.0, 2.0],
642 [4.0, -1.0],
643 [0.0, 5.0],
644 [-3.0, -4.0],
645 [6.0, 1.0],
646 [-1.0, 3.0],
647 [5.0, -3.0],
648 [2.0, 7.0],
649 [-6.0, -2.0],
650 [3.0, 3.0],
651 [1.0, -5.0],
652 [4.0, 6.0],
653 [-2.0, -3.0],
654 ];
655 let weights = ndarray::array![
656 0.5, 1.0, 2.0, 0.25, 1.5, 0.75, 1.0, 0.5, 2.5, 1.25, 0.5, 3.0, 0.75, 1.0, 1.75, 2.0
657 ];
658 let ones = Array1::<f64>::ones(16);
659 let y = ndarray::array![
660 2.0, -3.0, 5.0, 1.0, -4.0, 7.0, 2.0, -6.0, 3.0, 4.0, -2.0, 8.0, 1.0, -7.0, 5.0, -1.0
661 ];
662 (coords, weights, ones, y)
663 }
664
/// Shifting a table from `c` to `c′` must match accumulating directly about
/// `c′`, and shifting back must reproduce the table at `c`, both within the
/// 1e-14 closeness gate on generic float data.
#[test]
pub(crate) fn recenter_is_exact() {
    let (coords, weights, ones, y) = float_dataset(40);
    let channels = [ones.view(), y.view()];
    // Accumulate the full dataset about a given center.
    let table_about = |center: &Array1<f64>, what: &str| {
        accumulate_moment_table(coords.view(), weights.view(), &channels, center.view())
            .expect(what)
    };
    let c = ndarray::array![0.4, -0.3, 0.9];
    let c_prime = ndarray::array![-1.1, 0.25, 0.5];

    let at_c = table_about(&c, "accumulation about c");
    let shifted = recenter_moment_table(&at_c, c_prime.view());
    let direct = table_about(&c_prime, "accumulation about c'");
    assert_tables_close(&shifted, &direct, 1e-14);

    // Round trip back to c reproduces the original to the same gate.
    let round_trip = recenter_moment_table(&shifted, c.view());
    assert_tables_close(&round_trip, &at_c, 1e-14);
}
682
/// Monoid laws of the merge, pinned bitwise: commutativity on both the
/// dyadic lattice and generic float data, associativity on the dyadic
/// lattice (where all moment and shift arithmetic is exact).
#[test]
pub(crate) fn merge_is_associative_and_commutative_bitwise() {
    // Dyadic lattice ⇒ all moment/shift arithmetic is exact, so the monoid
    // laws hold BITWISE across groupings (the sorted-reduction convention
    // covers generic-float grouping determinism; see the module docs).
    let (coords, weights, ones, y) = dyadic_dataset();
    let table_for = |rows: Range<usize>, center: Array1<f64>| {
        let unit = ones.slice(s![rows.clone()]);
        let resp = y.slice(s![rows.clone()]);
        accumulate_moment_table(
            coords.slice(s![rows.clone(), ..]),
            weights.slice(s![rows]),
            &[unit, resp],
            center.view(),
        )
        .expect("chunk accumulation")
    };
    let a = table_for(0..5, ndarray::array![2.0, -1.0]);
    let b = table_for(5..9, ndarray::array![0.0, 3.0]);
    let c = table_for(9..14, ndarray::array![-4.0, 1.0]);

    // Commutativity is compared bitwise with no recentering first: the
    // canonical center orientation makes merge(a, b) and merge(b, a)
    // execute identical arithmetic.
    let ab = merge_moment_tables(&a, &b).expect("a+b");
    let ba = merge_moment_tables(&b, &a).expect("b+a");
    assert_tables_bit_identical(&ab, &ba);

    // ... including on generic (non-dyadic) float data.
    let (fc, fw, fo, fy) = float_dataset(24);
    let float_table = |rows: Range<usize>, center: Array1<f64>, what: &str| {
        accumulate_moment_table(
            fc.slice(s![rows.clone(), ..]),
            fw.slice(s![rows.clone()]),
            &[fo.slice(s![rows.clone()]), fy.slice(s![rows])],
            center.view(),
        )
        .expect(what)
    };
    let fa = float_table(0..12, ndarray::array![0.3, -0.7, 0.1], "float chunk a");
    let fb = float_table(12..24, ndarray::array![-0.9, 0.4, 0.6], "float chunk b");
    let fa_fb = merge_moment_tables(&fa, &fb).expect("fa+fb");
    let fb_fa = merge_moment_tables(&fb, &fa).expect("fb+fa");
    assert_tables_bit_identical(&fa_fb, &fb_fa);

    // Associativity, bitwise on the exact lattice.
    let ab_c = merge_moment_tables(&ab, &c).expect("(a+b)+c");
    let bc = merge_moment_tables(&b, &c).expect("b+c");
    let a_bc = merge_moment_tables(&a, &bc).expect("a+(b+c)");
    assert_tables_bit_identical(&ab_c, &a_bc);

    // And after recentering both to a common reference.
    let c_ref = ndarray::array![1.0, 2.0];
    let left = recenter_moment_table(&ab_c, c_ref.view());
    let right = recenter_moment_table(&a_bc, c_ref.view());
    assert_tables_bit_identical(&left, &right);
}
747
/// `jet_sufficient_stats` must match, to 1e-14, a reference that replicates
/// the raw-point local loop of `assemble_weighted_forms`; on the unit
/// channel the mean must be exactly 1 and the cross term exactly zero.
#[test]
pub(crate) fn jet_stats_match_assemble_weighted_forms_math() {
    // Small 2-D point set. The reference below follows the workhorse's
    // local loop from raw points: w_j = mass_j·exp(−d²/(2ε²)), q = Σ w,
    // Φ_{jk} = (x_{jk} − c_k)/ε, a = Φᵀw/q, G = (ΦᵀWΦ)/q − a·aᵀ,
    // uᵀv = wᵀv/q, Bᵀv/q with B = WΦ − w·aᵀ.
    let pts = ndarray::array![
        [0.0, 0.0],
        [0.45, -0.2],
        [-0.35, 0.4],
        [0.25, 0.55],
        [-0.5, -0.45],
        [0.6, 0.3]
    ];
    let masses = ndarray::array![0.22, 0.13, 0.19, 0.11, 0.2, 0.15];
    let y = ndarray::array![0.7, -1.3, 2.1, 0.4, -0.6, 1.9];
    let center = ndarray::array![0.0, 0.0];
    let eps = 0.75;
    let (m, d) = pts.dim();

    // Kernel weights exactly as the workhorse forms them; sums run in
    // increasing index order starting from 0.0.
    let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
    let w = Array1::<f64>::from_shape_fn(m, |j| {
        let dist2 = (0..d).fold(0.0_f64, |acc, k| {
            let dlt = pts[(j, k)] - center[k];
            acc + dlt * dlt
        });
        masses[j] * (-dist2 * inv_two_eps2).exp()
    });
    let q = w.iter().fold(0.0_f64, |acc, wj| acc + wj);

    let phi = Array2::<f64>::from_shape_fn((m, d), |(j, k)| (pts[(j, k)] - center[k]) / eps);
    let a_mean = phi.t().dot(&w) / q;
    let wphi = Array2::<f64>::from_shape_fn((m, d), |(j, k)| phi[(j, k)] * w[j]);
    let raw_gram = phi.t().dot(&wphi);
    let g_ref =
        Array2::<f64>::from_shape_fn((d, d), |(r, c)| raw_gram[(r, c)] / q - a_mean[r] * a_mean[c]);
    let mean_ref = w.dot(&y) / q;
    let cross_ref = Array1::<f64>::from_shape_fn(d, |k| {
        let acc = (0..m).fold(0.0_f64, |acc, j| {
            acc + (wphi[(j, k)] - w[j] * a_mean[k]) * y[j]
        });
        acc / q
    });

    // Substrate path: same caller-computed weights into a moment table.
    let ones = Array1::<f64>::ones(m);
    let table = accumulate_moment_table(
        pts.view(),
        w.view(),
        &[ones.view(), y.view()],
        center.view(),
    )
    .expect("moment table");
    let stats = jet_sufficient_stats(&table, eps, 1).expect("jet stats");

    let tol = 1e-14;
    let close = |x: f64, y_: f64, label: &str| {
        let budget = tol * (1.0 + x.abs().max(y_.abs()));
        assert!(
            (x - y_).abs() <= budget,
            "{label}: {x} vs {y_} beyond {tol} rel"
        );
    };
    close(stats.q, q, "q");
    close(stats.mean, mean_ref, "mean");
    for k in 0..d {
        close(stats.cross[k], cross_ref[k], &format!("cross[{k}]"));
        for l in 0..d {
            close(stats.gram[(k, l)], g_ref[(k, l)], &format!("gram[{k},{l}]"));
        }
    }

    // Unit channel: exact constant annihilation at the moment level —
    // mean is exactly 1, cross is identically +0.0.
    let unit_stats = jet_sufficient_stats(&table, eps, 0).expect("unit-channel stats");
    assert_eq!(unit_stats.mean, 1.0, "unit-channel mean must be exactly 1");
    for k in 0..d {
        assert_eq!(
            unit_stats.cross[k], 0.0,
            "unit-channel cross[{k}] must be exactly zero"
        );
    }
}
849
/// LEVEL/TILT truth-recovery gate (#1041).
///
/// Background: the 8-dataset benchmark showed a deficit that was worst on
/// pooled/pointwise risk (RMSE/Brier/R²) while staying mid-pack on
/// calibration SLOPE. That pattern is what a biased affine projection looks
/// like — either the recovered LEVEL `c₀` is shifted or the recovered
/// gradient `g` is TILTED. The projection consumes exactly the local affine
/// sufficient statistic computed by this module (`mean`, `G`, `cross`), so a
/// bias in those quantities would be visible in this test.
///
/// Method: build a channel that is EXACTLY affine in the coordinates,
/// `v(x) = c₀ + gᵀ(x − center)`, and weight it with ARBITRARY
/// (non-symmetric) weights. With no curved/higher-order energy present, the
/// weighted affine projection has to return `(c₀, g)` with ZERO residual;
/// any level or tilt error is therefore projection bias rather than a
/// smoothing artifact. The check is repeated over SHRINKING kernel widths ε
/// (more concentrated weights), the regime in which a bias in the centered
/// second moment `G` or the centered cross `Bᵀv/q` would be amplified.
#[test]
pub(crate) fn affine_projection_recovers_level_and_tilt_without_bias() {
    // Lopsided cloud placed off the reference center, so the weighted
    // barycenter and the center differ. A mis-centered projection would
    // then bleed level into tilt (and tilt into level).
    let pts = ndarray::array![
        [0.10, -0.30],
        [0.62, 0.05],
        [-0.18, 0.44],
        [0.37, 0.51],
        [-0.46, -0.22],
        [0.71, 0.33],
        [0.05, 0.62],
        [-0.33, 0.14],
    ];
    // Uneven, strictly positive masses: nothing symmetric to hide behind.
    let masses = ndarray::array![0.31, 0.07, 0.22, 0.05, 0.19, 0.11, 0.27, 0.13];
    let center = ndarray::array![0.05, 0.10];
    let (n_pts, n_dim) = pts.dim();

    // Ground truth in ambient coordinates: intercept c0 and gradient g.
    let c0 = 1.37_f64;
    let g = ndarray::array![-0.85_f64, 0.42_f64];
    // Accumulate c0 + Σ_k g_k (x_k − center_k) left to right, per row.
    let v = Array1::<f64>::from_shape_fn(n_pts, |row| {
        (0..n_dim).fold(c0, |acc, axis| {
            acc + g[axis] * (pts[(row, axis)] - center[axis])
        })
    });

    let ones = Array1::<f64>::ones(n_pts);
    // Sweep from a wide kernel to a narrow one; narrower kernels concentrate
    // the Gaussian weights and magnify any centering/projection bias.
    for eps in [1.0_f64, 0.5, 0.25, 0.12, 0.06] {
        let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
        // Caller-side weights: mass × Gaussian profile about `center`.
        let w = Array1::<f64>::from_shape_fn(n_pts, |row| {
            let dist2 = (0..n_dim).fold(0.0_f64, |acc, axis| {
                let dlt = pts[(row, axis)] - center[axis];
                acc + dlt * dlt
            });
            masses[row] * (-dist2 * inv_two_eps2).exp()
        });

        // Channel 0 is the unit channel, channel 1 the affine truth.
        let table = accumulate_moment_table(
            pts.view(),
            w.view(),
            &[ones.view(), v.view()],
            center.view(),
        )
        .expect("moment table");
        let stats = jet_sufficient_stats(&table, eps, 1).expect("affine jet stats");

        // Weighted affine projection: `G b̂ = cross` yields the ε-scaled
        // slope b̂, the ambient gradient is b̂/ε, and the LEVEL at the
        // reference center is `mean − āᵀ b̂` (weighted mean minus what the
        // slope contributes at the weighted barycenter). An exactly affine
        // truth must be reproduced with no residual.
        //
        // The 2×2 SPD system is solved by hand via the adjugate so the test
        // exercises the projection algebra itself, not a library inverse.
        let (g00, g01, g11) = (
            stats.gram[(0, 0)],
            stats.gram[(0, 1)],
            stats.gram[(1, 1)],
        );
        let det = g00 * g11 - g01 * g01;
        assert!(
            det > 1e-10,
            "centered slope Gram must stay nondegenerate at eps={eps}; det={det}"
        );
        let (cross0, cross1) = (stats.cross[0], stats.cross[1]);
        let slope = [
            (g11 * cross0 - g01 * cross1) / det,
            (-g01 * cross0 + g00 * cross1) / det,
        ];
        // Φ rows are (x−c)/ε, so dividing by ε returns to ambient units.
        let grad = slope.map(|b| b / eps);

        // ε-scaled barycenter offset ā from the unit channel's first
        // moments: m1 / (q·ε). Then
        //   level = mean − gradᵀ·(barycenter − center) = mean − b̂ᵀ ā.
        let bary_denom = stats.q * eps;
        let a_mean = [
            table.m1[(0, 0)] / bary_denom,
            table.m1[(0, 1)] / bary_denom,
        ];
        let level = stats.mean - (slope[0] * a_mean[0] + slope[1] * a_mean[1]);

        // TILT: every gradient component must agree with the truth, i.e.
        // the slope channel is neither rotated nor rescaled.
        let tilt_ok = grad
            .iter()
            .zip(g.iter())
            .all(|(got, want)| (got - want).abs() <= 1e-9);
        assert!(
            tilt_ok,
            "TILT bias at eps={eps}: recovered gradient {grad:?} vs truth {g:?}"
        );
        // LEVEL: the intercept at the reference center must agree with the
        // truth, i.e. the reconstructed surface carries no constant offset.
        assert!(
            (level - c0).abs() <= 1e-9,
            "LEVEL bias at eps={eps}: recovered {level} vs truth {c0}"
        );
    }
}
966
/// Streaming gate: accumulating the dataset in four row chunks, each about
/// its own center, and merging the chunk tables in chunk-index order must
/// give the same table as a single pass over all rows about `centers[0]`.
///
/// The comparison goes through `assert_tables_bit_identical`, so the
/// agreement is pinned exactly rather than to a tolerance; the fixture from
/// `dyadic_dataset()` is what makes exact agreement attainable (dyadic
/// lattice). The chunk ranges below assume that fixture has at least 16
/// rows — confirm against `dyadic_dataset` if it changes.
#[test]
pub(crate) fn streaming_chunked_accumulation_matches_single_pass() {
    // Four chunks, each accumulated about its OWN center, merged in
    // chunk-index order (the sorted reduction) — versus one pass about
    // the lexicographically smallest chunk center. Dyadic lattice ⇒ the
    // agreement is exact, pinned bitwise.
    let (coords, weights, ones, y) = dyadic_dataset();
    let centers = [
        ndarray::array![-3.0, 2.0], // lexicographic minimum: merge target
        ndarray::array![0.0, 0.0],
        ndarray::array![1.0, -5.0],
        ndarray::array![4.0, 1.0],
    ];
    // Accumulate one row range about the given center, with the same two
    // channels (unit, y) the single-pass reference uses.
    let chunk = |rows: Range<usize>, center: &Array1<f64>| {
        let ones_c = ones.slice(s![rows.clone()]);
        let y_c = y.slice(s![rows.clone()]);
        accumulate_moment_table(
            coords.slice(s![rows.clone(), ..]),
            weights.slice(s![rows]),
            &[ones_c, y_c],
            center.view(),
        )
        .expect("chunk accumulation")
    };
    // Each chunk borrows its own center from `centers`.
    let t0 = chunk(0..4, &centers[0]);
    let t1 = chunk(4..8, &centers[1]);
    let t2 = chunk(8..12, &centers[2]);
    let t3 = chunk(12..16, &centers[3]);

    // Left fold in chunk-index order: ((t0 ⊕ t1) ⊕ t2) ⊕ t3.
    let t01 = merge_moment_tables(&t0, &t1).expect("t0+t1");
    let t012 = merge_moment_tables(&t01, &t2).expect("(t0+t1)+t2");
    let merged = merge_moment_tables(&t012, &t3).expect("((t0+t1)+t2)+t3");

    let single = accumulate_moment_table(
        coords.view(),
        weights.view(),
        &[ones.view(), y.view()],
        centers[0].view(),
    )
    .expect("single pass");
    // The fold target is the lex-min center, so no final recentering is
    // even needed; pin that and the bitwise agreement.
    assert_tables_bit_identical(&merged, &single);

    // Merging the identity is a no-op.
    let with_zero =
        merge_moment_tables(&merged, &MeasureJetMomentTable::zero(centers[0].clone(), 2))
            .expect("merge with identity");
    assert_tables_bit_identical(&with_zero, &merged);
}
1017}