gam_solve/orthogonal_reparam.rs
1//! General exact orthogonal reparameterization of overlapping design blocks
2//! (universal robustness — the "orthogonalize" stage).
3//!
4//! # Why this lives in the shared solver layer (not in a family)
5//!
6//! Several families build a linear predictor from two (or more) design blocks
7//! whose column spans *overlap* by construction. The canonical case is the
8//! Bernoulli/survival marginal-slope index
9//!
10//! ```text
11//! η(x) = M·β_m + diag(z)·S·β_s
12//! ```
13//!
14//! where `M` is the marginal baseline surface and `S` is the score-weighted
15//! ("logslope") surface. Because the exposure `z` correlates with the same PC
16//! smooths that both `M` and `S` are built from, a component of `M·β_m` can be
17//! explained almost equally well by `diag(z)·S·β_s`. That structural confound
18//! makes the *joint* design rank-soft: the inner Newton sees a near-singular
19//! cross-block Hessian and the outer REML never settles.
20//!
21//! An earlier solver papered over this with pinned ridges (a penalty mass aimed
22//! at the confounded direction, now deleted). That *penalizes* the confound.
//! This module instead *resolves* it by construction: it reparameterizes the
//! confound block so its columns are orthogonal (in a chosen row metric `W`)
//! to the primary block's column span. After the transform the cross-block
//! Gram vanishes to working precision (the only residual is a tiny numerical
//! ridge term, derived below), so no identification ridge is needed — and
//! because the transform is a pure change of basis, the original-basis
//! coefficients are recovered **exactly** for prediction and reporting.
30//!
31//! The mechanism is family-general: it operates only on dense design columns and
32//! a per-row weight vector, so any family that can hand over a `(primary,
33//! confound, W)` triple inherits it. Activating it for BMS is fine, but nothing
34//! here is BMS-specific.
35//!
36//! # The math (exact, no approximation)
37//!
38//! Let `M` (`n × p_m`) be the primary block, `C` (`n × p_c`) the confound block,
39//! and `W = diag(w)` a non-negative row metric (`w_i ≥ 0`). Define the
40//! W-projection coefficients
41//!
42//! ```text
43//! B = (MᵀW M + ε I)⁻¹ MᵀW C (p_m × p_c)
44//! ```
45//!
46//! and the orthogonalized confound design
47//!
48//! ```text
49//! C̃ = C − M·B.
50//! ```
51//!
//! Then `MᵀW C̃ = MᵀW C − (MᵀW M)·B = ε·B`, because `(MᵀW M + ε I)·B = MᵀW C`
//! by the definition of `B`. The residual `ε·B` is negligible unless `MᵀW M`
//! is rank-deficient (ε is a tiny relative ridge), i.e. `C̃` is W-orthogonal to
//! `span(M)` to working precision.
55//!
56//! Crucially this is just a **shear** of the joint coefficient vector. The linear
57//! predictor is invariant:
58//!
59//! ```text
60//! M·β̃_m + C̃·β_c = M·β̃_m + (C − M·B)·β_c
61//! = M·(β̃_m − B·β_c) + C·β_c,
62//! ```
63//!
64//! so if the solver fits `(β̃_m, β_c)` in the reparameterized basis, the
65//! original-basis coefficients are recovered **exactly** by
66//!
67//! ```text
68//! β_m = β̃_m − B·β_c, β_c (unchanged).
69//! ```
70//!
71//! The confound coefficients are untouched; only the primary coefficients absorb
72//! the shear `B·β_c`. [`OrthogonalReparam::recover_original`] performs exactly
73//! this map, and [`OrthogonalReparam::reparameterized_confound`] returns `C̃`.
74//!
75//! Robustness is unconditional: the construction entry point
76//! [`OrthogonalReparam::build_unconditional`] always builds the exact reparam.
77//! The caller decides whether there is a confound block to orthogonalize.
78
79use ndarray::{Array1, Array2, ArrayView2};
80
81use gam_linalg::faer_ndarray::{
82 FaerArrayView, factorize_symmetricwith_fallback, fast_ab, fast_xt_diag_x, fast_xt_diag_y,
83};
84use gam_linalg::matrix::FactorizedSystem;
85use faer::Side;
86
/// Ridge added to the diagonal of the weighted primary Gram `MᵀW M` before the
/// shear `B` is solved for, expressed relative to the Gram's largest diagonal
/// entry.
///
/// The ridge is always applied, but it only matters materially when the primary
/// design is rank-deficient (a dropped or aliased column leaves a zero pivot).
/// Against a well-conditioned Gram it is negligible, so `MᵀW C̃ ≈ 0` holds to
/// working precision. The magnitude matches the §3 influence projection so the
/// two projections operate in the same numerical regime.
pub const ORTHOGONAL_PROJECTION_RELATIVE_RIDGE: f64 = 1e-10;

/// Absolute lower bound on the projection ridge. It keeps the system invertible
/// even when the weighted primary Gram is entirely zero.
pub const ORTHOGONAL_PROJECTION_RIDGE_FLOOR: f64 = 1e-12;
98
99/// An exact orthogonal reparameterization of one confound block against one
100/// primary block's column span in a fixed row metric `W`.
101///
102/// Holds the shear matrix `B` (`p_m × p_c`) and the reparameterized confound
103/// design `C̃ = C − M·B` (`n × p_c`). The transform is a pure change of basis, so
104/// it is fully described by `B`; `C̃` is cached because the solver needs the new
105/// design and recomputing it is wasteful.
106///
107/// Build with [`OrthogonalReparam::build_unconditional`]. The round-trip
108/// [`recover_original`](Self::recover_original) maps fitted reparameterized
109/// coefficients back to the original basis exactly.
110#[derive(Debug, Clone)]
111pub struct OrthogonalReparam {
112 /// W-projection / shear matrix `B = (MᵀWM + εI)⁻¹ MᵀW C` (`p_m × p_c`).
113 shear: Array2<f64>,
114 /// Reparameterized confound design `C̃ = C − M·B` (`n × p_c`).
115 confound_orthogonal: Array2<f64>,
116}
117
118impl OrthogonalReparam {
119 /// Build the exact orthogonal reparameterization of the `confound` block
120 /// against the `primary` block's column span in the `w_metric` row metric.
121 ///
122 /// Robustness is unconditional, so this always constructs the reparam (the
123 /// caller decides whether there is anything to orthogonalize; an empty span
124 /// `p_m == 0` or `p_c == 0` yields an identity-on-confound transform).
125 ///
126 /// Returns:
127 /// - `Ok(reparam)` with `C̃` exactly W-orthogonal to `span(primary)`.
128 /// - `Err` on a dimension mismatch, a non-finite/negative metric, or a
129 /// non-finite result.
130 ///
131 /// `primary` is `n × p_m`, `confound` is `n × p_c`, `w_metric` is length `n`
132 /// with `w_i ≥ 0` (the PIRLS row inner product at the pilot, so the resulting
133 /// orthogonality holds in the metric the penalized joint solve actually sees;
134 /// pass all-ones for the plain Euclidean metric).
135 pub fn build_unconditional(
136 primary: ArrayView2<f64>,
137 confound: ArrayView2<f64>,
138 w_metric: &Array1<f64>,
139 ) -> Result<Self, String> {
140 let n = primary.nrows();
141 if confound.nrows() != n {
142 return Err(format!(
143 "orthogonal_reparam: primary rows ({n}) != confound rows ({})",
144 confound.nrows()
145 ));
146 }
147 if w_metric.len() != n {
148 return Err(format!(
149 "orthogonal_reparam: row metric length ({}) != design rows ({n})",
150 w_metric.len()
151 ));
152 }
153 if w_metric.iter().any(|v| !v.is_finite() || *v < 0.0) {
154 return Err(
155 "orthogonal_reparam: row metric must be finite and non-negative".to_string(),
156 );
157 }
158 let p_m = primary.ncols();
159 let p_c = confound.ncols();
160
161 // No primary span (or no confound columns) ⇒ nothing to orthogonalize.
162 // Return an identity-shear reparam whose C̃ is the raw confound, so a
163 // caller that already chose Some(..) still gets a consistent object.
164 if p_m == 0 || p_c == 0 {
165 return Ok(Self {
166 shear: Array2::<f64>::zeros((p_m, p_c)),
167 confound_orthogonal: confound.to_owned(),
168 });
169 }
170
171 // Weighted primary Gram MᵀW M and cross term MᵀW C in the row metric.
172 let mut gram = fast_xt_diag_x(&primary, w_metric);
173 let gram_scale = (0..p_m).map(|i| gram[[i, i]]).fold(0.0_f64, f64::max);
174 let eps = (gram_scale * ORTHOGONAL_PROJECTION_RELATIVE_RIDGE)
175 .max(ORTHOGONAL_PROJECTION_RIDGE_FLOOR);
176 for i in 0..p_m {
177 gram[[i, i]] += eps;
178 }
179 let cross = fast_xt_diag_y(&primary, w_metric, &confound.to_owned());
180
181 let gram_view = FaerArrayView::new(&gram);
182 let factor =
183 factorize_symmetricwith_fallback(gram_view.as_ref(), Side::Lower).map_err(|e| {
184 format!("orthogonal_reparam: weighted primary Gram factorization failed: {e:?}")
185 })?;
186 // B = (MᵀWM + εI)⁻¹ MᵀW C (p_m × p_c)
187 let shear = factor
188 .solvemulti(&cross)
189 .map_err(|e| format!("orthogonal_reparam: projection solve failed: {e}"))?;
190
191 // C̃ = C − M·B.
192 let projection = fast_ab(&primary, &shear);
193 let confound_orthogonal = &confound - &projection;
194
195 if shear.iter().any(|v| !v.is_finite())
196 || confound_orthogonal.iter().any(|v| !v.is_finite())
197 {
198 return Err(
199 "orthogonal_reparam: reparameterization produced non-finite entries".to_string(),
200 );
201 }
202
203 Ok(Self {
204 shear,
205 confound_orthogonal,
206 })
207 }
208
209 /// The shear matrix `B` (`p_m × p_c`). Original primary coefficients are
210 /// `β_m = β̃_m − B·β_c`.
211 #[inline]
212 pub fn shear(&self) -> ArrayView2<'_, f64> {
213 self.shear.view()
214 }
215
216 /// The reparameterized confound design `C̃ = C − M·B` (`n × p_c`), exactly
217 /// W-orthogonal to `span(primary)`. This is the design the solver fits the
218 /// confound coefficients against.
219 #[inline]
220 pub fn reparameterized_confound(&self) -> ArrayView2<'_, f64> {
221 self.confound_orthogonal.view()
222 }
223
224 /// Number of primary columns `p_m`.
225 #[inline]
226 pub fn primary_cols(&self) -> usize {
227 self.shear.nrows()
228 }
229
230 /// Number of confound columns `p_c`.
231 #[inline]
232 pub fn confound_cols(&self) -> usize {
233 self.shear.ncols()
234 }
235
236 /// Map the fitted reparameterized coefficients `(β̃_m, β_c)` back to the
237 /// original basis `(β_m, β_c)` **exactly**:
238 ///
239 /// ```text
240 /// β_m = β̃_m − B·β_c, β_c unchanged.
241 /// ```
242 ///
243 /// `beta_m_reparam` has length `p_m`, `beta_c` has length `p_c`. Returns the
244 /// original-basis `(β_m, β_c)`. Because the predictor `M·β̃_m + C̃·β_c` equals
245 /// `M·β_m + C·β_c` for these recovered coefficients, predictions in the
246 /// original basis are unchanged.
247 pub fn recover_original(
248 &self,
249 beta_m_reparam: &Array1<f64>,
250 beta_c: &Array1<f64>,
251 ) -> Result<(Array1<f64>, Array1<f64>), String> {
252 let p_m = self.primary_cols();
253 let p_c = self.confound_cols();
254 if beta_m_reparam.len() != p_m {
255 return Err(format!(
256 "orthogonal_reparam: reparameterized primary coeffs length ({}) != p_m ({p_m})",
257 beta_m_reparam.len()
258 ));
259 }
260 if beta_c.len() != p_c {
261 return Err(format!(
262 "orthogonal_reparam: confound coeffs length ({}) != p_c ({p_c})",
263 beta_c.len()
264 ));
265 }
266 // β_m = β̃_m − B·β_c.
267 let shear_beta_c = self.shear.dot(beta_c);
268 let beta_m = beta_m_reparam - &shear_beta_c;
269 Ok((beta_m, beta_c.clone()))
270 }
271
272 /// Forward shear: map original-basis primary coefficients `β_m` to the
273 /// reparameterized basis `β̃_m = β_m + B·β_c` (the inverse of
274 /// [`recover_original`](Self::recover_original)). Useful for warm-starting the
275 /// reparameterized solve from an original-basis initial guess.
276 pub fn to_reparameterized(
277 &self,
278 beta_m: &Array1<f64>,
279 beta_c: &Array1<f64>,
280 ) -> Result<Array1<f64>, String> {
281 let p_m = self.primary_cols();
282 let p_c = self.confound_cols();
283 if beta_m.len() != p_m {
284 return Err(format!(
285 "orthogonal_reparam: primary coeffs length ({}) != p_m ({p_m})",
286 beta_m.len()
287 ));
288 }
289 if beta_c.len() != p_c {
290 return Err(format!(
291 "orthogonal_reparam: confound coeffs length ({}) != p_c ({p_c})",
292 beta_c.len()
293 ));
294 }
295 let shear_beta_c = self.shear.dot(beta_c);
296 Ok(beta_m + &shear_beta_c)
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    /// Largest absolute value in `values` (0.0 for an empty sequence).
    fn max_abs<'a>(values: impl Iterator<Item = &'a f64>) -> f64 {
        values.fold(0.0_f64, |acc, v| acc.max(v.abs()))
    }

    /// Primary `M = [1, t, sin 6t]` and a two-column confound `C`. The first
    /// column of `C` is `M`'s linear column plus a small perturbation, a genuine
    /// overlap. The second column carries a fresh direction.
    fn overlapping_blocks(n: usize) -> (Array2<f64>, Array2<f64>) {
        let mut m = Array2::<f64>::zeros((n, 3));
        let mut c = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = i as f64 / n as f64;
            m[[i, 0]] = 1.0;
            m[[i, 1]] = t;
            m[[i, 2]] = (t * 6.0).sin();
            c[[i, 0]] = t + 0.01 * (t * 13.0).cos();
            c[[i, 1]] = (t * 3.0).cos();
        }
        (m, c)
    }

    /// Verify the W-orthogonality `MᵀW C̃ ≈ 0` holds to working precision
    /// (it equals `ε·B`) when `C` genuinely overlaps `span(M)`.
    #[test]
    fn orthogonalized_confound_is_w_orthogonal_to_primary() {
        let n = 50;
        let (m, c) = overlapping_blocks(n);
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        let c_tilde = reparam.reparameterized_confound().to_owned();
        let cross = fast_xt_diag_y(&m, &w, &c_tilde);
        let max_entry = max_abs(cross.iter());
        assert!(
            max_entry < 1e-8,
            "MᵀW C̃ not orthogonal: max |entry| = {max_entry:e}"
        );
    }

    /// Same overlap, but in a non-uniform metric that includes a zero-weight
    /// row. Orthogonality must hold in the supplied `W`, not just for `W = I`.
    #[test]
    fn orthogonality_holds_in_non_uniform_metric() {
        let n = 50;
        let (m, c) = overlapping_blocks(n);
        let w = Array1::from_iter((0..n).map(|i| i as f64 / n as f64));
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        let c_tilde = reparam.reparameterized_confound().to_owned();
        let cross = fast_xt_diag_y(&m, &w, &c_tilde);
        let max_entry = max_abs(cross.iter());
        assert!(
            max_entry < 1e-7,
            "weighted MᵀW C̃ not orthogonal: max |entry| = {max_entry:e}"
        );
    }

    /// EXACT round-trip: fit (synthetically) in the reparameterized basis, then
    /// recover original coefficients and confirm the predictor is identical.
    #[test]
    fn coefficient_round_trip_is_exact() {
        let n = 40;
        let mut m = Array2::<f64>::zeros((n, 2));
        let mut c = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = i as f64 / n as f64;
            m[[i, 0]] = 1.0;
            m[[i, 1]] = (t * 4.0).sin();
            c[[i, 0]] = t; // overlaps the linear-ish part of M
            c[[i, 1]] = (t * 2.0).cos();
        }
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        // Pretend the solver returned these reparameterized-basis coefficients.
        let beta_m_reparam = Array1::from_vec(vec![0.7, -1.3]);
        let beta_c = Array1::from_vec(vec![2.1, 0.4]);

        // Predictor in the reparameterized basis: M·β̃_m + C̃·β_c.
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let eta_reparam = m.dot(&beta_m_reparam) + c_tilde.dot(&beta_c);

        // Recover original coefficients and form the predictor in the ORIGINAL
        // basis: M·β_m + C·β_c. Must match to tight tolerance.
        let (beta_m, beta_c_out) = reparam
            .recover_original(&beta_m_reparam, &beta_c)
            .expect("recover should succeed");
        let eta_original = m.dot(&beta_m) + c.dot(&beta_c_out);

        let max_diff = max_abs((&eta_reparam - &eta_original).iter());
        assert!(
            max_diff < 1e-10,
            "predictor changed under round-trip: max |Δη| = {max_diff:e}"
        );
        // Confound coefficients are untouched by the reparameterization.
        let cdiff = max_abs((&beta_c_out - &beta_c).iter());
        assert!(cdiff == 0.0, "confound coeffs changed: {cdiff:e}");

        // Forward map is the exact inverse of recover_original.
        let back = reparam
            .to_reparameterized(&beta_m, &beta_c)
            .expect("forward should succeed");
        let fdiff = max_abs((&back - &beta_m_reparam).iter());
        assert!(fdiff < 1e-10, "forward/inverse mismatch: {fdiff:e}");
    }

    /// When the confound does NOT overlap the primary span, the orthogonal
    /// design equals the raw confound (no shear) AND predictions are
    /// unchanged. This confirms the pass touches nothing it should not.
    #[test]
    fn absent_confound_leaves_design_and_predictions_unchanged() {
        let n = 30;
        // Primary spans constant + linear; the confound is a quadratic made
        // Euclidean-orthogonal to span{1, t} below.
        let mut m = Array2::<f64>::zeros((n, 2));
        let mut raw_quad = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64 / (n as f64 - 1.0);
            m[[i, 0]] = 1.0;
            m[[i, 1]] = t;
            raw_quad.push(t * t);
        }
        // Residualize the quadratic against {1, t} by hand so the confound
        // column is genuinely orthogonal to span(M) under W = I (the "confound
        // absent" regime). Using the very pass under test as the residualizer
        // would be circular, so do an explicit least-squares residual instead.
        let w = Array1::<f64>::from_elem(n, 1.0);
        // Solve M b = quad in LS; residual = quad - M b is ⊥ span(M).
        let gram = fast_xt_diag_x(&m, &w);
        let quad = Array1::from_vec(raw_quad);
        let cross = m.t().dot(&quad);
        let gview = FaerArrayView::new(&gram);
        let factor = factorize_symmetricwith_fallback(gview.as_ref(), Side::Lower).expect("factor");
        let b = FactorizedSystem::solve(&factor, &cross).expect("solve");
        let resid = &quad - &m.dot(&b);
        let mut c = Array2::<f64>::zeros((n, 1));
        c.column_mut(0).assign(&resid);

        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");
        // No overlap ⇒ shear ≈ 0 ⇒ C̃ ≈ C.
        let shear_max = max_abs(reparam.shear().iter());
        assert!(shear_max < 1e-8, "expected ~zero shear, got {shear_max:e}");
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let design_diff = max_abs((&c_tilde - &c).iter());
        assert!(
            design_diff < 1e-8,
            "orthogonalized design drifted from raw when confound absent: {design_diff:e}"
        );

        // Predictions: the same coefficients give the same predictor whether the
        // solver sees C̃ or the raw C, and recovery leaves β_m essentially put.
        let beta_m_fit = Array1::from_vec(vec![0.3, -0.8]);
        let beta_c = Array1::from_vec(vec![1.7]);
        let eta_reparam = m.dot(&beta_m_fit) + c_tilde.dot(&beta_c);
        let eta_raw = m.dot(&beta_m_fit) + c.dot(&beta_c);
        let eta_diff = max_abs((&eta_reparam - &eta_raw).iter());
        assert!(
            eta_diff < 1e-8,
            "predictions drifted when confound absent: max |Δη| = {eta_diff:e}"
        );
        let (beta_m, _) = reparam
            .recover_original(&beta_m_fit, &beta_c)
            .expect("recover should succeed");
        let coeff_shift = max_abs((&beta_m - &beta_m_fit).iter());
        assert!(
            coeff_shift < 1e-8,
            "primary coeffs shifted when confound absent: {coeff_shift:e}"
        );
    }

    /// Empty primary span ⇒ confound returned unchanged (nothing to project out).
    #[test]
    fn empty_primary_returns_raw_confound() {
        let n = 8;
        let m = Array2::<f64>::zeros((n, 0));
        let mut c = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            c[[i, 0]] = i as f64;
            c[[i, 1]] = 1.0;
        }
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");
        let c_tilde = reparam.reparameterized_confound().to_owned();
        let diff = max_abs((&c_tilde - &c).iter());
        assert!(diff == 0.0, "empty primary must return raw confound");
    }

    /// No confound columns ⇒ an empty `p_m × 0` shear, an empty `C̃`, and a
    /// recovery map that returns the primary coefficients untouched.
    #[test]
    fn empty_confound_yields_empty_shear_and_identity_recovery() {
        let n = 6;
        let mut m = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            m[[i, 0]] = 1.0;
            m[[i, 1]] = i as f64;
        }
        let c = Array2::<f64>::zeros((n, 0));
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");
        assert_eq!(reparam.primary_cols(), 2);
        assert_eq!(reparam.confound_cols(), 0);
        assert_eq!(reparam.reparameterized_confound().dim(), (n, 0));

        let beta_m = Array1::from_vec(vec![0.5, -2.0]);
        let beta_c = Array1::<f64>::zeros(0);
        let (recovered_m, recovered_c) = reparam
            .recover_original(&beta_m, &beta_c)
            .expect("recover should succeed");
        assert_eq!(recovered_m, beta_m);
        assert_eq!(recovered_c.len(), 0);
    }

    /// Shape mismatches and invalid metric entries are rejected before any
    /// numerical work is attempted.
    #[test]
    fn mismatched_shapes_and_bad_metric_are_rejected() {
        let m = Array2::<f64>::ones((4, 2));
        let c = Array2::<f64>::ones((4, 1));
        let w = Array1::<f64>::from_elem(4, 1.0);

        let short_c = Array2::<f64>::ones((3, 1));
        let err = OrthogonalReparam::build_unconditional(m.view(), short_c.view(), &w)
            .expect_err("row mismatch must be rejected");
        assert!(err.contains("primary rows"), "unexpected error: {err}");

        let short_w = Array1::<f64>::from_elem(3, 1.0);
        let err = OrthogonalReparam::build_unconditional(m.view(), c.view(), &short_w)
            .expect_err("metric length mismatch must be rejected");
        assert!(err.contains("row metric length"), "unexpected error: {err}");

        for bad in [-1.0, f64::NAN, f64::INFINITY] {
            let mut bad_w = w.clone();
            bad_w[0] = bad;
            let err = OrthogonalReparam::build_unconditional(m.view(), c.view(), &bad_w)
                .expect_err("invalid metric entry must be rejected");
            assert!(
                err.contains("finite and non-negative"),
                "unexpected error for w = {bad}: {err}"
            );
        }
    }

    /// Both coefficient maps validate vector lengths against `(p_m, p_c)`.
    #[test]
    fn coefficient_length_mismatch_is_rejected() {
        let n = 20;
        let (m, c) = overlapping_blocks(n);
        let w = Array1::<f64>::from_elem(n, 1.0);
        let reparam = OrthogonalReparam::build_unconditional(m.view(), c.view(), &w)
            .expect("build should succeed");

        let good_m = Array1::<f64>::zeros(3);
        let good_c = Array1::<f64>::zeros(2);
        let bad = Array1::<f64>::zeros(1);
        assert!(reparam.recover_original(&bad, &good_c).is_err());
        assert!(reparam.recover_original(&good_m, &bad).is_err());
        assert!(reparam.to_reparameterized(&bad, &good_c).is_err());
        assert!(reparam.to_reparameterized(&good_m, &bad).is_err());
        assert!(reparam.recover_original(&good_m, &good_c).is_ok());
    }
}