gam_solve/pirls/pls_solver.rs
1//! Penalized least-squares solver and Gaussian fast paths.
2//!
3//! Owns:
4//! - `GaussianFixedCache` — `XᵀWX`/`XᵀW(y−offset)` cache for the
5//! Gaussian-Identity short-circuit that the REML outer loop reuses across
6//! smoothing-parameter candidates.
7//! - `SparseXtwxPrecomputed` — the sparse-pattern-aligned twin of the above
8//! for designs that take the sparse-native PIRLS path.
9//! - `solve_penalized_least_squares_implicit` — identity/Gaussian implicit
10//! PLS, dense and sparse-native paths.
11
12use super::loop_driver::max_symmetric_asymmetry;
13use super::{
14 FIXED_STABILIZATION_RIDGE, PirlsPenalty, PirlsWorkspace, SparseXtWxCache, StablePLSResult,
15 WorkingReparamTransform, calculate_edf_from_sparse_factor,
16 calculate_edfwithworkspace_from_factor, ensure_sparse_positive_definitewithridge,
17 solve_sparse_spd,
18};
19use crate::estimate::EstimationError;
20use gam_linalg::faer_ndarray::{FaerLinalgError, array1_to_col_matmut};
21use gam_linalg::utils::{StableSolver, array_is_finite};
22use gam_linalg::matrix::{DesignMatrix, LinearOperator, SymmetricMatrix};
23use gam_problem::{Coefficients, LinkFunction};
24use faer::sparse::SparseColMat;
25use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder};
26use std::sync::Arc;
27
28/// Reusable `XᵀWX` and `XᵀW(y − offset)` for Gaussian + Identity REML fits.
29///
30/// The Gaussian-identity P-IRLS short-circuit solves a single linear system
31/// `(XᵀWX + Σ λ_k S_k + ρ·I) β = XᵀW(y − offset)`. The right-hand-side matrix
32/// and vector are independent of the smoothing parameters `λ`, so when the
33/// outer REML loop evaluates the same problem at many `(λ_1, …, λ_k)`
34/// candidates we only need to assemble them **once** before the loop and
35/// reuse them inside every inner PIRLS call.
36///
37/// Stored in *original* coordinates (no Qs rotation applied). When the
38/// inner solver uses a `WorkingReparamTransform`, it conjugates / projects
39/// these matrices on the fly — that step is O(p³) / O(p²), independent of N.
40#[derive(Debug)]
41pub struct GaussianFixedCache {
42 /// `XᵀWX` in the original coefficient basis. Symmetric, p × p.
43 pub xtwx_orig: Array2<f64>,
44 /// `XᵀW(y − offset)` in the original basis. Length p.
45 pub xtwy_orig: Array1<f64>,
46 /// `(y − offset)ᵀW(y − offset)`.
47 ///
48 /// Together with `xtwx_orig` and `xtwy_orig`, this is the last scalar
49 /// sufficient statistic needed to evaluate the Gaussian penalized RSS
50 /// exactly at any λ without re-streaming the rows.
51 pub centered_weighted_y_sq: f64,
52 /// When true, the caller is deliberately serving a design-moving trial from
53 /// sufficient statistics and the `DesignMatrix` rows on the current REML
54 /// surface may be a stale reference surface. Consumers must not apply those
55 /// rows for fitted values, RSS, or likelihood summaries.
56 pub row_prediction_is_stale: bool,
57 /// `XᵀWX` precomputed for the sparse path, aligned with the symbolic
58 /// pattern of `SparseXtWxCache::new(x)` on the original sparse design.
59 /// `None` when the design has no sparse form (e.g. dense-only fits).
60 ///
61 /// The sparse REML path rebuilds `H = XᵀWX + Sλ + δI` per outer
62 /// evaluation. For Gaussian-Identity the weights never change, so the
63 /// `XᵀWX` contribution is invariant across the outer loop and can be
64 /// scattered from this cached values vector instead of re-doing the
65 /// O(nnz²/n) SpGEMM each call.
66 pub xtwx_sparse_orig: Option<Arc<SparseXtwxPrecomputed>>,
67}
68
/// Numerical `XᵀWX` values stored on the symbolic pattern that
/// `SparseXtWxCache::new(x)` generates on first use.
///
/// Because faer's `sparse_sparse_matmul_symbolic` is deterministic, two
/// caches constructed from the same sparse `x` share a byte-identical
/// symbolic pattern. The values held here can therefore be dropped into a
/// freshly built `SparseXtWxCache` for that `x` with no second SpGEMM.
///
/// The pattern itself (`col_ptr` / `row_idx`) is captured next to the values
/// so that a consumer can confirm the patterns agree and revert to per-call
/// recomputation when they do not (for instance when an `x` with a different
/// symbolic shape is supplied).
#[derive(Debug, Clone)]
pub struct SparseXtwxPrecomputed {
    /// Column-pointer array of the snapshotted symbolic pattern.
    pub xtwx_symbolic_col_ptr: Vec<usize>,
    /// Row-index array of the snapshotted symbolic pattern.
    pub xtwx_symbolic_row_idx: Vec<usize>,
    /// Numerical entries of `XᵀWX`, ordered to match the pattern above.
    pub xtwxvalues: Vec<f64>,
}
86
87impl SparseXtwxPrecomputed {
88 /// Build the precomputed `XᵀWX` value layout for `x` at the given
89 /// `weights`. The output reuses the same construction path the inner
90 /// PIRLS workspace uses, so it lands in exactly the symbolic pattern
91 /// the consumer expects.
92 pub fn build(
93 x: &SparseColMat<usize, f64>,
94 weights: &Array1<f64>,
95 ) -> Result<Self, EstimationError> {
96 let mut cache = SparseXtWxCache::new(x)?;
97 cache.compute_numeric(x, weights)?;
98 Ok(Self {
99 xtwx_symbolic_col_ptr: cache.xtwx_symbolic.col_ptr().to_vec(),
100 xtwx_symbolic_row_idx: cache.xtwx_symbolic.row_idx().to_vec(),
101 xtwxvalues: cache.xtwxvalues,
102 })
103 }
104}
105
106/// Identity-link solver that operates in original or QS-transformed coordinates
107/// without materializing X·Qs. When the design is sparse and `qs` is `None`
108/// (sparse-native path), uses sparse Cholesky for O(nnz^{1.5}) cost instead
109/// of the O(p³) dense Cholesky.
110pub(super) fn solve_penalized_least_squares_implicit(
111 x_original: &DesignMatrix,
112 transform: Option<&WorkingReparamTransform>,
113 z: ArrayView1<f64>,
114 weights: ArrayView1<f64>,
115 offset: ArrayView1<f64>,
116 penalty: &PirlsPenalty,
117 workspace: &mut PirlsWorkspace,
118 y: ArrayView1<f64>,
119 link_function: LinkFunction,
120 gaussian_fixed_cache: Option<&GaussianFixedCache>,
121) -> Result<(StablePLSResult, usize), EstimationError> {
122 let p_dim = penalty.dim();
123
124 // ── Sparse-native fast path ──────────────────────────────────────────
125 // When design is sparse and we are in original coordinates (qs = None),
126 // assemble the penalized Hessian in sparse format and solve with sparse
127 // Cholesky. This avoids O(p²) dense X'WX and O(p³) dense factorization.
128 if transform.is_none()
129 && let Some(x_sparse) = x_original.as_sparse()
130 {
131 let PirlsPenalty::Dense { s_transformed, .. } = penalty else {
132 crate::bail_invalid_estim!(
133 "sparse-native PIRLS requires a dense transformed penalty matrix"
134 );
135 };
136 let weights_owned = weights.to_owned();
137
138 // Gaussian-Identity fast path: the inner sparse `XᵀWX` is invariant
139 // across the outer REML loop because the IRLS weights are constant
140 // (W = priorweights). The cached values land in the inner workspace
141 // and bypass the per-eval SpGEMM.
142 let precomputed_xtwx =
143 gaussian_fixed_cache.and_then(|c| c.xtwx_sparse_orig.as_ref().map(|arc| arc.as_ref()));
144
145 // 1. Sparse penalized Hessian: H = X'diag(w)X + S_λ + ridge·I.
146 // The Cholesky factor is reused from the SPD check so we avoid
147 // factorizing the same matrix twice.
148 let (h_sparse, factor, ridge_used) = ensure_sparse_positive_definitewithridge(|ridge| {
149 let ridge = if ridge == 0.0 {
150 FIXED_STABILIZATION_RIDGE
151 } else {
152 ridge
153 };
154 workspace.assemble_sparse_penalized_hessian(
155 x_sparse,
156 &weights_owned,
157 s_transformed,
158 ridge,
159 precomputed_xtwx,
160 )
161 })?;
162
163 // 2. RHS = X'W(z - offset) + S_λ μ + ridge_used · μ.
164 // The `ridge_used · μ` term matches the diagonal ridge added to
165 // the Hessian in step 1, keeping the augmented system a
166 // Tikhonov regularization centered at the prior mean target
167 // rather than at zero (see `prior_mean_target` field docs).
168 let mut wz = z.to_owned();
169 wz -= &offset;
170 wz *= &weights_owned;
171 let mut rhs = x_original.transpose_vector_multiply(&wz);
172 rhs += penalty.linear_shift();
173 if ridge_used > 0.0 {
174 let prior_mean_target = penalty.prior_mean_target();
175 if prior_mean_target.len() == rhs.len() {
176 rhs.scaled_add(ridge_used, prior_mean_target);
177 }
178 }
179
180 // 3. Sparse Cholesky solve (factor reused from step 1)
181 let betavec = solve_sparse_spd(&factor, &rhs)?;
182
183 // 4. EDF — reuse the sparse Cholesky factor from step 1 to avoid a
184 // second O(nnz·…) factorization of the identical penalized Hessian.
185 let h_sym = SymmetricMatrix::Sparse(h_sparse);
186 let edf = calculate_edf_from_sparse_factor(&factor, penalty)?;
187
188 // 5. Scale. When Gaussian sufficient statistics are installed, compute
189 // RSS from k-space only; the design rows may be a stale reference
190 // surface on the #1033 ψ-tensor fast path.
191 let standard_deviation = match link_function {
192 LinkFunction::Identity => {
193 let weighted_rss = if let Some(cache) = gaussian_fixed_cache {
194 let quadratic = betavec.dot(&cache.xtwx_orig.dot(&betavec));
195 (cache.centered_weighted_y_sq - 2.0 * betavec.dot(&cache.xtwy_orig) + quadratic)
196 .max(0.0)
197 } else {
198 let fitted_vals = {
199 let xb = x_original.apply(&betavec);
200 let mut f = xb;
201 f += &offset;
202 f
203 };
204 let residuals = &y - &fitted_vals;
205 weights
206 .iter()
207 .zip(residuals.iter())
208 .map(|(&w, &r)| w * r * r)
209 .sum()
210 };
211 let effective_n = y.len() as f64;
212 (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
213 }
214 _ => 1.0,
215 };
216
217 return Ok((
218 StablePLSResult {
219 beta: Coefficients::new(betavec),
220 penalized_hessian: h_sym,
221 edf,
222 standard_deviation,
223 ridge_used,
224 },
225 p_dim,
226 ));
227 }
228
229 // ── Dense / QS-rotated path ──────────────────────────────────────────
230
231 // 1. Prepare weighted buffers
232 if workspace.wz.len() != z.len() {
233 workspace.wz = Array1::zeros(z.len());
234 }
235 workspace.wz.assign(&z);
236 workspace.wz -= &offset;
237 workspace.wz *= &weights;
238
239 // 2. Form X'WX: compute in original coordinates, then rotate by Qs.
240 //
241 // Gaussian + Identity REML reuses a precomputed `XᵀWX` (the weights and
242 // design never change across the outer loop in that family), so when the
243 // caller supplied a `GaussianFixedCache` we skip the O(N·p²) dense
244 // assembly here and adopt the cached matrix as-is.
245 let weights_owned = weights.to_owned();
246 let xtwx_orig = if let Some(cache) = gaussian_fixed_cache {
247 // Cache hit: weights and design are invariant for Gaussian-Identity
248 // across the outer REML loop, so adopt the precomputed XᵀWX directly
249 // and avoid the O(N·p²) dense assembly entirely.
250 let p = x_original.ncols();
251 if cache.xtwx_orig.nrows() != p || cache.xtwx_orig.ncols() != p {
252 return Err(EstimationError::InvalidInput(format!(
253 "GaussianFixedCache XᵀWX shape {}×{} does not match design p={}",
254 cache.xtwx_orig.nrows(),
255 cache.xtwx_orig.ncols(),
256 p,
257 )));
258 }
259 cache.xtwx_orig.clone()
260 } else {
261 match x_original {
262 // Only materialized dense designs can use the shared dense assembly path.
263 // Lazy operator-backed dense designs route to diag_xtw_x like sparse.
264 DesignMatrix::Dense(x_dense) if x_dense.is_materialized_dense() => {
265 let p = x_dense.ncols();
266 let x_dense = x_dense.to_dense_arc();
267 if workspace.hessian_buf.nrows() != p || workspace.hessian_buf.ncols() != p {
268 workspace.hessian_buf = Array2::zeros((p, p).f());
269 } else {
270 workspace.hessian_buf.fill(0.0);
271 }
272 PirlsWorkspace::add_dense_xtwx_signed(
273 &weights_owned,
274 &mut workspace.weighted_x_chunk,
275 x_dense.as_ref(),
276 &mut workspace.hessian_buf,
277 );
278 std::mem::take(&mut workspace.hessian_buf)
279 }
280 _ => {
281 // Operator-form fallback: sparse designs and lazy operator-backed
282 // dense designs cannot be densified, so route through the signed
283 // XᵀWX operator.
284 gam_linalg::matrix::xt_diag_x_signed(
285 x_original,
286 gam_linalg::matrix::SignedWeightsView::from_array(&weights_owned),
287 )
288 .map(|h| h.to_dense())
289 .map_err(EstimationError::InvalidInput)?
290 }
291 }
292 };
293 let xtwx_orig_asym = max_symmetric_asymmetry(&xtwx_orig);
294 let xtwx_transformed = if let Some(transform) = transform {
295 transform.conjugate_matrix(&xtwx_orig)
296 } else {
297 xtwx_orig
298 };
299 let mut penalized_hessian = xtwx_transformed.clone();
300 penalty.add_to_hessian(&mut penalized_hessian);
301
302 // 3. Form X'Wz: compute in original coordinates, then rotate.
303 // With the Gaussian-Identity cache `z = y` and `wz = W·(y − offset)`
304 // is identical across outer iterations, so reuse the precomputed
305 // `XᵀW(y − offset)` directly.
306 let xtwy_orig = if let Some(cache) = gaussian_fixed_cache {
307 assert_eq!(
308 cache.xtwy_orig.len(),
309 x_original.ncols(),
310 "GaussianFixedCache XᵀW(y−offset) length must match design p"
311 );
312 cache.xtwy_orig.clone()
313 } else {
314 x_original.transpose_vector_multiply(&workspace.wz)
315 };
316 if workspace.vec_buf_p.len() != p_dim {
317 workspace.vec_buf_p = Array1::zeros(p_dim);
318 }
319 if let Some(transform) = transform {
320 workspace
321 .vec_buf_p
322 .assign(&transform.apply_transpose(&xtwy_orig));
323 } else {
324 workspace.vec_buf_p.assign(&xtwy_orig);
325 }
326 workspace.vec_buf_p += penalty.linear_shift();
327
328 {
329 // The penalized Hessian is assembled from symmetric pieces (XᵀWX and
330 // the penalty), so any asymmetry is pure floating-point accumulation
331 // error; anything above this floor signals a genuine assembly bug.
332 const PENALIZED_HESSIAN_ASYMMETRY_TOL: f64 = 1e-8;
333 let xtwx_asym = max_symmetric_asymmetry(&xtwx_transformed);
334 let penalty_asym = match penalty {
335 PirlsPenalty::Dense { s_transformed, .. } => max_symmetric_asymmetry(s_transformed),
336 PirlsPenalty::Diagonal { .. } => 0.0,
337 };
338 let total_asym = max_symmetric_asymmetry(&penalized_hessian);
339 assert!(
340 total_asym <= PENALIZED_HESSIAN_ASYMMETRY_TOL,
341 "implicit PLS penalized Hessian asymmetry too large: total={total_asym:.3e}, xtwx_orig={xtwx_orig_asym:.3e}, xtwx={xtwx_asym:.3e}, penalty={penalty_asym:.3e}, tol={PENALIZED_HESSIAN_ASYMMETRY_TOL:.3e}",
342 );
343 }
344
345 // 4. Ridge stabilization — CONDITIONAL, matching the sparse path
346 // (`ensure_sparse_positive_definitewithridge`) and the dense Newton path
347 // (`ensure_positive_definitewithridge`). A penalized Hessian assembled from
348 // `XᵀWX + S_λ` is mathematically PSD; a fixed tiny nugget is only needed to
349 // cure round-off when the bare matrix narrowly fails Cholesky. Applying the
350 // nugget UNCONDITIONALLY (the previous behaviour) made β̂ the stationary
351 // point of the RIDGED objective `½βᵀ(H+δI)β`, so the inner residual was
352 // `Xᵀu − S_λβ̂ = δβ̂` rather than 0. The outer REML ψ-gradient differentiates
353 // the BARE objective via the envelope theorem (it assumes exact
354 // stationarity), so the gratuitous δ broke the envelope identity: the
355 // analytic datafit derivative `a` was short by `½·δ·βᵀ(dβ̂/dψ)` and the
356 // β-independent `log|H|` term was differentiated on the un-ridged surface
357 // while the criterion VALUE used `log|H+δI|`. For the Matérn iso-κ joint
358 // REML at θ₀ (`TransformedQs` frame, δ_eff ≈ 1.75e-6 in the original basis)
359 // this is exactly the residual outer-gradient↔FD DESYNC of #1122 (gap
360 // 2.565e-2, with `cos(Xᵀu−S_λβ̂, β̂) = 1.0000` pinning the residual to the
361 // ridge gradient). Try the bare matrix first so the well-conditioned common
362 // case carries NO ridge (`ridge_used = 0`) and the envelope identity holds
363 // exactly; fall back to the Tikhonov nugget only when the bare factorization
364 // actually fails. The augmented RHS `r + δμ` keeps the fallback a Tikhonov
365 // regularization centered at the prior-mean target.
366 let bare_factor = StableSolver::new("pirls implicit pls")
367 .factorize(&penalized_hessian)
368 .ok();
369 let (factor, ridge_used) = if let Some(factor) = bare_factor {
370 (factor, 0.0)
371 } else {
372 let nugget = FIXED_STABILIZATION_RIDGE;
373 let mut regularizedhessian = penalized_hessian.clone();
374 if nugget > 0.0 {
375 for i in 0..p_dim {
376 regularizedhessian[[i, i]] += nugget;
377 }
378 }
379 let factor = StableSolver::new("pirls implicit pls")
380 .factorize(®ularizedhessian)
381 .map_err(EstimationError::LinearSystemSolveFailed)?;
382 (factor, nugget)
383 };
384
385 // 5. Solve
386 if workspace.rhs_full.len() != p_dim {
387 workspace.rhs_full = Array1::zeros(p_dim);
388 }
389 workspace.rhs_full.assign(&workspace.vec_buf_p);
390 if ridge_used > 0.0 {
391 let prior_mean_target = penalty.prior_mean_target();
392 if prior_mean_target.len() == p_dim {
393 workspace.rhs_full.scaled_add(ridge_used, prior_mean_target);
394 }
395 }
396 let mut rhsview = array1_to_col_matmut(&mut workspace.rhs_full);
397 factor.solve_in_place(rhsview.as_mut());
398 if !array_is_finite(&workspace.rhs_full) {
399 return Err(EstimationError::LinearSystemSolveFailed(
400 FaerLinalgError::FactorizationFailed {
401 context: "PIRLS implicit PLS non-finite solve",
402 },
403 ));
404 }
405 let betavec = workspace.rhs_full.clone();
406
407 // 6. EDF — reuse the factor already produced in step 5 to avoid a second
408 // O(p³) factorization of the identical regularized Hessian.
409 let edf = calculate_edfwithworkspace_from_factor(&factor, penalty, workspace)?;
410
411 // 7. Scale (composed: eta = offset + X Qs beta). When Gaussian sufficient
412 // statistics are installed, compute RSS from k-space only; the design rows
413 // may be a stale reference surface on the #1033 ψ-tensor fast path.
414 let qbeta = if let Some(transform) = transform {
415 transform.apply(&betavec)
416 } else {
417 betavec.clone()
418 };
419 let standard_deviation = match link_function {
420 LinkFunction::Identity => {
421 let weighted_rss = if let Some(cache) = gaussian_fixed_cache {
422 let quadratic = qbeta.dot(&cache.xtwx_orig.dot(&qbeta));
423 (cache.centered_weighted_y_sq - 2.0 * qbeta.dot(&cache.xtwy_orig) + quadratic)
424 .max(0.0)
425 } else {
426 let xqbeta = x_original.apply(&qbeta);
427 let mut fitted = xqbeta;
428 fitted += &offset;
429 let residuals = &y - &fitted;
430 weights
431 .iter()
432 .zip(residuals.iter())
433 .map(|(&w, &r)| w * r * r)
434 .sum()
435 };
436 let effective_n = y.len() as f64;
437 (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
438 }
439 _ => 1.0,
440 };
441
442 Ok((
443 StablePLSResult {
444 beta: Coefficients::new(betavec),
445 penalized_hessian: SymmetricMatrix::Dense(penalized_hessian),
446 edf,
447 standard_deviation,
448 ridge_used,
449 },
450 p_dim,
451 ))
452}