Skip to main content

gam_solve/pirls/
edf.rs

1use crate::estimate::EstimationError;
2use gam_linalg::faer_ndarray::{FaerSymmetricFactor, array2_to_matmut};
3use gam_linalg::utils::{StableSolver, array_is_finite};
4use gam_linalg::matrix::SymmetricMatrix;
5use gam_problem::Coefficients;
6use ndarray::{Array1, Array2};
7
8use super::{PirlsPenalty, PirlsWorkspace};
9
10/// Result of the stable penalized least squares solve
11#[derive(Clone)]
12pub struct StablePLSResult {
13    /// Solution vector beta
14    pub beta: Coefficients,
15    /// Final penalized Hessian matrix (sparse or dense depending on solve path)
16    pub penalized_hessian: SymmetricMatrix,
17    /// Effective degrees of freedom
18    pub edf: f64,
19    /// Residual standard deviation estimate.
20    ///
21    /// Contract: for Gaussian identity models this is the residual standard
22    /// deviation (sigma), not the residual variance/dispersion.
23    pub standard_deviation: f64,
24    /// Ridge added to ensure the SPD solve is well-posed.
25    pub ridge_used: f64,
26}
27
28/// EDF from an already-factorized dense regularized Hessian (dense path).
29///
30/// Mirrors `calculate_edfwithworkspace_with_penalty` but accepts the
31/// `FaerSymmetricFactor` that PLS already produced, eliminating the redundant
32/// second O(p³) factorization inside every PIRLS outer iteration.
33pub(super) fn calculate_edfwithworkspace_from_factor(
34    factor: &FaerSymmetricFactor,
35    penalty: &PirlsPenalty,
36    workspace: &mut PirlsWorkspace,
37) -> Result<f64, EstimationError> {
38    match penalty {
39        PirlsPenalty::Dense { e_transformed, .. } => {
40            let p = factor.n();
41            let r = e_transformed.nrows();
42            let mp = (p as f64 - r as f64).max(0.0);
43            if r == 0 {
44                return Ok(p as f64);
45            }
46            if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
47                workspace.final_aug_matrix = Array2::zeros((p, r));
48            }
49            for j in 0..r {
50                for i in 0..p {
51                    workspace.final_aug_matrix[[i, j]] = e_transformed[[j, i]];
52                }
53            }
54            {
55                let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
56                factor.solve_in_place(rhsview.as_mut());
57            }
58            if workspace.final_aug_matrix.nrows() == p
59                && workspace.final_aug_matrix.ncols() == r
60                && array_is_finite(&workspace.final_aug_matrix)
61            {
62                return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
63                    workspace.final_aug_matrix[(i, j)]
64                }));
65            }
66            Err(EstimationError::ModelIsIllConditioned {
67                condition_number: f64::INFINITY,
68            })
69        }
70        PirlsPenalty::Diagonal {
71            diag,
72            positive_indices,
73            ..
74        } => {
75            let p = factor.n();
76            let r = positive_indices.len();
77            let mp = (p as f64 - r as f64).max(0.0);
78            if r == 0 {
79                return Ok(p as f64);
80            }
81            if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
82                workspace.final_aug_matrix = Array2::zeros((p, r));
83            } else {
84                workspace.final_aug_matrix.fill(0.0);
85            }
86            for (col, &idx) in positive_indices.iter().enumerate() {
87                workspace.final_aug_matrix[[idx, col]] = 1.0;
88            }
89            {
90                let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
91                factor.solve_in_place(rhsview.as_mut());
92            }
93            let mut tr = 0.0;
94            for (col, &idx) in positive_indices.iter().enumerate() {
95                tr += diag[idx] * workspace.final_aug_matrix[[idx, col]];
96            }
97            Ok((p as f64 - tr).clamp(mp, p as f64))
98        }
99    }
100}
101
102/// EDF from an already-factorized sparse penalized Hessian (sparse path).
103///
104/// Mirrors `calculate_edf_with_penalty` but accepts the `SparseExactFactor`
105/// that PLS already produced, eliminating the redundant second sparse
106/// factorization inside every PIRLS outer iteration.
107///
108/// Only the `PirlsPenalty::Dense` variant is handled because the sparse-native
109/// path requires `PirlsPenalty::Dense` (enforced by the caller).
110pub(super) fn calculate_edf_from_sparse_factor(
111    factor: &gam_linalg::sparse_exact::SparseExactFactor,
112    penalty: &PirlsPenalty,
113) -> Result<f64, EstimationError> {
114    let PirlsPenalty::Dense { e_transformed, .. } = penalty else {
115        crate::bail_invalid_estim!("calculate_edf_from_sparse_factor requires PirlsPenalty::Dense");
116    };
117    // e_transformed has shape (r, p) — cols give the coefficient dimension p.
118    let p = e_transformed.ncols();
119    let r = e_transformed.nrows();
120    let mp = (p as f64 - r as f64).max(0.0);
121    if r == 0 {
122        return Ok(p as f64);
123    }
124    let rhs_arr = e_transformed.t().to_owned();
125    let sol =
126        gam_linalg::sparse_exact::solve_sparse_spdmulti(factor, &rhs_arr).map_err(|_| {
127            EstimationError::ModelIsIllConditioned {
128                condition_number: f64::INFINITY,
129            }
130        })?;
131    if sol.nrows() == p && sol.ncols() == r && sol.iter().all(|v| v.is_finite()) {
132        return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
133            sol[[i, j]]
134        }));
135    }
136    Err(EstimationError::ModelIsIllConditioned {
137        condition_number: f64::INFINITY,
138    })
139}
140
141pub(super) fn calculate_edf(
142    penalized_hessian: &SymmetricMatrix,
143    e_transformed: &Array2<f64>,
144) -> Result<f64, EstimationError> {
145    let p = penalized_hessian.ncols();
146    let r = e_transformed.nrows();
147    let mp = (p as f64 - r as f64).max(0.0);
148    if r == 0 {
149        return Ok(p as f64);
150    }
151    let rhs_arr = e_transformed.t().to_owned();
152    // Use SymmetricMatrix::factorize() which dispatches to sparse Cholesky
153    // for sparse Hessians and dense Cholesky for dense ones.
154    let factor =
155        penalized_hessian
156            .factorize()
157            .map_err(|_| EstimationError::ModelIsIllConditioned {
158                condition_number: f64::INFINITY,
159            })?;
160    let sol = factor
161        .solvemulti(&rhs_arr)
162        .map_err(|_| EstimationError::ModelIsIllConditioned {
163            condition_number: f64::INFINITY,
164        })?;
165    if sol.nrows() == p && sol.ncols() == r && sol.iter().all(|v| v.is_finite()) {
166        return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
167            sol[[i, j]]
168        }));
169    }
170
171    Err(EstimationError::ModelIsIllConditioned {
172        condition_number: f64::INFINITY,
173    })
174}
175
176pub(super) fn calculate_edf_with_penalty(
177    penalized_hessian: &SymmetricMatrix,
178    penalty: &PirlsPenalty,
179) -> Result<f64, EstimationError> {
180    match penalty {
181        PirlsPenalty::Dense { e_transformed, .. } => {
182            calculate_edf(penalized_hessian, e_transformed)
183        }
184        PirlsPenalty::Diagonal {
185            diag,
186            positive_indices,
187            ..
188        } => calculate_edf_from_diagonal_penalty(penalized_hessian, diag, positive_indices),
189    }
190}
191
192pub(super) fn calculate_edfwithworkspace(
193    penalized_hessian: &Array2<f64>,
194    e_transformed: &Array2<f64>,
195    workspace: &mut PirlsWorkspace,
196) -> Result<f64, EstimationError> {
197    let p = penalized_hessian.ncols();
198    let r = e_transformed.nrows();
199    let mp = (p as f64 - r as f64).max(0.0);
200    if r == 0 {
201        return Ok(p as f64);
202    }
203    if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
204        workspace.final_aug_matrix = Array2::zeros((p, r));
205    }
206    for j in 0..r {
207        for i in 0..p {
208            workspace.final_aug_matrix[[i, j]] = e_transformed[[j, i]];
209        }
210    }
211
212    let factor = StableSolver::new("pirls edf workspace")
213        .factorize(penalized_hessian)
214        .map_err(|_| EstimationError::ModelIsIllConditioned {
215            condition_number: f64::INFINITY,
216        })?;
217    {
218        let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
219        factor.solve_in_place(rhsview.as_mut());
220    }
221    if workspace.final_aug_matrix.nrows() == p
222        && workspace.final_aug_matrix.ncols() == r
223        && array_is_finite(&workspace.final_aug_matrix)
224    {
225        return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
226            workspace.final_aug_matrix[(i, j)]
227        }));
228    }
229
230    Err(EstimationError::ModelIsIllConditioned {
231        condition_number: f64::INFINITY,
232    })
233}
234
235pub(super) fn calculate_edfwithworkspace_with_penalty(
236    penalized_hessian: &Array2<f64>,
237    penalty: &PirlsPenalty,
238    workspace: &mut PirlsWorkspace,
239) -> Result<f64, EstimationError> {
240    match penalty {
241        PirlsPenalty::Dense { e_transformed, .. } => {
242            calculate_edfwithworkspace(penalized_hessian, e_transformed, workspace)
243        }
244        PirlsPenalty::Diagonal {
245            diag,
246            positive_indices,
247            ..
248        } => calculate_edfwithworkspace_from_diagonal_penalty(
249            penalized_hessian,
250            diag,
251            positive_indices,
252            workspace,
253        ),
254    }
255}
256
257pub(super) fn calculate_edf_from_diagonal_penalty(
258    penalized_hessian: &SymmetricMatrix,
259    diag: &Array1<f64>,
260    positive_indices: &[usize],
261) -> Result<f64, EstimationError> {
262    let p = penalized_hessian.ncols();
263    let r = positive_indices.len();
264    let mp = (p as f64 - r as f64).max(0.0);
265    if r == 0 {
266        return Ok(p as f64);
267    }
268    let mut rhs_arr = Array2::<f64>::zeros((p, r));
269    for (col, &idx) in positive_indices.iter().enumerate() {
270        rhs_arr[[idx, col]] = 1.0;
271    }
272    let factor =
273        penalized_hessian
274            .factorize()
275            .map_err(|_| EstimationError::ModelIsIllConditioned {
276                condition_number: f64::INFINITY,
277            })?;
278    let sol = factor
279        .solvemulti(&rhs_arr)
280        .map_err(|_| EstimationError::ModelIsIllConditioned {
281            condition_number: f64::INFINITY,
282        })?;
283    let mut tr = 0.0;
284    for (col, &idx) in positive_indices.iter().enumerate() {
285        tr += diag[idx] * sol[[idx, col]];
286    }
287    Ok((p as f64 - tr).clamp(mp, p as f64))
288}
289
290pub(super) fn calculate_edfwithworkspace_from_diagonal_penalty(
291    penalized_hessian: &Array2<f64>,
292    diag: &Array1<f64>,
293    positive_indices: &[usize],
294    workspace: &mut PirlsWorkspace,
295) -> Result<f64, EstimationError> {
296    let p = penalized_hessian.ncols();
297    let r = positive_indices.len();
298    let mp = (p as f64 - r as f64).max(0.0);
299    if r == 0 {
300        return Ok(p as f64);
301    }
302    if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
303        workspace.final_aug_matrix = Array2::zeros((p, r));
304    } else {
305        workspace.final_aug_matrix.fill(0.0);
306    }
307    for (col, &idx) in positive_indices.iter().enumerate() {
308        workspace.final_aug_matrix[[idx, col]] = 1.0;
309    }
310
311    let factor = StableSolver::new("pirls diagonal edf workspace")
312        .factorize(penalized_hessian)
313        .map_err(|_| EstimationError::ModelIsIllConditioned {
314            condition_number: f64::INFINITY,
315        })?;
316    {
317        let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
318        factor.solve_in_place(rhsview.as_mut());
319    }
320    let mut tr = 0.0;
321    for (col, &idx) in positive_indices.iter().enumerate() {
322        tr += diag[idx] * workspace.final_aug_matrix[[idx, col]];
323    }
324    Ok((p as f64 - tr).clamp(mp, p as f64))
325}
326
327#[inline]
328pub(super) fn edf_from_solution<F>(
329    p: usize,
330    r: usize,
331    mp: f64,
332    e_transformed: &Array2<f64>,
333    solved_at: F,
334) -> f64
335where
336    F: Fn(usize, usize) -> f64,
337{
338    let mut tr = 0.0;
339    for j in 0..r {
340        for i in 0..p {
341            tr += solved_at(i, j) * e_transformed[(j, i)];
342        }
343    }
344    (p as f64 - tr).clamp(mp, p as f64)
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use gam_linalg::matrix::SymmetricMatrix;
    use ndarray::array;

    /// Regression: a penalty with MORE rows than coefficient columns (`r > p`)
    /// is legitimate for factor-smooth / random-slope / random-effect
    /// structures whose penalty roots are stacked or full-rank. The
    /// min-penalty-dof floor `mp = max(p - r, 0)` must be computed in `f64`,
    /// because the `usize` subtraction `p - r` underflows and panics with
    /// "attempt to subtract with overflow" when `r > p`. This exercises the
    /// `r > p` path on the dense EDF entry point, checks that the floor is
    /// honored (no panic, EDF in `[0, p]`), and pins the closed-form value.
    #[test]
    fn calculate_edf_floors_when_penalty_rank_exceeds_coefficient_dim() {
        // p = 2 coefficients, r = 3 penalty rows (r > p).
        let p = 2usize;
        // SPD penalized Hessian (well-conditioned, dense path).
        let hessian = SymmetricMatrix::Dense(array![[4.0, 1.0], [1.0, 3.0]]);
        // e_transformed has shape (r, p) = (3, 2): more penalty rows than
        // coefficient columns — the factor/random-slope structure.
        let e_transformed = array![[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]];
        assert_eq!(e_transformed.nrows(), 3);
        assert_eq!(e_transformed.ncols(), p);

        let edf = calculate_edf(&hessian, &e_transformed)
            .expect("EDF solve should succeed for an SPD Hessian with r > p");

        // mp = max(p - r, 0) = 0, so the EDF is floored at 0 and capped at p.
        assert!(
            edf.is_finite(),
            "EDF must be finite for r > p penalty, got {edf}"
        );
        assert!(
            (0.0..=p as f64).contains(&edf),
            "EDF must lie in [0, {p}] for r > p penalty, got {edf}"
        );
        // Closed form: S = Eᵀ E = [[1.25, 0.25], [0.25, 1.25]] and
        // H⁻¹ = [[3, -1], [-1, 4]] / 11, so tr(H⁻¹ S) = 8.25 / 11 = 0.75 and
        // EDF = 2 - 0.75 = 1.25.
        assert!(
            (edf - 1.25).abs() < 1e-10,
            "EDF must match the closed form 1.25, got {edf}"
        );
    }

    /// The diagonal-penalty entry point must agree with the closed form
    /// `p - Σ diag[idx] · (H⁻¹)[idx, idx]` over the penalized indices.
    #[test]
    fn calculate_edf_from_diagonal_penalty_matches_closed_form() {
        // H = [[4, 1], [1, 3]] ⇒ (H⁻¹)[0, 0] = 3 / 11. Only index 0 is
        // penalized with weight 2, so tr(H⁻¹ S) = 6 / 11 and
        // EDF = 2 - 6 / 11 = 16 / 11 (inside the [1, 2] clamp window).
        let hessian = SymmetricMatrix::Dense(array![[4.0, 1.0], [1.0, 3.0]]);
        let diag = array![2.0, 0.0];
        let edf = calculate_edf_from_diagonal_penalty(&hessian, &diag, &[0])
            .expect("diagonal EDF solve should succeed for an SPD Hessian");
        assert!(
            (edf - 16.0 / 11.0).abs() < 1e-10,
            "diagonal EDF must match the closed form 16/11, got {edf}"
        );
    }
}