Skip to main content

gam_problem/
gauge.rs

1// One Gauge object (#933).
2//
3// Every identifiability mechanism in the engine performs the same
4// mathematical act: quotient the coefficient space by directions in
5// ker(J) ∩ ker(S), pick a section, fit in the reduced coordinates θ,
6// and lift estimates / covariance / geometry back to the raw
7// coordinates β. This module owns that act once.
8//
9// A `Gauge` is the affine section itself: the lift matrix
10// `T : reduced → raw` plus an affine shift `a`
11// (`β_raw = T · θ + a`) together with the per-block partitions
12// of both coordinate systems. Block-diagonal `T`
13// (independent per-block reductions, the canonical-audit case) and
14// block-upper-triangular `T` (cross-block residualisation, the
15// survival V+M-exact compile) are the same object — the partitions
16// record where each block's rows/columns live.
17//
18// Lift conventions (the whole point — there is exactly one):
19//   - point estimate:   β_raw = T · θ + a
20//   - covariance / any symmetric bilinear form: Σ_raw = T · Σ_θ · Tᵀ
21//   - η is invariant:   X_raw · (T · θ + a) = X_reduced · θ + offset_reduced
22//
23// Raw directions outside the section (zero rows of `T`) receive exactly
24// zero estimate, zero variance, and zero covariance with every other
25// coordinate: a coordinate the reduced fit cannot move carries no
26// posterior uncertainty in raw space.
27
28use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
29
30use gam_linalg::faer_ndarray::{fast_ab, fast_abt, fast_atb};
31
32/// Neutral view of a compiled identifiability reparametrisation that
33/// [`Gauge::from_compiled_map`] consumes. The concrete `CompiledMap`
34/// emitted by the identifiability compiler lives ABOVE this crate, so
35/// `Gauge` names only this trait (inverted dependency #1521); the
36/// compiler crate provides the `impl`.
37///
38/// `raw_from_compiled` IS the global triangular lift `T`; the two block
39/// range slices give the raw-width and compiled-width column partitions.
40pub trait CompiledBlockMap {
41    /// The `(p_raw × p_compiled)` raw-from-compiled reparam matrix `T`.
42    fn raw_from_compiled(&self) -> &Array2<f64>;
43    /// Per-block raw-width column ranges.
44    fn raw_block_ranges(&self) -> &[std::ops::Range<usize>];
45    /// Per-block compiled-width column ranges, parallel to
46    /// [`Self::raw_block_ranges`].
47    fn compiled_block_ranges(&self) -> &[std::ops::Range<usize>];
48}
49
50/// The lift `T : reduced → raw` plus the per-block partitions of both
51/// coordinate systems. See the module docs for the lift conventions.
52#[derive(Debug, Clone)]
53pub struct Gauge {
54    /// Global lift matrix, shape `(Σ p_b_raw) × (Σ r_b_reduced)`.
55    pub t_full: Array2<f64>,
56    /// Global affine shift in raw coordinates, length `Σ p_b_raw`.
57    pub affine_shift: Array1<f64>,
58    /// Raw-coordinate block partition: `block_starts_raw[b]..block_starts_raw[b+1]`
59    /// is block `b`'s raw row range in `t_full`. Length `n_blocks + 1`, starts at 0.
60    pub block_starts_raw: Vec<usize>,
61    /// Reduced-coordinate block partition (columns of `t_full`), same layout.
62    pub block_starts_reduced: Vec<usize>,
63}
64
/// Convert per-block widths into block boundaries: the returned vector has
/// `widths.len() + 1` entries, starts at 0, and entry `b + 1` is the sum of
/// the first `b + 1` widths.
fn starts_from_widths(widths: &[usize]) -> Vec<usize> {
    let mut boundaries = Vec::with_capacity(widths.len() + 1);
    let mut running = 0usize;
    boundaries.push(running);
    boundaries.extend(widths.iter().map(|&width| {
        running += width;
        running
    }));
    boundaries
}
73
74/// Assemble a block-upper-triangular lift `T` from per-block diagonal
75/// `V_b` matrices and strictly-upper residualisation blocks `R_{a→b}`.
76///
77/// `r_per_term[b]` (when `Some`) packs ALL strictly-upper off-diagonal
78/// columns for block `b` stacked row-wise across all earlier-priority
79/// blocks `a < b`: `nrows = Σ_{a<b} v_per_term[a].nrows()`,
80/// `ncols = v_per_term[b].ncols()`. The assembled `T` carries `V_b` on
81/// the diagonal and `−R_{a→b}` at `(a, b)`. `r_per_term[0]` must be
82/// `None` (no earlier block to residualise against).
83pub fn assemble_block_triangular_t(
84    v_per_term: &[Array2<f64>],
85    r_per_term: &[Option<Array2<f64>>],
86) -> Array2<f64> {
87    assert_eq!(
88        v_per_term.len(),
89        r_per_term.len(),
90        "assemble_block_triangular_t: v_per_term len {} != r_per_term len {}",
91        v_per_term.len(),
92        r_per_term.len(),
93    );
94    let raw_widths: Vec<usize> = v_per_term.iter().map(|v| v.nrows()).collect();
95    let kept_widths: Vec<usize> = v_per_term.iter().map(|v| v.ncols()).collect();
96    let row_offsets = starts_from_widths(&raw_widths);
97    let col_offsets = starts_from_widths(&kept_widths);
98    let total_rows = row_offsets.last().copied().unwrap_or(0);
99    let total_cols = col_offsets.last().copied().unwrap_or(0);
100    let mut t = Array2::<f64>::zeros((total_rows, total_cols));
101    // Diagonal: place V_b at (b, b).
102    for (b, v) in v_per_term.iter().enumerate() {
103        let r = v.nrows();
104        let c = v.ncols();
105        if r > 0 && c > 0 {
106            t.slice_mut(ndarray::s![
107                row_offsets[b]..row_offsets[b] + r,
108                col_offsets[b]..col_offsets[b] + c
109            ])
110            .assign(v);
111        }
112    }
113    // Strict upper triangle: for each b ≥ 1, place −R_{a→b} at (a, b),
114    // a < b, slicing the row-stacked `r_per_term[b]` in earlier-block order.
115    for b in 1..v_per_term.len() {
116        let Some(r_stack) = r_per_term[b].as_ref() else {
117            continue;
118        };
119        let kept_b = kept_widths[b];
120        assert_eq!(
121            r_stack.ncols(),
122            kept_b,
123            "assemble_block_triangular_t: r_per_term[{b}] has {} cols, expected {}",
124            r_stack.ncols(),
125            kept_b,
126        );
127        let expected_rows: usize = raw_widths.iter().take(b).sum();
128        assert_eq!(
129            r_stack.nrows(),
130            expected_rows,
131            "assemble_block_triangular_t: r_per_term[{b}] has {} rows, expected {} \
132             (sum of raw_widths[0..{}])",
133            r_stack.nrows(),
134            expected_rows,
135            b,
136        );
137        let mut local_row = 0usize;
138        for a in 0..b {
139            let r_a = raw_widths[a];
140            if r_a == 0 || kept_b == 0 {
141                local_row += r_a;
142                continue;
143            }
144            let block = r_stack.slice(ndarray::s![local_row..local_row + r_a, ..]);
145            let mut dst = t.slice_mut(ndarray::s![
146                row_offsets[a]..row_offsets[a] + r_a,
147                col_offsets[b]..col_offsets[b] + kept_b
148            ]);
149            for i in 0..r_a {
150                for j in 0..kept_b {
151                    dst[[i, j]] = -block[[i, j]];
152                }
153            }
154            local_row += r_a;
155        }
156    }
157    t
158}
159
160impl Gauge {
161    /// The trivial section: raw == reduced for every block.
162    pub fn identity(raw_widths: &[usize]) -> Self {
163        let transforms: Vec<Array2<f64>> =
164            raw_widths.iter().map(|&w| Array2::<f64>::eye(w)).collect();
165        Self::from_block_transforms(&transforms)
166    }
167
168    /// Block-diagonal section from independent per-block lifts
169    /// `T_b : reduced_b → raw_b` (selection matrices from the canonical
170    /// audit, orthogonalisation `V_b`s, or their compositions).
171    pub fn from_block_transforms(transforms: &[Array2<f64>]) -> Self {
172        let raw_total: usize = transforms.iter().map(|t| t.nrows()).sum();
173        Self::from_block_transforms_with_shift(transforms, Array1::zeros(raw_total))
174    }
175
176    /// Block-diagonal affine section from independent per-block lifts
177    /// plus one concatenated raw-coordinate shift.
178    pub fn from_block_transforms_with_shift(
179        transforms: &[Array2<f64>],
180        affine_shift: Array1<f64>,
181    ) -> Self {
182        let r_none: Vec<Option<Array2<f64>>> = transforms.iter().map(|_| None).collect();
183        let mut gauge = Self::from_v_and_r(transforms, &r_none);
184        assert_eq!(
185            affine_shift.len(),
186            gauge.raw_total(),
187            "Gauge::from_block_transforms_with_shift: affine shift len {} != raw width {}",
188            affine_shift.len(),
189            gauge.raw_total(),
190        );
191        gauge.affine_shift = affine_shift;
192        gauge
193    }
194
195    /// Single-block affine section.
196    pub fn from_block_transform_with_shift(
197        transform: Array2<f64>,
198        affine_shift: Array1<f64>,
199    ) -> Self {
200        Self::from_block_transforms_with_shift(&[transform], affine_shift)
201    }
202
203    /// Block-upper-triangular section from per-block `V_b` plus
204    /// cross-block residualisation stacks `R_{a→b}` — see
205    /// [`assemble_block_triangular_t`] for the packing convention.
206    pub fn from_v_and_r(v_per_term: &[Array2<f64>], r_per_term: &[Option<Array2<f64>>]) -> Self {
207        let raw_widths: Vec<usize> = v_per_term.iter().map(|v| v.nrows()).collect();
208        let reduced_widths: Vec<usize> = v_per_term.iter().map(|v| v.ncols()).collect();
209        Self {
210            t_full: assemble_block_triangular_t(v_per_term, r_per_term),
211            affine_shift: Array1::zeros(raw_widths.iter().sum::<usize>()),
212            block_starts_raw: starts_from_widths(&raw_widths),
213            block_starts_reduced: starts_from_widths(&reduced_widths),
214        }
215    }
216
217    /// The sum-to-zero (centering) section as a first-class single-block
218    /// gauge. `z` is the `(k × (k−1))` reparametrisation matrix returned by
219    /// `terms::basis::duchon_thinplate::apply_sum_to_zero_constraint`
220    /// (an orthonormal basis for `null(cᵀ)`, `c = Bᵀw` the weighted column
221    /// sums): the constrained design is `B_c = B · z`, so on the model
222    /// `η = B · β_raw = B_c · θ = B · z · θ` the raw coefficients lift back
223    /// from the reduced (centred) coefficients by exactly `β_raw = z · θ`.
224    ///
225    /// That is the one Gauge convention with `T = z` over a single block, so
226    /// the centring constraint stops being a special-cased outside-the-object
227    /// transform and becomes a `Gauge` section like every other reduction:
228    /// the covariance / penalised-Hessian of the centred fit pushes forward to
229    /// the raw basis through the SAME `z` via [`Gauge::lift_covariance`].
230    ///
231    /// `z` is taken as the section itself (rather than recomputed from a basis)
232    /// because the constraint matrix is the only gauge-relevant artifact — the
233    /// basis the column sums were taken over is irrelevant to the lift. The
234    /// only requirement is the structural one of a centring section:
235    /// `z.ncols() < z.nrows()` (at least one direction is removed); an identity
236    /// `z` would be `Gauge::identity` and is rejected so callers do not silently
237    /// treat an unconstrained block as centred.
238    pub fn sum_to_zero(z: Array2<f64>) -> Self {
239        let (k, r) = z.dim();
240        assert!(
241            k > 0 && r < k,
242            "Gauge::sum_to_zero: z must be a tall reparametrisation ({k}×{r}); \
243             a centring section removes at least one direction (r < k)",
244        );
245        Self::from_block_transforms(&[z])
246    }
247
248    /// Wrap an already-assembled global `T` given the per-block raw and
249    /// reduced width partitions.
250    pub fn from_t(t_full: Array2<f64>, raw_widths: &[usize], reduced_widths: &[usize]) -> Self {
251        let total_raw: usize = raw_widths.iter().sum();
252        Self::from_t_with_shift(t_full, raw_widths, reduced_widths, Array1::zeros(total_raw))
253    }
254
255    /// Wrap an already-assembled global affine section `β = Tθ + a` given the
256    /// per-block raw and reduced width partitions.
257    pub fn from_t_with_shift(
258        t_full: Array2<f64>,
259        raw_widths: &[usize],
260        reduced_widths: &[usize],
261        affine_shift: Array1<f64>,
262    ) -> Self {
263        assert_eq!(
264            raw_widths.len(),
265            reduced_widths.len(),
266            "Gauge::from_t: raw_widths len {} != reduced_widths len {}",
267            raw_widths.len(),
268            reduced_widths.len(),
269        );
270        let total_raw: usize = raw_widths.iter().sum();
271        let total_reduced: usize = reduced_widths.iter().sum();
272        assert_eq!(
273            t_full.dim(),
274            (total_raw, total_reduced),
275            "Gauge::from_t: T has shape {:?}, expected ({total_raw}, {total_reduced})",
276            t_full.dim(),
277        );
278        assert_eq!(
279            affine_shift.len(),
280            total_raw,
281            "Gauge::from_t_with_shift: affine shift len {} != raw width {total_raw}",
282            affine_shift.len(),
283        );
284        Self {
285            t_full,
286            affine_shift,
287            block_starts_raw: starts_from_widths(raw_widths),
288            block_starts_reduced: starts_from_widths(reduced_widths),
289        }
290    }
291
292    /// Build from a compiled identifiability reparametrisation
293    /// (see [`CompiledBlockMap`], implemented for the `CompiledMap` emitted by
294    /// the identifiability compiler): `map.raw_from_compiled()` IS the global
295    /// triangular `T`, and the block ranges give both partitions. `ordering`
296    /// is accepted purely as a length sanity check.
297    pub fn from_compiled_map<M: CompiledBlockMap, O>(map: &M, ordering: &[O]) -> Self {
298        assert_eq!(
299            map.raw_block_ranges().len(),
300            map.compiled_block_ranges().len(),
301            "Gauge::from_compiled_map: CompiledMap raw_block_ranges len {} != \
302             compiled_block_ranges len {}",
303            map.raw_block_ranges().len(),
304            map.compiled_block_ranges().len(),
305        );
306        assert_eq!(
307            map.raw_block_ranges().len(),
308            ordering.len(),
309            "Gauge::from_compiled_map: ordering len {} != block count {}",
310            ordering.len(),
311            map.raw_block_ranges().len(),
312        );
313        let mut block_starts_raw = Vec::with_capacity(map.raw_block_ranges().len() + 1);
314        block_starts_raw.push(0);
315        for r in map.raw_block_ranges() {
316            block_starts_raw.push(r.end);
317        }
318        let mut block_starts_reduced = Vec::with_capacity(map.compiled_block_ranges().len() + 1);
319        block_starts_reduced.push(0);
320        for r in map.compiled_block_ranges() {
321            block_starts_reduced.push(r.end);
322        }
323        let total_raw = block_starts_raw.last().copied().unwrap_or(0);
324        Self {
325            t_full: map.raw_from_compiled().clone(),
326            affine_shift: Array1::zeros(total_raw),
327            block_starts_raw,
328            block_starts_reduced,
329        }
330    }
331
332    /// Number of blocks in the partition.
333    pub fn n_blocks(&self) -> usize {
334        self.block_starts_raw.len().saturating_sub(1)
335    }
336
337    /// Total raw width `Σ p_b`.
338    pub fn raw_total(&self) -> usize {
339        self.block_starts_raw.last().copied().unwrap_or(0)
340    }
341
342    /// Total reduced width `Σ r_b`.
343    pub fn reduced_total(&self) -> usize {
344        self.block_starts_reduced.last().copied().unwrap_or(0)
345    }
346
347    /// Per-block raw widths.
348    pub fn raw_widths(&self) -> Vec<usize> {
349        self.block_starts_raw
350            .windows(2)
351            .map(|w| w[1] - w[0])
352            .collect()
353    }
354
355    /// Per-block reduced widths.
356    pub fn reduced_widths(&self) -> Vec<usize> {
357        self.block_starts_reduced
358            .windows(2)
359            .map(|w| w[1] - w[0])
360            .collect()
361    }
362
363    /// The diagonal slab `T_b = T[raw_b, reduced_b]` of block `b`.
364    /// For a block-diagonal gauge this is the whole story for the
365    /// block; for a triangular gauge it omits the cross-block `−R`.
366    pub fn block_transform(&self, b: usize) -> Array2<f64> {
367        assert!(
368            b < self.n_blocks(),
369            "Gauge::block_transform: block {b} out of range {}",
370            self.n_blocks(),
371        );
372        self.t_full
373            .slice(ndarray::s![
374                self.block_starts_raw[b]..self.block_starts_raw[b + 1],
375                self.block_starts_reduced[b]..self.block_starts_reduced[b + 1]
376            ])
377            .to_owned()
378    }
379
380    /// Compose a raw design with the section: `X_reduced = X_raw · T`.
381    pub fn restrict_design<S: Data<Elem = f64>>(
382        &self,
383        raw_design: &ArrayBase<S, Ix2>,
384    ) -> Array2<f64> {
385        let raw_total = self.raw_total();
386        assert_eq!(
387            raw_design.ncols(),
388            raw_total,
389            "Gauge::restrict_design: design has {} columns, expected raw width {raw_total}",
390            raw_design.ncols(),
391        );
392        // A trivial section (`T = I`) leaves the design untouched: `X·I = X`
393        // bit-for-bit (every off-diagonal `T` entry is an exact zero, the
394        // diagonal an exact one, so the reduction is the identity map). The
395        // unconstrained Wahba sphere chart hits this on every build, and the
396        // skipped GEMM is an `(n × w)·(w × w)` product — ~0.8 s of host
397        // matrixmultiply at production shapes (n ≳ 1e5, w ~ 200). Detecting
398        // identity costs O(w²), negligible beside the O(n·w²) it elides.
399        if self.t_full_is_identity() {
400            return raw_design.to_owned();
401        }
402        fast_ab(raw_design, &self.t_full)
403    }
404
405    /// Whether the lift `T` is the exact identity (square with unit diagonal
406    /// and zero off-diagonal). When true, `restrict_design`/`restrict_penalty`
407    /// are no-ops and skip their GEMMs. The comparison is exact equality, not
408    /// a tolerance — only a literal identity short-circuits, so the fast path
409    /// is always bit-identical to the full product.
410    fn t_full_is_identity(&self) -> bool {
411        let (r, c) = self.t_full.dim();
412        if r != c {
413            return false;
414        }
415        self.t_full
416            .indexed_iter()
417            .all(|((i, j), &v)| v == if i == j { 1.0 } else { 0.0 })
418    }
419
420    /// Compose a raw design and offset with the affine section:
421    /// `X_raw · (Tθ + a) + o_raw = (X_raw · T)θ + (o_raw + X_raw · a)`.
422    pub fn restrict_design_and_offset<S: Data<Elem = f64>>(
423        &self,
424        raw_design: &ArrayBase<S, Ix2>,
425        raw_offset: &Array1<f64>,
426    ) -> (Array2<f64>, Array1<f64>) {
427        assert_eq!(
428            raw_design.nrows(),
429            raw_offset.len(),
430            "Gauge::restrict_design_and_offset: design rows {} != offset len {}",
431            raw_design.nrows(),
432            raw_offset.len(),
433        );
434        let reduced_design = self.restrict_design(raw_design);
435        let reduced_offset = raw_offset + &raw_design.dot(&self.affine_shift);
436        (reduced_design, reduced_offset)
437    }
438
439    /// Pull a raw-coordinate quadratic form back to reduced coordinates:
440    /// `S_reduced = Tᵀ · S_raw · T`.
441    pub fn restrict_penalty<S: Data<Elem = f64>>(
442        &self,
443        raw_penalty: &ArrayBase<S, Ix2>,
444    ) -> Array2<f64> {
445        let raw_total = self.raw_total();
446        assert_eq!(
447            raw_penalty.dim(),
448            (raw_total, raw_total),
449            "Gauge::restrict_penalty: matrix has shape {:?}, expected ({raw_total}, {raw_total})",
450            raw_penalty.dim(),
451        );
452        // `Tᵀ S T = S` exactly when `T = I` (see `restrict_design`). Skip the
453        // two `(w × w)·(w × w)` products on the unconstrained chart.
454        if self.t_full_is_identity() {
455            return raw_penalty.to_owned();
456        }
457        let t_s = fast_atb(&self.t_full, raw_penalty);
458        fast_ab(&t_s, &self.t_full)
459    }
460
461    /// Append blocks that were never reduced (raw == reduced, identity
462    /// lift). Used to lift joint objects that span both gauged blocks
463    /// and untouched ones (e.g. the survival flex blocks alongside the
464    /// compiled parametric blocks).
465    pub fn extend_with_identity(&self, extra_raw_widths: &[usize]) -> Self {
466        let extra_total: usize = extra_raw_widths.iter().sum();
467        let raw_total = self.raw_total();
468        let reduced_total = self.reduced_total();
469        let mut t = Array2::<f64>::zeros((raw_total + extra_total, reduced_total + extra_total));
470        t.slice_mut(ndarray::s![0..raw_total, 0..reduced_total])
471            .assign(&self.t_full);
472        for k in 0..extra_total {
473            t[[raw_total + k, reduced_total + k]] = 1.0;
474        }
475        let mut block_starts_raw = self.block_starts_raw.clone();
476        let mut block_starts_reduced = self.block_starts_reduced.clone();
477        for &w in extra_raw_widths {
478            block_starts_raw.push(block_starts_raw.last().copied().unwrap() + w);
479            block_starts_reduced.push(block_starts_reduced.last().copied().unwrap() + w);
480        }
481        let mut affine_shift = Array1::<f64>::zeros(raw_total + extra_total);
482        affine_shift
483            .slice_mut(ndarray::s![0..raw_total])
484            .assign(&self.affine_shift);
485        Self {
486            t_full: t,
487            affine_shift,
488            block_starts_raw,
489            block_starts_reduced,
490        }
491    }
492
493    /// Lift per-block reduced coefficients to per-block raw
494    /// coefficients: concatenate into θ, apply `β = T · θ + a`, split at
495    /// the raw partition.
496    pub fn lift_block_betas(&self, reduced_block_betas: &[Array1<f64>]) -> Vec<Array1<f64>> {
497        let n_blocks = self.n_blocks();
498        assert_eq!(
499            reduced_block_betas.len(),
500            n_blocks,
501            "Gauge::lift_block_betas: got {} reduced block betas, expected {}",
502            reduced_block_betas.len(),
503            n_blocks,
504        );
505        for (b, beta) in reduced_block_betas.iter().enumerate() {
506            let expected = self.block_starts_reduced[b + 1] - self.block_starts_reduced[b];
507            assert_eq!(
508                beta.len(),
509                expected,
510                "Gauge::lift_block_betas: block {b} has β of len {}, expected reduced width {}",
511                beta.len(),
512                expected,
513            );
514        }
515        let mut theta_full = Array1::<f64>::zeros(self.reduced_total());
516        for (b, beta) in reduced_block_betas.iter().enumerate() {
517            let c0 = self.block_starts_reduced[b];
518            let c1 = self.block_starts_reduced[b + 1];
519            theta_full.slice_mut(ndarray::s![c0..c1]).assign(beta);
520        }
521        let beta_full = self.t_full.dot(&theta_full) + &self.affine_shift;
522        let mut out = Vec::with_capacity(n_blocks);
523        for b in 0..n_blocks {
524            let r0 = self.block_starts_raw[b];
525            let r1 = self.block_starts_raw[b + 1];
526            out.push(beta_full.slice(ndarray::s![r0..r1]).to_owned());
527        }
528        out
529    }
530
531    /// Push a reduced-coordinate symmetric matrix (posterior covariance,
532    /// penalized Hessian — any symmetric bilinear form on θ) forward to
533    /// raw coordinates via the exact sandwich `M_raw = T · M_θ · Tᵀ`.
534    ///
535    /// The result is explicitly symmetrised: `T · M · Tᵀ` is symmetric
536    /// for symmetric `M`, but the two matmuls accumulate independent
537    /// rounding, so the transpose pair is averaged to land an exactly
538    /// symmetric matrix for downstream Cholesky / eigensolves.
539    pub fn lift_covariance(&self, m_reduced: &Array2<f64>) -> Array2<f64> {
540        let total_reduced = self.reduced_total();
541        assert_eq!(
542            m_reduced.dim(),
543            (total_reduced, total_reduced),
544            "Gauge::lift_covariance: matrix has shape {:?}, expected ({total_reduced}, {total_reduced})",
545            m_reduced.dim(),
546        );
547        let t_m = fast_ab(&self.t_full, m_reduced);
548        let mut raw = fast_abt(&t_m, &self.t_full);
549        let n = raw.nrows();
550        for i in 0..n {
551            for j in (i + 1)..n {
552                let avg = 0.5 * (raw[[i, j]] + raw[[j, i]]);
553                raw[[i, j]] = avg;
554                raw[[j, i]] = avg;
555            }
556        }
557        raw
558    }
559}
560
561#[cfg(test)]
562mod tests {
563    use super::*;
564
565    #[test]
566    fn identity_gauge_round_trips_betas_and_covariance() {
567        let gauge = Gauge::identity(&[2, 3]);
568        assert_eq!(gauge.n_blocks(), 2);
569        assert_eq!(gauge.raw_total(), 5);
570        assert_eq!(gauge.reduced_total(), 5);
571        let theta = vec![
572            Array1::from(vec![0.5, -0.25]),
573            Array1::from(vec![1.0, 2.0, -3.0]),
574        ];
575        let raw = gauge.lift_block_betas(&theta);
576        assert_eq!(raw[0].as_slice().unwrap(), &[0.5, -0.25]);
577        assert_eq!(raw[1].as_slice().unwrap(), &[1.0, 2.0, -3.0]);
578
579        let mut cov = Array2::<f64>::eye(5);
580        cov[[0, 3]] = 0.4;
581        cov[[3, 0]] = 0.4;
582        let lifted = gauge.lift_covariance(&cov);
583        for i in 0..5 {
584            for j in 0..5 {
585                assert!(
586                    (lifted[[i, j]] - cov[[i, j]]).abs() < 1e-14,
587                    "identity gauge must be a covariance no-op at ({i},{j})",
588                );
589            }
590        }
591    }
592
593    #[test]
594    fn identity_section_short_circuits_restrict_bit_exactly() {
595        // A trivial section must restrict design/penalty to the *exact* input,
596        // matching the full GEMM bit-for-bit while skipping it.
597        let gauge = Gauge::identity(&[4]);
598        assert!(gauge.t_full_is_identity());
599
600        // An irregular design with values that would perturb under a real GEMM
601        // if any rounding crept in.
602        let raw_design = Array2::<f64>::from_shape_fn((7, 4), |(i, j)| {
603            ((i as f64) * 0.3 - (j as f64) * 1.7).sin() * 1.000000001
604        });
605        let restricted = gauge.restrict_design(&raw_design);
606        // Bit-exact equality with the input (the identity map).
607        assert_eq!(restricted, raw_design);
608        // And bit-exact with the full product it elides.
609        let via_gemm = fast_ab(&raw_design, &gauge.t_full);
610        assert_eq!(restricted, via_gemm);
611
612        let raw_penalty = Array2::<f64>::from_shape_fn((4, 4), |(i, j)| {
613            (i as f64 + 1.0) * (j as f64 + 2.0) * 0.111
614        });
615        let restricted_pen = gauge.restrict_penalty(&raw_penalty);
616        assert_eq!(restricted_pen, raw_penalty);
617        let pen_via_gemm = fast_ab(&fast_atb(&gauge.t_full, &raw_penalty), &gauge.t_full);
618        assert_eq!(restricted_pen, pen_via_gemm);
619    }
620
621    #[test]
622    fn non_identity_section_is_not_short_circuited() {
623        // A real reparametrisation must NOT take the identity fast path.
624        let mut t = Array2::<f64>::eye(3);
625        t[[0, 1]] = 0.5;
626        let gauge = Gauge::from_t(t.clone(), &[3], &[3]);
627        assert!(!gauge.t_full_is_identity());
628        let raw = Array2::<f64>::from_shape_fn((5, 3), |(i, j)| i as f64 + j as f64 * 0.25);
629        let restricted = gauge.restrict_design(&raw);
630        assert_eq!(restricted, fast_ab(&raw, &t));
631    }
632
633    #[test]
634    fn rectangular_section_is_not_identity() {
635        // A tall centring section is square-free and must never be mistaken
636        // for the identity (it removes a direction).
637        let z = Array2::<f64>::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, -1.0, -1.0]).unwrap();
638        let gauge = Gauge::sum_to_zero(z);
639        assert!(!gauge.t_full_is_identity());
640    }
641
642    #[test]
643    fn affine_gauge_lifts_betas_and_restricts_offsets() {
644        let t = Array2::from_shape_vec((3, 1), vec![2.0, -1.0, 0.5]).unwrap();
645        let shift = Array1::from(vec![0.25, 1.5, -0.75]);
646        let gauge = Gauge::from_block_transform_with_shift(t.clone(), shift.clone());
647        let theta = Array1::from(vec![4.0]);
648
649        let raw = gauge.lift_block_betas(&[theta.clone()]);
650        let expected_raw = t.dot(&theta) + &shift;
651        assert_eq!(raw[0], expected_raw);
652
653        let x = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 2.0, -1.0, 3.0, 0.5]).unwrap();
654        let offset = Array1::from(vec![0.1, -0.2]);
655        let (x_reduced, offset_reduced) = gauge.restrict_design_and_offset(&x, &offset);
656        assert_eq!(x_reduced, x.dot(&t));
657        assert_eq!(offset_reduced, &offset + &x.dot(&shift));
658
659        let eta_raw = x.dot(&expected_raw) + &offset;
660        let eta_reduced = x_reduced.dot(&theta) + &offset_reduced;
661        for i in 0..eta_raw.len() {
662            assert!((eta_raw[i] - eta_reduced[i]).abs() < 1e-14);
663        }
664
665        let cov_reduced = Array2::from_elem((1, 1), 3.0);
666        let lifted_cov = gauge.lift_covariance(&cov_reduced);
667        let expected_cov = t.dot(&cov_reduced).dot(&t.t());
668        assert_eq!(lifted_cov, expected_cov);
669    }
670
671    /// The covariance pushforward of an affine section `β = T·θ + a` must be
672    /// EXACTLY independent of the affine shift `a` — `Cov(T·θ + a) = T·Cov(θ)·Tᵀ`
673    /// for any constant `a`, because a deterministic offset adds no variance. The
674    /// b≡1 unit-log-t pin (#892) folds the warp into `a`; this is the property
675    /// that guarantees reporting the pinned coefficients carries the same
676    /// posterior uncertainty as the unpinned linear section. We assert it two
677    /// ways: (1) the analytic lift is bit-identical across a sweep of shift
678    /// magnitudes spanning the zero-shift linear case up to 1e7; and (2) an
679    /// empirical check — the sample covariance of `T·θ_k + a` over reduced draws
680    /// `θ_k` is unchanged when `a` is replaced by a 1e6-scale offset (the offset
681    /// cancels under centering).
682    #[test]
683    fn affine_shift_leaves_lifted_covariance_invariant() {
684        // A non-trivial 4-raw × 2-reduced section (so T mixes coordinates).
685        let t =
686            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.5, -1.0, 2.0, 0.3, -0.4, 1.5]).unwrap();
687        let raw_widths = [4usize];
688        let reduced_widths = [2usize];
689
690        // A non-diagonal reduced covariance.
691        let cov_reduced = Array2::from_shape_vec((2, 2), vec![2.0, -0.7, -0.7, 1.3]).unwrap();
692
693        // The reference lift is the zero-shift (purely linear) section.
694        let base =
695            Gauge::from_t_with_shift(t.clone(), &raw_widths, &reduced_widths, Array1::zeros(4));
696        let reference = base.lift_covariance(&cov_reduced);
697
698        // (1) Bit-identical across a wide sweep of shift magnitudes.
699        for &mag in &[0.0, 1e-7, 1.0, 1e3, 1e7] {
700            let shift = Array1::from(vec![mag, -mag, 0.5 * mag, -2.0 * mag]);
701            let gauge = Gauge::from_t_with_shift(t.clone(), &raw_widths, &reduced_widths, shift);
702            let lifted = gauge.lift_covariance(&cov_reduced);
703            for i in 0..4 {
704                for j in 0..4 {
705                    assert_eq!(
706                        lifted[[i, j]],
707                        reference[[i, j]],
708                        "affine shift magnitude {mag} must not perturb the lifted covariance \
709                         at ({i},{j}) — covariance is offset-invariant",
710                    );
711                }
712            }
713        }
714
715        // (2) Empirical check: draw reduced samples, push them through
716        // β = T·θ + a for two very different shifts, and confirm the sample
717        // covariance is the same for both shifts. Draws use a fixed Cholesky
718        // colouring of cov_reduced so the test is deterministic (no RNG).
719        let chol = {
720            let l00 = cov_reduced[[0, 0]].sqrt();
721            let l10 = cov_reduced[[1, 0]] / l00;
722            let l11 = (cov_reduced[[1, 1]] - l10 * l10).sqrt();
723            Array2::from_shape_vec((2, 2), vec![l00, 0.0, l10, l11]).unwrap()
724        };
725        let z_raw = [
726            [1.2, -0.4],
727            [-0.8, 0.9],
728            [0.3, 1.7],
729            [-1.5, -0.6],
730            [0.6, -1.1],
731            [-0.2, 0.3],
732            [1.9, 0.2],
733            [-1.4, -0.9],
734        ];
735        let sample_cov_for_shift = |shift: &Array1<f64>| -> Array2<f64> {
736            let n = z_raw.len();
737            let betas: Vec<Array1<f64>> = z_raw
738                .iter()
739                .map(|z| {
740                    let theta = chol.dot(&Array1::from(vec![z[0], z[1]]));
741                    t.dot(&theta) + shift
742                })
743                .collect();
744            let mut mean = Array1::<f64>::zeros(4);
745            for b in &betas {
746                mean = &mean + b;
747            }
748            mean /= n as f64;
749            let mut cov = Array2::<f64>::zeros((4, 4));
750            for b in &betas {
751                let c = b - &mean;
752                for i in 0..4 {
753                    for j in 0..4 {
754                        cov[[i, j]] += c[i] * c[j] / n as f64;
755                    }
756                }
757            }
758            cov
759        };
760        let cov_small = sample_cov_for_shift(&Array1::zeros(4));
761        let cov_big = sample_cov_for_shift(&Array1::from(vec![1e6, -1e6, 5e5, -2e6]));
762        for i in 0..4 {
763            for j in 0..4 {
764                assert!(
765                    (cov_small[[i, j]] - cov_big[[i, j]]).abs() < 1e-6,
766                    "empirical sample covariance must be offset-invariant at ({i},{j}): \
767                     small-shift {} vs big-shift {}",
768                    cov_small[[i, j]],
769                    cov_big[[i, j]],
770                );
771            }
772        }
773    }
774
775    #[test]
776    fn block_diagonal_gauge_matches_per_block_lift() {
777        // Block 0: selection keeping raw cols {0, 2} of width 3.
778        let mut t0 = Array2::<f64>::zeros((3, 2));
779        t0[[0, 0]] = 1.0;
780        t0[[2, 1]] = 1.0;
781        // Block 1: full identity of width 2.
782        let t1 = Array2::<f64>::eye(2);
783        let gauge = Gauge::from_block_transforms(&[t0.clone(), t1.clone()]);
784        assert_eq!(gauge.raw_widths(), vec![3, 2]);
785        assert_eq!(gauge.reduced_widths(), vec![2, 2]);
786
787        let theta = vec![Array1::from(vec![1.5, -2.5]), Array1::from(vec![0.5, 4.0])];
788        let raw = gauge.lift_block_betas(&theta);
789        assert_eq!(raw[0].as_slice().unwrap(), &[1.5, 0.0, -2.5]);
790        assert_eq!(raw[1].as_slice().unwrap(), &[0.5, 4.0]);
791
792        // block_transform recovers the diagonal slabs exactly.
793        assert_eq!(gauge.block_transform(0), t0);
794        assert_eq!(gauge.block_transform(1), t1);
795    }
796
797    #[test]
798    fn triangular_gauge_applies_negative_r_off_diagonal() {
799        // Two blocks, raw widths 2 and 2; block 1 keeps 1 column and is
800        // residualised against block 0 by R (2×1).
801        let v_a = Array2::<f64>::eye(2);
802        let mut v_b = Array2::<f64>::zeros((2, 1));
803        v_b[[0, 0]] = 1.0;
804        let mut r_ab = Array2::<f64>::zeros((2, 1));
805        r_ab[[0, 0]] = 0.5;
806        r_ab[[1, 0]] = -0.25;
807        let gauge = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_ab)]);
808
809        let theta = vec![Array1::from(vec![1.0, 2.0]), Array1::from(vec![4.0])];
810        let raw = gauge.lift_block_betas(&theta);
811        // β_a = V_a·θ_a − R_{a→b}·θ_b = [1 − 0.5·4, 2 + 0.25·4] = [−1, 3].
812        assert!((raw[0][0] - (-1.0)).abs() < 1e-14);
813        assert!((raw[0][1] - 3.0).abs() < 1e-14);
814        // β_b = V_b·θ_b = [4, 0].
815        assert!((raw[1][0] - 4.0).abs() < 1e-14);
816        assert!((raw[1][1] - 0.0).abs() < 1e-14);
817    }
818
819    /// For a zero-shift gauge, covariance lift must be the exact pushforward of
820    /// the SAME `T` the β lift applies: for a rank-1 `Σ_θ = θθᵀ`, the lifted
821    /// covariance must equal `(Tθ)(Tθ)ᵀ` built from the lifted β.
822    #[test]
823    fn covariance_lift_is_rank1_consistent_with_beta_lift() {
824        let v_a = Array2::<f64>::eye(2);
825        let mut v_b = Array2::<f64>::zeros((2, 1));
826        v_b[[0, 0]] = 1.0;
827        let mut r_ab = Array2::<f64>::zeros((2, 1));
828        r_ab[[0, 0]] = 0.3;
829        r_ab[[1, 0]] = 0.7;
830        let gauge = Gauge::from_v_and_r(&[v_a, v_b], &[None, Some(r_ab)]);
831
832        let theta = vec![Array1::from(vec![0.8, -1.2]), Array1::from(vec![2.0])];
833        let raw = gauge.lift_block_betas(&theta);
834        let beta_full: Vec<f64> = raw.iter().flat_map(|b| b.iter().copied()).collect();
835
836        let theta_full = Array1::from(vec![0.8, -1.2, 2.0]);
837        let cov_rank1 = {
838            let n = theta_full.len();
839            Array2::from_shape_fn((n, n), |(i, j)| theta_full[i] * theta_full[j])
840        };
841        let lifted = gauge.lift_covariance(&cov_rank1);
842        assert_eq!(lifted.dim(), (4, 4));
843        for i in 0..4 {
844            for j in 0..4 {
845                let expected = beta_full[i] * beta_full[j];
846                assert!(
847                    (lifted[[i, j]] - expected).abs() < 1e-12,
848                    "rank-1 covariance lift must equal (Tθ)(Tθ)ᵀ at ({i},{j}): \
849                     got {} expected {expected}",
850                    lifted[[i, j]],
851                );
852            }
853        }
854    }
855
856    /// `Gauge::sum_to_zero(z)` must lift exactly as `β_raw = z · θ`, and the
857    /// lift must preserve the linear predictor: for any centred design
858    /// `B_c = B · z` and any reduced coefficient `θ`, the raw prediction
859    /// `B · (z · θ)` equals the reduced prediction `B_c · θ`. This is the
860    /// invariant that makes `z` the correct section — a wrong gauge would
861    /// preserve coefficients but break η.
862    #[test]
863    fn sum_to_zero_gauge_lifts_via_z_and_preserves_eta() {
864        // A concrete orthonormal centring section: null space of c = [1,1,1]ᵀ
865        // (the unweighted sum-to-zero constraint on a width-3 block), built as
866        // two orthonormal columns each summing to zero.
867        let s = 1.0 / 2.0_f64.sqrt();
868        let s6 = 1.0 / 6.0_f64.sqrt();
869        let mut z = Array2::<f64>::zeros((3, 2));
870        z[[0, 0]] = s;
871        z[[1, 0]] = -s;
872        z[[2, 0]] = 0.0;
873        z[[0, 1]] = s6;
874        z[[1, 1]] = s6;
875        z[[2, 1]] = -2.0 * s6;
876        // The columns are orthonormal and sum to zero (cᵀz = 0).
877        for j in 0..2 {
878            assert!(
879                (z.column(j).sum()).abs() < 1e-14,
880                "column {j} must sum to 0"
881            );
882            assert!(
883                (z.column(j).dot(&z.column(j)) - 1.0).abs() < 1e-14,
884                "column {j} must be unit norm"
885            );
886        }
887
888        let gauge = Gauge::sum_to_zero(z.clone());
889        assert_eq!(gauge.n_blocks(), 1);
890        assert_eq!(gauge.raw_widths(), vec![3]);
891        assert_eq!(gauge.reduced_widths(), vec![2]);
892        assert_eq!(gauge.block_transform(0), z);
893
894        // Lift β_raw = z · θ exactly.
895        let theta = Array1::from(vec![1.3, -0.7]);
896        let raw = gauge.lift_block_betas(&[theta.clone()]);
897        let expected_raw = z.dot(&theta);
898        for i in 0..3 {
899            assert!((raw[0][i] - expected_raw[i]).abs() < 1e-14);
900        }
901        // Centring is satisfied: the raw coefficients sum to zero.
902        assert!(raw[0].sum().abs() < 1e-14, "lifted β must be centred");
903
904        // η preservation: B · (z · θ) == (B · z) · θ for an arbitrary B.
905        let b = Array2::from_shape_vec(
906            (4, 3),
907            vec![
908                1.0, 2.0, -1.0, 0.5, -0.5, 3.0, 2.0, 1.0, 1.0, -1.0, 0.0, 4.0,
909            ],
910        )
911        .unwrap();
912        let b_c = fast_ab(&b, &z); // the constrained design B_c
913        assert_eq!(gauge.restrict_design(&b), b_c);
914        let eta_reduced = b_c.dot(&theta);
915        let eta_raw = b.dot(&expected_raw);
916        for i in 0..4 {
917            assert!(
918                (eta_reduced[i] - eta_raw[i]).abs() < 1e-13,
919                "η must be invariant under the centring lift at row {i}",
920            );
921        }
922
923        // Covariance pushforward through the SAME z (rank-1 consistency).
924        let cov_rank1 = Array2::from_shape_fn((2, 2), |(i, j)| theta[i] * theta[j]);
925        let lifted = gauge.lift_covariance(&cov_rank1);
926        assert_eq!(lifted.dim(), (3, 3));
927        for i in 0..3 {
928            for j in 0..3 {
929                let expect = expected_raw[i] * expected_raw[j];
930                assert!(
931                    (lifted[[i, j]] - expect).abs() < 1e-13,
932                    "centring covariance lift must equal (zθ)(zθ)ᵀ at ({i},{j})",
933                );
934            }
935        }
936
937        let raw_penalty = Array2::from_shape_vec(
938            (3, 3),
939            vec![2.0, 0.5, 0.0, 0.5, 3.0, -0.25, 0.0, -0.25, 4.0],
940        )
941        .unwrap();
942        let reduced_penalty = gauge.restrict_penalty(&raw_penalty);
943        let expected_reduced_penalty = fast_ab(&fast_atb(&z, &raw_penalty), &z);
944        assert_eq!(reduced_penalty, expected_reduced_penalty);
945    }
946
947    #[test]
948    #[should_panic(expected = "removes at least one direction")]
949    fn sum_to_zero_rejects_identity_section() {
950        // A square z removes no direction — that is not a centring section.
951        drop(Gauge::sum_to_zero(Array2::<f64>::eye(3)));
952    }
953
954    #[test]
955    fn extend_with_identity_passes_extra_blocks_through() {
956        let mut t0 = Array2::<f64>::zeros((2, 1));
957        t0[[0, 0]] = 1.0;
958        let gauge = Gauge::from_block_transforms(&[t0]).extend_with_identity(&[2]);
959        assert_eq!(gauge.n_blocks(), 2);
960        assert_eq!(gauge.raw_total(), 4);
961        assert_eq!(gauge.reduced_total(), 3);
962
963        let theta = vec![Array1::from(vec![3.0]), Array1::from(vec![1.0, -1.0])];
964        let raw = gauge.lift_block_betas(&theta);
965        assert_eq!(raw[0].as_slice().unwrap(), &[3.0, 0.0]);
966        assert_eq!(raw[1].as_slice().unwrap(), &[1.0, -1.0]);
967
968        // Covariance: the extra (untouched) block's diagonal sub-matrix
969        // survives the lift bit-for-bit; the reduced block zero-pads.
970        let mut cov = Array2::<f64>::eye(3);
971        cov[[1, 2]] = 0.25;
972        cov[[2, 1]] = 0.25;
973        let lifted = gauge.lift_covariance(&cov);
974        assert_eq!(lifted.dim(), (4, 4));
975        assert!((lifted[[0, 0]] - 1.0).abs() < 1e-14);
976        assert!(
977            (lifted[[1, 1]] - 0.0).abs() < 1e-14,
978            "dropped raw row has zero variance"
979        );
980        assert!((lifted[[2, 2]] - 1.0).abs() < 1e-14);
981        assert!((lifted[[3, 3]] - 1.0).abs() < 1e-14);
982        assert!((lifted[[2, 3]] - 0.25).abs() < 1e-14);
983    }
984}