1use crate::estimate::EstimationError;
2use gam_linalg::faer_ndarray::{FaerSymmetricFactor, array2_to_matmut};
3use gam_linalg::utils::{StableSolver, array_is_finite};
4use gam_linalg::matrix::SymmetricMatrix;
5use gam_problem::Coefficients;
6use ndarray::{Array1, Array2};
7
8use super::{PirlsPenalty, PirlsWorkspace};
9
/// Output bundle of a stable penalized least-squares (PLS) solve.
///
/// NOTE(review): the code that fills this struct is not visible here; the field
/// descriptions below are inferred from names and should be confirmed against
/// the producer.
#[derive(Clone)]
pub struct StablePLSResult {
    /// Coefficient vector of the fit (presumably the PLS solution β).
    pub beta: Coefficients,
    /// Penalized Hessian associated with `beta` (assumed to be the matrix the
    /// `calculate_edf*` helpers in this module factorize — confirm).
    pub penalized_hessian: SymmetricMatrix,
    /// Effective degrees of freedom of the fit (presumably produced by one of
    /// the `calculate_edf*` helpers in this module).
    pub edf: f64,
    /// Scale estimate reported with the fit — presumably a residual standard
    /// deviation; TODO confirm its exact definition with the producer.
    pub standard_deviation: f64,
    /// Ridge amount reported as "used" by the solve — assumed to be diagonal
    /// stabilisation added to the system; verify against the producer.
    pub ridge_used: f64,
}
27
28pub(super) fn calculate_edfwithworkspace_from_factor(
34 factor: &FaerSymmetricFactor,
35 penalty: &PirlsPenalty,
36 workspace: &mut PirlsWorkspace,
37) -> Result<f64, EstimationError> {
38 match penalty {
39 PirlsPenalty::Dense { e_transformed, .. } => {
40 let p = factor.n();
41 let r = e_transformed.nrows();
42 let mp = (p as f64 - r as f64).max(0.0);
43 if r == 0 {
44 return Ok(p as f64);
45 }
46 if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
47 workspace.final_aug_matrix = Array2::zeros((p, r));
48 }
49 for j in 0..r {
50 for i in 0..p {
51 workspace.final_aug_matrix[[i, j]] = e_transformed[[j, i]];
52 }
53 }
54 {
55 let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
56 factor.solve_in_place(rhsview.as_mut());
57 }
58 if workspace.final_aug_matrix.nrows() == p
59 && workspace.final_aug_matrix.ncols() == r
60 && array_is_finite(&workspace.final_aug_matrix)
61 {
62 return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
63 workspace.final_aug_matrix[(i, j)]
64 }));
65 }
66 Err(EstimationError::ModelIsIllConditioned {
67 condition_number: f64::INFINITY,
68 })
69 }
70 PirlsPenalty::Diagonal {
71 diag,
72 positive_indices,
73 ..
74 } => {
75 let p = factor.n();
76 let r = positive_indices.len();
77 let mp = (p as f64 - r as f64).max(0.0);
78 if r == 0 {
79 return Ok(p as f64);
80 }
81 if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
82 workspace.final_aug_matrix = Array2::zeros((p, r));
83 } else {
84 workspace.final_aug_matrix.fill(0.0);
85 }
86 for (col, &idx) in positive_indices.iter().enumerate() {
87 workspace.final_aug_matrix[[idx, col]] = 1.0;
88 }
89 {
90 let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
91 factor.solve_in_place(rhsview.as_mut());
92 }
93 let mut tr = 0.0;
94 for (col, &idx) in positive_indices.iter().enumerate() {
95 tr += diag[idx] * workspace.final_aug_matrix[[idx, col]];
96 }
97 Ok((p as f64 - tr).clamp(mp, p as f64))
98 }
99 }
100}
101
102pub(super) fn calculate_edf_from_sparse_factor(
111 factor: &gam_linalg::sparse_exact::SparseExactFactor,
112 penalty: &PirlsPenalty,
113) -> Result<f64, EstimationError> {
114 let PirlsPenalty::Dense { e_transformed, .. } = penalty else {
115 crate::bail_invalid_estim!("calculate_edf_from_sparse_factor requires PirlsPenalty::Dense");
116 };
117 let p = e_transformed.ncols();
119 let r = e_transformed.nrows();
120 let mp = (p as f64 - r as f64).max(0.0);
121 if r == 0 {
122 return Ok(p as f64);
123 }
124 let rhs_arr = e_transformed.t().to_owned();
125 let sol =
126 gam_linalg::sparse_exact::solve_sparse_spdmulti(factor, &rhs_arr).map_err(|_| {
127 EstimationError::ModelIsIllConditioned {
128 condition_number: f64::INFINITY,
129 }
130 })?;
131 if sol.nrows() == p && sol.ncols() == r && sol.iter().all(|v| v.is_finite()) {
132 return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
133 sol[[i, j]]
134 }));
135 }
136 Err(EstimationError::ModelIsIllConditioned {
137 condition_number: f64::INFINITY,
138 })
139}
140
141pub(super) fn calculate_edf(
142 penalized_hessian: &SymmetricMatrix,
143 e_transformed: &Array2<f64>,
144) -> Result<f64, EstimationError> {
145 let p = penalized_hessian.ncols();
146 let r = e_transformed.nrows();
147 let mp = (p as f64 - r as f64).max(0.0);
148 if r == 0 {
149 return Ok(p as f64);
150 }
151 let rhs_arr = e_transformed.t().to_owned();
152 let factor =
155 penalized_hessian
156 .factorize()
157 .map_err(|_| EstimationError::ModelIsIllConditioned {
158 condition_number: f64::INFINITY,
159 })?;
160 let sol = factor
161 .solvemulti(&rhs_arr)
162 .map_err(|_| EstimationError::ModelIsIllConditioned {
163 condition_number: f64::INFINITY,
164 })?;
165 if sol.nrows() == p && sol.ncols() == r && sol.iter().all(|v| v.is_finite()) {
166 return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
167 sol[[i, j]]
168 }));
169 }
170
171 Err(EstimationError::ModelIsIllConditioned {
172 condition_number: f64::INFINITY,
173 })
174}
175
176pub(super) fn calculate_edf_with_penalty(
177 penalized_hessian: &SymmetricMatrix,
178 penalty: &PirlsPenalty,
179) -> Result<f64, EstimationError> {
180 match penalty {
181 PirlsPenalty::Dense { e_transformed, .. } => {
182 calculate_edf(penalized_hessian, e_transformed)
183 }
184 PirlsPenalty::Diagonal {
185 diag,
186 positive_indices,
187 ..
188 } => calculate_edf_from_diagonal_penalty(penalized_hessian, diag, positive_indices),
189 }
190}
191
192pub(super) fn calculate_edfwithworkspace(
193 penalized_hessian: &Array2<f64>,
194 e_transformed: &Array2<f64>,
195 workspace: &mut PirlsWorkspace,
196) -> Result<f64, EstimationError> {
197 let p = penalized_hessian.ncols();
198 let r = e_transformed.nrows();
199 let mp = (p as f64 - r as f64).max(0.0);
200 if r == 0 {
201 return Ok(p as f64);
202 }
203 if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
204 workspace.final_aug_matrix = Array2::zeros((p, r));
205 }
206 for j in 0..r {
207 for i in 0..p {
208 workspace.final_aug_matrix[[i, j]] = e_transformed[[j, i]];
209 }
210 }
211
212 let factor = StableSolver::new("pirls edf workspace")
213 .factorize(penalized_hessian)
214 .map_err(|_| EstimationError::ModelIsIllConditioned {
215 condition_number: f64::INFINITY,
216 })?;
217 {
218 let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
219 factor.solve_in_place(rhsview.as_mut());
220 }
221 if workspace.final_aug_matrix.nrows() == p
222 && workspace.final_aug_matrix.ncols() == r
223 && array_is_finite(&workspace.final_aug_matrix)
224 {
225 return Ok(edf_from_solution(p, r, mp, e_transformed, |i, j| {
226 workspace.final_aug_matrix[(i, j)]
227 }));
228 }
229
230 Err(EstimationError::ModelIsIllConditioned {
231 condition_number: f64::INFINITY,
232 })
233}
234
235pub(super) fn calculate_edfwithworkspace_with_penalty(
236 penalized_hessian: &Array2<f64>,
237 penalty: &PirlsPenalty,
238 workspace: &mut PirlsWorkspace,
239) -> Result<f64, EstimationError> {
240 match penalty {
241 PirlsPenalty::Dense { e_transformed, .. } => {
242 calculate_edfwithworkspace(penalized_hessian, e_transformed, workspace)
243 }
244 PirlsPenalty::Diagonal {
245 diag,
246 positive_indices,
247 ..
248 } => calculate_edfwithworkspace_from_diagonal_penalty(
249 penalized_hessian,
250 diag,
251 positive_indices,
252 workspace,
253 ),
254 }
255}
256
257pub(super) fn calculate_edf_from_diagonal_penalty(
258 penalized_hessian: &SymmetricMatrix,
259 diag: &Array1<f64>,
260 positive_indices: &[usize],
261) -> Result<f64, EstimationError> {
262 let p = penalized_hessian.ncols();
263 let r = positive_indices.len();
264 let mp = (p as f64 - r as f64).max(0.0);
265 if r == 0 {
266 return Ok(p as f64);
267 }
268 let mut rhs_arr = Array2::<f64>::zeros((p, r));
269 for (col, &idx) in positive_indices.iter().enumerate() {
270 rhs_arr[[idx, col]] = 1.0;
271 }
272 let factor =
273 penalized_hessian
274 .factorize()
275 .map_err(|_| EstimationError::ModelIsIllConditioned {
276 condition_number: f64::INFINITY,
277 })?;
278 let sol = factor
279 .solvemulti(&rhs_arr)
280 .map_err(|_| EstimationError::ModelIsIllConditioned {
281 condition_number: f64::INFINITY,
282 })?;
283 let mut tr = 0.0;
284 for (col, &idx) in positive_indices.iter().enumerate() {
285 tr += diag[idx] * sol[[idx, col]];
286 }
287 Ok((p as f64 - tr).clamp(mp, p as f64))
288}
289
290pub(super) fn calculate_edfwithworkspace_from_diagonal_penalty(
291 penalized_hessian: &Array2<f64>,
292 diag: &Array1<f64>,
293 positive_indices: &[usize],
294 workspace: &mut PirlsWorkspace,
295) -> Result<f64, EstimationError> {
296 let p = penalized_hessian.ncols();
297 let r = positive_indices.len();
298 let mp = (p as f64 - r as f64).max(0.0);
299 if r == 0 {
300 return Ok(p as f64);
301 }
302 if workspace.final_aug_matrix.nrows() != p || workspace.final_aug_matrix.ncols() != r {
303 workspace.final_aug_matrix = Array2::zeros((p, r));
304 } else {
305 workspace.final_aug_matrix.fill(0.0);
306 }
307 for (col, &idx) in positive_indices.iter().enumerate() {
308 workspace.final_aug_matrix[[idx, col]] = 1.0;
309 }
310
311 let factor = StableSolver::new("pirls diagonal edf workspace")
312 .factorize(penalized_hessian)
313 .map_err(|_| EstimationError::ModelIsIllConditioned {
314 condition_number: f64::INFINITY,
315 })?;
316 {
317 let mut rhsview = array2_to_matmut(&mut workspace.final_aug_matrix);
318 factor.solve_in_place(rhsview.as_mut());
319 }
320 let mut tr = 0.0;
321 for (col, &idx) in positive_indices.iter().enumerate() {
322 tr += diag[idx] * workspace.final_aug_matrix[[idx, col]];
323 }
324 Ok((p as f64 - tr).clamp(mp, p as f64))
325}
326
327#[inline]
328pub(super) fn edf_from_solution<F>(
329 p: usize,
330 r: usize,
331 mp: f64,
332 e_transformed: &Array2<f64>,
333 solved_at: F,
334) -> f64
335where
336 F: Fn(usize, usize) -> f64,
337{
338 let mut tr = 0.0;
339 for j in 0..r {
340 for i in 0..p {
341 tr += solved_at(i, j) * e_transformed[(j, i)];
342 }
343 }
344 (p as f64 - tr).clamp(mp, p as f64)
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use gam_linalg::matrix::SymmetricMatrix;
    use ndarray::array;

    /// Dense SPD Hessian shared by the tests: `H = [[4, 1], [1, 3]]`, with
    /// det 11 and `H⁻¹ = [[3, -1], [-1, 4]] / 11`.
    fn spd_hessian() -> SymmetricMatrix {
        SymmetricMatrix::Dense(array![[4.0, 1.0], [1.0, 3.0]])
    }

    /// A dense penalty factor with more rows than coefficients (`r = 3 > p = 2`)
    /// must still give a finite EDF inside `[0, p]`.
    ///
    /// Here `EᵀE = [[1.25, 0.25], [0.25, 1.25]]`, so `tr(H⁻¹ EᵀE) = 8.25 / 11 = 0.75`
    /// and the EDF is exactly `2 - 0.75 = 1.25`.
    #[test]
    pub(crate) fn calculate_edf_floors_when_penalty_rank_exceeds_coefficient_dim() {
        let p = 2usize;
        let hessian = spd_hessian();
        let e_transformed = array![[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]];
        assert_eq!(e_transformed.nrows(), 3);
        assert_eq!(e_transformed.ncols(), p);

        let edf = calculate_edf(&hessian, &e_transformed)
            .expect("EDF solve should succeed for an SPD Hessian with r > p");

        assert!(
            edf.is_finite(),
            "EDF must be finite for r > p penalty, got {edf}"
        );
        assert!(
            (0.0..=p as f64).contains(&edf),
            "EDF must lie in [0, {p}] for r > p penalty, got {edf}"
        );
        assert!(
            (edf - 1.25).abs() < 1e-9,
            "EDF must equal 2 - tr(H^-1 E^T E) = 1.25, got {edf}"
        );
    }

    /// Diagonal penalty `D = diag(2, 0)` with only index 0 penalized gives
    /// `tr(H⁻¹ D) = 2 · 3/11`, so `EDF = 2 - 6/11 = 16/11`. That lies inside the
    /// floor `[p - r, p] = [1, 2]`.
    #[test]
    fn calculate_edf_from_diagonal_penalty_matches_closed_form() {
        let hessian = spd_hessian();
        let diag = array![2.0, 0.0];

        let edf = calculate_edf_from_diagonal_penalty(&hessian, &diag, &[0])
            .expect("diagonal EDF solve should succeed for an SPD Hessian");

        assert!(
            (edf - 16.0 / 11.0).abs() < 1e-9,
            "EDF must equal 16/11 for diag(2, 0), got {edf}"
        );
    }

    /// With no penalty rows the EDF is the full coefficient count.
    #[test]
    fn calculate_edf_without_penalty_rows_returns_p() {
        let hessian = spd_hessian();
        let empty = Array2::<f64>::zeros((0, 2));

        let edf = calculate_edf(&hessian, &empty).expect("r == 0 must short-circuit");

        assert_eq!(edf, 2.0);
    }
}