Skip to main content

gam_solve/pirls/
pls_solver.rs

1//! Penalized least-squares solver and Gaussian fast paths.
2//!
3//! Owns:
4//! - `GaussianFixedCache` — `XᵀWX`/`XᵀW(y−offset)` cache for the
5//!   Gaussian-Identity short-circuit that the REML outer loop reuses across
6//!   smoothing-parameter candidates.
7//! - `SparseXtwxPrecomputed` — the sparse-pattern-aligned twin of the above
8//!   for designs that take the sparse-native PIRLS path.
9//! - `solve_penalized_least_squares_implicit` — identity/Gaussian implicit
10//!   PLS, dense and sparse-native paths.
11
12use super::loop_driver::max_symmetric_asymmetry;
13use super::{
14    FIXED_STABILIZATION_RIDGE, PirlsPenalty, PirlsWorkspace, SparseXtWxCache, StablePLSResult,
15    WorkingReparamTransform, calculate_edf_from_sparse_factor,
16    calculate_edfwithworkspace_from_factor, ensure_sparse_positive_definitewithridge,
17    solve_sparse_spd,
18};
19use crate::estimate::EstimationError;
20use gam_linalg::faer_ndarray::{FaerLinalgError, array1_to_col_matmut};
21use gam_linalg::utils::{StableSolver, array_is_finite};
22use gam_linalg::matrix::{DesignMatrix, LinearOperator, SymmetricMatrix};
23use gam_problem::{Coefficients, LinkFunction};
24use faer::sparse::SparseColMat;
25use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder};
26use std::sync::Arc;
27
28/// Reusable `XᵀWX` and `XᵀW(y − offset)` for Gaussian + Identity REML fits.
29///
30/// The Gaussian-identity P-IRLS short-circuit solves a single linear system
31/// `(XᵀWX + Σ λ_k S_k + ρ·I) β = XᵀW(y − offset)`. The right-hand-side matrix
32/// and vector are independent of the smoothing parameters `λ`, so when the
33/// outer REML loop evaluates the same problem at many `(λ_1, …, λ_k)`
34/// candidates we only need to assemble them **once** before the loop and
35/// reuse them inside every inner PIRLS call.
36///
37/// Stored in *original* coordinates (no Qs rotation applied). When the
38/// inner solver uses a `WorkingReparamTransform`, it conjugates / projects
39/// these matrices on the fly — that step is O(p³) / O(p²), independent of N.
40#[derive(Debug)]
41pub struct GaussianFixedCache {
42    /// `XᵀWX` in the original coefficient basis. Symmetric, p × p.
43    pub xtwx_orig: Array2<f64>,
44    /// `XᵀW(y − offset)` in the original basis. Length p.
45    pub xtwy_orig: Array1<f64>,
46    /// `(y − offset)ᵀW(y − offset)`.
47    ///
48    /// Together with `xtwx_orig` and `xtwy_orig`, this is the last scalar
49    /// sufficient statistic needed to evaluate the Gaussian penalized RSS
50    /// exactly at any λ without re-streaming the rows.
51    pub centered_weighted_y_sq: f64,
52    /// When true, the caller is deliberately serving a design-moving trial from
53    /// sufficient statistics and the `DesignMatrix` rows on the current REML
54    /// surface may be a stale reference surface. Consumers must not apply those
55    /// rows for fitted values, RSS, or likelihood summaries.
56    pub row_prediction_is_stale: bool,
57    /// `XᵀWX` precomputed for the sparse path, aligned with the symbolic
58    /// pattern of `SparseXtWxCache::new(x)` on the original sparse design.
59    /// `None` when the design has no sparse form (e.g. dense-only fits).
60    ///
61    /// The sparse REML path rebuilds `H = XᵀWX + Sλ + δI` per outer
62    /// evaluation. For Gaussian-Identity the weights never change, so the
63    /// `XᵀWX` contribution is invariant across the outer loop and can be
64    /// scattered from this cached values vector instead of re-doing the
65    /// O(nnz²/n) SpGEMM each call.
66    pub xtwx_sparse_orig: Option<Arc<SparseXtwxPrecomputed>>,
67}
68
/// Snapshot of the numeric `XᵀWX` values together with the CSC symbolic
/// pattern that `SparseXtWxCache::new(x)` produces for a sparse design `x`.
///
/// Building two `SparseXtWxCache`s from the same sparse `x` yields the same
/// symbolic pattern (faer's `sparse_sparse_matmul_symbolic` is
/// deterministic). The values stored here can therefore be installed into a
/// fresh cache for that `x`, which skips the SpGEMM.
///
/// The pattern is stored alongside the values so a consumer can confirm the
/// patterns match. If they diverge (for example, an `x` with a different
/// symbolic shape is passed in), the consumer can fall back to recomputing
/// per call.
#[derive(Debug, Clone)]
pub struct SparseXtwxPrecomputed {
    /// Column pointers of the `XᵀWX` symbolic pattern (CSC layout).
    pub xtwx_symbolic_col_ptr: Vec<usize>,
    /// Row indices of the `XᵀWX` symbolic pattern (CSC layout).
    pub xtwx_symbolic_row_idx: Vec<usize>,
    /// Numeric `XᵀWX` entries, ordered to match the row-index array.
    pub xtwxvalues: Vec<f64>,
}
86
87impl SparseXtwxPrecomputed {
88    /// Build the precomputed `XᵀWX` value layout for `x` at the given
89    /// `weights`. The output reuses the same construction path the inner
90    /// PIRLS workspace uses, so it lands in exactly the symbolic pattern
91    /// the consumer expects.
92    pub fn build(
93        x: &SparseColMat<usize, f64>,
94        weights: &Array1<f64>,
95    ) -> Result<Self, EstimationError> {
96        let mut cache = SparseXtWxCache::new(x)?;
97        cache.compute_numeric(x, weights)?;
98        Ok(Self {
99            xtwx_symbolic_col_ptr: cache.xtwx_symbolic.col_ptr().to_vec(),
100            xtwx_symbolic_row_idx: cache.xtwx_symbolic.row_idx().to_vec(),
101            xtwxvalues: cache.xtwxvalues,
102        })
103    }
104}
105
106/// Identity-link solver that operates in original or QS-transformed coordinates
107/// without materializing X·Qs.  When the design is sparse and `qs` is `None`
108/// (sparse-native path), uses sparse Cholesky for O(nnz^{1.5}) cost instead
109/// of the O(p³) dense Cholesky.
110pub(super) fn solve_penalized_least_squares_implicit(
111    x_original: &DesignMatrix,
112    transform: Option<&WorkingReparamTransform>,
113    z: ArrayView1<f64>,
114    weights: ArrayView1<f64>,
115    offset: ArrayView1<f64>,
116    penalty: &PirlsPenalty,
117    workspace: &mut PirlsWorkspace,
118    y: ArrayView1<f64>,
119    link_function: LinkFunction,
120    gaussian_fixed_cache: Option<&GaussianFixedCache>,
121) -> Result<(StablePLSResult, usize), EstimationError> {
122    let p_dim = penalty.dim();
123
124    // ── Sparse-native fast path ──────────────────────────────────────────
125    // When design is sparse and we are in original coordinates (qs = None),
126    // assemble the penalized Hessian in sparse format and solve with sparse
127    // Cholesky.  This avoids O(p²) dense X'WX and O(p³) dense factorization.
128    if transform.is_none()
129        && let Some(x_sparse) = x_original.as_sparse()
130    {
131        let PirlsPenalty::Dense { s_transformed, .. } = penalty else {
132            crate::bail_invalid_estim!(
133                "sparse-native PIRLS requires a dense transformed penalty matrix"
134            );
135        };
136        let weights_owned = weights.to_owned();
137
138        // Gaussian-Identity fast path: the inner sparse `XᵀWX` is invariant
139        // across the outer REML loop because the IRLS weights are constant
140        // (W = priorweights). The cached values land in the inner workspace
141        // and bypass the per-eval SpGEMM.
142        let precomputed_xtwx =
143            gaussian_fixed_cache.and_then(|c| c.xtwx_sparse_orig.as_ref().map(|arc| arc.as_ref()));
144
145        // 1. Sparse penalized Hessian: H = X'diag(w)X + S_λ + ridge·I.
146        //    The Cholesky factor is reused from the SPD check so we avoid
147        //    factorizing the same matrix twice.
148        let (h_sparse, factor, ridge_used) = ensure_sparse_positive_definitewithridge(|ridge| {
149            let ridge = if ridge == 0.0 {
150                FIXED_STABILIZATION_RIDGE
151            } else {
152                ridge
153            };
154            workspace.assemble_sparse_penalized_hessian(
155                x_sparse,
156                &weights_owned,
157                s_transformed,
158                ridge,
159                precomputed_xtwx,
160            )
161        })?;
162
163        // 2. RHS = X'W(z - offset) + S_λ μ + ridge_used · μ.
164        // The `ridge_used · μ` term matches the diagonal ridge added to
165        // the Hessian in step 1, keeping the augmented system a
166        // Tikhonov regularization centered at the prior mean target
167        // rather than at zero (see `prior_mean_target` field docs).
168        let mut wz = z.to_owned();
169        wz -= &offset;
170        wz *= &weights_owned;
171        let mut rhs = x_original.transpose_vector_multiply(&wz);
172        rhs += penalty.linear_shift();
173        if ridge_used > 0.0 {
174            let prior_mean_target = penalty.prior_mean_target();
175            if prior_mean_target.len() == rhs.len() {
176                rhs.scaled_add(ridge_used, prior_mean_target);
177            }
178        }
179
180        // 3. Sparse Cholesky solve (factor reused from step 1)
181        let betavec = solve_sparse_spd(&factor, &rhs)?;
182
183        // 4. EDF — reuse the sparse Cholesky factor from step 1 to avoid a
184        // second O(nnz·…) factorization of the identical penalized Hessian.
185        let h_sym = SymmetricMatrix::Sparse(h_sparse);
186        let edf = calculate_edf_from_sparse_factor(&factor, penalty)?;
187
188        // 5. Scale. When Gaussian sufficient statistics are installed, compute
189        // RSS from k-space only; the design rows may be a stale reference
190        // surface on the #1033 ψ-tensor fast path.
191        let standard_deviation = match link_function {
192            LinkFunction::Identity => {
193                let weighted_rss = if let Some(cache) = gaussian_fixed_cache {
194                    let quadratic = betavec.dot(&cache.xtwx_orig.dot(&betavec));
195                    (cache.centered_weighted_y_sq - 2.0 * betavec.dot(&cache.xtwy_orig) + quadratic)
196                        .max(0.0)
197                } else {
198                    let fitted_vals = {
199                        let xb = x_original.apply(&betavec);
200                        let mut f = xb;
201                        f += &offset;
202                        f
203                    };
204                    let residuals = &y - &fitted_vals;
205                    weights
206                        .iter()
207                        .zip(residuals.iter())
208                        .map(|(&w, &r)| w * r * r)
209                        .sum()
210                };
211                let effective_n = y.len() as f64;
212                (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
213            }
214            _ => 1.0,
215        };
216
217        return Ok((
218            StablePLSResult {
219                beta: Coefficients::new(betavec),
220                penalized_hessian: h_sym,
221                edf,
222                standard_deviation,
223                ridge_used,
224            },
225            p_dim,
226        ));
227    }
228
229    // ── Dense / QS-rotated path ──────────────────────────────────────────
230
231    // 1. Prepare weighted buffers
232    if workspace.wz.len() != z.len() {
233        workspace.wz = Array1::zeros(z.len());
234    }
235    workspace.wz.assign(&z);
236    workspace.wz -= &offset;
237    workspace.wz *= &weights;
238
239    // 2. Form X'WX: compute in original coordinates, then rotate by Qs.
240    //
241    // Gaussian + Identity REML reuses a precomputed `XᵀWX` (the weights and
242    // design never change across the outer loop in that family), so when the
243    // caller supplied a `GaussianFixedCache` we skip the O(N·p²) dense
244    // assembly here and adopt the cached matrix as-is.
245    let weights_owned = weights.to_owned();
246    let xtwx_orig = if let Some(cache) = gaussian_fixed_cache {
247        // Cache hit: weights and design are invariant for Gaussian-Identity
248        // across the outer REML loop, so adopt the precomputed XᵀWX directly
249        // and avoid the O(N·p²) dense assembly entirely.
250        let p = x_original.ncols();
251        if cache.xtwx_orig.nrows() != p || cache.xtwx_orig.ncols() != p {
252            return Err(EstimationError::InvalidInput(format!(
253                "GaussianFixedCache XᵀWX shape {}×{} does not match design p={}",
254                cache.xtwx_orig.nrows(),
255                cache.xtwx_orig.ncols(),
256                p,
257            )));
258        }
259        cache.xtwx_orig.clone()
260    } else {
261        match x_original {
262            // Only materialized dense designs can use the shared dense assembly path.
263            // Lazy operator-backed dense designs route to diag_xtw_x like sparse.
264            DesignMatrix::Dense(x_dense) if x_dense.is_materialized_dense() => {
265                let p = x_dense.ncols();
266                let x_dense = x_dense.to_dense_arc();
267                if workspace.hessian_buf.nrows() != p || workspace.hessian_buf.ncols() != p {
268                    workspace.hessian_buf = Array2::zeros((p, p).f());
269                } else {
270                    workspace.hessian_buf.fill(0.0);
271                }
272                PirlsWorkspace::add_dense_xtwx_signed(
273                    &weights_owned,
274                    &mut workspace.weighted_x_chunk,
275                    x_dense.as_ref(),
276                    &mut workspace.hessian_buf,
277                );
278                std::mem::take(&mut workspace.hessian_buf)
279            }
280            _ => {
281                // Operator-form fallback: sparse designs and lazy operator-backed
282                // dense designs cannot be densified, so route through the signed
283                // XᵀWX operator.
284                gam_linalg::matrix::xt_diag_x_signed(
285                    x_original,
286                    gam_linalg::matrix::SignedWeightsView::from_array(&weights_owned),
287                )
288                .map(|h| h.to_dense())
289                .map_err(EstimationError::InvalidInput)?
290            }
291        }
292    };
293    let xtwx_orig_asym = max_symmetric_asymmetry(&xtwx_orig);
294    let xtwx_transformed = if let Some(transform) = transform {
295        transform.conjugate_matrix(&xtwx_orig)
296    } else {
297        xtwx_orig
298    };
299    let mut penalized_hessian = xtwx_transformed.clone();
300    penalty.add_to_hessian(&mut penalized_hessian);
301
302    // 3. Form X'Wz: compute in original coordinates, then rotate.
303    //    With the Gaussian-Identity cache `z = y` and `wz = W·(y − offset)`
304    //    is identical across outer iterations, so reuse the precomputed
305    //    `XᵀW(y − offset)` directly.
306    let xtwy_orig = if let Some(cache) = gaussian_fixed_cache {
307        assert_eq!(
308            cache.xtwy_orig.len(),
309            x_original.ncols(),
310            "GaussianFixedCache XᵀW(y−offset) length must match design p"
311        );
312        cache.xtwy_orig.clone()
313    } else {
314        x_original.transpose_vector_multiply(&workspace.wz)
315    };
316    if workspace.vec_buf_p.len() != p_dim {
317        workspace.vec_buf_p = Array1::zeros(p_dim);
318    }
319    if let Some(transform) = transform {
320        workspace
321            .vec_buf_p
322            .assign(&transform.apply_transpose(&xtwy_orig));
323    } else {
324        workspace.vec_buf_p.assign(&xtwy_orig);
325    }
326    workspace.vec_buf_p += penalty.linear_shift();
327
328    {
329        // The penalized Hessian is assembled from symmetric pieces (XᵀWX and
330        // the penalty), so any asymmetry is pure floating-point accumulation
331        // error; anything above this floor signals a genuine assembly bug.
332        const PENALIZED_HESSIAN_ASYMMETRY_TOL: f64 = 1e-8;
333        let xtwx_asym = max_symmetric_asymmetry(&xtwx_transformed);
334        let penalty_asym = match penalty {
335            PirlsPenalty::Dense { s_transformed, .. } => max_symmetric_asymmetry(s_transformed),
336            PirlsPenalty::Diagonal { .. } => 0.0,
337        };
338        let total_asym = max_symmetric_asymmetry(&penalized_hessian);
339        assert!(
340            total_asym <= PENALIZED_HESSIAN_ASYMMETRY_TOL,
341            "implicit PLS penalized Hessian asymmetry too large: total={total_asym:.3e}, xtwx_orig={xtwx_orig_asym:.3e}, xtwx={xtwx_asym:.3e}, penalty={penalty_asym:.3e}, tol={PENALIZED_HESSIAN_ASYMMETRY_TOL:.3e}",
342        );
343    }
344
345    // 4. Ridge stabilization. Augment both sides by the ridge so the
346    // stabilization is a Tikhonov regularization centered at the prior
347    // mean target: (H + δI) β = r + δ μ. The prior_mean_target is zero
348    // when no penalty block carries a non-zero prior mean, so this is a
349    // no-op in the common case but recovers `β = μ` exactly on
350    // X'WX = 0 / X'Wz = 0 problems where the data carries no information.
351    let nugget = FIXED_STABILIZATION_RIDGE;
352    let mut regularizedhessian = penalized_hessian.clone();
353    if nugget > 0.0 {
354        for i in 0..p_dim {
355            regularizedhessian[[i, i]] += nugget;
356        }
357    }
358    let ridge_used = nugget;
359
360    // 5. Solve
361    if workspace.rhs_full.len() != p_dim {
362        workspace.rhs_full = Array1::zeros(p_dim);
363    }
364    workspace.rhs_full.assign(&workspace.vec_buf_p);
365    if nugget > 0.0 {
366        let prior_mean_target = penalty.prior_mean_target();
367        if prior_mean_target.len() == p_dim {
368            workspace.rhs_full.scaled_add(nugget, prior_mean_target);
369        }
370    }
371    let factor = StableSolver::new("pirls implicit pls")
372        .factorize(&regularizedhessian)
373        .map_err(EstimationError::LinearSystemSolveFailed)?;
374    let mut rhsview = array1_to_col_matmut(&mut workspace.rhs_full);
375    factor.solve_in_place(rhsview.as_mut());
376    if !array_is_finite(&workspace.rhs_full) {
377        return Err(EstimationError::LinearSystemSolveFailed(
378            FaerLinalgError::FactorizationFailed {
379                context: "PIRLS implicit PLS non-finite solve",
380            },
381        ));
382    }
383    let betavec = workspace.rhs_full.clone();
384
385    // 6. EDF — reuse the factor already produced in step 5 to avoid a second
386    // O(p³) factorization of the identical regularized Hessian.
387    let edf = calculate_edfwithworkspace_from_factor(&factor, penalty, workspace)?;
388
389    // 7. Scale (composed: eta = offset + X Qs beta). When Gaussian sufficient
390    // statistics are installed, compute RSS from k-space only; the design rows
391    // may be a stale reference surface on the #1033 ψ-tensor fast path.
392    let qbeta = if let Some(transform) = transform {
393        transform.apply(&betavec)
394    } else {
395        betavec.clone()
396    };
397    let standard_deviation = match link_function {
398        LinkFunction::Identity => {
399            let weighted_rss = if let Some(cache) = gaussian_fixed_cache {
400                let quadratic = qbeta.dot(&cache.xtwx_orig.dot(&qbeta));
401                (cache.centered_weighted_y_sq - 2.0 * qbeta.dot(&cache.xtwy_orig) + quadratic)
402                    .max(0.0)
403            } else {
404                let xqbeta = x_original.apply(&qbeta);
405                let mut fitted = xqbeta;
406                fitted += &offset;
407                let residuals = &y - &fitted;
408                weights
409                    .iter()
410                    .zip(residuals.iter())
411                    .map(|(&w, &r)| w * r * r)
412                    .sum()
413            };
414            let effective_n = y.len() as f64;
415            (weighted_rss / (effective_n - edf).max(1.0)).sqrt()
416        }
417        _ => 1.0,
418    };
419
420    Ok((
421        StablePLSResult {
422            beta: Coefficients::new(betavec),
423            penalized_hessian: SymmetricMatrix::Dense(penalized_hessian),
424            edf,
425            standard_deviation,
426            ridge_used,
427        },
428        p_dim,
429    ))
430}