Skip to main content

gam_terms/basis/
measure_jet_moments.rs

1//! Measure-jet frame data interface: per-cell frozen-weight polynomial
2//! moment tables with a binomial-shift merge monoid
3//! (`docs/measure_jet_frame.md`, §2 "Data interface: moments or
4//! nothing").
5//!
6//! This module aggregates caller-computed weights into order-0..2 coordinate
7//! moments. Those tables exactly determine polynomial couplings under the
8//! same frozen weights, including the local affine sufficient statistics used
9//! by `measure_jet_smooth.rs`. They do NOT exactly determine Gaussian
10//! transforms at moved kernel centers: support curves, Gaussian Gram entries,
11//! and Gaussian `XᵀWX` products need their own kernel pass or a separately
12//! controlled approximation. Truncation does NOT live here either: the caller
13//! computes the Gaussian weights `w_i` (mass × kernel profile, with whatever
14//! cutoff its explicit `e^{−ρ²/2}` tolerance budget licenses) and this module
15//! only aggregates what it is handed.
16//!
17//! # The monoid
18//!
19//! A table holds, per response channel `g` and per coordinate multi-index
20//! `α` with `|α| ≤ 2`, the centered moment `μ_α = Σ_i w_i g_i (x_i − c)^α`
21//! about the cell reference point `c`. The binomial shift
22//!
23//! ```text
24//!   μ′_α = Σ_{β ≤ α}  C(α, β) (c − c′)^{α−β} μ_β
25//! ```
26//!
27//! re-expresses the same frozen-weight polynomial table about any other
28//! center `c′` exactly as a finite polynomial identity. It does not move the
29//! Gaussian kernel center or recompute weights. Merging two tables with
30//! already-compatible frozen weights is therefore "recenter to a common
31//! reference, add componentwise":
32//! an associative, commutative monoid whose identity is the empty (all-zero)
33//! table at any center. Exact distributed fitting, exact online updates, and
34//! bit-reproducibility under sorted reduction are corollaries of that one
//! algebraic fact. [`accumulate_moment_table`] is a monoid homomorphism from
//! disjoint row sets under union to tables under ⊕, the operation that
//! [`merge_moment_tables`] implements; equality is taken up to recentering.
37//!
38//! # Determinism / bit-exactness convention (sorted reduction)
39//!
40//! Floating-point addition is commutative but not associative, so the monoid
41//! laws hold algebraically while bit-patterns depend on reduction ORDER.
42//! This module pins one order everywhere:
43//!
44//! - [`accumulate_moment_table`] splits rows into fixed-size chunks
45//!   ([`MEASURE_JET_MOMENT_CHUNK_ROWS`], never derived from thread count),
46//!   accumulates each chunk sequentially in row order, and folds the chunk
47//!   partials sequentially in chunk-index order — the sorted reduction. The
48//!   result is bit-identical across runs, machines, and rayon pool sizes.
49//! - [`recenter_moment_table`] evaluates the shift in ONE fixed expression
50//!   order (documented at the site).
51//! - [`merge_moment_tables`] canonically orients its operands by the
52//!   lexicographic total order on centers (`f64::total_cmp` per coordinate),
53//!   so `a ⊕ b` and `b ⊕ a` execute the SAME instruction stream and are
54//!   bit-identical for arbitrary inputs.
55//!
56//! Cross-GROUPING bit-identity — `(A⊕B)⊕C` vs `A⊕(B⊕C)` — additionally
57//! requires the moment arithmetic itself to be exact; the in-module tests
58//! pin it on dyadic lattices (integer coordinates/channels, dyadic weights),
59//! where every product and sum is exactly representable, and callers
60//! reducing many chunks get run-to-run determinism by folding in chunk-index
61//! order exactly as the accumulator does.
62//!
63//! # 1:1 contract with `assemble_weighted_forms`
64//!
65//! [`jet_sufficient_stats`] reproduces, in closed form from a stored table
66//! whose weights were computed for the same center and scale, exactly the
67//! local-fit quantities the current workhorse
68//! (`measure_jet_smooth.rs::assemble_weighted_forms`) computes from raw
69//! points per (center, scale) block: the kernel mass `q`, the dimensionless
70//! weighted feature mean `a_mean`, the dimensionless slope Gram
71//! `G = Φ̃ᵀWΦ̃/q`, the weighted channel mean `uᵀv`, and the exact-projection
72//! right-hand side `Bᵀv/q` — so the substrate can later replace that
73//! same-center point loop without changing a single number.
74
75use std::cmp::Ordering;
76use std::ops::Range;
77
78use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
79use rayon::prelude::*;
80
81use super::BasisError;
82
/// Fixed number of rows per accumulation chunk in the streaming fan-out.
///
/// This is deliberately a constant rather than a function of the rayon thread
/// count. The chunk partition decides the order in which partial tables are
/// summed, so pinning it keeps the sorted-reduction bit pattern identical on
/// every machine and every pool size.
///
/// The size mirrors the design evaluators' streaming blocks. It is big enough
/// to amortize per-chunk setup, and small enough that each chunk's partial
/// tables stay cache-resident for the d ≤ 8 regimes the jet order targets.
pub(crate) const MEASURE_JET_MOMENT_CHUNK_ROWS: usize = 8 * 1024;
90
91/// Per-cell moment table: Gaussian-weighted coordinate moments of orders
92/// 0..=2 crossed with response channels, all centered at the cell's
93/// reference point `c`.
94///
95/// Channel convention: channel 0 is the UNIT channel (`g ≡ 1`); further
96/// channels carry responses (`y`, and later `y²`, PIRLS working `z`, `w` per
97/// the frame notes). The table itself never enforces the convention — it
98/// aggregates whatever the caller hands it — but [`jet_sufficient_stats`]
99/// reads `q`, `a_mean`, and the Gram off channel 0.
100///
101/// `m2` is stored as the full (symmetric-by-construction) `d×d` second
102/// moment per channel.
103#[derive(Debug, Clone, PartialEq)]
104pub struct MeasureJetMomentTable {
105    /// Reference point `c` (length `d`).
106    pub center: Array1<f64>,
107    /// Per channel: `Σ_i w_i g_i`.
108    pub m0: Array1<f64>,
109    /// Per channel × d: `Σ_i w_i g_i (x_i − c)`.
110    pub m1: Array2<f64>,
111    /// Per channel: `d×d` matrix `Σ_i w_i g_i (x_i − c)(x_i − c)ᵀ`.
112    pub m2: Vec<Array2<f64>>,
113}
114
115impl MeasureJetMomentTable {
116    /// The monoid identity at `center`: an all-zero table over `n_channels`
117    /// channels. Merging it (at ANY center) into another table leaves that
118    /// table's moments unchanged up to the exact zero shift.
119    pub fn zero(center: Array1<f64>, n_channels: usize) -> Self {
120        let d = center.len();
121        Self {
122            center,
123            m0: Array1::zeros(n_channels),
124            m1: Array2::zeros((n_channels, d)),
125            m2: (0..n_channels).map(|_| Array2::zeros((d, d))).collect(),
126        }
127    }
128
129    /// Ambient dimension `d` of the cell.
130    pub fn dim(&self) -> usize {
131        self.center.len()
132    }
133
134    /// Number of response channels stored (channel 0 = unit by convention).
135    pub fn n_channels(&self) -> usize {
136        self.m0.len()
137    }
138}
139
140/// Shape/finiteness self-consistency of a (publicly constructible) table:
141/// returns `(n_channels, d)`. Single validation source for the fallible
142/// consumers ([`merge_moment_tables`], [`jet_sufficient_stats`]).
143pub(crate) fn validate_table_shape(
144    t: &MeasureJetMomentTable,
145    label: &str,
146) -> Result<(usize, usize), BasisError> {
147    let d = t.center.len();
148    let n_channels = t.m0.len();
149    if t.center.iter().any(|v| !v.is_finite()) {
150        crate::bail_invalid_basis!("measure-jet moment table `{label}` has a non-finite center");
151    }
152    if t.m1.dim() != (n_channels, d) {
153        crate::bail_dim_basis!(
154            "measure-jet moment table `{label}` m1 shape {:?} does not match (channels, d) = ({n_channels}, {d})",
155            t.m1.dim()
156        );
157    }
158    if t.m2.len() != n_channels {
159        crate::bail_dim_basis!(
160            "measure-jet moment table `{label}` has {} m2 blocks for {n_channels} channels",
161            t.m2.len()
162        );
163    }
164    for (ch, block) in t.m2.iter().enumerate() {
165        if block.dim() != (d, d) {
166            crate::bail_dim_basis!(
167                "measure-jet moment table `{label}` m2[{ch}] shape {:?} is not ({d}, {d})",
168                block.dim()
169            );
170        }
171    }
172    Ok((n_channels, d))
173}
174
175/// Sequential moment accumulation over one row chunk, in row order. The
176/// per-entry update order is fixed — `wg = w·g`, then `m1 += wg·dx_k`, then
177/// `m2 += (wg·dx_k)·dx_l` with that exact association — as part of the
178/// module's bit-determinism contract.
179pub(crate) fn accumulate_chunk(
180    coords: ArrayView2<'_, f64>,
181    weights: ArrayView1<'_, f64>,
182    channels: &[ArrayView1<'_, f64>],
183    center: ArrayView1<'_, f64>,
184    rows: Range<usize>,
185) -> Result<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>), BasisError> {
186    let d = center.len();
187    let n_channels = channels.len();
188    let mut m0 = Array1::<f64>::zeros(n_channels);
189    let mut m1 = Array2::<f64>::zeros((n_channels, d));
190    let mut m2: Vec<Array2<f64>> = (0..n_channels).map(|_| Array2::zeros((d, d))).collect();
191    let mut dx = vec![0.0_f64; d];
192    for r in rows {
193        let w = weights[r];
194        if !(w.is_finite() && w >= 0.0) {
195            crate::bail_invalid_basis!(
196                "measure-jet moment accumulation needs finite nonnegative weights; got {w} at row {r}"
197            );
198        }
199        for k in 0..d {
200            let x = coords[(r, k)];
201            if !x.is_finite() {
202                crate::bail_invalid_basis!(
203                    "measure-jet moment accumulation hit a non-finite coordinate at row {r}, axis {k}"
204                );
205            }
206            dx[k] = x - center[k];
207        }
208        for (ch, g) in channels.iter().enumerate() {
209            let gv = g[r];
210            if !gv.is_finite() {
211                crate::bail_invalid_basis!(
212                    "measure-jet moment accumulation hit a non-finite channel value at row {r}, channel {ch}"
213                );
214            }
215            let wg = w * gv;
216            m0[ch] += wg;
217            let m2_ch = &mut m2[ch];
218            for k in 0..d {
219                let wg_dk = wg * dx[k];
220                m1[(ch, k)] += wg_dk;
221                for l in 0..d {
222                    m2_ch[(k, l)] += wg_dk * dx[l];
223                }
224            }
225        }
226    }
227    Ok((m0, m1, m2))
228}
229
230/// Accumulate one cell's moment table from raw rows. The single point where
231/// data rows are read; everything downstream is closed-form algebra on the
232/// result.
233///
234/// `weights` are the caller-computed Gaussian kernel weights
235/// `w_i = mass_i · exp(−‖x_i − c‖²/(2ε²))` (or their truncated variant — the
236/// cutoff and its `e^{−ρ²/2}` budget are the caller's responsibility);
237/// `channels` are the per-row response channels `g_i`, typically
238/// `[ones, y]`, with channel 0 = unit by convention.
239///
240/// Streaming/parallel layout: rows are split into fixed
241/// [`MEASURE_JET_MOMENT_CHUNK_ROWS`]-sized chunks (the bms-style chunked row
242/// reduction), each chunk is accumulated sequentially in row order, and the
243/// chunk partials are folded sequentially in chunk-index order — the sorted
244/// reduction that makes the output bit-deterministic regardless of thread
245/// scheduling. `rows == 0` is allowed and yields the monoid identity at
246/// `center`.
247pub fn accumulate_moment_table(
248    coords: ArrayView2<'_, f64>,
249    weights: ArrayView1<'_, f64>,
250    channels: &[ArrayView1<'_, f64>],
251    center: ArrayView1<'_, f64>,
252) -> Result<MeasureJetMomentTable, BasisError> {
253    let n = coords.nrows();
254    let d = coords.ncols();
255    if d == 0 {
256        crate::bail_invalid_basis!(
257            "measure-jet moment accumulation needs at least one coordinate axis"
258        );
259    }
260    if center.len() != d {
261        crate::bail_dim_basis!(
262            "measure-jet moment center length {} does not match coordinate dimension {d}",
263            center.len()
264        );
265    }
266    if center.iter().any(|v| !v.is_finite()) {
267        crate::bail_invalid_basis!("measure-jet moment accumulation needs a finite center");
268    }
269    if weights.len() != n {
270        crate::bail_dim_basis!(
271            "measure-jet moment weights length {} does not match {n} rows",
272            weights.len()
273        );
274    }
275    if channels.is_empty() {
276        crate::bail_invalid_basis!(
277            "measure-jet moment accumulation needs at least one response channel (channel 0 = unit)"
278        );
279    }
280    for (ch, g) in channels.iter().enumerate() {
281        if g.len() != n {
282            crate::bail_dim_basis!(
283                "measure-jet moment channel {ch} length {} does not match {n} rows",
284                g.len()
285            );
286        }
287    }
288    let n_chunks = n.div_ceil(MEASURE_JET_MOMENT_CHUNK_ROWS).max(1);
289    let partials: Vec<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> = if n_chunks == 1 {
290        vec![accumulate_chunk(coords, weights, channels, center, 0..n)?]
291    } else {
292        (0..n_chunks)
293            .into_par_iter()
294            .map(|chunk| {
295                let start = chunk * MEASURE_JET_MOMENT_CHUNK_ROWS;
296                let end = (start + MEASURE_JET_MOMENT_CHUNK_ROWS).min(n);
297                accumulate_chunk(coords, weights, channels, center, start..end)
298            })
299            .collect::<Result<Vec<_>, BasisError>>()?
300    };
301    // Sorted reduction: fold chunk partials in chunk-index order. All
302    // partials share `center`, so the fold is plain componentwise addition.
303    let mut iter = partials.into_iter();
304    let (mut m0, mut m1, mut m2) = iter
305        .next()
306        .expect("chunk count is clamped to at least one partial");
307    for (p0, p1, p2) in iter {
308        m0 += &p0;
309        m1 += &p1;
310        for (total, part) in m2.iter_mut().zip(&p2) {
311            *total += part;
312        }
313    }
314    Ok(MeasureJetMomentTable {
315        center: center.to_owned(),
316        m0,
317        m1,
318        m2,
319    })
320}
321
322/// Exact recentering via the binomial shift: the same frozen-weight
323/// polynomial table re-expressed about `new_center`, with no kernel
324/// re-evaluation. This is not a moving-kernel identity; if the Gaussian
325/// center changes, the caller must recompute or approximate the weights.
326///
327/// Derivation (per channel; write `Δ = c − c′` so `x − c′ = (x − c) + Δ`):
328///
329/// - order 0: `μ′_0 = Σ w g = μ_0` — unchanged;
330/// - order 1: `μ′_1 = Σ w g ((x−c) + Δ) = μ_1 + Δ·μ_0`;
331/// - order 2: `μ′_2 = Σ w g ((x−c)+Δ)((x−c)+Δ)ᵀ
332///                 = μ_2 + Δ·μ_1ᵀ + μ_1·Δᵀ + ΔΔᵀ·μ_0`,
333///
334/// which is the multi-index binomial identity
335/// `μ′_α = Σ_{β≤α} C(α,β)(c−c′)^{α−β} μ_β` specialized to `|α| ≤ 2`. Every
336/// term is a finite product of stored moments and `Δ`, so the shift is an
337/// algebraic identity of the frozen-weight table — exact up to floating-point
338/// rounding, and exactly exact whenever the arithmetic is (dyadic lattices;
339/// pinned in the tests).
340///
341/// Bit-determinism: the order-2 entry is evaluated in the ONE fixed order
342/// `((μ_2 + Δ_k·μ_{1,l}) + μ_{1,k}·Δ_l) + (Δ_k·Δ_l)·μ_0`; same inputs always
343/// produce the same bits.
344pub fn recenter_moment_table(
345    t: &MeasureJetMomentTable,
346    new_center: ArrayView1<'_, f64>,
347) -> MeasureJetMomentTable {
348    let d = t.center.len();
349    assert_eq!(
350        new_center.len(),
351        d,
352        "measure-jet recenter: new center length {} does not match table dimension {d}",
353        new_center.len()
354    );
355    let n_channels = t.m0.len();
356    let mut delta = Array1::<f64>::zeros(d);
357    for k in 0..d {
358        delta[k] = t.center[k] - new_center[k];
359    }
360    let m0 = t.m0.clone();
361    let mut m1 = Array2::<f64>::zeros((n_channels, d));
362    for ch in 0..n_channels {
363        for k in 0..d {
364            m1[(ch, k)] = t.m1[(ch, k)] + delta[k] * t.m0[ch];
365        }
366    }
367    let mut m2 = Vec::with_capacity(n_channels);
368    for ch in 0..n_channels {
369        let src = &t.m2[ch];
370        let mut out = Array2::<f64>::zeros((d, d));
371        for k in 0..d {
372            for l in 0..d {
373                out[(k, l)] = ((src[(k, l)] + delta[k] * t.m1[(ch, l)]) + t.m1[(ch, k)] * delta[l])
374                    + (delta[k] * delta[l]) * t.m0[ch];
375            }
376        }
377        m2.push(out);
378    }
379    MeasureJetMomentTable {
380        center: new_center.to_owned(),
381        m0,
382        m1,
383        m2,
384    }
385}
386
387/// Lexicographic total order on cell centers (`f64::total_cmp` per
388/// coordinate). The canonical-orientation key that makes the merge bitwise
389/// argument-order-independent.
390pub(crate) fn lex_cmp_centers(a: &Array1<f64>, b: &Array1<f64>) -> Ordering {
391    for (x, y) in a.iter().zip(b.iter()) {
392        let ord = x.total_cmp(y);
393        if ord != Ordering::Equal {
394            return ord;
395        }
396    }
397    Ordering::Equal
398}
399
400/// Monoid merge: recenter compatible frozen-weight tables onto a common
401/// reference, then add componentwise. Exact for those polynomial moments
402/// (pure binomial shift, no kernel re-evaluation) and deterministic.
403///
404/// Canonical orientation: the merged table lives at the lexicographically
405/// SMALLER of the two operand centers ([`lex_cmp_centers`]), and the other
406/// operand is the one recentered. Because the (host, guest) roles depend
407/// only on the centers — never on argument position — `merge(a, b)` and
408/// `merge(b, a)` execute identical arithmetic and agree BITWISE for
409/// arbitrary inputs (IEEE addition is commutative; only grouping is not).
410/// This is a deliberate strengthening of the naive "recenter `b` onto
411/// `a.center`" rule, which is only commutative up to a recentering.
412pub fn merge_moment_tables(
413    a: &MeasureJetMomentTable,
414    b: &MeasureJetMomentTable,
415) -> Result<MeasureJetMomentTable, BasisError> {
416    let (a_channels, a_dim) = validate_table_shape(a, "a")?;
417    let (b_channels, b_dim) = validate_table_shape(b, "b")?;
418    if a_dim != b_dim || a_channels != b_channels {
419        crate::bail_dim_basis!(
420            "measure-jet merge needs matching tables; got (channels, d) = ({a_channels}, {a_dim}) vs ({b_channels}, {b_dim})"
421        );
422    }
423    let (host, guest) = if lex_cmp_centers(&a.center, &b.center) != Ordering::Greater {
424        (a, b)
425    } else {
426        (b, a)
427    };
428    let moved = recenter_moment_table(guest, host.center.view());
429    let mut m2 = Vec::with_capacity(a_channels);
430    for (h, g) in host.m2.iter().zip(&moved.m2) {
431        m2.push(h + g);
432    }
433    Ok(MeasureJetMomentTable {
434        center: host.center.clone(),
435        m0: &host.m0 + &moved.m0,
436        m1: &host.m1 + &moved.m1,
437        m2,
438    })
439}
440
441/// The local jet-fit sufficient statistics read off one table — exactly the
442/// per-block quantities `assemble_weighted_forms` (measure_jet_smooth.rs)
443/// computes from raw points when the table weights are frozen at the same
444/// center and scale, reproduced in closed form from stored moments.
445#[derive(Debug, Clone, PartialEq)]
446pub struct MeasureJetJetStats {
447    /// Kernel mass `q = Σ w_i` (unit-channel zeroth moment).
448    pub q: f64,
449    /// Weighted mean of the requested value channel: `uᵀv = m0[ch]/q`.
450    pub mean: f64,
451    /// Dimensionless slope Gram `G = Φ̃ᵀWΦ̃/q = m2[0]/(qε²) − ā·āᵀ` with
452    /// `ā = m1[0]/(qε)` (`Φ` rows are `(x_i − c)/ε`).
453    pub gram: Array2<f64>,
454    /// Local-fit right-hand side `Bᵀv/q = m1[ch]/(qε) − ā·(m0[ch]/q)` — the
455    /// vector the exact weighted affine projection consumes.
456    pub cross: Array1<f64>,
457}
458
459/// Read the local jet-fit sufficient statistics off a moment table at scale
460/// `eps`, for value channel `channel`.
461///
462/// 1:1 with `assemble_weighted_forms`' per-block math (its symbols on the
463/// right), under the energy convention that local features are the
464/// ε-SCALED offsets `Φ_{jk} = (x_{jk} − c_k)/ε`:
465///
466/// - `q     = m0[0]`                          ↔ `q = Σ_j w_j`,
467/// - `ā_k   = m1[0,k]/(q·ε)`                  ↔ `a_mean = Φᵀw/q`,
468/// - `G_kl  = m2[0][k,l]/(q·ε²) − ā_k·ā_l`    ↔ `G = (ΦᵀWΦ)/q − a·aᵀ`,
469/// - `mean  = m0[ch]/q`                       ↔ `uᵀv` (the weighted-centering
470///   projection `Cv = v − (uᵀv)·1` of the constant-annihilation contract),
471/// - `cross_k = m1[ch,k]/(q·ε) − ā_k·mean`    ↔ `Bᵀv/q` with
472///   `B = WΦ − w·aᵀ` (column-centering makes `Φ̃ᵀW·1 = 0`, so
473///   `Φ̃ᵀWCv/q = Bᵀv/q` — the exact RHS of the local affine projection).
474///
475/// For `channel == 0` (the unit channel) `mean` is exactly `1.0` and `cross`
476/// is identically `+0.0` (the same division is subtracted from itself) —
477/// the moment-level restatement of exact constant annihilation.
478pub fn jet_sufficient_stats(
479    t: &MeasureJetMomentTable,
480    eps: f64,
481    channel: usize,
482) -> Result<MeasureJetJetStats, BasisError> {
483    let (n_channels, d) = validate_table_shape(t, "t")?;
484    if !(eps.is_finite() && eps > 0.0) {
485        crate::bail_invalid_basis!(
486            "measure-jet jet stats need a finite positive scale eps; got {eps}"
487        );
488    }
489    if channel >= n_channels {
490        crate::bail_invalid_basis!(
491            "measure-jet jet stats channel {channel} out of range for {n_channels} channels"
492        );
493    }
494    let q = t.m0[0];
495    if !(q.is_finite() && q > 0.0) {
496        crate::bail_invalid_basis!(
497            "measure-jet jet stats need positive unit-channel kernel mass q; got {q}"
498        );
499    }
500    let q_eps = q * eps;
501    let mut a_mean = Array1::<f64>::zeros(d);
502    for k in 0..d {
503        a_mean[k] = t.m1[(0, k)] / q_eps;
504    }
505    let q_eps2 = q * eps * eps;
506    let m2_unit = &t.m2[0];
507    let mut gram = Array2::<f64>::zeros((d, d));
508    for k in 0..d {
509        for l in 0..d {
510            gram[(k, l)] = m2_unit[(k, l)] / q_eps2 - a_mean[k] * a_mean[l];
511        }
512    }
513    let mean = t.m0[channel] / q;
514    let mut cross = Array1::<f64>::zeros(d);
515    for k in 0..d {
516        cross[k] = t.m1[(channel, k)] / q_eps - a_mean[k] * mean;
517    }
518    Ok(MeasureJetJetStats {
519        q,
520        mean,
521        gram,
522        cross,
523    })
524}
525
526#[cfg(test)]
527mod tests {
528    use super::*;
529    use ndarray::s;
530
531    /// Closeness metric for the recenter-exactness gate: relative at scale,
532    /// absolute `tol` below unit scale (`|x−y| ≤ tol·(1 + max(|x|,|y|))`).
533    pub(crate) fn assert_tables_close(
534        a: &MeasureJetMomentTable,
535        b: &MeasureJetMomentTable,
536        tol: f64,
537    ) {
538        let pairs = |xs: &[f64], ys: &[f64], label: &str| {
539            assert_eq!(xs.len(), ys.len(), "{label}: length mismatch");
540            for (i, (x, y)) in xs.iter().zip(ys.iter()).enumerate() {
541                let scale = 1.0 + x.abs().max(y.abs());
542                assert!(
543                    (x - y).abs() <= tol * scale,
544                    "{label}[{i}]: {x} vs {y} differ beyond {tol} rel"
545                );
546            }
547        };
548        pairs(
549            a.center.as_slice().expect("contiguous center"),
550            b.center.as_slice().expect("contiguous center"),
551            "center",
552        );
553        pairs(
554            a.m0.as_slice().expect("contiguous m0"),
555            b.m0.as_slice().expect("contiguous m0"),
556            "m0",
557        );
558        pairs(
559            a.m1.as_slice().expect("contiguous m1"),
560            b.m1.as_slice().expect("contiguous m1"),
561            "m1",
562        );
563        assert_eq!(a.m2.len(), b.m2.len(), "m2: channel count mismatch");
564        for (ch, (x, y)) in a.m2.iter().zip(b.m2.iter()).enumerate() {
565            pairs(
566                x.as_slice().expect("contiguous m2"),
567                y.as_slice().expect("contiguous m2"),
568                &format!("m2[{ch}]"),
569            );
570        }
571    }
572
573    /// Bit-identity gate: every stored f64 must agree by `to_bits`.
574    pub(crate) fn assert_tables_bit_identical(
575        a: &MeasureJetMomentTable,
576        b: &MeasureJetMomentTable,
577    ) {
578        let bits = |xs: &[f64], ys: &[f64], label: &str| {
579            assert_eq!(xs.len(), ys.len(), "{label}: length mismatch");
580            for (i, (x, y)) in xs.iter().zip(ys.iter()).enumerate() {
581                assert_eq!(
582                    x.to_bits(),
583                    y.to_bits(),
584                    "{label}[{i}]: {x} vs {y} differ bitwise"
585                );
586            }
587        };
588        bits(
589            a.center.as_slice().expect("contiguous center"),
590            b.center.as_slice().expect("contiguous center"),
591            "center",
592        );
593        bits(
594            a.m0.as_slice().expect("contiguous m0"),
595            b.m0.as_slice().expect("contiguous m0"),
596            "m0",
597        );
598        bits(
599            a.m1.as_slice().expect("contiguous m1"),
600            b.m1.as_slice().expect("contiguous m1"),
601            "m1",
602        );
603        assert_eq!(a.m2.len(), b.m2.len(), "m2: channel count mismatch");
604        for (ch, (x, y)) in a.m2.iter().zip(b.m2.iter()).enumerate() {
605            bits(
606                x.as_slice().expect("contiguous m2"),
607                y.as_slice().expect("contiguous m2"),
608                &format!("m2[{ch}]"),
609            );
610        }
611    }
612
613    /// Deterministic generic-float dataset (no RNG): low-discrepancy
614    /// fractional parts, d = 3, with a unit channel and one value channel.
615    pub(crate) fn float_dataset(n: usize) -> (Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>) {
616        let mut coords = Array2::<f64>::zeros((n, 3));
617        let mut weights = Array1::<f64>::zeros(n);
618        let mut ones = Array1::<f64>::zeros(n);
619        let mut y = Array1::<f64>::zeros(n);
620        for i in 0..n {
621            let t = (i + 1) as f64;
622            coords[(i, 0)] = (t * 0.618034).fract() * 4.0 - 2.0;
623            coords[(i, 1)] = (t * 0.414214).fract() * 3.0 - 1.0;
624            coords[(i, 2)] = (t * 0.732051).fract() * 2.0 - 1.5;
625            weights[i] = 0.05 + (t * 0.292893).fract();
626            ones[i] = 1.0;
627            y[i] = (t * 0.539345).fract() * 6.0 - 3.0;
628        }
629        (coords, weights, ones, y)
630    }
631
632    /// Dyadic-lattice dataset: integer coordinates and channel values,
633    /// dyadic weights — every moment product and sum is exactly
634    /// representable in f64, so the algebraic monoid laws become BIT
635    /// identities and the tests below can pin them with `to_bits`.
636    pub(crate) fn dyadic_dataset() -> (Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>) {
637        let coords = ndarray::array![
638            [3.0, -2.0],
639            [1.0, 4.0],
640            [-5.0, 2.0],
641            [2.0, 2.0],
642            [4.0, -1.0],
643            [0.0, 5.0],
644            [-3.0, -4.0],
645            [6.0, 1.0],
646            [-1.0, 3.0],
647            [5.0, -3.0],
648            [2.0, 7.0],
649            [-6.0, -2.0],
650            [3.0, 3.0],
651            [1.0, -5.0],
652            [4.0, 6.0],
653            [-2.0, -3.0],
654        ];
655        let weights = ndarray::array![
656            0.5, 1.0, 2.0, 0.25, 1.5, 0.75, 1.0, 0.5, 2.5, 1.25, 0.5, 3.0, 0.75, 1.0, 1.75, 2.0
657        ];
658        let ones = Array1::<f64>::ones(16);
659        let y = ndarray::array![
660            2.0, -3.0, 5.0, 1.0, -4.0, 7.0, 2.0, -6.0, 3.0, 4.0, -2.0, 8.0, 1.0, -7.0, 5.0, -1.0
661        ];
662        (coords, weights, ones, y)
663    }
664
665    #[test]
666    pub(crate) fn recenter_is_exact() {
667        let (coords, weights, ones, y) = float_dataset(40);
668        let channels = [ones.view(), y.view()];
669        let c = ndarray::array![0.4, -0.3, 0.9];
670        let c_prime = ndarray::array![-1.1, 0.25, 0.5];
671        let at_c = accumulate_moment_table(coords.view(), weights.view(), &channels, c.view())
672            .expect("accumulation about c");
673        let shifted = recenter_moment_table(&at_c, c_prime.view());
674        let direct =
675            accumulate_moment_table(coords.view(), weights.view(), &channels, c_prime.view())
676                .expect("accumulation about c'");
677        assert_tables_close(&shifted, &direct, 1e-14);
678        // Round trip back to c reproduces the original to the same gate.
679        let back = recenter_moment_table(&shifted, c.view());
680        assert_tables_close(&back, &at_c, 1e-14);
681    }
682
683    #[test]
684    pub(crate) fn merge_is_associative_and_commutative_bitwise() {
685        // Dyadic lattice ⇒ all moment/shift arithmetic is exact, so the
686        // monoid laws hold BITWISE across groupings (the sorted-reduction
687        // convention covers generic-float grouping determinism; see the
688        // module docs).
689        let (coords, weights, ones, y) = dyadic_dataset();
690        let chunk = |rows: Range<usize>, center: &Array1<f64>| {
691            let ones_c = ones.slice(s![rows.clone()]);
692            let y_c = y.slice(s![rows.clone()]);
693            accumulate_moment_table(
694                coords.slice(s![rows.clone(), ..]),
695                weights.slice(s![rows]),
696                &[ones_c, y_c],
697                center.view(),
698            )
699            .expect("chunk accumulation")
700        };
701        let c_a = ndarray::array![2.0, -1.0];
702        let c_b = ndarray::array![0.0, 3.0];
703        let c_c = ndarray::array![-4.0, 1.0];
704        let a = chunk(0..5, &c_a);
705        let b = chunk(5..9, &c_b);
706        let c = chunk(9..14, &c_c);
707
708        // Commutativity is bitwise for ARBITRARY inputs: the canonical
709        // center orientation makes merge(a, b) and merge(b, a) execute
710        // identical arithmetic. No recentering needed before comparing.
711        let ab = merge_moment_tables(&a, &b).expect("a+b");
712        let ba = merge_moment_tables(&b, &a).expect("b+a");
713        assert_tables_bit_identical(&ab, &ba);
714        // ... including on generic (non-dyadic) float data.
715        let (fc, fw, fo, fy) = float_dataset(24);
716        let fa = accumulate_moment_table(
717            fc.slice(s![0..12, ..]),
718            fw.slice(s![0..12]),
719            &[fo.slice(s![0..12]), fy.slice(s![0..12])],
720            ndarray::array![0.3, -0.7, 0.1].view(),
721        )
722        .expect("float chunk a");
723        let fb = accumulate_moment_table(
724            fc.slice(s![12..24, ..]),
725            fw.slice(s![12..24]),
726            &[fo.slice(s![12..24]), fy.slice(s![12..24])],
727            ndarray::array![-0.9, 0.4, 0.6].view(),
728        )
729        .expect("float chunk b");
730        assert_tables_bit_identical(
731            &merge_moment_tables(&fa, &fb).expect("fa+fb"),
732            &merge_moment_tables(&fb, &fa).expect("fb+fa"),
733        );
734
735        // Associativity, bitwise on the exact lattice.
736        let ab_c = merge_moment_tables(&ab, &c).expect("(a+b)+c");
737        let bc = merge_moment_tables(&b, &c).expect("b+c");
738        let a_bc = merge_moment_tables(&a, &bc).expect("a+(b+c)");
739        assert_tables_bit_identical(&ab_c, &a_bc);
740        // And after recentering both to a common reference.
741        let c_ref = ndarray::array![1.0, 2.0];
742        assert_tables_bit_identical(
743            &recenter_moment_table(&ab_c, c_ref.view()),
744            &recenter_moment_table(&a_bc, c_ref.view()),
745        );
746    }
747
748    #[test]
749    pub(crate) fn jet_stats_match_assemble_weighted_forms_math() {
750        // Small 2-D point set, replicating assemble_weighted_forms' local
751        // loop verbatim from raw points: w_j = mass_j·exp(−d²/(2ε²)),
752        // q = Σ w, Φ_{jk} = (x_{jk} − c_k)/ε, a = Φᵀw/q,
753        // G = (ΦᵀWΦ)/q − a·aᵀ, uᵀv = wᵀv/q, Bᵀv/q with B = WΦ − w·aᵀ.
754        let pts = ndarray::array![
755            [0.0, 0.0],
756            [0.45, -0.2],
757            [-0.35, 0.4],
758            [0.25, 0.55],
759            [-0.5, -0.45],
760            [0.6, 0.3]
761        ];
762        let masses = ndarray::array![0.22, 0.13, 0.19, 0.11, 0.2, 0.15];
763        let y = ndarray::array![0.7, -1.3, 2.1, 0.4, -0.6, 1.9];
764        let center = ndarray::array![0.0, 0.0];
765        let eps = 0.75;
766        let m = pts.nrows();
767        let d = pts.ncols();
768
769        // Kernel weights exactly as the workhorse forms them.
770        let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
771        let mut w = Array1::<f64>::zeros(m);
772        let mut q = 0.0_f64;
773        for j in 0..m {
774            let mut dist2 = 0.0_f64;
775            for k in 0..d {
776                let dlt = pts[(j, k)] - center[k];
777                dist2 += dlt * dlt;
778            }
779            w[j] = masses[j] * (-dist2 * inv_two_eps2).exp();
780            q += w[j];
781        }
782        let mut phi = Array2::<f64>::zeros((m, d));
783        for j in 0..m {
784            for k in 0..d {
785                phi[(j, k)] = (pts[(j, k)] - center[k]) / eps;
786            }
787        }
788        let a_mean = phi.t().dot(&w) / q;
789        let mut wphi = phi.clone();
790        for (j, mut row) in wphi.outer_iter_mut().enumerate() {
791            row.mapv_inplace(|v| v * w[j]);
792        }
793        let mut g_ref = phi.t().dot(&wphi);
794        g_ref.mapv_inplace(|v| v / q);
795        for r in 0..d {
796            for c in 0..d {
797                g_ref[(r, c)] -= a_mean[r] * a_mean[c];
798            }
799        }
800        let mean_ref = w.dot(&y) / q;
801        let mut cross_ref = Array1::<f64>::zeros(d);
802        for k in 0..d {
803            let mut acc = 0.0_f64;
804            for j in 0..m {
805                acc += (wphi[(j, k)] - w[j] * a_mean[k]) * y[j];
806            }
807            cross_ref[k] = acc / q;
808        }
809
810        // Substrate path: same caller-computed weights into a moment table.
811        let ones = Array1::<f64>::ones(m);
812        let table = accumulate_moment_table(
813            pts.view(),
814            w.view(),
815            &[ones.view(), y.view()],
816            center.view(),
817        )
818        .expect("moment table");
819        let stats = jet_sufficient_stats(&table, eps, 1).expect("jet stats");
820
821        let tol = 1e-14;
822        let close = |x: f64, y_: f64, label: &str| {
823            let scale = 1.0 + x.abs().max(y_.abs());
824            assert!(
825                (x - y_).abs() <= tol * scale,
826                "{label}: {x} vs {y_} beyond {tol} rel"
827            );
828        };
829        close(stats.q, q, "q");
830        close(stats.mean, mean_ref, "mean");
831        for k in 0..d {
832            close(stats.cross[k], cross_ref[k], &format!("cross[{k}]"));
833            for l in 0..d {
834                close(stats.gram[(k, l)], g_ref[(k, l)], &format!("gram[{k},{l}]"));
835            }
836        }
837
838        // Unit channel: exact constant annihilation at the moment level —
839        // mean is exactly 1, cross is identically +0.0.
840        let unit_stats = jet_sufficient_stats(&table, eps, 0).expect("unit-channel stats");
841        assert_eq!(unit_stats.mean, 1.0, "unit-channel mean must be exactly 1");
842        for k in 0..d {
843            assert_eq!(
844                unit_stats.cross[k], 0.0,
845                "unit-channel cross[{k}] must be exactly zero"
846            );
847        }
848    }
849
850    /// LEVEL/TILT truth-recovery gate (#1041). The deficit pattern flagged in
851    /// the 8-dataset benchmark — worst on pooled/pointwise risk (RMSE/Brier/R²)
852    /// but only mid-pack on calibration SLOPE — is the fingerprint of a biased
853    /// affine projection: a systematic shift in the recovered LEVEL `c₀` or a
854    /// TILT in the recovered gradient `g`. The local affine sufficient statistic
855    /// this module computes (`mean`, `G`, `cross`) is the exact object that
856    /// projection consumes, so a bias there would surface here.
857    ///
858    /// Construct a channel value that is EXACTLY affine in the coordinates,
859    /// `v(x) = c₀ + gᵀ(x − center)`, under ARBITRARY (non-symmetric) weights.
860    /// The weighted affine projection must then recover `(c₀, g)` with ZERO
861    /// residual — the curved/higher-order energy is empty, so any nonzero level
862    /// or tilt error is pure projection bias, not a smoothing artifact. We
863    /// assert this across SHRINKING kernel widths ε (concentrating the weights),
864    /// the regime where a level/tilt bias in the centered second moment `G` or
865    /// the centered cross `Bᵀv/q` would be amplified.
866    #[test]
867    pub(crate) fn affine_projection_recovers_level_and_tilt_without_bias() {
868        // Asymmetric, off-center point cloud so the weighted barycenter does
869        // NOT coincide with the reference center: this is exactly where a
870        // mis-centered (biased) projection would leak the level into the tilt
871        // and vice versa.
872        let pts = ndarray::array![
873            [0.10, -0.30],
874            [0.62, 0.05],
875            [-0.18, 0.44],
876            [0.37, 0.51],
877            [-0.46, -0.22],
878            [0.71, 0.33],
879            [0.05, 0.62],
880            [-0.33, 0.14],
881        ];
882        // Strictly positive, deliberately uneven masses (no symmetry to lean on).
883        let masses = ndarray::array![0.31, 0.07, 0.22, 0.05, 0.19, 0.11, 0.27, 0.13];
884        let center = ndarray::array![0.05, 0.10];
885        let m = pts.nrows();
886        let d = pts.ncols();
887
888        // Exact affine truth in ambient coordinates: level c0, gradient g.
889        let c0 = 1.37_f64;
890        let g = ndarray::array![-0.85_f64, 0.42_f64];
891        let mut v = Array1::<f64>::zeros(m);
892        for j in 0..m {
893            let mut acc = c0;
894            for k in 0..d {
895                acc += g[k] * (pts[(j, k)] - center[k]);
896            }
897            v[j] = acc;
898        }
899
900        let ones = Array1::<f64>::ones(m);
901        // Tighten the kernel across several scales: shrinking eps concentrates
902        // the Gaussian weights and amplifies any centering/projection bias.
903        for &eps in &[1.0_f64, 0.5, 0.25, 0.12, 0.06] {
904            let inv_two_eps2 = 1.0 / (2.0 * eps * eps);
905            let mut w = Array1::<f64>::zeros(m);
906            for j in 0..m {
907                let mut dist2 = 0.0_f64;
908                for k in 0..d {
909                    let dlt = pts[(j, k)] - center[k];
910                    dist2 += dlt * dlt;
911                }
912                w[j] = masses[j] * (-dist2 * inv_two_eps2).exp();
913            }
914
915            let table = accumulate_moment_table(
916                pts.view(),
917                w.view(),
918                &[ones.view(), v.view()],
919                center.view(),
920            )
921            .expect("moment table");
922            let stats = jet_sufficient_stats(&table, eps, 1).expect("affine jet stats");
923
924            // The weighted affine projection solves `G b̂ = cross` for the
925            // ε-scaled slope; the ambient gradient is b̂/ε and the recovered
926            // LEVEL is `mean − āᵀ b̂` (the weighted mean minus the slope's
927            // contribution at the weighted barycenter). For an exactly affine
928            // truth both must equal the truth with zero residual.
929            //
930            // Solve the 2×2 SPD system directly (no external solver) so the
931            // test pins the projection math, not a library inverse.
932            let g00 = stats.gram[(0, 0)];
933            let g01 = stats.gram[(0, 1)];
934            let g11 = stats.gram[(1, 1)];
935            let det = g00 * g11 - g01 * g01;
936            assert!(
937                det > 1e-10,
938                "centered slope Gram must stay nondegenerate at eps={eps}; det={det}"
939            );
940            let b0 = (g11 * stats.cross[0] - g01 * stats.cross[1]) / det;
941            let b1 = (-g01 * stats.cross[0] + g00 * stats.cross[1]) / det;
942            // Ambient gradient = scaled slope / eps (Φ rows are (x−c)/ε).
943            let grad = [b0 / eps, b1 / eps];
944
945            // Recovered weighted barycenter offset ā (ambient) = a_mean·ε.
946            // Level at the reference center = mean − gradᵀ·(barycenter − center)
947            //                               = mean − (b̂ᵀ ā).
948            let a_mean0 = table.m1[(0, 0)] / (stats.q * eps);
949            let a_mean1 = table.m1[(0, 1)] / (stats.q * eps);
950            let level = stats.mean - (b0 * a_mean0 + b1 * a_mean1);
951
952            // TILT: the recovered gradient must match the truth — no systematic
953            // rotation/scaling of the slope channel.
954            assert!(
955                (grad[0] - g[0]).abs() <= 1e-9 && (grad[1] - g[1]).abs() <= 1e-9,
956                "TILT bias at eps={eps}: recovered gradient {grad:?} vs truth {g:?}"
957            );
958            // LEVEL: the recovered intercept at the reference center must match
959            // the truth — no systematic offset of the reconstructed surface.
960            assert!(
961                (level - c0).abs() <= 1e-9,
962                "LEVEL bias at eps={eps}: recovered {level} vs truth {c0}"
963            );
964        }
965    }
966
967    #[test]
968    pub(crate) fn streaming_chunked_accumulation_matches_single_pass() {
969        // Four chunks, each accumulated about its OWN center, merged in
970        // chunk-index order (the sorted reduction) — versus one pass about
971        // the lexicographically smallest chunk center. Dyadic lattice ⇒ the
972        // agreement is exact, pinned bitwise.
973        let (coords, weights, ones, y) = dyadic_dataset();
974        let centers = [
975            ndarray::array![-3.0, 2.0], // lexicographic minimum: merge target
976            ndarray::array![0.0, 0.0],
977            ndarray::array![1.0, -5.0],
978            ndarray::array![4.0, 1.0],
979        ];
980        let chunk = |rows: Range<usize>, center: &Array1<f64>| {
981            let ones_c = ones.slice(s![rows.clone()]);
982            let y_c = y.slice(s![rows.clone()]);
983            accumulate_moment_table(
984                coords.slice(s![rows.clone(), ..]),
985                weights.slice(s![rows]),
986                &[ones_c, y_c],
987                center.view(),
988            )
989            .expect("chunk accumulation")
990        };
991        let t0 = chunk(0..4, &centers[0]);
992        let t1 = chunk(4..8, &centers[1]);
993        let t2 = chunk(8..12, &centers[2]);
994        let t3 = chunk(12..16, &centers[3]);
995        let merged = merge_moment_tables(
996            &merge_moment_tables(&merge_moment_tables(&t0, &t1).expect("t0+t1"), &t2)
997                .expect("(t0+t1)+t2"),
998            &t3,
999        )
1000        .expect("((t0+t1)+t2)+t3");
1001        let single = accumulate_moment_table(
1002            coords.view(),
1003            weights.view(),
1004            &[ones.view(), y.view()],
1005            centers[0].view(),
1006        )
1007        .expect("single pass");
1008        // The fold target is the lex-min center, so no final recentering is
1009        // even needed; pin that and the bitwise agreement.
1010        assert_tables_bit_identical(&merged, &single);
1011        // Merging the identity is a no-op.
1012        let with_zero =
1013            merge_moment_tables(&merged, &MeasureJetMomentTable::zero(centers[0].clone(), 2))
1014                .expect("merge with identity");
1015        assert_tables_bit_identical(&with_zero, &merged);
1016    }
1017}