Skip to main content

gam_solve/pirls/
sparse_system.rs

1//! Sparse-native P-IRLS eligibility decision and the sparse penalized-system
2//! (XᵀWX + Sλ) symbolic pattern / numeric cache used by the sparse solve path.
3
4use super::*;
5
/// Outcome of the sparse-native P-IRLS eligibility check, together with the
/// size statistics that are reported alongside it in `[pirls-path]` debug logs
/// (see `format_fields` / `log_once`).
pub struct SparsePirlsDecision {
    /// Linear-solve path that was selected.
    pub path: PirlsLinearSolvePath,
    /// Short static explanation of the choice; logged verbatim.
    pub reason: &'static str,
    /// Problem dimension `p` as reported by the caller (presumably the
    /// coefficient count — confirm at the construction site).
    pub p: usize,
    /// Nonzero count of the design matrix as reported by the caller.
    pub nnz_x: usize,
    /// Symbolic nonzero count of `XᵀWX`, if it was computed; `None` logs as `na`.
    pub nnz_xtwx_symbolic: Option<usize>,
    /// Nonzero count of the penalty `S_λ` as reported by the caller.
    pub nnz_s_lambda: usize,
    /// Estimated nonzero count of the penalized Hessian `H`, if available;
    /// `None` logs as `na`.
    pub nnz_h_est: Option<usize>,
    /// Estimated density of `H`, if available; logged with 4 decimals, `None`
    /// logs as `na`.
    pub density_h_est: Option<f64>,
}
16
/// Formats an optional count for log output: the number itself, or `"na"` when absent.
pub(crate) fn fmt_opt_usize(v: Option<usize>) -> String {
    match v {
        Some(count) => count.to_string(),
        None => String::from("na"),
    }
}
20
/// Formats an optional float for log output with four decimal places, or
/// `"na"` when absent.
pub(crate) fn fmt_opt_f64(v: Option<f64>) -> String {
    match v {
        Some(value) => format!("{value:.4}"),
        None => String::from("na"),
    }
}
25
26impl SparsePirlsDecision {
27    pub(crate) fn path_str(&self) -> &'static str {
28        match self.path {
29            PirlsLinearSolvePath::DenseTransformed => "dense_transformed",
30            PirlsLinearSolvePath::SparseNative => "sparse_native",
31        }
32    }
33
34    pub(crate) fn format_fields(&self, path: &str) -> String {
35        format!(
36            "path={path} reason={} p={} nnz_x={} nnz_xtwx_symbolic={} nnz_s_lambda={} nnz_h_est={} density_h_est={}",
37            self.reason,
38            self.p,
39            self.nnz_x,
40            fmt_opt_usize(self.nnz_xtwx_symbolic),
41            self.nnz_s_lambda,
42            fmt_opt_usize(self.nnz_h_est),
43            fmt_opt_f64(self.density_h_est),
44        )
45    }
46
47    pub(crate) fn log_once(&self) {
48        let path = self.path_str();
49        let key = self.format_fields(path);
50        let repetition_count = pirls_decision_repetition_count(key.clone());
51        if repetition_count == 1 {
52            log::debug!("[pirls-path] {key}");
53            return;
54        }
55
56        if should_log_pirls_decision_summary(repetition_count) {
57            log::debug!(
58                "[pirls-path] repeated path={} reason={} count={} (suppressing identical decisions)",
59                path,
60                self.reason,
61                repetition_count,
62            );
63        }
64    }
65}
66
/// Bumps and returns the process-wide occurrence count for `log_key`.
///
/// Counts live in a lazily created `Mutex<HashMap<String, usize>>` that is
/// never pruned. The first call for a key returns 1, the next returns 2, and so on.
///
/// # Panics
/// Panics if the counter mutex has been poisoned.
pub(crate) fn pirls_decision_repetition_count(log_key: String) -> usize {
    static PIRLS_DECISION_LOG_COUNTS: OnceLock<Mutex<HashMap<String, usize>>> = OnceLock::new();
    let mut table = PIRLS_DECISION_LOG_COUNTS
        .get_or_init(Default::default)
        .lock()
        .expect("pirls decision log counter poisoned");
    let seen = table.entry(log_key).or_default();
    *seen += 1;
    *seen
}
75
/// Whether a repeated decision should emit a summary line. This is true for
/// repeat counts 2, 4, 8, 16, ..., so the suppressed log volume grows only
/// logarithmically with the repetitions.
pub(crate) fn should_log_pirls_decision_summary(repetition_count: usize) -> bool {
    // A value >= 2 with exactly one set bit is a power of two greater than 1.
    repetition_count >= 2 && repetition_count.count_ones() == 1
}
79
/// Density ceiling (a fraction; 0.30 = 30%) for the estimated penalized
/// Hessian `H` on the sparse-native path.
/// NOTE(review): the comparison against this threshold is not in this part of
/// the file. Presumably an estimated density above it (cf.
/// `SparsePirlsDecision::density_h_est`) selects the dense path — confirm at
/// the call site.
pub(crate) const SPARSE_NATIVE_MAX_H_DENSITY: f64 = 0.30;
81
/// Upper-triangular entries (pattern and values) of a penalty matrix `S_λ`.
#[derive(Clone, Debug)]
pub(crate) struct SparsePenaltyPattern {
    /// `(row, col, value)` with `row <= col`. `from_dense_upper` emits them
    /// column by column with ascending rows inside each column.
    pub(crate) upper_triplets: Vec<(usize, usize, f64)>,
    /// Number of entries in `upper_triplets`.
    pub(crate) nnz_upper: usize,
}
87
88impl SparsePenaltyPattern {
89    pub(crate) fn from_dense_upper(matrix: &Array2<f64>, tol: f64) -> Self {
90        let p = matrix.nrows().min(matrix.ncols());
91        let mut upper_triplets = Vec::new();
92        for col in 0..p {
93            for row in 0..=col {
94                let value = matrix[[row, col]];
95                if value.abs() > tol {
96                    upper_triplets.push((row, col, value));
97                }
98            }
99        }
100        let nnz_upper = upper_triplets.len();
101        Self {
102            upper_triplets,
103            nnz_upper,
104        }
105    }
106}
107
/// Size summary of a `SparsePenalizedSystemCache`, produced by its `stats` method.
#[derive(Clone, Debug)]
pub(crate) struct SparsePenalizedSystemStats {
    /// Stored entries in the `XᵀWX` symbolic pattern.
    pub(crate) nnz_xtwx_symbolic: usize,
    /// Upper-triangular penalty entries (`SparsePenaltyPattern::nnz_upper`).
    pub(crate) nnz_s_lambda_upper: usize,
    /// Stored entries in the merged upper-triangular pattern of `H`.
    pub(crate) nnz_h_upper: usize,
    /// `nnz_h_upper / (p (p + 1) / 2)`, or `0.0` when `p == 0`.
    pub(crate) density_upper: f64,
}
115
// Phase 2 sparse-native PIRLS will reuse this cache for symbolic structure and
// repeated numeric assembly of H = X'WX + S_lambda + ridge I.
// Only the upper triangle (row <= col) of H is stored; the lower triangle is
// implied by symmetry.
//
// This is the natural insertion point for any future selected-inversion /
// Takahashi trace backend. In original spline coefficient order, the assembled
// penalized system can remain sparse/banded, so exact traces like
// tr(H^{-1} S_k) can be computed from a sparse factorization without ever
// materializing a dense inverse. That is not true after the REML
// reparameterization rotates the problem into the dense Qs basis.
//
// Algebra:
//   H = X'WX + sum_k lambda_k S_k + delta I
// (delta is the `ridge` argument of `assemble_upper`), and the REML/LAML
// first-order trace terms have the form
//   T_k = tr(H^{-1} S_k).
// Since tr(AB) = sum_ij A_ij B_ji, for symmetric sparse S_k we only need
// inverse entries on the support of S_k:
//   T_k = sum_{(i,j) in nz(S_k), i>=j} (2 - 1{i=j}) (H^{-1})_{ij} (S_k)_{ij}.
// Takahashi/selected inversion exploits exactly this fact. Given a sparse
// Cholesky-type factorization H = LDL', it computes only those entries of
// H^{-1} that lie on the filled graph of L, which contains the structural
// nonzeros needed for spline penalties. For banded spline systems with
// half-bandwidth b, the work scales like sum_j |N(j)|^2 = O(p b^2) instead of
// dense O(p^3), where N(j) is the subdiagonal nonzero pattern of column j of L.
/// Cached symbolic structure and reusable numeric buffer for the upper
/// triangle of the sparse penalized system `H = XᵀWX + S_λ + ridge·I`.
pub(crate) struct SparsePenalizedSystemCache {
    /// `XᵀWX` symbolic pattern plus its numeric value buffer.
    pub(crate) xtwx_cache: SparseXtWxCache,
    /// Penalty entries, values included, that are added on every assembly.
    pub(crate) penalty_pattern: SparsePenaltyPattern,
    /// Merged upper-triangular pattern: diagonal ∪ upper(XᵀWX) ∪ penalty support.
    pub(crate) h_upper_symbolic: SymbolicSparseColMat<usize>,
    /// Numeric values aligned one-to-one with `h_upper_symbolic`'s row indices.
    pub(crate) h_uppervalues: Vec<f64>,
    /// Copy of `h_upper_symbolic`'s column pointers (presumably kept so
    /// assembly can index the pattern while mutating `h_uppervalues`).
    pub(crate) h_upper_col_ptr: Vec<usize>,
    /// Copy of `h_upper_symbolic`'s row indices.
    pub(crate) h_upperrow_idx: Vec<usize>,
    /// Number of columns of the design matrix, i.e. the dimension of `H`.
    pub(crate) p: usize,
}
148
149impl SparsePenalizedSystemCache {
150    pub(crate) fn new(
151        x: &SparseColMat<usize, f64>,
152        penalty_pattern: SparsePenaltyPattern,
153    ) -> Result<Self, EstimationError> {
154        let xtwx_cache = SparseXtWxCache::new(x)?;
155        let p = x.ncols();
156        let h_upper_symbolic = build_penalized_symbolic(
157            p,
158            xtwx_cache.xtwx_symbolic.col_ptr(),
159            xtwx_cache.xtwx_symbolic.row_idx(),
160            &penalty_pattern.upper_triplets,
161        )?;
162        let h_uppervalues = vec![0.0; h_upper_symbolic.row_idx().len()];
163        Ok(Self {
164            xtwx_cache,
165            penalty_pattern,
166            h_upper_col_ptr: h_upper_symbolic.col_ptr().to_vec(),
167            h_upperrow_idx: h_upper_symbolic.row_idx().to_vec(),
168            h_upper_symbolic,
169            h_uppervalues,
170            p,
171        })
172    }
173
174    pub(crate) fn matches(
175        &self,
176        x: &SparseColMat<usize, f64>,
177        penalty_pattern: &SparsePenaltyPattern,
178    ) -> bool {
179        self.xtwx_cache.matches(x)
180            && self.penalty_pattern.nnz_upper == penalty_pattern.nnz_upper
181            && self.penalty_pattern.upper_triplets == penalty_pattern.upper_triplets
182    }
183
184    pub(crate) fn stats(&self) -> SparsePenalizedSystemStats {
185        let upper_total = self.p.saturating_mul(self.p + 1) / 2;
186        SparsePenalizedSystemStats {
187            nnz_xtwx_symbolic: self.xtwx_cache.xtwx_symbolic.row_idx().len(),
188            nnz_s_lambda_upper: self.penalty_pattern.nnz_upper,
189            nnz_h_upper: self.h_upper_symbolic.row_idx().len(),
190            density_upper: if upper_total == 0 {
191                0.0
192            } else {
193                self.h_upper_symbolic.row_idx().len() as f64 / upper_total as f64
194            },
195        }
196    }
197
198    pub(crate) fn assemble_upper(
199        &mut self,
200        x: &SparseColMat<usize, f64>,
201        weights: &Array1<f64>,
202        ridge: f64,
203        precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
204    ) -> Result<SparseColMat<usize, f64>, EstimationError> {
205        if weights.len() != self.xtwx_cache.nrows {
206            crate::bail_invalid_estim!(
207                "weights length {} does not match design rows {}",
208                weights.len(),
209                self.xtwx_cache.nrows
210            );
211        }
212        // Gaussian-Identity fast path: when the caller has pre-built the
213        // `XᵀWX` numerical values (weights are constant across the outer
214        // loop), install them into the inner cache and skip the SpGEMM.
215        // We verify symbolic-pattern equivalence first; on mismatch we
216        // fall back to the regular per-call recomputation rather than
217        // installing values keyed to a different sparsity layout.
218        let use_precomputed = match precomputed_xtwx {
219            Some(pre) => {
220                let col_ptr_ok =
221                    pre.xtwx_symbolic_col_ptr.as_slice() == self.xtwx_cache.xtwx_symbolic.col_ptr();
222                let row_idx_ok =
223                    pre.xtwx_symbolic_row_idx.as_slice() == self.xtwx_cache.xtwx_symbolic.row_idx();
224                let values_ok = pre.xtwxvalues.len() == self.xtwx_cache.xtwxvalues.len();
225                if col_ptr_ok && row_idx_ok && values_ok {
226                    self.xtwx_cache.xtwxvalues.copy_from_slice(&pre.xtwxvalues);
227                    true
228                } else {
229                    log::warn!(
230                        "[sparse-xtwx-cache] precomputed XᵀWX pattern mismatch; \
231                         falling back to per-call recompute"
232                    );
233                    false
234                }
235            }
236            None => false,
237        };
238        if !use_precomputed {
239            self.xtwx_cache.compute_numeric(x, weights)?;
240        }
241        self.h_uppervalues.fill(0.0);
242
243        let mut cursor = self.h_upper_col_ptr[..self.p].to_vec();
244
245        let xtwx_col_ptr = self.xtwx_cache.xtwx_symbolic.col_ptr();
246        let xtwxrow_idx = self.xtwx_cache.xtwx_symbolic.row_idx();
247        for col in 0..self.p {
248            let start = xtwx_col_ptr[col];
249            let end = xtwx_col_ptr[col + 1];
250            for idx in start..end {
251                let row = xtwxrow_idx[idx];
252                if row <= col {
253                    let cursor_idx = &mut cursor[col];
254                    while *cursor_idx < self.h_upper_col_ptr[col + 1]
255                        && self.h_upperrow_idx[*cursor_idx] < row
256                    {
257                        *cursor_idx += 1;
258                    }
259                    if *cursor_idx >= self.h_upper_col_ptr[col + 1]
260                        || self.h_upperrow_idx[*cursor_idx] != row
261                    {
262                        crate::bail_invalid_estim!("penalized symbolic pattern missing XtWX entry");
263                    }
264                    self.h_uppervalues[*cursor_idx] += self.xtwx_cache.xtwxvalues[idx];
265                }
266            }
267        }
268
269        cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
270        for &(row, col, value) in &self.penalty_pattern.upper_triplets {
271            let cursor_idx = &mut cursor[col];
272            while *cursor_idx < self.h_upper_col_ptr[col + 1]
273                && self.h_upperrow_idx[*cursor_idx] < row
274            {
275                *cursor_idx += 1;
276            }
277            if *cursor_idx >= self.h_upper_col_ptr[col + 1]
278                || self.h_upperrow_idx[*cursor_idx] != row
279            {
280                crate::bail_invalid_estim!("penalized symbolic pattern missing penalty entry");
281            }
282            self.h_uppervalues[*cursor_idx] += value;
283        }
284
285        if ridge > 0.0 {
286            cursor.copy_from_slice(&self.h_upper_col_ptr[..self.p]);
287            for col in 0..self.p {
288                let cursor_idx = &mut cursor[col];
289                while *cursor_idx < self.h_upper_col_ptr[col + 1]
290                    && self.h_upperrow_idx[*cursor_idx] < col
291                {
292                    *cursor_idx += 1;
293                }
294                if *cursor_idx >= self.h_upper_col_ptr[col + 1]
295                    || self.h_upperrow_idx[*cursor_idx] != col
296                {
297                    crate::bail_invalid_estim!("penalized symbolic pattern missing diagonal entry");
298                }
299                self.h_uppervalues[*cursor_idx] += ridge;
300            }
301        }
302
303        Ok(SparseColMat::new(
304            self.h_upper_symbolic.clone(),
305            self.h_uppervalues.clone(),
306        ))
307    }
308}
309
310pub(crate) fn build_penalized_symbolic(
311    p: usize,
312    xtwx_col_ptr: &[usize],
313    xtwxrow_idx: &[usize],
314    penalty_triplets: &[(usize, usize, f64)],
315) -> Result<SymbolicSparseColMat<usize>, EstimationError> {
316    let mut cols: Vec<BTreeSet<usize>> = (0..p).map(|_| BTreeSet::new()).collect();
317    for col in 0..p {
318        cols[col].insert(col);
319        let start = xtwx_col_ptr[col];
320        let end = xtwx_col_ptr[col + 1];
321        for &row in &xtwxrow_idx[start..end] {
322            if row <= col {
323                cols[col].insert(row);
324            }
325        }
326    }
327    for &(row, col, _) in penalty_triplets {
328        if row > col || col >= p {
329            crate::bail_invalid_estim!(
330                "penalty sparse pattern must be upper-triangular within bounds"
331            );
332        }
333        cols[col].insert(row);
334    }
335
336    let mut col_ptr = Vec::with_capacity(p + 1);
337    let mut row_idx = Vec::new();
338    col_ptr.push(0);
339    for rows in cols {
340        row_idx.extend(rows.into_iter());
341        col_ptr.push(row_idx.len());
342    }
343    // `cols` has exactly p BTreeSet columns. Draining them into CSC order
344    // gives p+1 monotone col_ptr entries ending at row_idx.len(), and each
345    // per-column row slice is sorted and duplicate-free. Every inserted row
346    // satisfies row <= col < p: diagonal and XᵀWX entries are inserted only
347    // for the current column's upper triangle, and penalty triplets were
348    // checked above.
349    // SAFETY: the generated col_ptr length, monotonicity, terminal nnz,
350    // sorted per-column rows, absence of duplicates, and row bounds are
351    // exactly the CSC invariants skipped by new_unchecked.
352    Ok(unsafe { SymbolicSparseColMat::new_unchecked(p, p, col_ptr, None, row_idx) })
353}
354
/// Assembled sparse penalized Hessian together with its factorization and
/// log-determinant, as produced by `assemble_and_factor_sparse_penalized_system`.
#[derive(Clone)]
pub struct SparsePenalizedSystem {
    /// The assembled penalized Hessian returned by `sparse_reml_penalized_hessian`.
    pub h_sparse: SparseColMat<usize, f64>,
    /// Factorization of `h_sparse` from `factorize_sparse_spd`.
    pub factor: gam_linalg::sparse_exact::SparseExactFactor,
    /// `log det H`, computed from `factor` by `logdet_from_factor`.
    pub logdet_h: f64,
}
361
362pub(crate) fn sparse_reml_penalized_hessian(
363    workspace: &mut PirlsWorkspace,
364    x: &SparseColMat<usize, f64>,
365    weights: &Array1<f64>,
366    s_lambda: &Array2<f64>,
367    ridge: f64,
368    precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
369) -> Result<SparseColMat<usize, f64>, EstimationError> {
370    workspace.assemble_sparse_penalized_hessian(x, weights, s_lambda, ridge, precomputed_xtwx)
371}
372
373pub fn assemble_and_factor_sparse_penalized_system(
374    workspace: &mut PirlsWorkspace,
375    x: &SparseColMat<usize, f64>,
376    weights: &Array1<f64>,
377    s_lambda: &Array2<f64>,
378    ridge: f64,
379    precomputed_xtwx: Option<&SparseXtwxPrecomputed>,
380) -> Result<SparsePenalizedSystem, EstimationError> {
381    use gam_linalg::sparse_exact::{factorize_sparse_spd, logdet_from_factor};
382
383    let logdet_h_start = std::time::Instant::now();
384    let h_sparse =
385        sparse_reml_penalized_hessian(workspace, x, weights, s_lambda, ridge, precomputed_xtwx)?;
386    let factor = factorize_sparse_spd(&h_sparse)?;
387    let logdet_h = logdet_from_factor(&factor)?;
388    log::info!(
389        "[STAGE] logdet H (sparse Cholesky) p={} elapsed={:.3}s",
390        h_sparse.nrows(),
391        logdet_h_start.elapsed().as_secs_f64(),
392    );
393    Ok(SparsePenalizedSystem {
394        h_sparse,
395        factor,
396        logdet_h,
397    })
398}