1use super::*;
5
/// Diagnostic record of which linear-solve path PIRLS selected, together with the sparsity
/// statistics reported alongside that decision.
///
/// Rendered as `key=value` pairs by `SparsePirlsDecision::format_fields` and emitted at debug
/// level, with repeats suppressed, by `SparsePirlsDecision::log_once`.
pub struct SparsePirlsDecision {
    /// Solve path that was chosen.
    pub path: PirlsLinearSolvePath,
    /// Short static explanation of why `path` was chosen.
    pub reason: &'static str,
    /// NOTE(review): presumably the number of coefficients (design columns); confirm at the
    /// producer.
    pub p: usize,
    /// Nonzero count reported for the design matrix (presumably `X`; confirm at the producer).
    pub nnz_x: usize,
    /// Nonzero count of the symbolic XᵀWX pattern, if it was available; logged as `na` otherwise.
    pub nnz_xtwx_symbolic: Option<usize>,
    /// Nonzero count reported for the penalty matrix (whether this is full or upper-triangular
    /// is decided by the producer; confirm there).
    pub nnz_s_lambda: usize,
    /// Estimated nonzero count of the penalized Hessian, if available; logged as `na` otherwise.
    pub nnz_h_est: Option<usize>,
    /// Estimated density of the penalized Hessian, if available; logged with 4 decimals.
    pub density_h_est: Option<f64>,
}
16
/// Formats an optional count for log output, rendering `None` as `"na"`.
pub(crate) fn fmt_opt_usize(v: Option<usize>) -> String {
    match v {
        Some(count) => count.to_string(),
        None => String::from("na"),
    }
}
20
/// Formats an optional float with four decimal places for log output, rendering `None` as
/// `"na"`.
pub(crate) fn fmt_opt_f64(v: Option<f64>) -> String {
    match v {
        Some(value) => format!("{value:.4}"),
        None => String::from("na"),
    }
}
25
26impl SparsePirlsDecision {
27 pub(crate) fn path_str(&self) -> &'static str {
28 match self.path {
29 PirlsLinearSolvePath::DenseTransformed => "dense_transformed",
30 PirlsLinearSolvePath::SparseNative => "sparse_native",
31 }
32 }
33
34 pub(crate) fn format_fields(&self, path: &str) -> String {
35 format!(
36 "path={path} reason={} p={} nnz_x={} nnz_xtwx_symbolic={} nnz_s_lambda={} nnz_h_est={} density_h_est={}",
37 self.reason,
38 self.p,
39 self.nnz_x,
40 fmt_opt_usize(self.nnz_xtwx_symbolic),
41 self.nnz_s_lambda,
42 fmt_opt_usize(self.nnz_h_est),
43 fmt_opt_f64(self.density_h_est),
44 )
45 }
46
47 pub(crate) fn log_once(&self) {
48 let path = self.path_str();
49 let key = self.format_fields(path);
50 let repetition_count = pirls_decision_repetition_count(key.clone());
51 if repetition_count == 1 {
52 log::debug!("[pirls-path] {key}");
53 return;
54 }
55
56 if should_log_pirls_decision_summary(repetition_count) {
57 log::debug!(
58 "[pirls-path] repeated path={} reason={} count={} (suppressing identical decisions)",
59 path,
60 self.reason,
61 repetition_count,
62 );
63 }
64 }
65}
66
/// Records one more occurrence of `log_key` and returns how many times it has now been seen.
///
/// Counts live in a process-wide, mutex-guarded map. They are shared by all threads and are
/// never pruned, so the map grows with every distinct key.
///
/// A poisoned mutex is recovered rather than escalated. The map holds only plain counters that
/// are updated in a single step, so a panic in another holder cannot leave it inconsistent, and
/// a diagnostics helper should not turn such a panic into a second one.
pub(crate) fn pirls_decision_repetition_count(log_key: String) -> usize {
    static PIRLS_DECISION_LOG_COUNTS: OnceLock<Mutex<HashMap<String, usize>>> = OnceLock::new();
    let counts = PIRLS_DECISION_LOG_COUNTS.get_or_init(|| Mutex::new(HashMap::new()));
    let mut counts = counts
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let count = counts.entry(log_key).or_insert(0);
    *count += 1;
    *count
}
75
/// Decides whether a repeated decision should emit a summary line.
///
/// Returns `true` exactly when `repetition_count` is a power of two of at least 2
/// (2, 4, 8, ...), so suppressed repeats are reported at exponentially spaced intervals.
pub(crate) fn should_log_pirls_decision_summary(repetition_count: usize) -> bool {
    // For n >= 2, `n & (n - 1)` clears the lowest set bit, leaving zero only for powers of two.
    // The `>= 2` guard also keeps `n - 1` from underflowing at zero.
    repetition_count >= 2 && (repetition_count & (repetition_count - 1)) == 0
}
79
/// Maximum penalized-Hessian density (per its name) for using the sparse-native PIRLS path.
///
/// NOTE(review): this constant is not referenced in this section. It is presumably compared
/// against an estimated density of `H` (cf. `SparsePirlsDecision::density_h_est`) when choosing
/// between the dense-transformed and sparse-native solves; confirm at the call site.
pub(crate) const SPARSE_NATIVE_MAX_H_DENSITY: f64 = 0.30;
81
82#[derive(Clone, Debug)]
83pub(crate) struct SparsePenaltyPattern {
84 pub(crate) upper_triplets: Vec<(usize, usize, f64)>,
85 pub(crate) nnz_upper: usize,
86}
87
88impl SparsePenaltyPattern {
89 pub(crate) fn from_dense_upper(matrix: &Array2<f64>, tol: f64) -> Self {
90 let p = matrix.nrows().min(matrix.ncols());
91 let mut upper_triplets = Vec::new();
92 for col in 0..p {
93 for row in 0..=col {
94 let value = matrix[[row, col]];
95 if value.abs() > tol {
96 upper_triplets.push((row, col, value));
97 }
98 }
99 }
100 let nnz_upper = upper_triplets.len();
101 Self {
102 upper_triplets,
103 nnz_upper,
104 }
105 }
106}
107
/// Size and density summary of a `SparsePenalizedSystemCache`, as produced by
/// `SparsePenalizedSystemCache::stats`.
#[derive(Clone, Debug)]
pub(crate) struct SparsePenalizedSystemStats {
    /// Stored entries in the symbolic XᵀWX pattern.
    pub(crate) nnz_xtwx_symbolic: usize,
    /// Entries in the penalty's upper-triangular (diagonal included) pattern.
    pub(crate) nnz_s_lambda_upper: usize,
    /// Stored entries in the upper-triangular penalized Hessian pattern.
    pub(crate) nnz_h_upper: usize,
    /// Fill fraction of the upper triangle: `nnz_h_upper / (p * (p + 1) / 2)`.
    ///
    /// This is `0.0` when `p == 0`.
    pub(crate) density_upper: f64,
}
115
/// Reusable symbolic and numeric workspace for assembling the upper triangle of the sparse
/// penalized Hessian, for a fixed design and penalty pattern.
pub(crate) struct SparsePenalizedSystemCache {
    /// XᵀWX symbolic pattern and numeric value buffer for the design the cache was built from.
    pub(crate) xtwx_cache: SparseXtWxCache,
    /// Penalty triplets, values included, that the penalized pattern was built from.
    ///
    /// `assemble_upper` adds these values on every call.
    pub(crate) penalty_pattern: SparsePenaltyPattern,
    /// Upper-triangular symbolic pattern of the penalized Hessian.
    ///
    /// This is the union of the XᵀWX upper triangle, the penalty pattern and the full diagonal.
    pub(crate) h_upper_symbolic: SymbolicSparseColMat<usize>,
    /// Numeric values aligned with `h_upper_symbolic.row_idx()`; overwritten by `assemble_upper`.
    pub(crate) h_uppervalues: Vec<f64>,
    /// Owned copy of `h_upper_symbolic.col_ptr()`, used for entry lookup during assembly.
    pub(crate) h_upper_col_ptr: Vec<usize>,
    /// Owned copy of `h_upper_symbolic.row_idx()`, used for entry lookup during assembly.
    pub(crate) h_upperrow_idx: Vec<usize>,
    /// Dimension of the square penalized Hessian, i.e. the design's column count.
    pub(crate) p: usize,
}
148
149impl SparsePenalizedSystemCache {
150 pub(crate) fn new(
151 x: &SparseColMat<usize, f64>,
152 penalty_pattern: SparsePenaltyPattern,
153 ) -> Result<Self, EstimationError> {
154 let xtwx_cache = SparseXtWxCache::new(x)?;
155 let p = x.ncols();
156 let h_upper_symbolic = build_penalized_symbolic(
157 p,
158 xtwx_cache.xtwx_symbolic.col_ptr(),
159 xtwx_cache.xtwx_symbolic.row_idx(),
160 &penalty_pattern.upper_triplets,
161 )?;
162 let h_uppervalues = vec![0.0; h_upper_symbolic.row_idx().len()];
163 Ok(Self {
164 xtwx_cache,
165 penalty_pattern,
166 h_upper_col_ptr: h_upper_symbolic.col_ptr().to_vec(),
167 h_upperrow_idx: h_upper_symbolic.row_idx().to_vec(),
168 h_upper_symbolic,
169 h_uppervalues,
170 p,
171 })
172 }
173
174 pub(crate) fn matches(
175 &self,
176 x: &SparseColMat<usize, f64>,
177 penalty_pattern: &SparsePenaltyPattern,
178 ) -> bool {
179 self.xtwx_cache.matches(x)
180 && self.penalty_pattern.nnz_upper == penalty_pattern.nnz_upper
181 && self.penalty_pattern.upper_triplets == penalty_pattern.upper_triplets
182 }
183
184 pub(crate) fn stats(&self) -> SparsePenalizedSystemStats {
185 let upper_total = self.p.saturating_mul(self.p + 1) / 2;
186 SparsePenalizedSystemStats {
187 nnz_xtwx_symbolic: self.xtwx_cache.xtwx_symbolic.row_idx().len(),
188 nnz_s_lambda_upper: self.penalty_pattern.nnz_upper,
189 nnz_h_upper: self.h_upper_symbolic.row_idx().len(),
190 density_upper: if upper_total == 0 {
191 0.0
192 } else {
193 self.h_upper_symbolic.row_idx().len() as f64 / upper_total as f64
194 },
195 }
196 }
197
198 pub(crate) fn assemble_upper(
199 &mut self,
200 x: &SparseColMat<usize, f64>,
201 weights: &Array1<f64>,
202 ridge: f64,
203 precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
204 ) -> Result<SparseColMat<usize, f64>, EstimationError> {
205 if weights.len() != self.xtwx_cache.nrows {
206 crate::bail_invalid_estim!(
207 "weights length {} does not match design rows {}",
208 weights.len(),
209 self.xtwx_cache.nrows
210 );
211 }
212 let use_precomputed = match precomputed_xtwx {
219 Some(pre) => {
220 let col_ptr_ok =
221 pre.xtwx_symbolic_col_ptr.as_slice() == self.xtwx_cache.xtwx_symbolic.col_ptr();
222 let row_idx_ok =
223 pre.xtwx_symbolic_row_idx.as_slice() == self.xtwx_cache.xtwx_symbolic.row_idx();
224 let values_ok = pre.xtwxvalues.len() == self.xtwx_cache.xtwxvalues.len();
225 if col_ptr_ok && row_idx_ok && values_ok {
226 self.xtwx_cache.xtwxvalues.copy_from_slice(&pre.xtwxvalues);
227 true
228 } else {
229 log::warn!(
230 "[sparse-xtwx-cache] precomputed XᵀWX pattern mismatch; \
231 falling back to per-call recompute"
232 );
233 false
234 }
235 }
236 None => false,
237 };
238 if !use_precomputed {
239 self.xtwx_cache.compute_numeric(x, weights)?;
240 }
241 self.h_uppervalues.fill(0.0);
242
243 let mut cursor = self.h_upper_col_ptr[..self.p].to_vec();
244
245 let xtwx_col_ptr = self.xtwx_cache.xtwx_symbolic.col_ptr();
246 let xtwxrow_idx = self.xtwx_cache.xtwx_symbolic.row_idx();
247 for col in 0..self.p {
248 let start = xtwx_col_ptr[col];
249 let end = xtwx_col_ptr[col + 1];
250 for idx in start..end {
251 let row = xtwxrow_idx[idx];
252 if row <= col {
253 let cursor_idx = &mut cursor[col];
254 while *cursor_idx < self.h_upper_col_ptr[col + 1]
255 && self.h_upperrow_idx[*cursor_idx] < row
256 {
257 *cursor_idx += 1;
258 }
259 if *cursor_idx >= self.h_upper_col_ptr[col + 1]
260 || self.h_upperrow_idx[*cursor_idx] != row
261 {
262 crate::bail_invalid_estim!("penalized symbolic pattern missing XtWX entry");
263 }
264 self.h_uppervalues[*cursor_idx] += self.xtwx_cache.xtwxvalues[idx];
265 }
266 }
267 }
268
269 cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
270 for &(row, col, value) in &self.penalty_pattern.upper_triplets {
271 let cursor_idx = &mut cursor[col];
272 while *cursor_idx < self.h_upper_col_ptr[col + 1]
273 && self.h_upperrow_idx[*cursor_idx] < row
274 {
275 *cursor_idx += 1;
276 }
277 if *cursor_idx >= self.h_upper_col_ptr[col + 1]
278 || self.h_upperrow_idx[*cursor_idx] != row
279 {
280 crate::bail_invalid_estim!("penalized symbolic pattern missing penalty entry");
281 }
282 self.h_uppervalues[*cursor_idx] += value;
283 }
284
285 if ridge > 0.0 {
286 cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
287 for col in 0..self.p {
288 let cursor_idx = &mut cursor[col];
289 while *cursor_idx < self.h_upper_col_ptr[col + 1]
290 && self.h_upperrow_idx[*cursor_idx] < col
291 {
292 *cursor_idx += 1;
293 }
294 if *cursor_idx >= self.h_upper_col_ptr[col + 1]
295 || self.h_upperrow_idx[*cursor_idx] != col
296 {
297 crate::bail_invalid_estim!("penalized symbolic pattern missing diagonal entry");
298 }
299 self.h_uppervalues[*cursor_idx] += ridge;
300 }
301 }
302
303 Ok(SparseColMat::new(
304 self.h_upper_symbolic.clone(),
305 self.h_uppervalues.clone(),
306 ))
307 }
308}
309
310pub(crate) fn build_penalized_symbolic(
311 p: usize,
312 xtwx_col_ptr: &[usize],
313 xtwxrow_idx: &[usize],
314 penalty_triplets: &[(usize, usize, f64)],
315) -> Result<SymbolicSparseColMat<usize>, EstimationError> {
316 let mut cols: Vec<BTreeSet<usize>> = (0..p).map(|_| BTreeSet::new()).collect();
317 for col in 0..p {
318 cols[col].insert(col);
319 let start = xtwx_col_ptr[col];
320 let end = xtwx_col_ptr[col + 1];
321 for &row in &xtwxrow_idx[start..end] {
322 if row <= col {
323 cols[col].insert(row);
324 }
325 }
326 }
327 for &(row, col, _) in penalty_triplets {
328 if row > col || col >= p {
329 crate::bail_invalid_estim!(
330 "penalty sparse pattern must be upper-triangular within bounds"
331 );
332 }
333 cols[col].insert(row);
334 }
335
336 let mut col_ptr = Vec::with_capacity(p + 1);
337 let mut row_idx = Vec::new();
338 col_ptr.push(0);
339 for rows in cols {
340 row_idx.extend(rows.into_iter());
341 col_ptr.push(row_idx.len());
342 }
343 Ok(unsafe { SymbolicSparseColMat::new_unchecked(p, p, col_ptr, None, row_idx) })
353}
354
/// Assembled sparse penalized Hessian together with its factorization and log-determinant,
/// as returned by `assemble_and_factor_sparse_penalized_system`.
#[derive(Clone)]
pub struct SparsePenalizedSystem {
    /// Penalized Hessian returned by `sparse_reml_penalized_hessian`.
    ///
    /// NOTE(review): presumably stored as the upper triangle only (cf.
    /// `SparsePenalizedSystemCache::assemble_upper`); confirm in
    /// `PirlsWorkspace::assemble_sparse_penalized_hessian`.
    pub h_sparse: SparseColMat<usize, f64>,
    /// Factorization of `h_sparse` produced by `factorize_sparse_spd`.
    pub factor: gam_linalg::sparse_exact::SparseExactFactor,
    /// Log-determinant of `h_sparse`, computed from `factor` by `logdet_from_factor`.
    pub logdet_h: f64,
}
361
362pub(crate) fn sparse_reml_penalized_hessian(
363 workspace: &mut PirlsWorkspace,
364 x: &SparseColMat<usize, f64>,
365 weights: &Array1<f64>,
366 s_lambda: &Array2<f64>,
367 ridge: f64,
368 precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
369) -> Result<SparseColMat<usize, f64>, EstimationError> {
370 workspace.assemble_sparse_penalized_hessian(x, weights, s_lambda, ridge, precomputed_xtwx)
371}
372
373pub fn assemble_and_factor_sparse_penalized_system(
374 workspace: &mut PirlsWorkspace,
375 x: &SparseColMat<usize, f64>,
376 weights: &Array1<f64>,
377 s_lambda: &Array2<f64>,
378 ridge: f64,
379 precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
380) -> Result<SparsePenalizedSystem, EstimationError> {
381 use gam_linalg::sparse_exact::{factorize_sparse_spd, logdet_from_factor};
382
383 let logdet_h_start = std::time::Instant::now();
384 let h_sparse =
385 sparse_reml_penalized_hessian(workspace, x, weights, s_lambda, ridge, precomputed_xtwx)?;
386 let factor = factorize_sparse_spd(&h_sparse)?;
387 let logdet_h = logdet_from_factor(&factor)?;
388 log::info!(
389 "[STAGE] logdet H (sparse Cholesky) p={} elapsed={:.3}s",
390 h_sparse.nrows(),
391 logdet_h_start.elapsed().as_secs_f64(),
392 );
393 Ok(SparsePenalizedSystem {
394 h_sparse,
395 factor,
396 logdet_h,
397 })
398}