gam_solve/pirls/pls_solver.rs
//! Penalized least-squares solver and Gaussian fast paths.
//!
//! Owns:
//! - `GaussianFixedCache` — `XᵀWX`/`XᵀW(y−offset)` cache for the
//!   Gaussian-Identity short-circuit that the REML outer loop reuses across
//!   smoothing-parameter candidates.
//! - `SparseXtwxPrecomputed` — the sparse-pattern-aligned twin of the above
//!   for designs that take the sparse-native PIRLS path.
//! - `solve_penalized_least_squares_implicit` — identity/Gaussian implicit
//!   PLS, dense and sparse-native paths.
11
12use super::loop_driver::max_symmetric_asymmetry;
13use super::{
14 FIXED_STABILIZATION_RIDGE, PirlsPenalty, PirlsWorkspace, SparseXtWxCache, StablePLSResult,
15 WorkingReparamTransform, calculate_edf_from_sparse_factor,
16 calculate_edfwithworkspace_from_factor, ensure_sparse_positive_definitewithridge,
17 solve_sparse_spd,
18};
19use crate::estimate::EstimationError;
20use gam_linalg::faer_ndarray::{FaerLinalgError, array1_to_col_matmut};
21use gam_linalg::utils::{StableSolver, array_is_finite};
22use gam_linalg::matrix::{DesignMatrix, LinearOperator, SymmetricMatrix};
23use gam_problem::{Coefficients, LinkFunction};
24use faer::sparse::SparseColMat;
25use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder};
26use std::sync::Arc;
27
28/// Reusable `XᵀWX` and `XᵀW(y − offset)` for Gaussian + Identity REML fits.
29///
30/// The Gaussian-identity P-IRLS short-circuit solves a single linear system
31/// `(XᵀWX + Σ λ_k S_k + ρ·I) β = XᵀW(y − offset)`. The right-hand-side matrix
32/// and vector are independent of the smoothing parameters `λ`, so when the
33/// outer REML loop evaluates the same problem at many `(λ_1, …, λ_k)`
34/// candidates we only need to assemble them **once** before the loop and
35/// reuse them inside every inner PIRLS call.
36///
37/// Stored in *original* coordinates (no Qs rotation applied). When the
38/// inner solver uses a `WorkingReparamTransform`, it conjugates / projects
39/// these matrices on the fly — that step is O(p³) / O(p²), independent of N.
40#[derive(Debug)]
41pub struct GaussianFixedCache {
42 /// `XᵀWX` in the original coefficient basis. Symmetric, p × p.
43 pub xtwx_orig: Array2<f64>,
44 /// `XᵀW(y − offset)` in the original basis. Length p.
45 pub xtwy_orig: Array1<f64>,
46 /// `(y − offset)ᵀW(y − offset)`.
47 ///
48 /// Together with `xtwx_orig` and `xtwy_orig`, this is the last scalar
49 /// sufficient statistic needed to evaluate the Gaussian penalized RSS
50 /// exactly at any λ without re-streaming the rows.
51 pub centered_weighted_y_sq: f64,
52 /// When true, the caller is deliberately serving a design-moving trial from
53 /// sufficient statistics and the `DesignMatrix` rows on the current REML
54 /// surface may be a stale reference surface. Consumers must not apply those
55 /// rows for fitted values, RSS, or likelihood summaries.
56 pub row_prediction_is_stale: bool,
57 /// `XᵀWX` precomputed for the sparse path, aligned with the symbolic
58 /// pattern of `SparseXtWxCache::new(x)` on the original sparse design.
59 /// `None` when the design has no sparse form (e.g. dense-only fits).
60 ///
61 /// The sparse REML path rebuilds `H = XᵀWX + Sλ + δI` per outer
62 /// evaluation. For Gaussian-Identity the weights never change, so the
63 /// `XᵀWX` contribution is invariant across the outer loop and can be
64 /// scattered from this cached values vector instead of re-doing the
65 /// O(nnz²/n) SpGEMM each call.
66 pub xtwx_sparse_orig: Option<Arc<SparseXtwxPrecomputed>>,
67}
68
/// Numerical `XᵀWX` values stored in the symbolic layout that
/// `SparseXtWxCache::new(x)` yields on its first call.
///
/// Two caches constructed from the same sparse `x` end up with
/// byte-identical symbolic patterns (faer's `sparse_sparse_matmul_symbolic`
/// is deterministic). The values held here can therefore be dropped into a
/// fresh `SparseXtWxCache` for that same `x` without running the SpGEMM
/// again.
///
/// The symbolic pattern (`col_ptr` / `row_idx`) is snapshotted next to the
/// values so that a consumer can check the patterns really agree and, if
/// they do not (e.g. an `x` with a different symbolic shape sneaks in), fall
/// back to recomputing per call.
#[derive(Debug, Clone)]
pub struct SparseXtwxPrecomputed {
    /// Snapshot of the column-pointer array of the `XᵀWX` symbolic pattern.
    pub xtwx_symbolic_col_ptr: Vec<usize>,
    /// Snapshot of the row-index array of the `XᵀWX` symbolic pattern.
    pub xtwx_symbolic_row_idx: Vec<usize>,
    /// Numerical `XᵀWX` entries, aligned with the pattern snapshot above.
    pub xtwxvalues: Vec<f64>,
}
86
87impl SparseXtwxPrecomputed {
88 /// Build the precomputed `XᵀWX` value layout for `x` at the given
89 /// `weights`. The output reuses the same construction path the inner
90 /// PIRLS workspace uses, so it lands in exactly the symbolic pattern
91 /// the consumer expects.
92 pub fn build(
93 x: &SparseColMat<usize, f64>,
94 weights: &Array1<f64>,
95 ) -> Result<Self, EstimationError> {
96 let mut cache = SparseXtWxCache::new(x)?;
97 cache.compute_numeric(x, weights)?;
98 Ok(Self {
99 xtwx_symbolic_col_ptr: cache.xtwx_symbolic.col_ptr().to_vec(),
100 xtwx_symbolic_row_idx: cache.xtwx_symbolic.row_idx().to_vec(),
101 xtwxvalues: cache.xtwxvalues,
102 })
103 }
104}
105
106/// Identity-link solver that operates in original or QS-transformed coordinates
107/// without materializing X·Qs. When the design is sparse and `qs` is `None`
108/// (sparse-native path), uses sparse Cholesky for O(nnz^{1.5}) cost instead
109/// of the O(p³) dense Cholesky.
110pub(super) fn solve_penalized_least_squares_implicit(
111 x_original: &DesignMatrix,
112 transform: Option<&WorkingReparamTransform>,
113 z: ArrayView1<f64>,
114 weights: ArrayView1<f64>,
115 offset: ArrayView1<f64>,
116 penalty: &PirlsPenalty,
117 workspace: &mut PirlsWorkspace,
118 y: ArrayView1<f64>,
119 link_function: LinkFunction,
120 gaussian_fixed_cache: Option<&GaussianFixedCache>,
121) -> Result<(StablePLSResult, usize), EstimationError> {
122 let p_dim = penalty.dim();
123
124 // ── Sparse-native fast path ──────────────────────────────────────────
125 // When design is sparse and we are in original coordinates (qs = None),
126 // assemble the penalized Hessian in sparse format and solve with sparse
127 // Cholesky. This avoids O(p²) dense X'WX and O(p³) dense factorization.
128 if transform.is_none()
129 && let Some(x_sparse) = x_original.as_sparse()
130 {
131 let PirlsPenalty::Dense { s_transformed, .. } = penalty else {
132 crate::bail_invalid_estim!(
133 "sparse-native PIRLS requires a dense transformed penalty matrix"
134 );
135 };
136 let weights_owned = weights.to_owned();
137
138 // Gaussian-Identity fast path: the inner sparse `XᵀWX` is invariant
139 // across the outer REML loop because the IRLS weights are constant
140 // (W = priorweights). The cached values land in the inner workspace
141 // and bypass the per-eval SpGEMM.
142 let precomputed_xtwx =
143 gaussian_fixed_cache.and_then(|c| c.xtwx_sparse_orig.as_ref().map(|arc| arc.as_ref()));
144
145 // 1. Sparse penalized Hessian: H = X'diag(w)X + S_λ + ridge·I.
146 // The Cholesky factor is reused from the SPD check so we avoid
147 // factorizing the same matrix twice.
148 let (h_sparse, factor, ridge_used) = ensure_sparse_positive_definitewithridge(|ridge| {
149 let ridge = if ridge == 0.0 {
150 FIXED_STABILIZATION_RIDGE
151 } else {
152 ridge
153 };
154 workspace.assemble_sparse_penalized_hessian(
155 x_sparse,
156 &weights_owned,
157 s_transformed,
158 ridge,
159 precomputed_xtwx,
160 )
161 })?;
162
163 // 2. RHS = X'W(z - offset) + S_λ μ + ridge_used · μ.
164 // The `ridge_used · μ` term matches the diagonal ridge added to
165 // the Hessian in step 1, keeping the augmented system a
166 // Tikhonov regularization centered at the prior mean target
167 // rather than at zero (see `prior_mean_target` field docs).
168 let mut wz = z.to_owned();
169 wz -= &offset;
170 wz *= &weights_owned;
171 let mut rhs = x_original.transpose_vector_multiply(&wz);
172 rhs += penalty.linear_shift();
173 if ridge_used > 0.0 {
174 let prior_mean_target = penalty.prior_mean_target();
175 if prior_mean_target.len() == rhs.len() {
176 rhs.scaled_add(ridge_used, prior_mean_target);
177 }
178 }
179
180 // 3. Sparse Cholesky solve (factor reused from step 1)
181 let betavec = solve_sparse_spd(&factor, &rhs)?;
182
183 // 4. EDF — reuse the sparse Cholesky factor from step 1 to avoid a
184 // second O(nnz·…) factorization of the identical penalized Hessian.
185 let h_sym = SymmetricMatrix::Sparse(h_sparse);
186 let edf = calculate_edf_from_sparse_factor(&factor, penalty)?;
187
188 // 5. Scale. When Gaussian sufficient statistics are installed, compute
189 // RSS from k-space only; the design rows may be a stale reference
190 // surface on the #1033 ψ-tensor fast path.
191 let standard_deviation = match link_function {
192 LinkFunction::Identity => {
193 let weighted_rss = if let Some(cache) = gaussian_fixed_cache {
194 let quadratic = betavec.dot(&cache.xtwx_orig.dot(&betavec));
195 (cache.centered_weighted_y_sq - 2.0 * betavec.dot(&cache.xtwy_orig) + quadratic)
196 .max(0.0)
197 } else {
198 let fitted_vals = {
199 let xb = x_original.apply(&betavec);
200 let mut f = xb;
201 f += &offset;
202 f
203 };
204 let residuals = &y - &fitted_vals;
205 weights
206 .iter()
207 .zip(residuals.iter())
208 .map(|(&w, &r)| w * r * r)
209 .sum()
210 };
211 let effective_n = y.len() as f64;
212 (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
213 }
214 _ => 1.0,
215 };
216
217 return Ok((
218 StablePLSResult {
219 beta: Coefficients::new(betavec),
220 penalized_hessian: h_sym,
221 edf,
222 standard_deviation,
223 ridge_used,
224 },
225 p_dim,
226 ));
227 }
228
229 // ── Dense / QS-rotated path ──────────────────────────────────────────
230
231 // 1. Prepare weighted buffers
232 if workspace.wz.len() != z.len() {
233 workspace.wz = Array1::zeros(z.len());
234 }
235 workspace.wz.assign(&z);
236 workspace.wz -= &offset;
237 workspace.wz *= &weights;
238
239 // 2. Form X'WX: compute in original coordinates, then rotate by Qs.
240 //
241 // Gaussian + Identity REML reuses a precomputed `XᵀWX` (the weights and
242 // design never change across the outer loop in that family), so when the
243 // caller supplied a `GaussianFixedCache` we skip the O(N·p²) dense
244 // assembly here and adopt the cached matrix as-is.
245 let weights_owned = weights.to_owned();
246 let xtwx_orig = if let Some(cache) = gaussian_fixed_cache {
247 // Cache hit: weights and design are invariant for Gaussian-Identity
248 // across the outer REML loop, so adopt the precomputed XᵀWX directly
249 // and avoid the O(N·p²) dense assembly entirely.
250 let p = x_original.ncols();
251 if cache.xtwx_orig.nrows() != p || cache.xtwx_orig.ncols() != p {
252 return Err(EstimationError::InvalidInput(format!(
253 "GaussianFixedCache XᵀWX shape {}×{} does not match design p={}",
254 cache.xtwx_orig.nrows(),
255 cache.xtwx_orig.ncols(),
256 p,
257 )));
258 }
259 cache.xtwx_orig.clone()
260 } else {
261 match x_original {
262 // Only materialized dense designs can use the shared dense assembly path.
263 // Lazy operator-backed dense designs route to diag_xtw_x like sparse.
264 DesignMatrix::Dense(x_dense) if x_dense.is_materialized_dense() => {
265 let p = x_dense.ncols();
266 let x_dense = x_dense.to_dense_arc();
267 if workspace.hessian_buf.nrows() != p || workspace.hessian_buf.ncols() != p {
268 workspace.hessian_buf = Array2::zeros((p, p).f());
269 } else {
270 workspace.hessian_buf.fill(0.0);
271 }
272 PirlsWorkspace::add_dense_xtwx_signed(
273 &weights_owned,
274 &mut workspace.weighted_x_chunk,
275 x_dense.as_ref(),
276 &mut workspace.hessian_buf,
277 );
278 std::mem::take(&mut workspace.hessian_buf)
279 }
280 _ => {
281 // Operator-form fallback: sparse designs and lazy operator-backed
282 // dense designs cannot be densified, so route through the signed
283 // XᵀWX operator.
284 gam_linalg::matrix::xt_diag_x_signed(
285 x_original,
286 gam_linalg::matrix::SignedWeightsView::from_array(&weights_owned),
287 )
288 .map(|h| h.to_dense())
289 .map_err(EstimationError::InvalidInput)?
290 }
291 }
292 };
293 let xtwx_orig_asym = max_symmetric_asymmetry(&xtwx_orig);
294 let xtwx_transformed = if let Some(transform) = transform {
295 transform.conjugate_matrix(&xtwx_orig)
296 } else {
297 xtwx_orig
298 };
299 let mut penalized_hessian = xtwx_transformed.clone();
300 penalty.add_to_hessian(&mut penalized_hessian);
301
302 // 3. Form X'Wz: compute in original coordinates, then rotate.
303 // With the Gaussian-Identity cache `z = y` and `wz = W·(y − offset)`
304 // is identical across outer iterations, so reuse the precomputed
305 // `XᵀW(y − offset)` directly.
306 let xtwy_orig = if let Some(cache) = gaussian_fixed_cache {
307 assert_eq!(
308 cache.xtwy_orig.len(),
309 x_original.ncols(),
310 "GaussianFixedCache XᵀW(y−offset) length must match design p"
311 );
312 cache.xtwy_orig.clone()
313 } else {
314 x_original.transpose_vector_multiply(&workspace.wz)
315 };
316 if workspace.vec_buf_p.len() != p_dim {
317 workspace.vec_buf_p = Array1::zeros(p_dim);
318 }
319 if let Some(transform) = transform {
320 workspace
321 .vec_buf_p
322 .assign(&transform.apply_transpose(&xtwy_orig));
323 } else {
324 workspace.vec_buf_p.assign(&xtwy_orig);
325 }
326 workspace.vec_buf_p += penalty.linear_shift();
327
328 {
329 // The penalized Hessian is assembled from symmetric pieces (XᵀWX and
330 // the penalty), so any asymmetry is pure floating-point accumulation
331 // error; anything above this floor signals a genuine assembly bug.
332 const PENALIZED_HESSIAN_ASYMMETRY_TOL: f64 = 1e-8;
333 let xtwx_asym = max_symmetric_asymmetry(&xtwx_transformed);
334 let penalty_asym = match penalty {
335 PirlsPenalty::Dense { s_transformed, .. } => max_symmetric_asymmetry(s_transformed),
336 PirlsPenalty::Diagonal { .. } => 0.0,
337 };
338 let total_asym = max_symmetric_asymmetry(&penalized_hessian);
339 assert!(
340 total_asym <= PENALIZED_HESSIAN_ASYMMETRY_TOL,
341 "implicit PLS penalized Hessian asymmetry too large: total={total_asym:.3e}, xtwx_orig={xtwx_orig_asym:.3e}, xtwx={xtwx_asym:.3e}, penalty={penalty_asym:.3e}, tol={PENALIZED_HESSIAN_ASYMMETRY_TOL:.3e}",
342 );
343 }
344
345 // 4. Ridge stabilization. Augment both sides by the ridge so the
346 // stabilization is a Tikhonov regularization centered at the prior
347 // mean target: (H + δI) β = r + δ μ. The prior_mean_target is zero
348 // when no penalty block carries a non-zero prior mean, so this is a
349 // no-op in the common case but recovers `β = μ` exactly on
350 // X'WX = 0 / X'Wz = 0 problems where the data carries no information.
351 let nugget = FIXED_STABILIZATION_RIDGE;
352 let mut regularizedhessian = penalized_hessian.clone();
353 if nugget > 0.0 {
354 for i in 0..p_dim {
355 regularizedhessian[[i, i]] += nugget;
356 }
357 }
358 let ridge_used = nugget;
359
360 // 5. Solve
361 if workspace.rhs_full.len() != p_dim {
362 workspace.rhs_full = Array1::zeros(p_dim);
363 }
364 workspace.rhs_full.assign(&workspace.vec_buf_p);
365 if nugget > 0.0 {
366 let prior_mean_target = penalty.prior_mean_target();
367 if prior_mean_target.len() == p_dim {
368 workspace.rhs_full.scaled_add(nugget, prior_mean_target);
369 }
370 }
371 let factor = StableSolver::new("pirls implicit pls")
372 .factorize(®ularizedhessian)
373 .map_err(EstimationError::LinearSystemSolveFailed)?;
374 let mut rhsview = array1_to_col_matmut(&mut workspace.rhs_full);
375 factor.solve_in_place(rhsview.as_mut());
376 if !array_is_finite(&workspace.rhs_full) {
377 return Err(EstimationError::LinearSystemSolveFailed(
378 FaerLinalgError::FactorizationFailed {
379 context: "PIRLS implicit PLS non-finite solve",
380 },
381 ));
382 }
383 let betavec = workspace.rhs_full.clone();
384
385 // 6. EDF — reuse the factor already produced in step 5 to avoid a second
386 // O(p³) factorization of the identical regularized Hessian.
387 let edf = calculate_edfwithworkspace_from_factor(&factor, penalty, workspace)?;
388
389 // 7. Scale (composed: eta = offset + X Qs beta). When Gaussian sufficient
390 // statistics are installed, compute RSS from k-space only; the design rows
391 // may be a stale reference surface on the #1033 ψ-tensor fast path.
392 let qbeta = if let Some(transform) = transform {
393 transform.apply(&betavec)
394 } else {
395 betavec.clone()
396 };
397 let standard_deviation = match link_function {
398 LinkFunction::Identity => {
399 let weighted_rss = if let Some(cache) = gaussian_fixed_cache {
400 let quadratic = qbeta.dot(&cache.xtwx_orig.dot(&qbeta));
401 (cache.centered_weighted_y_sq - 2.0 * qbeta.dot(&cache.xtwy_orig) + quadratic)
402 .max(0.0)
403 } else {
404 let xqbeta = x_original.apply(&qbeta);
405 let mut fitted = xqbeta;
406 fitted += &offset;
407 let residuals = &y - &fitted;
408 weights
409 .iter()
410 .zip(residuals.iter())
411 .map(|(&w, &r)| w * r * r)
412 .sum()
413 };
414 let effective_n = y.len() as f64;
415 (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
416 }
417 _ => 1.0,
418 };
419
420 Ok((
421 StablePLSResult {
422 beta: Coefficients::new(betavec),
423 penalized_hessian: SymmetricMatrix::Dense(penalized_hessian),
424 edf,
425 standard_deviation,
426 ridge_used,
427 },
428 p_dim,
429 ))
430}