// gam_solve/orthogonal_reparam.rs

1//! General exact orthogonal reparameterization of overlapping design blocks
2//! (universal robustness — the "orthogonalize" stage).
3//!
4//! # Why this lives in the shared solver layer (not in a family)
5//!
6//! Several families build a linear predictor from two (or more) design blocks
7//! whose column spans *overlap* by construction. The canonical case is the
8//! Bernoulli/survival marginal-slope index
9//!
10//! ```text
11//!     η(x) = M·β_m  +  diag(z)·S·β_s
12//! ```
13//!
14//! where `M` is the marginal baseline surface and `S` is the score-weighted
15//! ("logslope") surface. Because the exposure `z` correlates with the same PC
16//! smooths that both `M` and `S` are built from, a component of `M·β_m` can be
17//! explained almost equally well by `diag(z)·S·β_s`. That structural confound
18//! makes the *joint* design rank-soft: the inner Newton sees a near-singular
19//! cross-block Hessian and the outer REML never settles.
20//!
//! An earlier solver papered over this with pinned ridges (a penalty mass aimed
//! at the confounded direction, now deleted). That *penalizes* the confound.
//! This module instead *resolves* it by construction: it reparameterizes the
//! confound block so that its columns are orthogonal (in a chosen row metric
//! `W`) to the primary block's column span. After the transform the cross-block
//! Gram vanishes up to the tiny stabilizing ridge `ε` introduced in the math
//! section, so no penalty ridge is needed for identification — and the
//! transform is a pure change of basis, so the original-basis coefficients are
//! recovered **exactly** for prediction and reporting.
30//!
31//! The mechanism is family-general: it operates only on dense design columns and
32//! a per-row weight vector, so any family that can hand over a `(primary,
33//! confound, W)` triple inherits it. Activating it for BMS is fine, but nothing
34//! here is BMS-specific.
35//!
36//! # The math (exact, no approximation)
37//!
38//! Let `M` (`n × p_m`) be the primary block, `C` (`n × p_c`) the confound block,
39//! and `W = diag(w)` a non-negative row metric (`w_i ≥ 0`). Define the
40//! W-projection coefficients
41//!
42//! ```text
43//!     B = (MᵀW M + ε I)⁻¹ MᵀW C          (p_m × p_c)
44//! ```
45//!
46//! and the orthogonalized confound design
47//!
48//! ```text
49//!     C̃ = C − M·B.
50//! ```
51//!
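//! Then, since `(MᵀW M + ε I)·B = MᵀW C`, we get
//! `MᵀW C̃ = MᵀW C − (MᵀW M)·B = ε·B`: the cross-block Gram is zero up to the
//! ridge term. The ridge `ε` is always added, but it is scaled to be negligible
//! against a well-conditioned `MᵀW M` and only matters materially when `MᵀW M`
//! is rank-deficient. In that sense `C̃` is W-orthogonal to `span(M)`.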
55//!
56//! Crucially this is just a **shear** of the joint coefficient vector. The linear
57//! predictor is invariant:
58//!
59//! ```text
60//!     M·β̃_m + C̃·β_c = M·β̃_m + (C − M·B)·β_c
61//!                    = M·(β̃_m − B·β_c) + C·β_c,
62//! ```
63//!
64//! so if the solver fits `(β̃_m, β_c)` in the reparameterized basis, the
65//! original-basis coefficients are recovered **exactly** by
66//!
67//! ```text
68//!     β_m = β̃_m − B·β_c,      β_c (unchanged).
69//! ```
70//!
71//! The confound coefficients are untouched; only the primary coefficients absorb
72//! the shear `B·β_c`. [`OrthogonalReparam::recover_original`] performs exactly
73//! this map, and [`OrthogonalReparam::reparameterized_confound`] returns `C̃`.
74//!
75//! Robustness is unconditional: the construction entry point
76//! [`OrthogonalReparam::build_unconditional`] always builds the exact reparam.
77//! The caller decides whether there is a confound block to orthogonalize.
78
79use ndarray::{Array1, Array2, ArrayView2};
80
81use gam_linalg::faer_ndarray::{
82    FaerArrayView, factorize_symmetricwith_fallback, fast_ab, fast_xt_diag_x, fast_xt_diag_y,
83};
84use gam_linalg::matrix::FactorizedSystem;
85use faer::Side;
86
/// Relative size of the stabilizing ridge added to the weighted primary Gram
/// `MᵀW M` before solving for the projection coefficients `B`.
///
/// The ridge actually applied is this factor times the largest diagonal entry
/// of `MᵀW M`, bounded below by [`ORTHOGONAL_PROJECTION_RIDGE_FLOOR`]. Its job
/// is to keep the solve well-posed when the primary design is rank-deficient
/// (a dropped/aliased column leaves a zero pivot); against a well-conditioned
/// Gram it is negligible, so `MᵀW C̃ ≈ 0` holds to working precision. The
/// magnitude is chosen to match the one used by the §3 influence projection so
/// the two share a numerical regime.
pub const ORTHOGONAL_PROJECTION_RELATIVE_RIDGE: f64 = 1e-10;

/// Smallest ridge ever added to the weighted primary Gram, regardless of its
/// scale, so that even a degenerate (all-zero) `MᵀW M` produces an invertible
/// system.
pub const ORTHOGONAL_PROJECTION_RIDGE_FLOOR: f64 = 1e-12;
98
99/// An exact orthogonal reparameterization of one confound block against one
100/// primary block's column span in a fixed row metric `W`.
101///
102/// Holds the shear matrix `B` (`p_m × p_c`) and the reparameterized confound
103/// design `C̃ = C − M·B` (`n × p_c`). The transform is a pure change of basis, so
104/// it is fully described by `B`; `C̃` is cached because the solver needs the new
105/// design and recomputing it is wasteful.
106///
107/// Build with [`OrthogonalReparam::build_unconditional`]. The round-trip
108/// [`recover_original`](Self::recover_original) maps fitted reparameterized
109/// coefficients back to the original basis exactly.
110#[derive(Debug, Clone)]
111pub struct OrthogonalReparam {
112    /// W-projection / shear matrix `B = (MᵀWM + εI)⁻¹ MᵀW C` (`p_m × p_c`).
113    shear: Array2<f64>,
114    /// Reparameterized confound design `C̃ = C − M·B` (`n × p_c`).
115    confound_orthogonal: Array2<f64>,
116}
117
118impl OrthogonalReparam {
119    /// Build the exact orthogonal reparameterization of the `confound` block
120    /// against the `primary` block's column span in the `w_metric` row metric.
121    ///
122    /// Robustness is unconditional, so this always constructs the reparam (the
123    /// caller decides whether there is anything to orthogonalize; an empty span
124    /// `p_m == 0` or `p_c == 0` yields an identity-on-confound transform).
125    ///
126    /// Returns:
127    ///   - `Ok(reparam)` with `C̃` exactly W-orthogonal to `span(primary)`.
128    ///   - `Err` on a dimension mismatch, a non-finite/negative metric, or a
129    ///     non-finite result.
130    ///
131    /// `primary` is `n × p_m`, `confound` is `n × p_c`, `w_metric` is length `n`
132    /// with `w_i ≥ 0` (the PIRLS row inner product at the pilot, so the resulting
133    /// orthogonality holds in the metric the penalized joint solve actually sees;
134    /// pass all-ones for the plain Euclidean metric).
135    pub fn build_unconditional(
136        primary: ArrayView2<f64>,
137        confound: ArrayView2<f64>,
138        w_metric: &Array1<f64>,
139    ) -> Result<Self, String> {
140        let n = primary.nrows();
141        if confound.nrows() != n {
142            return Err(format!(
143                "orthogonal_reparam: primary rows ({n}) != confound rows ({})",
144                confound.nrows()
145            ));
146        }
147        if w_metric.len() != n {
148            return Err(format!(
149                "orthogonal_reparam: row metric length ({}) != design rows ({n})",
150                w_metric.len()
151            ));
152        }
153        if w_metric.iter().any(|v| !v.is_finite() || *v < 0.0) {
154            return Err(
155                "orthogonal_reparam: row metric must be finite and non-negative".to_string(),
156            );
157        }
158        let p_m = primary.ncols();
159        let p_c = confound.ncols();
160
161        // No primary span (or no confound columns) ⇒ nothing to orthogonalize.
162        // Return an identity-shear reparam whose C̃ is the raw confound, so a
163        // caller that already chose Some(..) still gets a consistent object.
164        if p_m == 0 || p_c == 0 {
165            return Ok(Self {
166                shear: Array2::<f64>::zeros((p_m, p_c)),
167                confound_orthogonal: confound.to_owned(),
168            });
169        }
170
171        // Weighted primary Gram MᵀW M and cross term MᵀW C in the row metric.
172        let mut gram = fast_xt_diag_x(&primary, w_metric);
173        let gram_scale = (0..p_m).map(|i| gram[[i, i]]).fold(0.0_f64, f64::max);
174        let eps = (gram_scale * ORTHOGONAL_PROJECTION_RELATIVE_RIDGE)
175            .max(ORTHOGONAL_PROJECTION_RIDGE_FLOOR);
176        for i in 0..p_m {
177            gram[[i, i]] += eps;
178        }
179        let cross = fast_xt_diag_y(&primary, w_metric, &confound.to_owned());
180
181        let gram_view = FaerArrayView::new(&gram);
182        let factor =
183            factorize_symmetricwith_fallback(gram_view.as_ref(), Side::Lower).map_err(|e| {
184                format!("orthogonal_reparam: weighted primary Gram factorization failed: {e:?}")
185            })?;
186        // B = (MᵀWM + εI)⁻¹ MᵀW C   (p_m × p_c)
187        let shear = factor
188            .solvemulti(&cross)
189            .map_err(|e| format!("orthogonal_reparam: projection solve failed: {e}"))?;
190
191        // C̃ = C − M·B.
192        let projection = fast_ab(&primary, &shear);
193        let confound_orthogonal = &confound - &projection;
194
195        if shear.iter().any(|v| !v.is_finite())
196            || confound_orthogonal.iter().any(|v| !v.is_finite())
197        {
198            return Err(
199                "orthogonal_reparam: reparameterization produced non-finite entries".to_string(),
200            );
201        }
202
203        Ok(Self {
204            shear,
205            confound_orthogonal,
206        })
207    }
208
209    /// The shear matrix `B` (`p_m × p_c`). Original primary coefficients are
210    /// `β_m = β̃_m − B·β_c`.
211    #[inline]
212    pub fn shear(&self) -> ArrayView2<'_, f64> {
213        self.shear.view()
214    }
215
216    /// The reparameterized confound design `C̃ = C − M·B` (`n × p_c`), exactly
217    /// W-orthogonal to `span(primary)`. This is the design the solver fits the
218    /// confound coefficients against.
219    #[inline]
220    pub fn reparameterized_confound(&self) -> ArrayView2<'_, f64> {
221        self.confound_orthogonal.view()
222    }
223
224    /// Number of primary columns `p_m`.
225    #[inline]
226    pub fn primary_cols(&self) -> usize {
227        self.shear.nrows()
228    }
229
230    /// Number of confound columns `p_c`.
231    #[inline]
232    pub fn confound_cols(&self) -> usize {
233        self.shear.ncols()
234    }
235
236    /// Map the fitted reparameterized coefficients `(β̃_m, β_c)` back to the
237    /// original basis `(β_m, β_c)` **exactly**:
238    ///
239    /// ```text
240    ///     β_m = β̃_m − B·β_c,      β_c unchanged.
241    /// ```
242    ///
243    /// `beta_m_reparam` has length `p_m`, `beta_c` has length `p_c`. Returns the
244    /// original-basis `(β_m, β_c)`. Because the predictor `M·β̃_m + C̃·β_c` equals
245    /// `M·β_m + C·β_c` for these recovered coefficients, predictions in the
246    /// original basis are unchanged.
247    pub fn recover_original(
248        &self,
249        beta_m_reparam: &Array1<f64>,
250        beta_c: &Array1<f64>,
251    ) -> Result<(Array1<f64>, Array1<f64>), String> {
252        let p_m = self.primary_cols();
253        let p_c = self.confound_cols();
254        if beta_m_reparam.len() != p_m {
255            return Err(format!(
256                "orthogonal_reparam: reparameterized primary coeffs length ({}) != p_m ({p_m})",
257                beta_m_reparam.len()
258            ));
259        }
260        if beta_c.len() != p_c {
261            return Err(format!(
262                "orthogonal_reparam: confound coeffs length ({}) != p_c ({p_c})",
263                beta_c.len()
264            ));
265        }
266        // β_m = β̃_m − B·β_c.
267        let shear_beta_c = self.shear.dot(beta_c);
268        let beta_m = beta_m_reparam - &shear_beta_c;
269        Ok((beta_m, beta_c.clone()))
270    }
271
272    /// Forward shear: map original-basis primary coefficients `β_m` to the
273    /// reparameterized basis `β̃_m = β_m + B·β_c` (the inverse of
274    /// [`recover_original`](Self::recover_original)). Useful for warm-starting the
275    /// reparameterized solve from an original-basis initial guess.
276    pub fn to_reparameterized(
277        &self,
278        beta_m: &Array1<f64>,
279        beta_c: &Array1<f64>,
280    ) -> Result<Array1<f64>, String> {
281        let p_m = self.primary_cols();
282        let p_c = self.confound_cols();
283        if beta_m.len() != p_m {
284            return Err(format!(
285                "orthogonal_reparam: primary coeffs length ({}) != p_m ({p_m})",
286                beta_m.len()
287            ));
288        }
289        if beta_c.len() != p_c {
290            return Err(format!(
291                "orthogonal_reparam: confound coeffs length ({}) != p_c ({p_c})",
292                beta_c.len()
293            ));
294        }
295        let shear_beta_c = self.shear.dot(beta_c);
296        Ok(beta_m + &shear_beta_c)
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    /// Largest absolute value among `values`; `0.0` for an empty sequence.
    fn peak_abs<'a>(values: impl IntoIterator<Item = &'a f64>) -> f64 {
        values
            .into_iter()
            .fold(0.0_f64, |acc, v| acc.max(v.abs()))
    }

    /// A confound that genuinely overlaps the primary design (one of its
    /// columns is a primary column plus a small perturbation) must come out
    /// with `MᵀW C̃ ≈ 0`.
    ///
    /// NOTE(review): because of the stabilizing ridge, `MᵀW C̃` equals `ε·B`
    /// rather than exactly zero, so the `1e-8` bound below depends on `ε` and
    /// on the size of `B` for this fixture — confirm the margin is comfortable.
    #[test]
    fn orthogonalized_confound_is_w_orthogonal_to_primary() {
        let n = 50;
        let grid = |i: usize| i as f64 / n as f64;
        let m = Array2::from_shape_fn((n, 3), |(i, j)| {
            let t = grid(i);
            match j {
                0 => 1.0,
                1 => t,
                _ => (t * 6.0).sin(),
            }
        });
        // Column 0 nearly duplicates M's linear column (the confound); column 1
        // brings a direction M does not contain.
        let c = Array2::from_shape_fn((n, 2), |(i, j)| {
            let t = grid(i);
            match j {
                0 => t + 0.01 * (t * 13.0).cos(),
                _ => (t * 3.0).cos(),
            }
        });
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        // Recompute the cross-block Gram against the sheared design.
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let cross = fast_xt_diag_y(&m, &w, &c_tilde);
        let max_abs = peak_abs(cross.iter());
        assert!(
            max_abs < 1e-8,
            "MᵀW C̃ not orthogonal: max |entry| = {max_abs:e}"
        );
    }

    /// Round trip: take synthetic coefficients in the reparameterized basis,
    /// map them to the original basis, and check that the linear predictor is
    /// the same and that the forward map inverts the recovery.
    #[test]
    fn coefficient_round_trip_is_exact() {
        let n = 40;
        let grid = |i: usize| i as f64 / n as f64;
        let m = Array2::from_shape_fn((n, 2), |(i, j)| match j {
            0 => 1.0,
            _ => (grid(i) * 4.0).sin(),
        });
        let c = Array2::from_shape_fn((n, 2), |(i, j)| match j {
            // Overlaps the linear-ish part of M.
            0 => grid(i),
            _ => (grid(i) * 2.0).cos(),
        });
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        // Stand-in for what the solver would return in the new basis.
        let beta_m_reparam = Array1::from_vec(vec![0.7, -1.3]);
        let beta_c = Array1::from_vec(vec![2.1, 0.4]);

        // η in the reparameterized basis: M·β̃_m + C̃·β_c.
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let eta_reparam = m.dot(&beta_m_reparam) + c_tilde.dot(&beta_c);

        // η in the original basis from the recovered coefficients:
        // M·β_m + C·β_c. The two must agree tightly.
        let (beta_m, beta_c_out) = reparam
            .recover_original(&beta_m_reparam, &beta_c)
            .expect("recover should succeed");
        let eta_original = m.dot(&beta_m) + c.dot(&beta_c_out);

        let max_diff = peak_abs((&eta_reparam - &eta_original).iter());
        assert!(
            max_diff < 1e-10,
            "predictor changed under round-trip: max |Δη| = {max_diff:e}"
        );

        // The confound coefficients pass through bit-for-bit.
        let cdiff = peak_abs((&beta_c_out - &beta_c).iter());
        assert!(cdiff == 0.0, "confound coeffs changed: {cdiff:e}");

        // Shearing forward again lands back on the starting coefficients.
        let back = reparam
            .to_reparameterized(&beta_m, &beta_c)
            .expect("forward should succeed");
        let fdiff = peak_abs((&back - &beta_m_reparam).iter());
        assert!(fdiff < 1e-10, "forward/inverse mismatch: {fdiff:e}");
    }

    /// If the confound is already orthogonal to the primary span, the shear is
    /// (numerically) zero and the returned design equals the raw confound, so
    /// the pass leaves alone what it has no reason to touch.
    #[test]
    fn absent_confound_leaves_design_and_predictions_unchanged() {
        let n = 30;
        // Primary = {1, t} on an even grid over [0, 1].
        let m = Array2::from_shape_fn((n, 2), |(i, j)| match j {
            0 => 1.0,
            _ => i as f64 / (n as f64 - 1.0),
        });
        let w = Array1::<f64>::from_elem(n, 1.0);

        // Confound = t² residualized against {1, t} by an explicit least-squares
        // solve. Using the pass under test to do the residualizing would be
        // circular, so the normal equations are solved here directly (no ridge).
        let quad = m.column(1).mapv(|t| t * t);
        let gram = fast_xt_diag_x(&m, &w);
        let cross = m.t().dot(&quad);
        let gview = FaerArrayView::new(&gram);
        let factor = factorize_symmetricwith_fallback(gview.as_ref(), Side::Lower).expect("factor");
        let b = FactorizedSystem::solve(&factor, &cross).expect("solve");
        let resid = &quad - &m.dot(&b);
        let c = Array2::from_shape_fn((n, 1), |(i, _)| resid[i]);

        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        // No overlap ⇒ shear ≈ 0 ⇒ C̃ ≈ C.
        let shear_max = peak_abs(reparam.shear().iter());
        assert!(shear_max < 1e-8, "expected ~zero shear, got {shear_max:e}");
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let design_diff = peak_abs((&c_tilde - &c).iter());
        assert!(
            design_diff < 1e-8,
            "orthogonalized design drifted from raw when confound absent: {design_diff:e}"
        );
    }

    /// A primary block with zero columns leaves nothing to project out, so the
    /// confound must come back untouched.
    #[test]
    fn empty_primary_returns_raw_confound() {
        let n = 8;
        let m = Array2::<f64>::zeros((n, 0));
        let c = Array2::from_shape_fn((n, 2), |(i, j)| match j {
            0 => i as f64,
            _ => 1.0,
        });
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let diff = peak_abs((&c_tilde - &c).iter());
        assert!(diff == 0.0, "empty primary must return raw confound");
    }
}