Skip to main content

gam_linalg/matrix/
mod.rs

1use crate::faer_ndarray::{
2    CrossprodAccum, CrossprodStructure, FaerArrayView, array2_to_matmut,
3    effective_global_parallelism, fast_ab, fast_atb, fast_atv, fast_atv_into, fast_av,
4    fast_av_into, fast_xt_diag_x, stream_weighted_crossprod_into,
5};
6use crate::types::RidgePolicy;
7use faer::Accum;
8use faer::linalg::matmul::matmul;
9use faer::sparse::{SparseColMat, SparseRowMat, Triplet};
10use gam_runtime::resource::{
11    MaterializationPolicy, MatrixMaterializationError, ResourcePolicy, rows_for_target_bytes,
12};
13use ndarray::{
14    Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis, ShapeBuilder, s,
15};
16use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
17use std::borrow::Cow;
18use std::collections::BTreeMap;
19use std::ops::Deref;
20use std::ops::Range;
21use std::sync::{Arc, OnceLock};
22
// NOTE(review): the three `MATRIX_FREE_PCG_*` constants are not referenced in
// the visible part of this file; by name they presumably tune a matrix-free
// preconditioned conjugate-gradient solve (size threshold, relative
// tolerance, iteration cap) — confirm at the use sites.
const MATRIX_FREE_PCG_MIN_P: usize = 2048;
const MATRIX_FREE_PCG_REL_TOL: f64 = 1e-8;
/// Minimum numerical ridge added to the (penalized) normal matrix before an SPD
/// solve. Near `f64` precision: large enough to lift an exactly-singular system
/// off zero so the factorization succeeds, small enough not to bias a
/// well-conditioned solve. Acts as a floor on any caller-supplied `ridge_floor`.
const SPD_SOLVE_RIDGE_FLOOR: f64 = 1e-15;
const MATRIX_FREE_PCG_MAX_ITER: usize = 2000;
/// 256 MiB. Also the value of `MAX_SPARSE_TO_DENSE_BYTES` below.
const MAX_SINGLE_DENSE_MATERIALIZATION_BYTES: usize = 256 * 1024 * 1024;
/// 256 MiB. `SparseDesignMatrix::try_to_dense_arc` keeps the dense copy in its
/// persistent cache only when the dense size is at or below this value.
const MAX_PERSISTENT_SPARSE_DENSE_CACHE_BYTES: usize = 256 * 1024 * 1024;
/// Dense size above which `SparseDesignMatrix::try_to_dense_arc` refuses to
/// densify a sparse design.
const MAX_SPARSE_TO_DENSE_BYTES: usize = MAX_SINGLE_DENSE_MATERIALIZATION_BYTES;
/// 8 MiB. Byte target handed to `rows_for_target_bytes` when sizing the row
/// chunks in `dense_materialization_chunk_rows`.
const CHUNKED_DENSE_MATERIALIZATION_BYTES: usize = 8 * 1024 * 1024;
// NOTE(review): not referenced in the visible part of this file; presumably a
// default row-chunk length for operator-backed designs — confirm.
const OPERATOR_ROW_CHUNK_SIZE: usize = 256;
/// Minimum n*p product for the dense-row parallel fold/reduce paths
/// (`diag_gram`, `apply_weighted_normal`, dense transpose reductions).
/// Below this, the sequential row loop wins on overhead.
const DENSE_ROW_PARALLEL_MIN_NP: u64 = 200_000;
/// Work estimate (`n * p_left * p_right`) below which
/// `weighted_crossprod_dense_view` stays on the sequential kernel.
const WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS: u64 = 500_000;
/// Work estimate below which the sparse CSR row kernels
/// (`sparse_csr_weighted_xtwx`, `sparse_csr_diag_gram`) stay sequential.
const SPARSE_ROW_PARALLEL_MIN_FLOPS: u64 = 100_000;
/// Maximum bytes for the (n, tail_total) intermediate in GEMM-batched tensor
/// product matvecs.  Beyond this threshold, fall back to per-column GEMV.
const TENSOR_GEMM_MAX_INTERMEDIATE_BYTES: usize = 128 * 1024 * 1024; // 128 MiB
45
46pub use crate::utils::PcgSolveInfo;
47
48mod sparse_hessian;
49pub use sparse_hessian::SparseHessianAccumulator;
50
51mod weights;
52pub use weights::{PsdWeightsView, SignedWeightsArc, SignedWeightsView};
53
/// Typed error for this module's matrix operations.  All error sites in this
/// module construct a `MatrixError` variant; trait method bodies that still
/// return `Result<_, String>` convert via `From<MatrixError> for String` (which
/// is byte-equivalent to the prior `format!` / `to_string` payloads).
///
/// Each variant carries a single human-readable `reason` string.
#[derive(Debug, Clone)]
pub enum MatrixError {
    /// Operand shapes (rows, columns, lengths) do not satisfy the operation's
    /// dimension contract.  Also covers integer-overflow in dimension products.
    DimensionMismatch { reason: String },
    /// Refused to materialize an operator-backed or sparse design to a dense
    /// `Array2<f64>` because the active `ResourcePolicy` (size cap or strict
    /// operator-only mode) forbids it.
    DensificationRefused { reason: String },
}
68
// Crate-level macro invoked with every `MatrixError` variant.
// NOTE(review): the macro body is not visible here; presumably it generates the
// `Display` / `Error` impls and the `From<MatrixError> for String` conversion
// that the `.into()` call sites in this file rely on — confirm against the
// macro definition. A new variant must be listed here as well.
crate::impl_reason_error_boilerplate! {
    MatrixError {
        DimensionMismatch,
        DensificationRefused,
    }
}
75
76#[inline]
77fn dense_materialization_chunk_rows(nrows: usize, ncols: usize) -> usize {
78    rows_for_target_bytes(CHUNKED_DENSE_MATERIALIZATION_BYTES, ncols)
79        .max(1)
80        .min(nrows.max(1))
81}
82
83fn dense_operator_to_dense_by_chunks<O: DenseDesignOperator + ?Sized>(
84    op: &O,
85) -> Result<Array2<f64>, MatrixMaterializationError> {
86    let n = op.nrows();
87    let p = op.ncols();
88    let chunk_rows = dense_materialization_chunk_rows(n, p);
89    let mut out = Array2::<f64>::zeros((n, p));
90    for start in (0..n).step_by(chunk_rows) {
91        let end = (start + chunk_rows).min(n);
92        let slice = out.slice_mut(s![start..end, ..]);
93        op.row_chunk_into(start..end, slice)?;
94    }
95    Ok(out)
96}
97
98pub fn checked_dense_nbytes(nrows: usize, ncols: usize, context: &str) -> Result<usize, String> {
99    nrows
100        .checked_mul(ncols)
101        .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
102        .ok_or_else(|| {
103            MatrixError::DimensionMismatch {
104                reason: format!("{context}: dense size overflow for {nrows}x{ncols}"),
105            }
106            .into()
107        })
108}
109
110pub fn panic_or_error_if_large_scale_mode_and_to_dense_called_with_policy(
111    context: &str,
112    n: usize,
113    p: usize,
114    policy: &ResourcePolicy,
115) -> Result<(), String> {
116    // Strict-operator mode: refuse any dense materialization, regardless of
117    // size.  Callers in this mode have committed to operator-only math; any
118    // dense fallback (cache or otherwise) violates that contract and would
119    // silently turn an analytic-operator path into a hidden dense path at
120    // large scale.
121    if matches!(
122        policy.derivative_storage_mode,
123        gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired
124    ) {
125        return Err(MatrixError::DensificationRefused {
126            reason: format!(
127                "{context}: refusing to densify operator-backed design {n}x{p} under \
128             AnalyticOperatorRequired policy; provide an operator-form path"
129            ),
130        }
131        .into());
132    }
133    let dense_bytes = checked_dense_nbytes(n, p, context)?;
134    let limit = policy.max_single_materialization_bytes;
135    if dense_bytes > limit {
136        let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
137        return Err(MatrixError::DensificationRefused {
138            reason: format!(
139                "{context}: refusing to densify operator-backed design {n}x{p} (~{gib:.2} GiB); use matrix-free or chunked code"
140            ),
141        }
142        .into());
143    }
144    Ok(())
145}
146
147fn weighted_crossprod_dense(
148    left: &Array2<f64>,
149    weights: &Array1<f64>,
150    right: &Array2<f64>,
151) -> Result<Array2<f64>, String> {
152    if left.nrows() != weights.len() || right.nrows() != weights.len() {
153        return Err(MatrixError::DimensionMismatch {
154            reason: format!(
155                "weighted_crossprod_dense row mismatch: left={}, weights={}, right={}",
156                left.nrows(),
157                weights.len(),
158                right.nrows()
159            ),
160        }
161        .into());
162    }
163    Ok(weighted_crossprod_dense_view(left, weights.view(), right))
164}
165
166fn weighted_crossprod_dense_view(
167    left: &Array2<f64>,
168    weights: ArrayView1<'_, f64>,
169    right: &Array2<f64>,
170) -> Array2<f64> {
171    let n = weights.len();
172    let p_left = left.ncols();
173    let p_right = right.ncols();
174    let work = (n as u64)
175        .saturating_mul(p_left as u64)
176        .saturating_mul(p_right as u64);
177    if rayon::current_num_threads() <= 1 || work < WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS {
178        return weighted_crossprod_dense_rows(left, weights, right, 0..n);
179    }
180
181    let min_parallel_work = WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS.min(usize::MAX as u64) as usize;
182    let Some(chunk_rows) = crate::parallel::row_reduction_chunk_rows(
183        n,
184        p_left.saturating_mul(p_right),
185        p_left.saturating_mul(p_right),
186        min_parallel_work,
187    ) else {
188        return weighted_crossprod_dense_rows(left, weights, right, 0..n);
189    };
190    let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
191    let partials: Vec<Array2<f64>> = starts
192        .into_par_iter()
193        .map(|start| {
194            weighted_crossprod_dense_rows(left, weights, right, start..(start + chunk_rows).min(n))
195        })
196        .collect();
197    let mut out = Array2::<f64>::zeros((p_left, p_right));
198    for partial in &partials {
199        out += partial;
200    }
201    out
202}
203
204fn weighted_crossprod_dense_rows(
205    left: &Array2<f64>,
206    weights: ArrayView1<'_, f64>,
207    right: &Array2<f64>,
208    rows: Range<usize>,
209) -> Array2<f64> {
210    // The per-row body below is `Σᵢ wᵢ · leftᵢᵀ · rightᵢ`, which is linear in
211    // `wᵢ` and therefore sign-correct without any PSD assumption. The PSD
212    // precondition belongs at the symmetric `Xᵀ W X` caller (`weighted_crossprod_dense_view`),
213    // not at this kernel: `BlockDesignOperator::cross_block` legitimately uses
214    // the asymmetric form `X_iᵀ W X_j` with signed `c·Xv` weights from the outer
215    // REML Hessian-derivative correction, which is not PSD even when `w ≥ 0`.
216    // The prior assert here turned that legitimate signed use into a panic.
217    let p_left = left.ncols();
218    let p_right = right.ncols();
219    let mut out = Array2::<f64>::zeros((p_left, p_right));
220    if left.is_standard_layout()
221        && right.is_standard_layout()
222        && let (Some(lx), Some(rx), Some(w)) =
223            (left.as_slice(), right.as_slice(), weights.as_slice())
224    {
225        let out_slice = out.as_slice_mut().expect("zeros are contiguous");
226        for i in rows {
227            let wi = w[i];
228            if wi == 0.0 {
229                continue;
230            }
231            let l_row = &lx[i * p_left..i * p_left + p_left];
232            let r_row = &rx[i * p_right..i * p_right + p_right];
233            for a in 0..p_left {
234                let scaled = wi * l_row[a];
235                if scaled == 0.0 {
236                    continue;
237                }
238                let out_row = &mut out_slice[a * p_right..a * p_right + p_right];
239                for b in 0..p_right {
240                    out_row[b] += scaled * r_row[b];
241                }
242            }
243        }
244        return out;
245    }
246    for i in rows {
247        let wi = weights[i];
248        if wi == 0.0 {
249            continue;
250        }
251        for a in 0..p_left {
252            let scaled = wi * left[[i, a]];
253            if scaled == 0.0 {
254                continue;
255            }
256            for b in 0..p_right {
257                out[[a, b]] += scaled * right[[i, b]];
258            }
259        }
260    }
261    out
262}
263
264pub struct DenseRightProductView<'a> {
265    base: &'a Array2<f64>,
266    first: Option<&'a Array2<f64>>,
267    second: Option<&'a Array2<f64>>,
268}
269
270impl<'a> DenseRightProductView<'a> {
271    pub fn new(base: &'a Array2<f64>) -> Self {
272        Self {
273            base,
274            first: None,
275            second: None,
276        }
277    }
278
279    pub fn with_factor(mut self, factor: &'a Array2<f64>) -> Self {
280        if self.first.is_none() {
281            self.first = Some(factor);
282        } else if self.second.is_none() {
283            self.second = Some(factor);
284        } else {
285            // SAFETY: DenseRightProductView statically carries exactly two optional
286            // factor slots (`first` and `second`); reaching this branch means a
287            // caller invoked `with_factor` a third time, which violates the
288            // type's documented contract of at most two right factors.
289            // SAFETY: third `with_factor` call violates the type's two-factor invariant.
290            std::panic::panic_any("DenseRightProductView supports at most two right factors");
291        }
292        self
293    }
294
295    pub fn with_optional_factor(self, factor: Option<&'a Array2<f64>>) -> Self {
296        match factor {
297            Some(factor) => self.with_factor(factor),
298            None => self,
299        }
300    }
301
302    pub fn materialize(&self) -> Array2<f64> {
303        let mut out = self.base.clone();
304        if let Some(factor) = self.first {
305            out = fast_ab(&out, factor);
306        }
307        if let Some(factor) = self.second {
308            out = fast_ab(&out, factor);
309        }
310        out
311    }
312
313    fn transformed_ncols(&self) -> usize {
314        if let Some(factor) = self.second {
315            factor.ncols()
316        } else if let Some(factor) = self.first {
317            factor.ncols()
318        } else {
319            self.base.ncols()
320        }
321    }
322}
323
324pub struct EmbeddedColumnBlock<'a> {
325    local: &'a Array2<f64>,
326    global_range: Range<usize>,
327    total_cols: usize,
328}
329
330impl<'a> EmbeddedColumnBlock<'a> {
331    pub fn new(local: &'a Array2<f64>, global_range: Range<usize>, total_cols: usize) -> Self {
332        Self {
333            local,
334            global_range,
335            total_cols,
336        }
337    }
338
339    pub fn materialize(&self) -> Array2<f64> {
340        if self.local.nrows() == 0 {
341            return Array2::<f64>::zeros((0, self.total_cols));
342        }
343        assert_eq!(
344            self.local.ncols(),
345            self.global_range.len(),
346            "embedded column block width mismatch"
347        );
348        let mut out = Array2::<f64>::zeros((self.local.nrows(), self.total_cols));
349        out.slice_mut(ndarray::s![.., self.global_range.clone()])
350            .assign(self.local);
351        out
352    }
353}
354
355pub struct EmbeddedSquareBlock<'a> {
356    local: &'a Array2<f64>,
357    global_range: Range<usize>,
358    total_dim: usize,
359}
360
361impl<'a> EmbeddedSquareBlock<'a> {
362    pub fn new(local: &'a Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
363        Self {
364            local,
365            global_range,
366            total_dim,
367        }
368    }
369
370    pub fn materialize(&self) -> Array2<f64> {
371        let mut out = Array2::<f64>::zeros((self.total_dim, self.total_dim));
372        out.slice_mut(ndarray::s![
373            self.global_range.clone(),
374            self.global_range.clone()
375        ])
376        .assign(self.local);
377        out
378    }
379}
380
381struct PenalizedWeightedNormalOperator<'a, O: LinearOperator + ?Sized> {
382    operator: &'a O,
383    weights: &'a Array1<f64>,
384    penalty: Option<&'a Array2<f64>>,
385    ridge: f64,
386}
387
388impl<'a, O: LinearOperator + ?Sized> PenalizedWeightedNormalOperator<'a, O> {
389    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
390        self.operator
391            .apply_weighted_normal(self.weights, vector, self.penalty, self.ridge)
392    }
393
394    fn jacobi_preconditioner(&self) -> Result<Array1<f64>, String> {
395        let mut diag = self.operator.diag_gram(self.weights)?;
396        if let Some(pen) = self.penalty {
397            for i in 0..diag.len() {
398                diag[i] += pen[[i, i]];
399            }
400        }
401        if self.ridge > 0.0 {
402            for i in 0..diag.len() {
403                diag[i] += self.ridge;
404            }
405        }
406        Ok(diag)
407    }
408}
409
410#[inline]
411fn dense_diag_gram_view(matrix: &Array2<f64>, weights: PsdWeightsView<'_>) -> Array1<f64> {
412    // Diagonal of XᵀWX — used as Fisher-info diagonal for preconditioning and
413    // for diagonal-of-Gram queries. Negative weights have no sensible meaning
414    // here (the diagonal must be nonneg for it to act as a preconditioner);
415    // typed at the boundary via `PsdWeightsView` so the previous runtime
416    // `assert!` is no longer required inside the kernel.
417    let weights = weights.view();
418    let p = matrix.ncols();
419    let n = matrix.nrows();
420    let large = (n as u64) * (p as u64) >= DENSE_ROW_PARALLEL_MIN_NP;
421    let parallel = large && rayon::current_thread_index().is_none();
422    // Fast path: if the matrix is row-major contiguous, read each row as a
423    // slice and avoid n*p bounds-checked indexing.
424    if matrix.is_standard_layout()
425        && let (Some(x), Some(w)) = (matrix.as_slice(), weights.as_slice())
426    {
427        if parallel {
428            return (0..n)
429                .into_par_iter()
430                .fold(
431                    || vec![0.0_f64; p],
432                    |mut acc, i| {
433                        let wi = w[i];
434                        if wi != 0.0 {
435                            let row = &x[i * p..i * p + p];
436                            for j in 0..p {
437                                let xij = row[j];
438                                acc[j] += wi * xij * xij;
439                            }
440                        }
441                        acc
442                    },
443                )
444                .reduce(
445                    || vec![0.0_f64; p],
446                    |mut a, b| {
447                        for (av, bv) in a.iter_mut().zip(b) {
448                            *av += bv;
449                        }
450                        a
451                    },
452                )
453                .into();
454        }
455        let mut diag = Array1::<f64>::zeros(p);
456        let diag_slice = diag.as_slice_mut().expect("zeros are contiguous");
457        for i in 0..n {
458            let wi = w[i];
459            if wi == 0.0 {
460                continue;
461            }
462            let row = &x[i * p..i * p + p];
463            for j in 0..p {
464                let xij = row[j];
465                diag_slice[j] += wi * xij * xij;
466            }
467        }
468        return diag;
469    }
470    let mut diag = Array1::<f64>::zeros(p);
471    for i in 0..n {
472        let wi = weights[i];
473        if wi == 0.0 {
474            continue;
475        }
476        for j in 0..p {
477            let xij = matrix[[i, j]];
478            diag[j] += wi * xij * xij;
479        }
480    }
481    diag
482}
483
484fn sparse_csr_weighted_xtwx(
485    row_ptr: &[usize],
486    col_idx: &[usize],
487    vals: &[f64],
488    n: usize,
489    p: usize,
490    weights: ArrayView1<'_, f64>,
491) -> Array2<f64> {
492    let nnz = vals.len() as u64;
493    let avg = nnz.checked_div(n.max(1) as u64).unwrap_or(0);
494    let work = (n as u64).saturating_mul(avg.saturating_mul(avg));
495    if rayon::current_num_threads() <= 1 || work < SPARSE_ROW_PARALLEL_MIN_FLOPS {
496        return sparse_csr_weighted_xtwx_rows(row_ptr, col_idx, vals, p, weights, 0..n);
497    }
498
499    let min_parallel_work = SPARSE_ROW_PARALLEL_MIN_FLOPS.min(usize::MAX as u64) as usize;
500    let Some(chunk_rows) = crate::parallel::row_reduction_chunk_rows(
501        n,
502        avg.min(usize::MAX as u64) as usize,
503        p.saturating_mul(p),
504        min_parallel_work,
505    ) else {
506        return sparse_csr_weighted_xtwx_rows(row_ptr, col_idx, vals, p, weights, 0..n);
507    };
508    let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
509    let partials: Vec<Array2<f64>> = starts
510        .into_par_iter()
511        .map(|start| {
512            sparse_csr_weighted_xtwx_rows(
513                row_ptr,
514                col_idx,
515                vals,
516                p,
517                weights,
518                start..(start + chunk_rows).min(n),
519            )
520        })
521        .collect();
522    let mut xtwx = Array2::<f64>::zeros((p, p));
523    for partial in &partials {
524        xtwx += partial;
525    }
526    xtwx
527}
528
529fn sparse_csr_weighted_xtwx_rows(
530    row_ptr: &[usize],
531    col_idx: &[usize],
532    vals: &[f64],
533    p: usize,
534    weights: ArrayView1<'_, f64>,
535    rows: Range<usize>,
536) -> Array2<f64> {
537    // PSD precondition is discharged at the typed boundary
538    // (`PsdWeightsView::try_new` inside callers of `xt_diag_x_psd_op`). The CSC
539    // counterpart (`streaming_sparse_csc_xt_diag_x`) accepts signed weights and
540    // is the right path for observed-Hessian assembly; this CSR-row kernel is
541    // reserved for Fisher-scoring Gram builds where the working weights are
542    // guaranteed nonneg by typed construction.
543    let mut xtwx = Array2::<f64>::zeros((p, p));
544    for i in rows {
545        let wi = weights[i];
546        if wi == 0.0 {
547            continue;
548        }
549        let start = row_ptr[i];
550        let end = row_ptr[i + 1];
551        for a_ptr in start..end {
552            let a = col_idx[a_ptr];
553            let wxa = wi * vals[a_ptr];
554            for b_ptr in a_ptr..end {
555                let b = col_idx[b_ptr];
556                let v = wxa * vals[b_ptr];
557                xtwx[[a, b]] += v;
558                if a != b {
559                    xtwx[[b, a]] += v;
560                }
561            }
562        }
563    }
564    xtwx
565}
566
/// Adds `Xᵀ · diag(weights) · X` to `out` for a sparse `n x p` matrix `X`
/// given in CSC form, processing the rows in dense chunks.
///
/// * `col_ptr`, `row_idx`, `vals` — CSC arrays; column `c` owns the stored
///   entries `col_ptr[c]..col_ptr[c + 1]`.
/// * `weights` — per-row weights, indexed by global row number. They enter
///   linearly, and no sign check is performed here.
/// * `out` — accumulated into (`Accum::Add`), never cleared by this function;
///   presumably `p x p` — NOTE(review): confirm at the call sites.
///
/// Returns immediately, leaving `out` untouched, when `n == 0` or `p == 0`.
///
/// NOTE(review): the per-column row window is located with `partition_point`,
/// which is only meaningful when the row indices within each column are sorted
/// ascending — confirm that callers guarantee this.
pub fn streaming_sparse_csc_xt_diag_x(
    col_ptr: &[usize],
    row_idx: &[usize],
    vals: &[f64],
    n: usize,
    p: usize,
    weights: ArrayView1<'_, f64>,
    out: &mut Array2<f64>,
) {
    if n == 0 || p == 0 {
        return;
    }

    let chunk_rows = dense_materialization_chunk_rows(n, p);
    let par = effective_global_parallelism();
    // Two reusable column-major (`.f()`) scratch buffers: the densified row
    // chunk of X and the same chunk with each row scaled by its weight.
    let mut x_chunk = Array2::<f64>::zeros((chunk_rows, p).f());
    let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p).f());

    {
        let mut out_view = array2_to_matmut(out);

        for start in (0..n).step_by(chunk_rows) {
            // The final chunk may hold fewer than `chunk_rows` rows.
            let rows = (n - start).min(chunk_rows);
            {
                let mut x_slice = x_chunk.slice_mut(s![0..rows, ..]);
                let mut wx_slice = wx_chunk.slice_mut(s![0..rows, ..]);
                // Clear whatever the previous chunk left in the scratch buffers.
                x_slice.fill(0.0);
                wx_slice.fill(0.0);
                let end = start + rows;
                for col in 0..p {
                    let col_start = col_ptr[col];
                    let col_end = col_ptr[col + 1];
                    let rows_for_col = &row_idx[col_start..col_end];
                    // Binary-search the stored entries of this column whose
                    // row index lies in `start..end`.
                    let local_start = rows_for_col.partition_point(|&row| row < start);
                    let local_end = rows_for_col.partition_point(|&row| row < end);
                    for local_ptr in local_start..local_end {
                        let ptr = col_start + local_ptr;
                        let row = row_idx[ptr];
                        let local = row - start;
                        let wi = weights[row];
                        let value = vals[ptr];
                        // `+=` so repeated (row, col) entries are summed.
                        x_slice[[local, col]] += value;
                        wx_slice[[local, col]] += wi * value;
                    }
                }
            }
            let x_slice = x_chunk.slice(s![0..rows, ..]);
            let wx_slice = wx_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wx_view = FaerArrayView::new(&wx_slice);
            // out += 1.0 * x_chunkᵀ · (W x)_chunk
            matmul(
                out_view.as_mut(),
                Accum::Add,
                x_view.as_ref().transpose(),
                wx_view.as_ref(),
                1.0,
                par,
            );
        }
    }
}
628
629fn sparse_csr_diag_gram(
630    row_ptr: &[usize],
631    col_idx: &[usize],
632    vals: &[f64],
633    n: usize,
634    p: usize,
635    weights: ArrayView1<'_, f64>,
636) -> Array1<f64> {
637    let work = vals.len() as u64;
638    if rayon::current_num_threads() <= 1 || work < SPARSE_ROW_PARALLEL_MIN_FLOPS {
639        return sparse_csr_diag_gram_rows(row_ptr, col_idx, vals, p, weights, 0..n);
640    }
641    let min_parallel_work = SPARSE_ROW_PARALLEL_MIN_FLOPS.min(usize::MAX as u64) as usize;
642    let Some(chunk_rows) = crate::parallel::row_reduction_chunk_rows(n, 1, p, min_parallel_work)
643    else {
644        return sparse_csr_diag_gram_rows(row_ptr, col_idx, vals, p, weights, 0..n);
645    };
646    let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
647    let partials: Vec<Array1<f64>> = starts
648        .into_par_iter()
649        .map(|start| {
650            sparse_csr_diag_gram_rows(
651                row_ptr,
652                col_idx,
653                vals,
654                p,
655                weights,
656                start..(start + chunk_rows).min(n),
657            )
658        })
659        .collect();
660    let mut diag = Array1::<f64>::zeros(p);
661    for partial in &partials {
662        diag += partial;
663    }
664    diag
665}
666
667fn sparse_csr_diag_gram_rows(
668    row_ptr: &[usize],
669    col_idx: &[usize],
670    vals: &[f64],
671    p: usize,
672    weights: ArrayView1<'_, f64>,
673    rows: Range<usize>,
674) -> Array1<f64> {
675    // PSD precondition discharged at the typed boundary
676    // (`PsdWeightsView::try_new` inside callers of `xt_diag_x_psd_op`).
677    // Signed observed-Hessian assembly uses the signed Gram path
678    // (xt_diag_x_signed → streaming kernels) and never reaches this routine.
679    let mut diag = Array1::<f64>::zeros(p);
680    for i in rows {
681        let wi = weights[i];
682        if wi == 0.0 {
683            continue;
684        }
685        for idx in row_ptr[i]..row_ptr[i + 1] {
686            let j = col_idx[idx];
687            let xij = vals[idx];
688            diag[j] += wi * xij * xij;
689        }
690    }
691    diag
692}
693
694#[inline]
695fn dense_transpose_weighted_response(
696    matrix: &Array2<f64>,
697    weights: &Array1<f64>,
698    y: &Array1<f64>,
699    row_scale: Option<&Array1<f64>>,
700) -> Array1<f64> {
701    // Signed-safe: XᵀWy is linear in W, so observed-Hessian / non-canonical-link
702    // IRLS sites that drive signed working weights through this kernel must be
703    // preserved end-to-end. Clipping negative weights here silently biases the
704    // pseudo-response and was the source of the Gram-cleanup mismatch.
705    let p = matrix.ncols();
706    let n = matrix.nrows();
707    let mut out = Array1::<f64>::zeros(p);
708    if matrix.is_standard_layout()
709        && let (Some(x), Some(w), Some(yslice)) =
710            (matrix.as_slice(), weights.as_slice(), y.as_slice())
711    {
712        let scale_slice = row_scale.and_then(|s| s.as_slice());
713        let out_slice = out.as_slice_mut().expect("zeros are contiguous");
714        for i in 0..n {
715            let mut scaled = yslice[i] * w[i];
716            if let Some(s) = scale_slice {
717                scaled *= s[i];
718            } else if let Some(scale) = row_scale {
719                scaled *= scale[i];
720            }
721            if scaled == 0.0 {
722                continue;
723            }
724            let row = &x[i * p..i * p + p];
725            for j in 0..p {
726                out_slice[j] += row[j] * scaled;
727            }
728        }
729        return out;
730    }
731    for i in 0..n {
732        let mut scaled = y[i] * weights[i];
733        if let Some(scale) = row_scale {
734            scaled *= scale[i];
735        }
736        if scaled == 0.0 {
737            continue;
738        }
739        for j in 0..p {
740            out[j] += matrix[[i, j]] * scaled;
741        }
742    }
743    out
744}
745
746#[inline]
747fn dense_transpose_weighted_response_view(
748    matrix: &Array2<f64>,
749    weights: ArrayView1<'_, f64>,
750    y: ArrayView1<'_, f64>,
751) -> Array1<f64> {
752    // Signed-safe view variant of dense_transpose_weighted_response; see that
753    // function for the rationale on preserving sign through XᵀWy.
754    let p = matrix.ncols();
755    let n = matrix.nrows();
756    let mut out = Array1::<f64>::zeros(p);
757    if matrix.is_standard_layout()
758        && let (Some(x), Some(w), Some(yslice)) =
759            (matrix.as_slice(), weights.as_slice(), y.as_slice())
760    {
761        let out_slice = out.as_slice_mut().expect("zeros are contiguous");
762        for i in 0..n {
763            let scaled = yslice[i] * w[i];
764            if scaled == 0.0 {
765                continue;
766            }
767            let row = &x[i * p..i * p + p];
768            for j in 0..p {
769                out_slice[j] += row[j] * scaled;
770            }
771        }
772        return out;
773    }
774    for i in 0..n {
775        let scaled = y[i] * weights[i];
776        if scaled == 0.0 {
777            continue;
778        }
779        for j in 0..p {
780            out[j] += matrix[[i, j]] * scaled;
781        }
782    }
783    out
784}
785
786#[derive(Clone)]
787pub struct SparseDesignMatrix {
788    matrix: SparseColMat<usize, f64>,
789    dense_cache: Arc<OnceLock<Arc<Array2<f64>>>>,
790    csr_cache: Arc<OnceLock<Arc<SparseRowMat<usize, f64>>>>,
791}
792
793impl SparseDesignMatrix {
794    pub fn new(matrix: SparseColMat<usize, f64>) -> Self {
795        Self {
796            matrix,
797            dense_cache: Arc::new(OnceLock::new()),
798            csr_cache: Arc::new(OnceLock::new()),
799        }
800    }
801
802    fn dense_nbytes(&self) -> Result<usize, String> {
803        self.matrix
804            .nrows()
805            .checked_mul(self.matrix.ncols())
806            .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
807            .ok_or_else(|| {
808                format!(
809                    "dense size overflow for sparse design {}x{}",
810                    self.matrix.nrows(),
811                    self.matrix.ncols()
812                )
813            })
814    }
815
816    fn materialize_dense_arc(&self) -> Arc<Array2<f64>> {
817        let mut out = Array2::<f64>::zeros((self.matrix.nrows(), self.matrix.ncols()));
818        let (symbolic, values) = self.matrix.parts();
819        let col_ptr = symbolic.col_ptr();
820        let row_idx = symbolic.row_idx();
821        for col in 0..self.matrix.ncols() {
822            let start = col_ptr[col];
823            let end = col_ptr[col + 1];
824            for idx in start..end {
825                out[[row_idx[idx], col]] += values[idx];
826            }
827        }
828        Arc::new(out)
829    }
830
831    pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
832        let dense_bytes = self.dense_nbytes()?;
833        if dense_bytes > MAX_SPARSE_TO_DENSE_BYTES {
834            let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
835            return Err(MatrixError::DensificationRefused {
836                reason: format!(
837                    "{context}: refusing to densify sparse design {}x{} (~{gib:.2} GiB); use sparse or matrix-free code",
838                    self.matrix.nrows(),
839                    self.matrix.ncols(),
840                ),
841            }
842            .into());
843        }
844        if dense_bytes <= MAX_PERSISTENT_SPARSE_DENSE_CACHE_BYTES {
845            Ok(self
846                .dense_cache
847                .get_or_init(|| self.materialize_dense_arc())
848                .clone())
849        } else {
850            Ok(self.materialize_dense_arc())
851        }
852    }
853
854    pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
855        self.try_to_dense_arc("SparseDesignMatrix::to_dense_arc")
856            .unwrap_or_else(|msg| {
857                let bt = std::backtrace::Backtrace::force_capture();
858                // SAFETY: infallible-style accessor used at sites where the
859                // caller has already established that densifying this sparse
860                // matrix is permitted (size below the densification guard); a
861                // failure here means the caller broke that contract, which
862                // warrants an immediate abort with backtrace for diagnosis.
863                // SAFETY: infallible accessor; densification refusal here is a caller contract violation.
864                std::panic::panic_any(format!("{msg}\nbacktrace:\n{bt}"))
865            })
866    }
867
868    pub fn to_csr_arc(&self) -> Option<Arc<SparseRowMat<usize, f64>>> {
869        if let Some(cached) = self.csr_cache.get() {
870            return Some(cached.clone());
871        }
872        let csr = self.matrix.as_ref().to_row_major().ok()?;
873        let arc = Arc::new(csr);
874        self.csr_cache.set(arc.clone()).ok();
875        Some(arc)
876    }
877}
878
/// Lets a `SparseDesignMatrix` be used wherever a `&SparseColMat<usize, f64>`
/// is expected by dereferencing to the wrapped `matrix` field.
impl Deref for SparseDesignMatrix {
    type Target = SparseColMat<usize, f64>;
    fn deref(&self) -> &Self::Target {
        &self.matrix
    }
}
885
/// Explicit borrow of the wrapped column-major sparse matrix.
impl AsRef<SparseColMat<usize, f64>> for SparseDesignMatrix {
    fn as_ref(&self) -> &SparseColMat<usize, f64> {
        &self.matrix
    }
}
891
892/// Trait for dense-backed design operators that avoid eager materialization.
893///
894/// Implement this trait for structured designs (multi-channel, rowwise-Kronecker,
895/// etc.) that can perform matvecs and Gram-matrix assembly without forming the
896/// full dense matrix. Wrap implementations in `DenseDesignMatrix::Lazy(Arc<..>)`
897/// to integrate them with the rest of the codebase while keeping the top-level
898/// `DesignMatrix` split strictly `Dense | Sparse`.
899pub trait DenseDesignOperator: LinearOperator + Send + Sync {
900    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
901        // Default: X'(w ⊙ y) via apply_transpose.
902        let n = self.nrows();
903        if weights.len() != n || y.len() != n {
904            return Err(format!(
905                "DenseDesignOperator::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
906                weights.len(),
907                y.len(),
908                n
909            ));
910        }
911        // Signed-safe XᵀWy: linear in w, so observed-Hessian / non-canonical
912        // working weights must flow through unclipped.
913        let mut wy = Array1::<f64>::zeros(n);
914        ndarray::Zip::from(&mut wy)
915            .and(weights)
916            .and(y)
917            .par_for_each(|o, &w, &yi| *o = w * yi);
918        Ok(self.apply_transpose(&wy))
919    }
920
921    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
922        // Default: diag(X M X') computed in chunks via row_chunk — avoids
923        // materializing the full n×p dense matrix at once.
924        if middle.nrows() != self.ncols() || middle.ncols() != self.ncols() {
925            return Err(format!(
926                "DenseDesignOperator::quadratic_form_diag dimension mismatch: {}x{} vs expected {}x{}",
927                middle.nrows(),
928                middle.ncols(),
929                self.ncols(),
930                self.ncols()
931            ));
932        }
933        let n = self.nrows();
934        let mut out = Array1::<f64>::zeros(n);
935        // Process in chunks to bound memory: ~8 MB working set.
936        let chunk_size = (8 * 1024 * 1024 / (self.ncols().max(1) * 8 * 2))
937            .max(16)
938            .min(n.max(1));
939        let mut start = 0;
940        while start < n {
941            let end = (start + chunk_size).min(n);
942            let x_chunk = self.try_row_chunk(start..end).map_err(|e| e.to_string())?;
943            let xm_chunk = fast_ab(&x_chunk, middle);
944            let mut chunk_out = out.slice_mut(ndarray::s![start..end]);
945            ndarray::Zip::from(&mut chunk_out)
946                .and(x_chunk.rows())
947                .and(xm_chunk.rows())
948                // clamp tiny-negative fp drift on diag(X M Xᵀ) when M is a
949                // PSD covariance/precision matrix; not a weight clip.
950                .par_for_each(|o, xr, xmr| *o = xr.dot(&xmr).max(0.0));
951            start = end;
952        }
953        Ok(out)
954    }
955
956    /// Fill a dense row chunk without materializing the full matrix.
957    /// Required: every implementor must provide row-local access here.
958    fn row_chunk_into(
959        &self,
960        rows: Range<usize>,
961        out: ArrayViewMut2<'_, f64>,
962    ) -> Result<(), MatrixMaterializationError>;
963
964    /// Extract a dense row chunk without materializing the full matrix.
965    /// Non-panicking owned-chunk API built on top of `row_chunk_into`.
966    fn try_row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, MatrixMaterializationError> {
967        let mut out = Array2::<f64>::zeros((rows.end - rows.start, self.ncols()));
968        self.row_chunk_into(rows, out.view_mut())?;
969        Ok(out)
970    }
971
972    /// Borrow dense storage when this operator already owns it.
973    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
974        None
975    }
976
977    /// Batched column extraction: returns an `nrows × cols.len()` dense block
978    /// whose k-th column is `apply(e_{cols[k]})`.
979    ///
980    /// Default impl loops over columns and applies a unit vector per call. Operator
981    /// types like `ReparamOperator` that can express the batch as a single GEMM
982    /// (`X · Qs[:, cols]`) should override this — it avoids re-walking the inner
983    /// matvec for every column.
984    fn apply_columns(&self, cols: &[usize]) -> Array2<f64> {
985        let n = self.nrows();
986        let p = self.ncols();
987        let mut out = Array2::<f64>::zeros((n, cols.len()));
988        let mut e = Array1::<f64>::zeros(p);
989        for (k, &j) in cols.iter().enumerate() {
990            assert!(
991                j < p,
992                "DenseDesignOperator::apply_columns: column index {j} out of bounds (ncols={p})"
993            );
994            e[j] = 1.0;
995            let col = self.apply(&e);
996            e[j] = 0.0;
997            out.column_mut(k).assign(&col);
998        }
999        out
1000    }
1001
1002    /// Materialize the full dense matrix. Operators that exist precisely to
1003    /// avoid materialization should still support this for fallback paths,
1004    /// diagnostics, and prediction.
1005    fn to_dense(&self) -> Array2<f64>;
1006
1007    fn estimated_dense_bytes(&self) -> usize {
1008        self.nrows()
1009            .saturating_mul(self.ncols())
1010            .saturating_mul(std::mem::size_of::<f64>())
1011    }
1012
1013    fn try_to_dense_with_policy(
1014        &self,
1015        policy: &MaterializationPolicy,
1016        context: &'static str,
1017    ) -> Result<Arc<Array2<f64>>, MatrixMaterializationError> {
1018        let bytes = self.estimated_dense_bytes();
1019        if !policy.allow_operator_materialization {
1020            return Err(MatrixMaterializationError::Forbidden {
1021                context,
1022                mode: gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired,
1023            });
1024        }
1025        if bytes > policy.max_single_dense_bytes {
1026            return Err(MatrixMaterializationError::TooLarge {
1027                context,
1028                nrows: self.nrows(),
1029                ncols: self.ncols(),
1030                bytes,
1031                limit_bytes: policy.max_single_dense_bytes,
1032            });
1033        }
1034        dense_operator_to_dense_by_chunks(self).map(Arc::new)
1035    }
1036
1037    /// Shared dense materialization via the required row-chunk API.
1038    ///
1039    /// This deliberately does not fall back through `to_dense()`: operator-backed
1040    /// designs can be large-scale, and their chunked row path is the bounded
1041    /// memory materialization contract. Implementations that already own an
1042    /// `Arc<Array2<_>>` should override this to return it directly.
1043    fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1044        Arc::new(
1045            dense_operator_to_dense_by_chunks(self)
1046                .expect("DenseDesignOperator::to_dense_arc: row-chunk materialization failed"),
1047        )
1048    }
1049}
1050
/// Dense-side design matrix: either an owned dense array or an operator that
/// produces dense rows on demand.
///
/// Both variants hold an `Arc`, so `Clone` shares the underlying storage or
/// operator rather than copying it.
#[derive(Clone)]
pub enum DenseDesignMatrix {
    /// Fully materialized dense matrix.
    Materialized(Arc<Array2<f64>>),
    /// Operator-backed design accessed through `DenseDesignOperator`.
    Lazy(Arc<dyn DenseDesignOperator>),
}
1056
1057impl std::fmt::Debug for DenseDesignMatrix {
1058    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1059        match self {
1060            Self::Materialized(matrix) => {
1061                write!(
1062                    f,
1063                    "DenseDesignMatrix::Materialized({}x{})",
1064                    matrix.nrows(),
1065                    matrix.ncols()
1066                )
1067            }
1068            Self::Lazy(op) => write!(f, "DenseDesignMatrix::Lazy({}x{})", op.nrows(), op.ncols()),
1069        }
1070    }
1071}
1072
1073impl From<Arc<Array2<f64>>> for DenseDesignMatrix {
1074    fn from(value: Arc<Array2<f64>>) -> Self {
1075        Self::Materialized(value)
1076    }
1077}
1078
1079impl From<Array2<f64>> for DenseDesignMatrix {
1080    fn from(value: Array2<f64>) -> Self {
1081        Self::Materialized(Arc::new(value))
1082    }
1083}
1084
1085impl<T> From<Arc<T>> for DenseDesignMatrix
1086where
1087    T: DenseDesignOperator + 'static,
1088{
1089    fn from(value: Arc<T>) -> Self {
1090        Self::Lazy(value)
1091    }
1092}
1093
1094impl DenseDesignMatrix {
1095    /// Stable identity for cache keying.
1096    ///
1097    /// Returns the address of the inner shared `Arc`, which `Clone` shares by
1098    /// reference. Two `DenseDesignMatrix` values produced by cloning the same
1099    /// origin (e.g. the `k` per-coordinate `GlmCurvatureCorrectionOperator`s all
1100    /// built from one converged design) report the same identity, so a `X·F`
1101    /// projection memoized under this id is reused across them within a single
1102    /// outer REML evaluation instead of being re-streamed once per coordinate.
1103    pub fn cache_identity(&self) -> usize {
1104        match self {
1105            Self::Materialized(matrix) => Arc::as_ptr(matrix) as *const () as usize,
1106            Self::Lazy(op) => Arc::as_ptr(op) as *const () as usize,
1107        }
1108    }
1109
1110    pub fn nrows(&self) -> usize {
1111        match self {
1112            Self::Materialized(matrix) => matrix.nrows(),
1113            Self::Lazy(op) => op.nrows(),
1114        }
1115    }
1116
1117    pub fn ncols(&self) -> usize {
1118        match self {
1119            Self::Materialized(matrix) => matrix.ncols(),
1120            Self::Lazy(op) => op.ncols(),
1121        }
1122    }
1123
1124    pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
1125        match self {
1126            Self::Materialized(matrix) => Some(matrix.as_ref()),
1127            Self::Lazy(op) => op.as_dense_ref(),
1128        }
1129    }
1130
1131    pub const fn is_materialized_dense(&self) -> bool {
1132        matches!(self, Self::Materialized(_))
1133    }
1134
1135    pub const fn is_operator_backed(&self) -> bool {
1136        matches!(self, Self::Lazy(_))
1137    }
1138
1139    pub fn to_dense(&self) -> Array2<f64> {
1140        match self {
1141            Self::Materialized(matrix) => matrix.as_ref().clone(),
1142            // Infallible-by-contract dense materialization: callers that reach
1143            // `to_dense` are committed to a dense `Array2<f64>` consumer and
1144            // own the memory budget. Stream row chunks directly via the
1145            // operator's `row_chunk_into`, bypassing the conservative
1146            // single-materialization byte cap (which only callers with a
1147            // strict-operator policy actually need). Strict callers must use
1148            // `try_to_dense_arc_with_policy(ctx, &analytic_operator_required())`
1149            // to get refusal semantics — they explicitly opted into operator-only math.
1150            Self::Lazy(op) => {
1151                dense_operator_to_dense_by_chunks(op.as_ref()).unwrap_or_else(|err| {
1152                    // SAFETY: this branch is the infallible-by-contract dense
1153                    // materialization noted above; the row-chunk path only
1154                    // fails on operator implementation bugs (it does not
1155                    // enforce a byte budget), so failure here is a hard
1156                    // contract violation rather than a runtime condition.
1157                    // SAFETY: row_chunk_into is infallible-by-contract for valid operators.
1158                    std::panic::panic_any(format!(
1159                        "DenseDesignMatrix::to_dense: failed to materialize {}x{} \
1160                         operator-backed design via row chunks: {err}",
1161                        op.nrows(),
1162                        op.ncols(),
1163                    ))
1164                })
1165            }
1166        }
1167    }
1168
1169    pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1170        match self {
1171            Self::Materialized(matrix) => Arc::clone(matrix),
1172            Self::Lazy(op) => Arc::new(
1173                dense_operator_to_dense_by_chunks(op.as_ref()).unwrap_or_else(|err| {
1174                    // SAFETY: companion to the `to_dense` arm above — this
1175                    // path is infallible-by-contract; row-chunk
1176                    // materialization only fails on operator implementation
1177                    // bugs, so a non-Ok result here is a hard contract
1178                    // violation rather than a runtime budget issue.
1179                    // SAFETY: row_chunk_into is infallible-by-contract for valid operators.
1180                    std::panic::panic_any(format!(
1181                        "DenseDesignMatrix::to_dense_arc: failed to materialize {}x{} \
1182                         operator-backed design via row chunks: {err}",
1183                        op.nrows(),
1184                        op.ncols(),
1185                    ))
1186                }),
1187            ),
1188        }
1189    }
1190
1191    pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
1192        // Auto-policy from the design's own dense footprint. The earlier
1193        // shape-based pick reused `for_problem(nrows, ncols, _)`, which is
1194        // intended for classifying the *whole fitting problem* — it flips to
1195        // `AnalyticOperatorRequired` at `nrows >= 100_000` regardless of
1196        // column count. That was wrong for an individual design: a 102052x4
1197        // operator-backed block dense-materializes to only ~3 MiB and is
1198        // genuinely safe. We now pick the permissive policy and let the
1199        // byte-cap inside the materialization guard reject anything that
1200        // would actually blow the default 1 GiB single-materialization budget.
1201        // Callers that need strict refusal still get it by calling
1202        // `try_to_dense_arc_with_policy(ctx, &analytic_operator_required())`.
1203        let policy = ResourcePolicy::default_library();
1204        self.try_to_dense_arc_with_policy(context, &policy)
1205    }
1206
1207    /// Policy-aware variant of [`Self::try_to_dense_arc`].
1208    ///
1209    /// Uses the supplied policy's `max_single_materialization_bytes` cap when
1210    /// deciding whether to densify a lazy operator-backed design.  The default
1211    /// `try_to_dense_arc` always uses `ResourcePolicy::default_library()` (the
1212    /// 1 GiB cap suitable for ad-hoc dense conversions, matching the
1213    /// `CoefficientTransformOperator::MATERIALIZE_MAX_BYTES` ceiling); cache
1214    /// layers that have their own larger cap (e.g.
1215    /// `CoefficientTransformOperator::MATERIALIZE_MAX_BYTES`) can call this
1216    /// method to consume the inner under their own threshold without forcing
1217    /// the conservative default on every consumer.
1218    pub fn try_to_dense_arc_with_policy(
1219        &self,
1220        context: &str,
1221        policy: &ResourcePolicy,
1222    ) -> Result<Arc<Array2<f64>>, String> {
1223        match self {
1224            Self::Materialized(matrix) => Ok(Arc::clone(matrix)),
1225            Self::Lazy(op) => {
1226                panic_or_error_if_large_scale_mode_and_to_dense_called_with_policy(
1227                    context,
1228                    op.nrows(),
1229                    op.ncols(),
1230                    policy,
1231                )?;
1232                dense_operator_to_dense_by_chunks(op.as_ref())
1233                    .map(Arc::new)
1234                    .map_err(|err| {
1235                        format!("{context}: failed to materialize dense row chunks: {err}")
1236                    })
1237            }
1238        }
1239    }
1240
1241    pub fn try_row_chunk(
1242        &self,
1243        rows: Range<usize>,
1244    ) -> Result<Array2<f64>, MatrixMaterializationError> {
1245        match self {
1246            Self::Materialized(matrix) => Ok(matrix.slice(s![rows, ..]).to_owned()),
1247            Self::Lazy(op) => op.try_row_chunk(rows),
1248        }
1249    }
1250
1251    pub fn row_chunk_into(
1252        &self,
1253        rows: Range<usize>,
1254        out: ArrayViewMut2<'_, f64>,
1255    ) -> Result<(), MatrixMaterializationError> {
1256        match self {
1257            Self::Materialized(matrix) => {
1258                let mut out = out;
1259                out.assign(&matrix.slice(s![rows, ..]));
1260                Ok(())
1261            }
1262            Self::Lazy(op) => op.row_chunk_into(rows, out),
1263        }
1264    }
1265}
1266
1267impl LinearOperator for DenseDesignMatrix {
1268    fn nrows(&self) -> usize {
1269        DenseDesignMatrix::nrows(self)
1270    }
1271
1272    fn ncols(&self) -> usize {
1273        DenseDesignMatrix::ncols(self)
1274    }
1275
1276    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
1277        match self {
1278            Self::Materialized(matrix) => fast_av(matrix, vector),
1279            Self::Lazy(op) => op.apply(vector),
1280        }
1281    }
1282
1283    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
1284        match self {
1285            Self::Materialized(matrix) => fast_atv(matrix, vector),
1286            Self::Lazy(op) => op.apply_transpose(vector),
1287        }
1288    }
1289
1290    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
1291        match self {
1292            Self::Materialized(matrix) => {
1293                if weights.len() != matrix.nrows() {
1294                    return Err(format!(
1295                        "DenseDesignMatrix::diag_xtw_x weight length mismatch: weights={}, nrows={}",
1296                        weights.len(),
1297                        matrix.nrows()
1298                    ));
1299                }
1300                let mut xtwx = Array2::<f64>::zeros((matrix.ncols(), matrix.ncols()));
1301                stream_weighted_crossprod_into(
1302                    matrix,
1303                    weights,
1304                    &mut xtwx,
1305                    CrossprodStructure::Full,
1306                    CrossprodAccum::Replace,
1307                    effective_global_parallelism(),
1308                );
1309                Ok(xtwx)
1310            }
1311            Self::Lazy(op) => op.diag_xtw_x(weights),
1312        }
1313    }
1314
1315    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
1316        // PSD precondition discharged at the typed boundary
1317        // (`PsdWeightsView::try_new` inside callers of `xt_diag_x_psd_op` /
1318        // `diag_gram_view`). Observed-Hessian sites use the signed Gram path.
1319        match self {
1320            Self::Materialized(matrix) => {
1321                let n = matrix.nrows();
1322                let p = matrix.ncols();
1323                if weights.len() != n {
1324                    return Err(format!(
1325                        "DenseDesignMatrix::diag_gram weight length mismatch: weights={}, nrows={}",
1326                        weights.len(),
1327                        n
1328                    ));
1329                }
1330                if (n as u64) * (p as u64) < DENSE_ROW_PARALLEL_MIN_NP {
1331                    let mut diag = Array1::<f64>::zeros(p);
1332                    for i in 0..n {
1333                        let wi = weights[i];
1334                        if wi == 0.0 {
1335                            continue;
1336                        }
1337                        for j in 0..p {
1338                            let xij = matrix[[i, j]];
1339                            diag[j] += wi * xij * xij;
1340                        }
1341                    }
1342                    return Ok(diag);
1343                }
1344                let diag = (0..n)
1345                    .into_par_iter()
1346                    .fold(
1347                        || Array1::<f64>::zeros(p),
1348                        |mut acc, i| {
1349                            let wi = weights[i];
1350                            if wi != 0.0 {
1351                                for j in 0..p {
1352                                    let xij = matrix[[i, j]];
1353                                    acc[j] += wi * xij * xij;
1354                                }
1355                            }
1356                            acc
1357                        },
1358                    )
1359                    .reduce(
1360                        || Array1::<f64>::zeros(p),
1361                        |mut a, b| {
1362                            a += &b;
1363                            a
1364                        },
1365                    );
1366                Ok(diag)
1367            }
1368            Self::Lazy(op) => op.diag_gram(weights),
1369        }
1370    }
1371
1372    fn apply_weighted_normal(
1373        &self,
1374        weights: &Array1<f64>,
1375        vector: &Array1<f64>,
1376        penalty: Option<&Array2<f64>>,
1377        ridge: f64,
1378    ) -> Array1<f64> {
1379        assert_eq!(
1380            weights.len(),
1381            self.nrows(),
1382            "DenseDesignMatrix::apply_weighted_normal weight length mismatch"
1383        );
1384        assert_eq!(
1385            vector.len(),
1386            self.ncols(),
1387            "DenseDesignMatrix::apply_weighted_normal vector length mismatch"
1388        );
1389        // PSD precondition discharged at the typed boundary: callers driving
1390        // this Fisher-scoring PCG matvec ((XᵀWX + S + ρI) v) construct their
1391        // weights through `PsdWeightsView::try_new`. Signed observed-Hessian
1392        // assembly routes through `xt_diag_x_signed_op` instead.
1393        match self {
1394            Self::Materialized(matrix) => {
1395                let n = matrix.nrows();
1396                let p = matrix.ncols();
1397                let mut out = if (n as u64) * (p as u64) < DENSE_ROW_PARALLEL_MIN_NP {
1398                    let mut out = Array1::<f64>::zeros(p);
1399                    for i in 0..n {
1400                        let wi = weights[i];
1401                        if wi == 0.0 {
1402                            continue;
1403                        }
1404                        let mut row_dot = 0.0_f64;
1405                        for j in 0..p {
1406                            row_dot += matrix[[i, j]] * vector[j];
1407                        }
1408                        if row_dot == 0.0 {
1409                            continue;
1410                        }
1411                        let scaled = wi * row_dot;
1412                        for j in 0..p {
1413                            out[j] += scaled * matrix[[i, j]];
1414                        }
1415                    }
1416                    out
1417                } else {
1418                    (0..n)
1419                        .into_par_iter()
1420                        .fold(
1421                            || Array1::<f64>::zeros(p),
1422                            |mut acc, i| {
1423                                let wi = weights[i];
1424                                if wi != 0.0 {
1425                                    let mut row_dot = 0.0_f64;
1426                                    for j in 0..p {
1427                                        row_dot += matrix[[i, j]] * vector[j];
1428                                    }
1429                                    if row_dot != 0.0 {
1430                                        let scaled = wi * row_dot;
1431                                        for j in 0..p {
1432                                            acc[j] += scaled * matrix[[i, j]];
1433                                        }
1434                                    }
1435                                }
1436                                acc
1437                            },
1438                        )
1439                        .reduce(
1440                            || Array1::<f64>::zeros(p),
1441                            |mut a, b| {
1442                                a += &b;
1443                                a
1444                            },
1445                        )
1446                };
1447                if let Some(pen) = penalty {
1448                    out += &fast_av(pen, vector);
1449                }
1450                if ridge > 0.0 {
1451                    for j in 0..p {
1452                        out[j] += ridge * vector[j];
1453                    }
1454                }
1455                out
1456            }
1457            Self::Lazy(op) => op.apply_weighted_normal(weights, vector, penalty, ridge),
1458        }
1459    }
1460
1461    fn uses_matrix_free_pcg(&self) -> bool {
1462        match self {
1463            Self::Materialized(_) => true,
1464            Self::Lazy(op) => op.uses_matrix_free_pcg(),
1465        }
1466    }
1467}
1468
1469impl DenseDesignOperator for DenseDesignMatrix {
1470    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
1471        match self {
1472            Self::Materialized(matrix) => {
1473                if weights.len() != matrix.nrows() || y.len() != matrix.nrows() {
1474                    return Err(format!(
1475                        "DenseDesignMatrix::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
1476                        weights.len(),
1477                        y.len(),
1478                        matrix.nrows()
1479                    ));
1480                }
1481                Ok(dense_transpose_weighted_response(matrix, weights, y, None))
1482            }
1483            Self::Lazy(op) => op.compute_xtwy(weights, y),
1484        }
1485    }
1486
1487    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
1488        match self {
1489            Self::Materialized(matrix) => {
1490                if middle.nrows() != matrix.ncols() || middle.ncols() != matrix.ncols() {
1491                    return Err(format!(
1492                        "quadratic_form_diag dimension mismatch: matrix is {}x{}, expected {}x{}",
1493                        middle.nrows(),
1494                        middle.ncols(),
1495                        matrix.ncols(),
1496                        matrix.ncols()
1497                    ));
1498                }
1499                let xc = fast_ab(matrix, middle);
1500                let n = matrix.nrows();
1501                let p = matrix.ncols();
1502                let mut out = Array1::<f64>::zeros(n);
1503                if matrix.is_standard_layout()
1504                    && xc.is_standard_layout()
1505                    && let (Some(m_all), Some(xc_all), Some(out_slice)) =
1506                        (matrix.as_slice(), xc.as_slice(), out.as_slice_mut())
1507                {
1508                    // Parallel per-row clamped quadratic-form diagonal with
1509                    // stride-1 reads from both row-major operands. Avoids the
1510                    // per-row `Array1::dot` call's overhead at large-scale shapes
1511                    // (n ≈ 2e5, p ≈ 33).
1512                    use rayon::iter::{IndexedParallelIterator, ParallelIterator};
1513                    use rayon::slice::ParallelSliceMut;
1514                    out_slice
1515                        .par_chunks_mut(1)
1516                        .enumerate()
1517                        .for_each(|(i, slot)| {
1518                            let off = i * p;
1519                            let m_row = &m_all[off..off + p];
1520                            let xc_row = &xc_all[off..off + p];
1521                            let mut acc = 0.0_f64;
1522                            for j in 0..p {
1523                                acc += m_row[j] * xc_row[j];
1524                            }
1525                            // clamp tiny-negative fp drift on diag(X M Xᵀ)
1526                            // when M is a PSD covariance/precision matrix.
1527                            slot[0] = acc.max(0.0);
1528                        });
1529                } else {
1530                    for i in 0..n {
1531                        // clamp tiny-negative fp drift on diag(X M Xᵀ)
1532                        // when M is a PSD covariance/precision matrix.
1533                        out[i] = matrix.row(i).dot(&xc.row(i)).max(0.0);
1534                    }
1535                }
1536                Ok(out)
1537            }
1538            Self::Lazy(op) => op.quadratic_form_diag(middle),
1539        }
1540    }
1541
1542    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
1543        DenseDesignMatrix::as_dense_ref(self)
1544    }
1545
1546    fn row_chunk_into(
1547        &self,
1548        rows: Range<usize>,
1549        mut out: ArrayViewMut2<'_, f64>,
1550    ) -> Result<(), MatrixMaterializationError> {
1551        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
1552            return Err(MatrixMaterializationError::MissingRowChunk {
1553                context: "DenseDesignMatrix::row_chunk_into shape mismatch",
1554            });
1555        }
1556        match self {
1557            Self::Materialized(matrix) => {
1558                out.assign(&matrix.slice(s![rows, ..]));
1559                Ok(())
1560            }
1561            Self::Lazy(op) => op.row_chunk_into(rows, out),
1562        }
1563    }
1564
1565    fn to_dense(&self) -> Array2<f64> {
1566        DenseDesignMatrix::to_dense(self)
1567    }
1568
1569    fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1570        DenseDesignMatrix::to_dense_arc(self)
1571    }
1572}
1573
1574// ---------------------------------------------------------------------------
1575// ReparamOperator — lazy X·Qs composition without materialization
1576// ---------------------------------------------------------------------------
1577
/// Lazy composed operator for reparameterized design: X_transformed = X_original · Qs.
///
/// Instead of materializing the dense n×p product X·Qs, this operator applies
/// the p×p orthogonal transform Qs on the coefficient side:
///
///   apply(v)           → X · (Qs · v)
///   apply_transpose(v) → Qs^T · (X^T · v)
///   diag_xtw_x(w)      → Qs^T · (X^T W X) · Qs
///
/// This preserves the sparsity of X and avoids an O(n·p) dense allocation.
pub struct ReparamOperator {
    // Design matrix in the original coefficient basis.
    x_original: DesignMatrix,
    // Shared transform applied on the coefficient side.
    qs: Arc<Array2<f64>>,
    // Cached row count of `x_original`.
    n: usize,
    // Cached column count of `qs`, i.e. the operator's column count.
    p: usize,
}
1594
1595impl ReparamOperator {
1596    pub fn new(x_original: DesignMatrix, qs: Arc<Array2<f64>>) -> Self {
1597        let n = x_original.nrows();
1598        let p = qs.ncols();
1599        assert_eq!(
1600            x_original.ncols(),
1601            qs.nrows(),
1602            "ReparamOperator: X cols ({}) must match Qs rows ({})",
1603            x_original.ncols(),
1604            qs.nrows()
1605        );
1606        Self {
1607            x_original,
1608            qs,
1609            n,
1610            p,
1611        }
1612    }
1613
1614    /// Access the underlying original design matrix.
1615    pub fn x_original(&self) -> &DesignMatrix {
1616        &self.x_original
1617    }
1618
1619    /// Access the Qs orthogonal transform.
1620    pub fn qs(&self) -> &Array2<f64> {
1621        &self.qs
1622    }
1623}
1624
1625impl LinearOperator for ReparamOperator {
1626    fn nrows(&self) -> usize {
1627        self.n
1628    }
1629
1630    fn ncols(&self) -> usize {
1631        self.p
1632    }
1633
1634    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
1635        // X · (Qs · v): apply Qs on the p-dimensional side first, then sparse/dense X.
1636        let qv = self.qs.dot(vector);
1637        self.x_original.apply(&qv)
1638    }
1639
1640    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
1641        // Qs^T · (X^T · v): apply X^T first (sparse matvec), then small dense Qs^T.
1642        let xtv = self.x_original.apply_transpose(vector);
1643        fast_atv(&self.qs, &xtv)
1644    }
1645
1646    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
1647        // Qs^T · (X^T W X) · Qs: compute X^TWX in original basis (sparse-friendly),
1648        // then two small p×p multiplications.
1649        let xtwx = self.x_original.diag_xtw_x(weights)?;
1650        let tmp = fast_atb(&self.qs, &xtwx);
1651        Ok(fast_ab(&tmp, &self.qs))
1652    }
1653
1654    fn apply_weighted_normal(
1655        &self,
1656        weights: &Array1<f64>,
1657        vector: &Array1<f64>,
1658        penalty: Option<&Array2<f64>>,
1659        ridge: f64,
1660    ) -> Array1<f64> {
1661        assert_eq!(
1662            weights.len(),
1663            self.x_original.nrows(),
1664            "ReparamOperator::apply_weighted_normal weight length mismatch"
1665        );
1666        assert_eq!(
1667            vector.len(),
1668            self.qs.ncols(),
1669            "ReparamOperator::apply_weighted_normal vector length mismatch"
1670        );
1671        // PSD precondition discharged at the typed boundary: this is the
1672        // Fisher-scoring PCG normal-equations matvec; signed observed-Hessian
1673        // assembly does not reach this path.
1674        // Qs^T X^T W X Qs v + S v + ridge v
1675        let qv = self.qs.dot(vector);
1676        let xqv = self.x_original.apply(&qv);
1677        let mut wxqv = xqv;
1678        for i in 0..wxqv.len() {
1679            wxqv[i] *= weights[i];
1680        }
1681        let xtw = self.x_original.apply_transpose(&wxqv);
1682        let mut out = fast_atv(&self.qs, &xtw);
1683        if let Some(pen) = penalty {
1684            out += &fast_av(pen, vector);
1685        }
1686        if ridge > 0.0 {
1687            // BLAS axpy: out += ridge * vector, no temporary allocation.
1688            out.scaled_add(ridge, vector);
1689        }
1690        out
1691    }
1692}
1693
1694impl DenseDesignOperator for ReparamOperator {
1695    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
1696        // Qs^T · X^T(w ⊙ y)
1697        let xtwy = self.x_original.compute_xtwy(weights, y)?;
1698        Ok(fast_atv(&self.qs, &xtwy))
1699    }
1700
1701    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
1702        // diag(X Qs M Qs^T X^T) = diag(X · (Qs M Qs^T) · X^T)
1703        // Compute M_orig = Qs · M · Qs^T (p×p), then delegate to x_original.
1704        let qm = fast_ab(&self.qs, middle);
1705        let m_orig = fast_ab(&qm, &self.qs.t().to_owned());
1706        self.x_original.quadratic_form_diag(&m_orig)
1707    }
1708
1709    fn to_dense(&self) -> Array2<f64> {
1710        match &self.x_original {
1711            DesignMatrix::Dense(x) => fast_ab(x.to_dense_arc().as_ref(), &self.qs),
1712            _ => {
1713                let x_dense = self.x_original.to_dense();
1714                fast_ab(&x_dense, &self.qs)
1715            }
1716        }
1717    }
1718
1719    fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1720        Arc::new(self.to_dense())
1721    }
1722
1723    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
1724        None
1725    }
1726
1727    fn apply_columns(&self, cols: &[usize]) -> Array2<f64> {
1728        // (X · Qs)[:, cols] = X · Qs[:, cols] — one batched matvec over the inner
1729        // design instead of one-per-column dispatch on a unit vector.
1730        let qs_cols = self.qs.select(Axis(1), cols);
1731        match &self.x_original {
1732            DesignMatrix::Dense(x) => match x.as_dense_ref() {
1733                Some(x_dense) => fast_ab(x_dense, &qs_cols),
1734                None => {
1735                    let n = self.n;
1736                    let mut out = Array2::<f64>::zeros((n, cols.len()));
1737                    for k in 0..cols.len() {
1738                        let col = qs_cols.column(k).to_owned();
1739                        let xc = self.x_original.apply(&col);
1740                        out.column_mut(k).assign(&xc);
1741                    }
1742                    out
1743                }
1744            },
1745            DesignMatrix::Sparse(_) => {
1746                // Sparse X: apply column-by-column over the small qs_cols block.
1747                let n = self.n;
1748                let mut out = Array2::<f64>::zeros((n, cols.len()));
1749                for k in 0..cols.len() {
1750                    let col = qs_cols.column(k).to_owned();
1751                    let xc = self.x_original.apply(&col);
1752                    out.column_mut(k).assign(&xc);
1753                }
1754                out
1755            }
1756        }
1757    }
1758
1759    fn row_chunk_into(
1760        &self,
1761        rows: Range<usize>,
1762        mut out: ArrayViewMut2<'_, f64>,
1763    ) -> Result<(), MatrixMaterializationError> {
1764        if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
1765            return Err(MatrixMaterializationError::MissingRowChunk {
1766                context: "ReparamOperator::row_chunk_into shape mismatch",
1767            });
1768        }
1769        match &self.x_original {
1770            DesignMatrix::Dense(x) => {
1771                let chunk = x.try_row_chunk(rows)?;
1772                out.assign(&fast_ab(&chunk, &self.qs));
1773            }
1774            DesignMatrix::Sparse(sdm) => {
1775                // Extract rows directly from CSR without densifying the full matrix.
1776                let csr = sdm
1777                    .to_csr_arc()
1778                    .ok_or(MatrixMaterializationError::MissingRowChunk {
1779                        context: "ReparamOperator::row_chunk_into: failed to obtain CSR view",
1780                    })?;
1781                let sym = csr.symbolic();
1782                let row_ptr = sym.row_ptr();
1783                let col_idx = sym.col_idx();
1784                let vals = csr.val();
1785                let chunk_rows = rows.end - rows.start;
1786                let p_inner = sdm.ncols();
1787                let mut chunk = Array2::<f64>::zeros((chunk_rows, p_inner));
1788                for (local, global) in (rows.start..rows.end).enumerate() {
1789                    for ptr in row_ptr[global]..row_ptr[global + 1] {
1790                        chunk[[local, col_idx[ptr]]] = vals[ptr];
1791                    }
1792                }
1793                out.assign(&fast_ab(&chunk, &self.qs));
1794            }
1795        }
1796        Ok(())
1797    }
1798}
1799
1800// ---------------------------------------------------------------------------
1801// RandomEffectOperator — O(n) implicit design for random intercepts
1802// ---------------------------------------------------------------------------
1803
/// Random-intercept design represented implicitly by one group label per row.
///
/// Conceptually this is an n × q one-hot matrix, but only the label vector is
/// stored. Products against it are computed by walking the labels; the dense
/// one-hot form is produced only on request (`to_dense`, `row_chunk_into`).
#[derive(Clone)]
pub struct RandomEffectOperator {
    /// Group (column) index in `0..num_groups` for each observation. `None`
    /// marks a row whose level is not in the kept set (e.g. an unseen level
    /// at prediction time); such a row is all zeros.
    pub group_ids: Vec<Option<usize>>,
    /// Number of observations (rows).
    pub n: usize,
    /// Number of groups (columns).
    pub num_groups: usize,
}
1820
1821impl RandomEffectOperator {
1822    pub fn new(group_ids: Vec<Option<usize>>, num_groups: usize) -> Self {
1823        let n = group_ids.len();
1824        Self {
1825            group_ids,
1826            n,
1827            num_groups,
1828        }
1829    }
1830
1831    /// For a dense block X_dense (n × p_dense) and weights w, compute
1832    /// X_dense' diag(w) X_re  →  (p_dense × num_groups) matrix.
1833    ///
1834    /// Column g of the result = Σ_{i: group[i]=g} w[i] * X_dense.row(i).
1835    /// Total cost: O(n × p_dense).
1836    pub fn weighted_cross_with_dense(
1837        &self,
1838        dense: &Array2<f64>,
1839        weights: &Array1<f64>,
1840    ) -> Array2<f64> {
1841        assert_eq!(
1842            dense.nrows(),
1843            self.n,
1844            "RandomEffectOperator::weighted_cross_with_dense row mismatch"
1845        );
1846        assert_eq!(
1847            weights.len(),
1848            self.n,
1849            "RandomEffectOperator::weighted_cross_with_dense weight length mismatch"
1850        );
1851        let p_dense = dense.ncols();
1852        let mut cross = Array2::<f64>::zeros((p_dense, self.num_groups));
1853        for i in 0..self.n {
1854            if let Some(g) = self.group_ids[i] {
1855                let wi = weights[i].max(0.0);
1856                if wi == 0.0 {
1857                    continue;
1858                }
1859                for j in 0..p_dense {
1860                    cross[[j, g]] += wi * dense[[i, j]];
1861                }
1862            }
1863        }
1864        cross
1865    }
1866
1867    /// For two RE operators, compute X_re_a' diag(w) X_re_b → (qa × qb).
1868    /// Entry (a, b) = Σ_{i: group_a[i]=a AND group_b[i]=b} w[i].
1869    /// Cost: O(n).
1870    pub fn weighted_cross_with_re(
1871        &self,
1872        other: &RandomEffectOperator,
1873        weights: &Array1<f64>,
1874    ) -> Array2<f64> {
1875        assert_eq!(
1876            other.n, self.n,
1877            "RandomEffectOperator::weighted_cross_with_re row mismatch"
1878        );
1879        assert_eq!(
1880            weights.len(),
1881            self.n,
1882            "RandomEffectOperator::weighted_cross_with_re weight length mismatch"
1883        );
1884        let mut cross = Array2::<f64>::zeros((self.num_groups, other.num_groups));
1885        for i in 0..self.n {
1886            if let (Some(a), Some(b)) = (self.group_ids[i], other.group_ids[i]) {
1887                let wi = weights[i].max(0.0);
1888                if wi != 0.0 {
1889                    cross[[a, b]] += wi;
1890                }
1891            }
1892        }
1893        cross
1894    }
1895}
1896
1897impl LinearOperator for RandomEffectOperator {
1898    fn nrows(&self) -> usize {
1899        self.n
1900    }
1901
1902    fn ncols(&self) -> usize {
1903        self.num_groups
1904    }
1905
1906    /// Forward: out[i] = β[group[i]], or 0 if unmatched.
1907    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
1908        use rayon::prelude::*;
1909        let out: Vec<f64> = self
1910            .group_ids
1911            .par_iter()
1912            .map(|g| g.map(|g| vector[g]).unwrap_or(0.0))
1913            .collect();
1914        Array1::from(out)
1915    }
1916
1917    /// Transpose: out[g] = Σ_{i: group[i]=g} v[i].
1918    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
1919        let mut out = Array1::<f64>::zeros(self.num_groups);
1920        for i in 0..self.n {
1921            if let Some(g) = self.group_ids[i] {
1922                out[g] += vector[i];
1923            }
1924        }
1925        out
1926    }
1927
1928    /// X'WX for a one-hot design is diagonal: D[g,g] = Σ_{i: group[i]=g} w[i].
1929    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
1930        if weights.len() != self.n {
1931            return Err(format!(
1932                "RandomEffectOperator::diag_xtw_x weight length mismatch: weights={}, nrows={}",
1933                weights.len(),
1934                self.n
1935            ));
1936        }
1937        let q = self.num_groups;
1938        let mut xtwx = Array2::<f64>::zeros((q, q));
1939        for i in 0..self.n {
1940            if let Some(g) = self.group_ids[i] {
1941                xtwx[[g, g]] += weights[i].max(0.0);
1942            }
1943        }
1944        Ok(xtwx)
1945    }
1946
1947    /// Diagonal of X'WX: per-group weight sums.
1948    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
1949        if weights.len() != self.n {
1950            return Err(format!(
1951                "RandomEffectOperator::diag_gram weight length mismatch: weights={}, nrows={}",
1952                weights.len(),
1953                self.n
1954            ));
1955        }
1956        let mut diag = Array1::<f64>::zeros(self.num_groups);
1957        for i in 0..self.n {
1958            if let Some(g) = self.group_ids[i] {
1959                diag[g] += weights[i].max(0.0);
1960            }
1961        }
1962        Ok(diag)
1963    }
1964
1965    /// Fused X'WXβ + Sβ + ridge·β.  O(n + q).
1966    fn apply_weighted_normal(
1967        &self,
1968        weights: &Array1<f64>,
1969        vector: &Array1<f64>,
1970        penalty: Option<&Array2<f64>>,
1971        ridge: f64,
1972    ) -> Array1<f64> {
1973        assert_eq!(
1974            weights.len(),
1975            self.n,
1976            "RandomEffectOperator::apply_weighted_normal weight length mismatch"
1977        );
1978        assert_eq!(
1979            vector.len(),
1980            self.num_groups,
1981            "RandomEffectOperator::apply_weighted_normal vector length mismatch"
1982        );
1983        // Step 1: accumulate per-group weighted β[g] contributions.
1984        //   group_acc[g] = Σ_{i in group g} w[i]
1985        //   result[g] = group_acc[g] * vector[g]
1986        let mut group_wacc = Array1::<f64>::zeros(self.num_groups);
1987        for i in 0..self.n {
1988            if let Some(g) = self.group_ids[i] {
1989                group_wacc[g] += weights[i].max(0.0);
1990            }
1991        }
1992        let mut out = Array1::<f64>::zeros(self.num_groups);
1993        for g in 0..self.num_groups {
1994            out[g] = group_wacc[g] * vector[g];
1995        }
1996        if let Some(pen) = penalty {
1997            out += &pen.dot(vector);
1998        }
1999        if ridge > 0.0 {
2000            for g in 0..self.num_groups {
2001                out[g] += ridge * vector[g];
2002            }
2003        }
2004        out
2005    }
2006
2007    fn uses_matrix_free_pcg(&self) -> bool {
2008        true
2009    }
2010}
2011
2012impl DenseDesignOperator for RandomEffectOperator {
2013    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
2014        if weights.len() != self.n || y.len() != self.n {
2015            return Err(format!(
2016                "RandomEffectOperator::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
2017                weights.len(),
2018                y.len(),
2019                self.n
2020            ));
2021        }
2022        let mut out = Array1::<f64>::zeros(self.num_groups);
2023        for i in 0..self.n {
2024            if let Some(g) = self.group_ids[i] {
2025                let wi = weights[i].max(0.0);
2026                out[g] += wi * y[i];
2027            }
2028        }
2029        Ok(out)
2030    }
2031
2032    /// diag(X M X') for one-hot X: out[i] = M[group[i], group[i]].
2033    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
2034        use rayon::prelude::*;
2035        let out: Vec<f64> = self
2036            .group_ids
2037            .par_iter()
2038            .map(|g| g.map(|g| middle[[g, g]].max(0.0)).unwrap_or(0.0))
2039            .collect();
2040        Ok(Array1::from(out))
2041    }
2042
2043    fn row_chunk_into(
2044        &self,
2045        rows: Range<usize>,
2046        mut out: ArrayViewMut2<'_, f64>,
2047    ) -> Result<(), MatrixMaterializationError> {
2048        if out.nrows() != rows.end - rows.start || out.ncols() != self.num_groups {
2049            return Err(MatrixMaterializationError::MissingRowChunk {
2050                context: "RandomEffectOperator::row_chunk_into shape mismatch",
2051            });
2052        }
2053        out.fill(0.0);
2054        for (local, global) in rows.enumerate() {
2055            if let Some(g) = self.group_ids[global] {
2056                out[[local, g]] = 1.0;
2057            }
2058        }
2059        Ok(())
2060    }
2061
2062    /// Materialize the full n × q one-hot matrix (fallback for diagnostics).
2063    fn to_dense(&self) -> Array2<f64> {
2064        let mut out = Array2::<f64>::zeros((self.n, self.num_groups));
2065        ndarray::Zip::indexed(out.rows_mut()).par_for_each(|i, mut row| {
2066            if let Some(g) = self.group_ids[i] {
2067                row[g] = 1.0;
2068            }
2069        });
2070        out
2071    }
2072}
2073
2074// ---------------------------------------------------------------------------
2075// BlockDesignOperator — horizontal block composition [B₀ | B₁ | … | Bₖ]
2076// ---------------------------------------------------------------------------
2077
2078/// A single block in a horizontally-composed design operator.
2079#[derive(Clone)]
2080pub enum DesignBlock {
2081    Dense(DenseDesignMatrix),
2082    Sparse(SparseDesignMatrix),
2083    RandomEffect(Arc<RandomEffectOperator>),
2084    /// Implicit all-ones intercept column: n rows, 1 column, zero storage.
2085    Intercept(usize),
2086}
2087
2088impl DesignBlock {
2089    pub fn nrows(&self) -> usize {
2090        match self {
2091            Self::Dense(d) => d.nrows(),
2092            Self::Sparse(s) => s.nrows(),
2093            Self::RandomEffect(op) => op.nrows(),
2094            Self::Intercept(n) => *n,
2095        }
2096    }
2097
2098    pub fn ncols(&self) -> usize {
2099        match self {
2100            Self::Dense(d) => d.ncols(),
2101            Self::Sparse(s) => s.ncols(),
2102            Self::RandomEffect(op) => op.ncols(),
2103            Self::Intercept(_) => 1,
2104        }
2105    }
2106
2107    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
2108        match self {
2109            Self::Dense(d) => d.apply(vector),
2110            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).apply(vector),
2111            Self::RandomEffect(op) => op.apply(vector),
2112            Self::Intercept(n) => Array1::from_elem(*n, vector[0]),
2113        }
2114    }
2115
2116    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
2117        match self {
2118            Self::Dense(d) => d.apply_transpose(vector),
2119            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).apply_transpose(vector),
2120            Self::RandomEffect(op) => op.apply_transpose(vector),
2121            Self::Intercept(_) => {
2122                let sum: f64 = vector.iter().sum();
2123                Array1::from_vec(vec![sum])
2124            }
2125        }
2126    }
2127
2128    fn try_row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, MatrixMaterializationError> {
2129        match self {
2130            Self::Dense(d) => d.try_row_chunk(rows),
2131            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).try_row_chunk(rows),
2132            Self::RandomEffect(op) => op.try_row_chunk(rows),
2133            Self::Intercept(_) => Ok(Array2::ones((rows.end - rows.start, 1))),
2134        }
2135    }
2136
2137    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
2138        match self {
2139            Self::Dense(d) => d.diag_xtw_x(weights),
2140            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).diag_xtw_x(weights),
2141            Self::RandomEffect(op) => op.diag_xtw_x(weights),
2142            Self::Intercept(_) => {
2143                let sum: f64 = weights.iter().map(|w| w.max(0.0)).sum();
2144                Ok(Array2::from_elem((1, 1), sum))
2145            }
2146        }
2147    }
2148
2149    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
2150        match self {
2151            Self::Dense(d) => d.diag_gram(weights),
2152            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).diag_gram(weights),
2153            Self::RandomEffect(op) => op.diag_gram(weights),
2154            Self::Intercept(_) => {
2155                let sum: f64 = weights.iter().map(|w| w.max(0.0)).sum();
2156                Ok(Array1::from_vec(vec![sum]))
2157            }
2158        }
2159    }
2160
2161    /// Materialize this block as a dense (n, p_k) matrix.
2162    fn to_dense(&self) -> Array2<f64> {
2163        match self {
2164            Self::Dense(d) => d.to_dense(),
2165            Self::Sparse(s) => s.to_dense_arc().as_ref().clone(),
2166            Self::RandomEffect(op) => op.to_dense(),
2167            Self::Intercept(n) => Array2::ones((*n, 1)),
2168        }
2169    }
2170}
2171
2172/// Horizontally-composed design operator: X = [B₀ | B₁ | … | Bₖ].
2173///
2174/// Each block can be dense or operator-based.  The coefficient vector β is
2175/// partitioned by block, and the forward product is the sum of per-block
2176/// contributions.  Cross-block terms in X'WX are computed via specialized
2177/// methods on `RandomEffectOperator` for efficiency.
2178#[derive(Clone)]
2179pub struct BlockDesignOperator {
2180    pub blocks: Vec<DesignBlock>,
2181    /// Cumulative column offsets: block i owns columns col_offsets[i]..col_offsets[i+1].
2182    pub col_offsets: Vec<usize>,
2183    pub total_cols: usize,
2184    pub n: usize,
2185}
2186
2187impl BlockDesignOperator {
2188    pub fn new(blocks: Vec<DesignBlock>) -> Result<Self, String> {
2189        if blocks.is_empty() {
2190            return Err("BlockDesignOperator: need at least one block".to_string());
2191        }
2192        let n = blocks[0].nrows();
2193        for (i, b) in blocks.iter().enumerate() {
2194            if b.nrows() != n {
2195                return Err(format!(
2196                    "BlockDesignOperator: block {i} has {} rows, expected {n}",
2197                    b.nrows()
2198                ));
2199            }
2200        }
2201        let mut col_offsets = Vec::with_capacity(blocks.len() + 1);
2202        col_offsets.push(0);
2203        for b in &blocks {
2204            col_offsets.push(col_offsets.last().unwrap() + b.ncols());
2205        }
2206        let total_cols = *col_offsets.last().unwrap();
2207        Ok(Self {
2208            blocks,
2209            col_offsets,
2210            total_cols,
2211            n,
2212        })
2213    }
2214
2215    fn weighted_cross_chunked(
2216        &self,
2217        left: &DesignBlock,
2218        right: &DesignBlock,
2219        weights: &Array1<f64>,
2220    ) -> Result<Array2<f64>, String> {
2221        let pi = left.ncols();
2222        let pj = right.ncols();
2223        let mut cross = Array2::<f64>::zeros((pi, pj));
2224        for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2225            let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2226            let left_chunk = left.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2227            let right_chunk = right.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2228            for local in 0..(end - start) {
2229                // Cross-block X_iᵀ diag(w) X_j is linear in w and well-defined
2230                // for any sign — observed-Hessian assembly (binomial+cloglog,
2231                // Gamma+identity, etc.) legitimately supplies signed w_hessian
2232                // here. The prior `.max(0.0)` silently zeroed the negative-
2233                // curvature contribution, producing an inconsistent off-
2234                // diagonal block. Mirrors the dense-rows kernel's sign-correct
2235                // accumulation a few hundred lines above.
2236                let wi = weights[start + local];
2237                if wi == 0.0 {
2238                    continue;
2239                }
2240                for a in 0..pi {
2241                    let scaled = wi * left_chunk[[local, a]];
2242                    if scaled == 0.0 {
2243                        continue;
2244                    }
2245                    for b in 0..pj {
2246                        cross[[a, b]] += scaled * right_chunk[[local, b]];
2247                    }
2248                }
2249            }
2250        }
2251        Ok(cross)
2252    }
2253
2254    fn quadratic_form_diag_cross_chunked(
2255        &self,
2256        block_a: &DesignBlock,
2257        block_b: &DesignBlock,
2258        m_ab: &Array2<f64>,
2259    ) -> Result<Array1<f64>, String> {
2260        let mut out = Array1::<f64>::zeros(self.n);
2261        for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2262            let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2263            let a_chunk = block_a
2264                .try_row_chunk(start..end)
2265                .map_err(|e| e.to_string())?;
2266            let b_chunk = block_b
2267                .try_row_chunk(start..end)
2268                .map_err(|e| e.to_string())?;
2269            let a_m = fast_ab(&a_chunk, m_ab);
2270            for local in 0..(end - start) {
2271                out[start + local] = a_m.row(local).dot(&b_chunk.row(local));
2272            }
2273        }
2274        Ok(out)
2275    }
2276
2277    /// Compute the cross-block X_i' diag(w) X_j for blocks i < j.
2278    fn cross_block(
2279        &self,
2280        i: usize,
2281        j: usize,
2282        weights: &Array1<f64>,
2283    ) -> Result<Array2<f64>, String> {
2284        match (&self.blocks[i], &self.blocks[j]) {
2285            // ── Dense × Dense ───────────────────────────────────────────
2286            (DesignBlock::Dense(d_i), DesignBlock::Dense(d_j)) => {
2287                if let (Some(xi), Some(xj)) = (d_i.as_dense_ref(), d_j.as_dense_ref()) {
2288                    weighted_crossprod_dense(xi, weights, xj)
2289                } else {
2290                    self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2291                }
2292            }
2293            (DesignBlock::Dense(_), DesignBlock::Sparse(_))
2294            | (DesignBlock::Sparse(_), DesignBlock::Dense(_))
2295            | (DesignBlock::Sparse(_), DesignBlock::Sparse(_))
2296            | (DesignBlock::Sparse(_), DesignBlock::RandomEffect(_))
2297            | (DesignBlock::RandomEffect(_), DesignBlock::Sparse(_)) => {
2298                self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2299            }
2300
2301            // ── Dense × RandomEffect ────────────────────────────────────
2302            (DesignBlock::Dense(d), DesignBlock::RandomEffect(re)) => {
2303                if let Some(dense) = d.as_dense_ref() {
2304                    Ok(re.weighted_cross_with_dense(dense, weights))
2305                } else {
2306                    self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2307                }
2308            }
2309            (DesignBlock::RandomEffect(re), DesignBlock::Dense(d)) => {
2310                if let Some(dense) = d.as_dense_ref() {
2311                    let cross_t = re.weighted_cross_with_dense(dense, weights);
2312                    Ok(cross_t.t().to_owned())
2313                } else {
2314                    self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2315                }
2316            }
2317
2318            // ── RandomEffect × RandomEffect ─────────────────────────────
2319            (DesignBlock::RandomEffect(re_a), DesignBlock::RandomEffect(re_b)) => {
2320                Ok(re_a.weighted_cross_with_re(re_b, weights))
2321            }
2322
2323            // ── Intercept × anything ────────────────────────────────────
2324            // 1'·diag(w)·B_j  →  (1 × p_j) where entry [0,c] = Σ_i w[i] * B_j[i,c]
2325            (DesignBlock::Intercept(_), other) => {
2326                let pj = other.ncols();
2327                let mut cross = Array2::<f64>::zeros((1, pj));
2328                let weighted = Array1::from_shape_fn(self.n, |idx| weights[idx].max(0.0));
2329                let row = other.apply_transpose(&weighted);
2330                cross.row_mut(0).assign(&row);
2331                Ok(cross)
2332            }
2333            (other, DesignBlock::Intercept(_)) => {
2334                let pi = other.ncols();
2335                let mut cross = Array2::<f64>::zeros((pi, 1));
2336                let weighted = Array1::from_shape_fn(self.n, |idx| weights[idx].max(0.0));
2337                let col = other.apply_transpose(&weighted);
2338                cross.column_mut(0).assign(&col);
2339                Ok(cross)
2340            }
2341        }
2342    }
2343
2344    /// Diagonal contribution diag(X_k M X_k') for a single block.
2345    fn quadratic_form_diag_block(
2346        &self,
2347        block: &DesignBlock,
2348        m_kk: &Array2<f64>,
2349    ) -> Result<Array1<f64>, String> {
2350        match block {
2351            DesignBlock::Dense(d) => {
2352                if let Some(dense) = d.as_dense_ref() {
2353                    let xm = fast_ab(dense, m_kk);
2354                    let mut out = Array1::<f64>::zeros(self.n);
2355                    ndarray::Zip::from(&mut out)
2356                        .and(dense.rows())
2357                        .and(xm.rows())
2358                        .par_for_each(|o, dr, xmr| *o = dr.dot(&xmr));
2359                    Ok(out)
2360                } else {
2361                    d.quadratic_form_diag(m_kk)
2362                }
2363            }
2364            DesignBlock::Sparse(s) => {
2365                let sparse = DesignMatrix::Sparse(s.clone());
2366                sparse.quadratic_form_diag(m_kk)
2367            }
2368            DesignBlock::RandomEffect(re) => {
2369                use rayon::prelude::*;
2370                let out: Vec<f64> = re
2371                    .group_ids
2372                    .par_iter()
2373                    .map(|g| g.map(|g| m_kk[[g, g]]).unwrap_or(0.0))
2374                    .collect();
2375                Ok(Array1::from(out))
2376            }
2377            DesignBlock::Intercept(_) => {
2378                // Row i of intercept block is [1], so contribution = M[0,0] for all i.
2379                Ok(Array1::from_elem(self.n, m_kk[[0, 0]]))
2380            }
2381        }
2382    }
2383
2384    /// Cross-block contribution diag(X_a M_ab X_b') for two distinct blocks.
2385    fn quadratic_form_diag_cross(
2386        &self,
2387        block_a: &DesignBlock,
2388        block_b: &DesignBlock,
2389        m_ab: &Array2<f64>,
2390    ) -> Result<Array1<f64>, String> {
2391        match (block_a, block_b) {
2392            (DesignBlock::Dense(da), DesignBlock::Dense(db)) => {
2393                if let (Some(da), Some(db)) = (da.as_dense_ref(), db.as_dense_ref()) {
2394                    let da_m = fast_ab(da, m_ab);
2395                    let mut out = Array1::<f64>::zeros(self.n);
2396                    ndarray::Zip::from(&mut out)
2397                        .and(da_m.rows())
2398                        .and(db.rows())
2399                        .par_for_each(|o, ar, br| *o = ar.dot(&br));
2400                    Ok(out)
2401                } else {
2402                    self.quadratic_form_diag_cross_chunked(block_a, block_b, m_ab)
2403                }
2404            }
2405            (DesignBlock::Dense(_), DesignBlock::Sparse(_))
2406            | (DesignBlock::Sparse(_), DesignBlock::Dense(_))
2407            | (DesignBlock::Sparse(_), DesignBlock::Sparse(_))
2408            | (DesignBlock::Sparse(_), DesignBlock::RandomEffect(_))
2409            | (DesignBlock::RandomEffect(_), DesignBlock::Sparse(_)) => {
2410                self.quadratic_form_diag_cross_chunked(block_a, block_b, m_ab)
2411            }
2412            (DesignBlock::Dense(d), DesignBlock::RandomEffect(re)) => {
2413                let mut out = Array1::<f64>::zeros(self.n);
2414                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2415                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2416                    let chunk = d.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2417                    for local in 0..chunk.nrows() {
2418                        let i = start + local;
2419                        if let Some(g) = re.group_ids[i] {
2420                            let mut val = 0.0;
2421                            for j in 0..chunk.ncols() {
2422                                val += chunk[[local, j]] * m_ab[[j, g]];
2423                            }
2424                            out[i] = val;
2425                        }
2426                    }
2427                }
2428                Ok(out)
2429            }
2430            (DesignBlock::RandomEffect(re), DesignBlock::Dense(d)) => {
2431                let mut out = Array1::<f64>::zeros(self.n);
2432                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2433                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2434                    let chunk = d.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2435                    for local in 0..chunk.nrows() {
2436                        let i = start + local;
2437                        if let Some(g) = re.group_ids[i] {
2438                            let mut val = 0.0;
2439                            for j in 0..chunk.ncols() {
2440                                val += m_ab[[g, j]] * chunk[[local, j]];
2441                            }
2442                            out[i] = val;
2443                        }
2444                    }
2445                }
2446                Ok(out)
2447            }
2448            (DesignBlock::RandomEffect(re_a), DesignBlock::RandomEffect(re_b)) => {
2449                use rayon::prelude::*;
2450                let out: Vec<f64> = re_a
2451                    .group_ids
2452                    .par_iter()
2453                    .zip(re_b.group_ids.par_iter())
2454                    .map(|(ga, gb)| match (ga, gb) {
2455                        (Some(ga), Some(gb)) => m_ab[[*ga, *gb]],
2456                        _ => 0.0,
2457                    })
2458                    .collect();
2459                Ok(Array1::from(out))
2460            }
2461
2462            // Intercept × anything: contribution at row i = m_ab[0, :] · row_i(B_b)
2463            (DesignBlock::Intercept(_), other) => {
2464                let m_row = m_ab.row(0);
2465                let mut out = Array1::<f64>::zeros(self.n);
2466                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2467                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2468                    let chunk = other.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2469                    for local in 0..(end - start) {
2470                        out[start + local] = chunk.row(local).dot(&m_row);
2471                    }
2472                }
2473                Ok(out)
2474            }
2475            (other, DesignBlock::Intercept(_)) => {
2476                let m_col = m_ab.column(0);
2477                let mut out = Array1::<f64>::zeros(self.n);
2478                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2479                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2480                    let chunk = other.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2481                    for local in 0..(end - start) {
2482                        out[start + local] = chunk.row(local).dot(&m_col);
2483                    }
2484                }
2485                Ok(out)
2486            }
2487        }
2488    }
2489}
2490
2491impl LinearOperator for BlockDesignOperator {
2492    fn nrows(&self) -> usize {
2493        self.n
2494    }
2495
2496    fn ncols(&self) -> usize {
2497        self.total_cols
2498    }
2499
2500    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
2501        let mut out = Array1::<f64>::zeros(self.n);
2502        for (idx, block) in self.blocks.iter().enumerate() {
2503            let start = self.col_offsets[idx];
2504            let end = self.col_offsets[idx + 1];
2505            let slice = vector.slice(s![start..end]).to_owned();
2506            let contribution = block.apply(&slice);
2507            out += &contribution;
2508        }
2509        out
2510    }
2511
2512    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
2513        let mut out = Array1::<f64>::zeros(self.total_cols);
2514        for (idx, block) in self.blocks.iter().enumerate() {
2515            let start = self.col_offsets[idx];
2516            let end = self.col_offsets[idx + 1];
2517            let transposed = block.apply_transpose(vector);
2518            out.slice_mut(s![start..end]).assign(&transposed);
2519        }
2520        out
2521    }
2522
2523    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
2524        if weights.len() != self.n {
2525            return Err(format!(
2526                "BlockDesignOperator::diag_xtw_x weight length mismatch: weights={}, nrows={}",
2527                weights.len(),
2528                self.n
2529            ));
2530        }
2531        let p = self.total_cols;
2532        let mut result = Array2::<f64>::zeros((p, p));
2533
2534        // Diagonal blocks.
2535        for (idx, block) in self.blocks.iter().enumerate() {
2536            let start = self.col_offsets[idx];
2537            let end = self.col_offsets[idx + 1];
2538            let block_xtwx = block.diag_xtw_x(weights)?;
2539            result
2540                .slice_mut(s![start..end, start..end])
2541                .assign(&block_xtwx);
2542        }
2543
2544        // Cross blocks (i, j) for i < j.
2545        //
2546        // Perf (#1017): the shared weight-scaled design `diag(w)·X_i` is
2547        // identical across every pairing `(i, j>i)`. The prior code recomputed
2548        // it inside each `cross_block` call (re-scaling X_i by w once per
2549        // partner j) and folded the product with a naive O(n·p_i·p_j) triple
2550        // loop. We now scale each dense block by `w` exactly ONCE up front and
2551        // route every Dense×Dense pair through a single blocked BLAS GEMM
2552        // (`fast_atb`), collapsing the c² hand-rolled accumulations into c²
2553        // batched matmuls over a design that is weight-scaled c times instead
2554        // of O(c²) times. Non-dense pairs keep their specialized kernels.
2555        //
2556        // Bit-identity: `cross[a,b] = Σ_i w_i · X_i[i,a] · X_j[i,b]` is exactly
2557        // `(diag(w)·X_i)ᵀ · X_j`, so pre-scaling then GEMM is the same sum,
2558        // reassociated only by the matmul's blocking (≤1e-10).
2559        let weighted_dense: Vec<Option<Array2<f64>>> = self
2560            .blocks
2561            .iter()
2562            .map(|block| match block {
2563                DesignBlock::Dense(d) => d.as_dense_ref().map(|x| {
2564                    // diag(w)·X computed once; signed w (no .max(0.0)) to match
2565                    // the asymmetric cross-block kernel's sign-correct form.
2566                    x * &weights.view().insert_axis(Axis(1))
2567                }),
2568                _ => None,
2569            })
2570            .collect();
2571
2572        for i in 0..self.blocks.len() {
2573            for j in (i + 1)..self.blocks.len() {
2574                let cross = match (&weighted_dense[i], &self.blocks[j]) {
2575                    // Fused Dense×Dense: single GEMM over the shared,
2576                    // already-once-scaled left design.
2577                    (Some(wx_i), DesignBlock::Dense(d_j)) => match d_j.as_dense_ref() {
2578                        Some(x_j) => fast_atb(wx_i, x_j),
2579                        None => self.cross_block(i, j, weights)?,
2580                    },
2581                    _ => self.cross_block(i, j, weights)?,
2582                };
2583                let si = self.col_offsets[i];
2584                let ei = self.col_offsets[i + 1];
2585                let sj = self.col_offsets[j];
2586                let ej = self.col_offsets[j + 1];
2587                result.slice_mut(s![si..ei, sj..ej]).assign(&cross);
2588                result.slice_mut(s![sj..ej, si..ei]).assign(&cross.t());
2589            }
2590        }
2591
2592        Ok(result)
2593    }
2594
2595    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
2596        if weights.len() != self.n {
2597            return Err(format!(
2598                "BlockDesignOperator::diag_gram weight length mismatch: weights={}, nrows={}",
2599                weights.len(),
2600                self.n
2601            ));
2602        }
2603        let mut out = Array1::<f64>::zeros(self.total_cols);
2604        for (idx, block) in self.blocks.iter().enumerate() {
2605            let start = self.col_offsets[idx];
2606            let end = self.col_offsets[idx + 1];
2607            let block_diag = block.diag_gram(weights)?;
2608            out.slice_mut(s![start..end]).assign(&block_diag);
2609        }
2610        Ok(out)
2611    }
2612
2613    fn apply_weighted_normal(
2614        &self,
2615        weights: &Array1<f64>,
2616        vector: &Array1<f64>,
2617        penalty: Option<&Array2<f64>>,
2618        ridge: f64,
2619    ) -> Array1<f64> {
2620        assert_eq!(
2621            weights.len(),
2622            self.n,
2623            "BlockDesignOperator::apply_weighted_normal weight length mismatch"
2624        );
2625        assert_eq!(
2626            vector.len(),
2627            self.total_cols,
2628            "BlockDesignOperator::apply_weighted_normal vector length mismatch"
2629        );
2630        // Fused: X'W(Xβ) + Sβ + ridge·β
2631        let xv = self.apply(vector);
2632        let mut weighted = xv;
2633        for i in 0..weighted.len() {
2634            weighted[i] *= weights[i].max(0.0);
2635        }
2636        let mut out = self.apply_transpose(&weighted);
2637        if let Some(pen) = penalty {
2638            out += &fast_av(pen, vector);
2639        }
2640        if ridge > 0.0 {
2641            // BLAS axpy: out += ridge * vector, no temporary allocation.
2642            out.scaled_add(ridge, vector);
2643        }
2644        out
2645    }
2646
2647    fn uses_matrix_free_pcg(&self) -> bool {
2648        // Enable PCG when any block is non-dense (RE, Operator, or Intercept).
2649        self.blocks
2650            .iter()
2651            .any(|b| matches!(b, DesignBlock::RandomEffect(_) | DesignBlock::Intercept(_)))
2652    }
2653}
2654
2655impl DenseDesignOperator for BlockDesignOperator {
2656    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
2657        if weights.len() != self.n || y.len() != self.n {
2658            return Err(format!(
2659                "BlockDesignOperator::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
2660                weights.len(),
2661                y.len(),
2662                self.n
2663            ));
2664        }
2665        let mut wy = Array1::<f64>::zeros(self.n);
2666        ndarray::Zip::from(&mut wy)
2667            .and(weights)
2668            .and(y)
2669            .par_for_each(|o, &w, &yi| *o = w.max(0.0) * yi);
2670        Ok(self.apply_transpose(&wy))
2671    }
2672
2673    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
2674        // diag(X M X'): for each observation i, compute row_i(X) · M · row_i(X)'.
2675        // With block structure, this decomposes into diagonal and cross-block terms.
2676        let mut out = Array1::<f64>::zeros(self.n);
2677        let nb = self.blocks.len();
2678
2679        // Diagonal contributions: diag(X_k M_kk X_k')
2680        for k in 0..nb {
2681            let sk = self.col_offsets[k];
2682            let ek = self.col_offsets[k + 1];
2683            let m_kk = middle.slice(s![sk..ek, sk..ek]).to_owned();
2684            let block_diag = self.quadratic_form_diag_block(&self.blocks[k], &m_kk)?;
2685            out += &block_diag;
2686        }
2687
2688        // Cross-block contributions: 2·diag(X_a M_ab X_b')
2689        for a in 0..nb {
2690            for b in (a + 1)..nb {
2691                let sa = self.col_offsets[a];
2692                let ea = self.col_offsets[a + 1];
2693                let sb = self.col_offsets[b];
2694                let eb = self.col_offsets[b + 1];
2695                let m_ab = middle.slice(s![sa..ea, sb..eb]);
2696
2697                let cross_diag = self.quadratic_form_diag_cross(
2698                    &self.blocks[a],
2699                    &self.blocks[b],
2700                    &m_ab.to_owned(),
2701                )?;
2702                for i in 0..self.n {
2703                    out[i] += 2.0 * cross_diag[i];
2704                }
2705            }
2706        }
2707
2708        // Clamp to non-negative (variance-like quantity).
2709        for v in out.iter_mut() {
2710            *v = v.max(0.0);
2711        }
2712        Ok(out)
2713    }
2714
2715    fn row_chunk_into(
2716        &self,
2717        rows: Range<usize>,
2718        mut out: ArrayViewMut2<'_, f64>,
2719    ) -> Result<(), MatrixMaterializationError> {
2720        if out.nrows() != rows.end - rows.start || out.ncols() != self.total_cols {
2721            return Err(MatrixMaterializationError::MissingRowChunk {
2722                context: "BlockDesignOperator::row_chunk_into shape mismatch",
2723            });
2724        }
2725        for (idx, block) in self.blocks.iter().enumerate() {
2726            let cs = self.col_offsets[idx];
2727            let ce = self.col_offsets[idx + 1];
2728            let block_chunk = block.try_row_chunk(rows.clone())?;
2729            out.slice_mut(s![.., cs..ce]).assign(&block_chunk);
2730        }
2731        Ok(())
2732    }
2733
2734    fn to_dense(&self) -> Array2<f64> {
2735        let mut out = Array2::<f64>::zeros((self.n, self.total_cols));
2736        for (idx, block) in self.blocks.iter().enumerate() {
2737            let start = self.col_offsets[idx];
2738            let end = self.col_offsets[idx + 1];
2739            let dense_block = block.to_dense();
2740            out.slice_mut(s![.., start..end]).assign(&dense_block);
2741        }
2742        out
2743    }
2744}
2745
2746// ---------------------------------------------------------------------------
2747// MultiChannelOperator
2748// ---------------------------------------------------------------------------
2749
/// Multi-channel design operator: presents k views of shape (n, p) as a single
/// (k*n, p) operator without materializing the stacked matrix.
///
/// Primary use: survival time blocks with entry/exit/derivative channels.
/// Each channel contributes independently to matvecs and Gram assembly:
///
///   apply(β) = [X₀ β; X₁ β; …; X_{k-1} β]      (concatenated)
///   apply_transpose(v) = Σᵢ Xᵢᵀ vᵢ              (summed over channel slices)
///   X'WX = Σᵢ Xᵢᵀ diag(wᵢ) Xᵢ                  (summed over channel slices)
///
/// Stacked row `r` belongs to channel `r / n` at local row `r % n`.
///
/// `MultiChannelOperator::new` checks that the channel list is non-empty and
/// that every channel has the same shape. The fields are public, so a value
/// built with a struct literal bypasses those checks.
#[derive(Clone)]
pub struct MultiChannelOperator {
    /// Per-channel design matrices, each (n, p).
    pub channels: Vec<DesignMatrix>,
    /// Number of rows per channel (all channels must share the same n).
    pub n_per_channel: usize,
    /// Number of columns (shared across all channels).
    pub p: usize,
}
2768
2769impl MultiChannelOperator {
2770    pub fn new(channels: Vec<DesignMatrix>) -> Result<Self, String> {
2771        if channels.is_empty() {
2772            return Err("MultiChannelOperator: need at least one channel".to_string());
2773        }
2774        let n = channels[0].nrows();
2775        let p = channels[0].ncols();
2776        for (i, ch) in channels.iter().enumerate() {
2777            if ch.nrows() != n {
2778                return Err(format!(
2779                    "MultiChannelOperator: channel {i} has {} rows, expected {n}",
2780                    ch.nrows()
2781                ));
2782            }
2783            if ch.ncols() != p {
2784                return Err(format!(
2785                    "MultiChannelOperator: channel {i} has {} cols, expected {p}",
2786                    ch.ncols()
2787                ));
2788            }
2789        }
2790        Ok(Self {
2791            channels,
2792            n_per_channel: n,
2793            p,
2794        })
2795    }
2796}
2797
2798impl LinearOperator for MultiChannelOperator {
2799    fn nrows(&self) -> usize {
2800        self.n_per_channel * self.channels.len()
2801    }
2802
2803    fn ncols(&self) -> usize {
2804        self.p
2805    }
2806
2807    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
2808        let total = self.nrows();
2809        let mut out = Array1::<f64>::zeros(total);
2810        let n = self.n_per_channel;
2811        for (i, ch) in self.channels.iter().enumerate() {
2812            let ch_result = ch.matrixvectormultiply(vector);
2813            out.slice_mut(s![i * n..(i + 1) * n]).assign(&ch_result);
2814        }
2815        out
2816    }
2817
2818    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
2819        let n = self.n_per_channel;
2820        let mut out = Array1::<f64>::zeros(self.p);
2821        for (i, ch) in self.channels.iter().enumerate() {
2822            out += &ch.apply_transpose_view(vector.slice(s![i * n..(i + 1) * n]));
2823        }
2824        out
2825    }
2826
2827    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
2828        let n = self.n_per_channel;
2829        if weights.len() != self.nrows() {
2830            return Err(format!(
2831                "MultiChannelOperator::diag_xtw_x: weights length {} != nrows {}",
2832                weights.len(),
2833                self.nrows()
2834            ));
2835        }
2836        // PSD-clamp the weights to `w ≥ 0`, consistent with this operator's own
2837        // `diag_gram` (`PsdWeightsView::try_new`) and `compute_xtwy`/
2838        // `apply_weighted_normal` (`w.max(0.0)`). Routing the signed/observed
2839        // path here would let a negative working weight flip a column's
2840        // contribution, leaving `diag_xtw_x` and its own diagonal `diag_gram`
2841        // disagreeing on their shared entries for the same operator+weights.
2842        // Multi-channel Grams are always consumed as PSD preconditioners, so
2843        // the clamped XᵀWX is the correct shared semantics (gam#846).
2844        let w_pos = weights.mapv(|w: f64| w.max(0.0));
2845        let mut xtwx = Array2::<f64>::zeros((self.p, self.p));
2846        for (i, ch) in self.channels.iter().enumerate() {
2847            let ch_xtwx = ch
2848                .xt_diag_x_signed_op(SignedWeightsView::new(w_pos.slice(s![i * n..(i + 1) * n])))?;
2849            xtwx += &ch_xtwx;
2850        }
2851        Ok(xtwx)
2852    }
2853
2854    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
2855        let n = self.n_per_channel;
2856        if weights.len() != self.nrows() {
2857            return Err(format!(
2858                "MultiChannelOperator::diag_gram: weights length {} != nrows {}",
2859                weights.len(),
2860                self.nrows()
2861            ));
2862        }
2863        // PSD-clamp the weights to match this operator's own `diag_xtw_x` and
2864        // `compute_xtwy` semantics. Per-channel `diag_gram_view` is backed by
2865        // `PsdWeights::try_new`, which rejects negative entries outright: a
2866        // negative working weight would surface as a hard error here while
2867        // `diag_xtw_x(weights)` succeeds, so the operator's own Gram diagonal
2868        // would disagree with the diagonal of its full Gram for the same
2869        // signed weights. Multi-channel Grams are always consumed as PSD
2870        // preconditioners (gam#846), so clamping is the correct shared
2871        // semantics.
2872        let w_pos = weights.mapv(|w: f64| w.max(0.0));
2873        let mut diag = Array1::<f64>::zeros(self.p);
2874        for (i, ch) in self.channels.iter().enumerate() {
2875            diag += &ch.diag_gram_view(w_pos.slice(s![i * n..(i + 1) * n]))?;
2876        }
2877        Ok(diag)
2878    }
2879
2880    fn uses_matrix_free_pcg(&self) -> bool {
2881        true
2882    }
2883}
2884
2885impl DenseDesignOperator for MultiChannelOperator {
2886    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
2887        let n = self.n_per_channel;
2888        let total = self.nrows();
2889        if weights.len() != total || y.len() != total {
2890            return Err(format!(
2891                "MultiChannelOperator::compute_xtwy: weights={}, y={}, nrows={}",
2892                weights.len(),
2893                y.len(),
2894                total
2895            ));
2896        }
2897        // Clamp signed weights to non-negative to match this operator's
2898        // `diag_xtw_x` / `diag_gram` semantics (multi-channel Grams are PSD
2899        // preconditioners — gam#846). The dense per-channel
2900        // `compute_xtwy_view` path is signed-safe by design (it preserves the
2901        // sign through XᵀWy so observed-Hessian assembly is exact), while the
2902        // sparse path clamps internally — passing raw signed weights would
2903        // therefore produce different XᵀWy depending on which channels are
2904        // sparse vs dense for the same operator+weights.
2905        let w_pos = weights.mapv(|w: f64| w.max(0.0));
2906        let mut out = Array1::<f64>::zeros(self.p);
2907        for (i, ch) in self.channels.iter().enumerate() {
2908            out += &ch.compute_xtwy_view(
2909                w_pos.slice(s![i * n..(i + 1) * n]),
2910                y.slice(s![i * n..(i + 1) * n]),
2911            )?;
2912        }
2913        Ok(out)
2914    }
2915
2916    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
2917        let n = self.n_per_channel;
2918        let mut out = Array1::<f64>::zeros(self.nrows());
2919        for (i, ch) in self.channels.iter().enumerate() {
2920            let ch_diag = ch.quadratic_form_diag(middle)?;
2921            out.slice_mut(s![i * n..(i + 1) * n]).assign(&ch_diag);
2922        }
2923        Ok(out)
2924    }
2925
2926    fn to_dense(&self) -> Array2<f64> {
2927        let total = self.nrows();
2928        let n = self.n_per_channel;
2929        let mut out = Array2::<f64>::zeros((total, self.p));
2930        for (i, ch) in self.channels.iter().enumerate() {
2931            let dense = ch.to_dense();
2932            out.slice_mut(s![i * n..(i + 1) * n, ..]).assign(&dense);
2933        }
2934        out
2935    }
2936
2937    fn row_chunk_into(
2938        &self,
2939        rows: Range<usize>,
2940        mut out: ArrayViewMut2<'_, f64>,
2941    ) -> Result<(), MatrixMaterializationError> {
2942        if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
2943            return Err(MatrixMaterializationError::MissingRowChunk {
2944                context: "MultiChannelOperator::row_chunk_into shape mismatch",
2945            });
2946        }
2947        let n = self.n_per_channel;
2948        let mut local = 0usize;
2949        let mut global = rows.start;
2950        while global < rows.end {
2951            let ch_idx = global / n;
2952            let ch_local_start = global % n;
2953            let ch_local_end = ((ch_idx + 1) * n).min(rows.end) - ch_idx * n;
2954            let segment_len = ch_local_end - ch_local_start;
2955            let ch_chunk = self.channels[ch_idx].try_row_chunk(ch_local_start..ch_local_end)?;
2956            out.slice_mut(s![local..local + segment_len, ..])
2957                .assign(&ch_chunk);
2958            local += segment_len;
2959            global += segment_len;
2960        }
2961        Ok(())
2962    }
2963}
2964
2965// Rowwise-Kronecker + tensor-product design operators (#1145): see `kronecker.rs`.
2966mod kronecker;
2967pub use kronecker::*;
2968
/// Coefficient-side transform operator: represents X_eff = X_inner * T without
/// materializing the product. Preserves the sparsity/operator structure of the
/// inner design by applying T on the coefficient side:
///   apply(v) = X_inner * (T * v)
///   apply_transpose(v) = T^T * (X_inner^T * v)
///   diag_xtw_x(w) = T^T * (X_inner^T W X_inner) * T
pub struct CoefficientTransformOperator {
    /// Inner design `X_inner`. `new` requires its column count to equal the
    /// row count of `transform`.
    inner: DenseDesignMatrix,
    /// Coefficient-side transform `T`, shape `(inner.ncols(), p_out)`.
    transform: Arc<Array2<f64>>,
    /// Row count, taken from `inner.nrows()` at construction.
    n: usize,
    /// Output column count, taken from `transform.ncols()` at construction.
    p_out: usize,
    /// One-time-materialized X · T dense block, populated on first hot use.
    /// Only allocated when the n × p_out block fits within MATERIALIZE_MAX_BYTES;
    /// reused across all PIRLS iterations and outer-seed evaluations. Without
    /// this cache, `BlockDesignOperator::cross_block` (used by per-iter
    /// curvature builds) calls `row_chunk_into` repeatedly, each time
    /// re-running `fast_ab(inner_chunk, transform)` — measured ~3.7 s / iter
    /// at large-scale duchon60 shape (n=320 K, p_out=42 effective) for a single
    /// `update_with_curvature`, all of it allocations + chunked GEMM.
    ///
    /// The slot holds `None` once materialization has been attempted and
    /// declined, so the attempt is not repeated.
    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
2990
2991impl CoefficientTransformOperator {
2992    /// Maximum bytes for the one-shot X · T materialization. 1 GiB is generous
2993    /// enough to cover large-scale (n = 320 K, p_out = 42 → ~107 MiB) and
2994    /// rejects pathological designs. Matches ChunkedKernelDesignOperator.
2995    const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
2996
2997    pub fn new(inner: DenseDesignMatrix, transform: Array2<f64>) -> Result<Self, String> {
2998        let p_inner = inner.ncols();
2999        if transform.nrows() != p_inner {
3000            return Err(format!(
3001                "CoefficientTransformOperator: inner has {} cols but transform has {} rows",
3002                p_inner,
3003                transform.nrows(),
3004            ));
3005        }
3006        let n = inner.nrows();
3007        let p_out = transform.ncols();
3008        Ok(Self {
3009            inner,
3010            transform: Arc::new(transform),
3011            n,
3012            p_out,
3013            materialized: OnceLock::new(),
3014        })
3015    }
3016
3017    /// Get-or-build the materialized X · T dense block. Returns `None` when
3018    /// either the X·T product exceeds the operator's local cap *or* when the
3019    /// inner design refuses dense materialization under the cache-local
3020    /// policy (the cache owns the densified block's lifetime, so it gates
3021    /// the inner densification on the same `MATERIALIZE_MAX_BYTES` budget
3022    /// rather than falling back to the conservative library default).  In
3023    /// either case callers fall back to per-chunk evaluation; the cache is
3024    /// best-effort optimization, not a hard requirement, so a refusal must
3025    /// never panic — the streaming `row_chunk_into` / `apply` paths still
3026    /// work.
3027    /// Same OnceLock-under-rayon hazard as
3028    /// `ChunkedKernelDesignOperator::materialized_combined`: the inner
3029    /// `try_to_dense_arc_with_policy` may dispatch parallel work
3030    /// (kernel-evaluation chunks, BLAS GEMM via faer's rayon pool, etc.),
3031    /// so we compute outside the lock and write with `set`. See
3032    /// `feedback_oncelock_rayon_deadlock`.
3033    fn materialized_combined(&self) -> Option<&Array2<f64>> {
3034        if let Some(slot) = self.materialized.get() {
3035            return slot.as_ref().map(|a| a.as_ref());
3036        }
3037        let bytes = self
3038            .n
3039            .checked_mul(self.p_out)
3040            .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
3041        let computed = match bytes {
3042            Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
3043                // Auto-strict at large-scale shape: even though the cache's
3044                // own MATERIALIZE_MAX_BYTES budget would admit this
3045                // block, refuse densification when the operator's
3046                // outer shape says we're in strict territory. Falls
3047                // through to streaming row_chunk_into / apply paths.
3048                let auto_policy = ResourcePolicy::for_problem(
3049                    self.n,
3050                    self.p_out,
3051                    gam_runtime::resource::ProblemHints::default(),
3052                );
3053                let cache_policy = ResourcePolicy {
3054                    max_single_materialization_bytes: Self::MATERIALIZE_MAX_BYTES,
3055                    derivative_storage_mode: auto_policy.derivative_storage_mode,
3056                    ..ResourcePolicy::default_library()
3057                };
3058                self.inner
3059                    .try_to_dense_arc_with_policy(
3060                        "CoefficientTransformOperator materialization",
3061                        &cache_policy,
3062                    )
3063                    .ok()
3064                    .map(|x| Arc::new(fast_ab(x.as_ref(), &self.transform)))
3065            }
3066            _ => None,
3067        };
3068        if self.materialized.set(computed).is_err() {
3069            return self
3070                .materialized
3071                .get()
3072                .and_then(|opt| opt.as_ref().map(|a| a.as_ref()));
3073        }
3074        self.materialized
3075            .get()
3076            .and_then(|opt| opt.as_ref().map(|a| a.as_ref()))
3077    }
3078}
3079
3080impl LinearOperator for CoefficientTransformOperator {
3081    fn nrows(&self) -> usize {
3082        self.n
3083    }
3084    fn ncols(&self) -> usize {
3085        self.p_out
3086    }
3087    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3088        if let Some(combined) = self.materialized_combined() {
3089            return fast_av(combined, vector);
3090        }
3091        let tv = fast_av(&self.transform, vector);
3092        self.inner.apply(&tv)
3093    }
3094    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
3095        if let Some(combined) = self.materialized_combined() {
3096            return fast_atv(combined, vector);
3097        }
3098        let xtv = self.inner.apply_transpose(vector);
3099        fast_atv(&self.transform, &xtv)
3100    }
3101    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
3102        if let Some(combined) = self.materialized_combined() {
3103            let mut xtwx = Array2::<f64>::zeros((self.p_out, self.p_out));
3104            stream_weighted_crossprod_into(
3105                combined,
3106                weights,
3107                &mut xtwx,
3108                CrossprodStructure::Full,
3109                CrossprodAccum::Replace,
3110                effective_global_parallelism(),
3111            );
3112            return Ok(xtwx);
3113        }
3114        let inner_xtwx = self.inner.diag_xtw_x(weights)?;
3115        // T^T * (X^T W X) * T
3116        let tmp = fast_ab(&self.transform.t().to_owned(), &inner_xtwx);
3117        Ok(fast_ab(&tmp, &self.transform))
3118    }
3119}
3120
3121impl DenseDesignOperator for CoefficientTransformOperator {
3122    /// Expose the cached X·T materialization when populated. This is what lets
3123    /// `BlockDesignOperator::cross_block` recognize a Dense × Dense pair and
3124    /// route to `weighted_crossprod_dense` (BLAS-3 GEMM) instead of the
3125    /// scalar `weighted_cross_chunked` triple loop. Without this override the
3126    /// default trait impl returns `None`, the fast path is skipped, and a
3127    /// 4-block large-scale fit (pgs + sex + smooth_age + duchon) pays a 24 s
3128    /// cross-block cost per PIRLS curvature build.
3129    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
3130        self.materialized_combined()
3131    }
3132
3133    fn to_dense(&self) -> Array2<f64> {
3134        if let Some(combined) = self.materialized_combined() {
3135            return combined.clone();
3136        }
3137        let x = self.inner.to_dense();
3138        fast_ab(&x, &self.transform)
3139    }
3140    fn row_chunk_into(
3141        &self,
3142        rows: Range<usize>,
3143        mut out: ArrayViewMut2<'_, f64>,
3144    ) -> Result<(), MatrixMaterializationError> {
3145        if out.nrows() != rows.end - rows.start || out.ncols() != self.p_out {
3146            return Err(MatrixMaterializationError::MissingRowChunk {
3147                context: "CoefficientTransformOperator::row_chunk_into shape mismatch",
3148            });
3149        }
3150        if let Some(combined) = self.materialized_combined() {
3151            out.assign(&combined.slice(s![rows, ..]));
3152            return Ok(());
3153        }
3154        let chunk = self.inner.try_row_chunk(rows)?;
3155        out.assign(&fast_ab(&chunk, &self.transform));
3156        Ok(())
3157    }
3158}
3159
/// SMGS Phase 4b residualised design operator: emits the mathematically-exact
/// row `C_b · V_b − Σ_{a<b} A_a · R_{a,b}` for block `b`, where `V_b` is the
/// kept-direction reparametrisation of the inner raw block `C_b` and each
/// `R_{a,b} = M_{a,b} · V_b` is the precomputed residualised contribution of
/// an earlier anchor design `A_a` against block `b`.
///
/// The operator presents shape `(n × V_b.ncols())` so the rest of the design
/// stack sees the compiled (kept) width. Row-chunk emission computes
/// `inner_chunk · V_b` and then subtracts `anchor_chunk · r_block` for every
/// anchor pair. The combined `n × kept` block is cached via `OnceLock` under
/// the same `MATERIALIZE_MAX_BYTES = 1 GiB` ceiling as
/// [`CoefficientTransformOperator`], so streaming consumers fall back to
/// per-chunk evaluation when the materialisation would exceed budget.
pub struct ResidualisedDesignOperator {
    /// Raw inner block `C_b`.
    inner: DenseDesignMatrix,
    /// Reparametrisation `V_b`; the constructor requires its row count to
    /// equal `inner.ncols()`.
    transform: Arc<Array2<f64>>,
    /// Anchor pairs `(A_a, R_{a,b})`. The constructor requires each anchor to
    /// have `n` rows and each `R_{a,b}` to be `anchor.ncols() × p_out`.
    anchors: Vec<(DesignMatrix, Arc<Array2<f64>>)>,
    /// Row count, taken from `inner.nrows()` at construction.
    n: usize,
    /// Output width, taken from `transform.ncols()` at construction.
    p_out: usize,
    /// Lazily filled cache of the combined `n × p_out` block. An inner `None`
    /// records that materialisation was attempted and declined or failed.
    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
3181
3182impl ResidualisedDesignOperator {
3183    /// Matches `CoefficientTransformOperator::MATERIALIZE_MAX_BYTES`: 1 GiB
3184    /// covers large-scale shapes (n=320 K, p_out≈40 → ~100 MiB) and rejects
3185    /// pathological designs.
3186    const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
3187
3188    pub fn new(
3189        inner: DenseDesignMatrix,
3190        transform: Array2<f64>,
3191        anchors: Vec<(DesignMatrix, Arc<Array2<f64>>)>,
3192    ) -> Result<Self, String> {
3193        let p_inner = inner.ncols();
3194        if transform.nrows() != p_inner {
3195            return Err(format!(
3196                "ResidualisedDesignOperator: inner has {} cols but transform has {} rows",
3197                p_inner,
3198                transform.nrows(),
3199            ));
3200        }
3201        let n = inner.nrows();
3202        let p_out = transform.ncols();
3203        for (idx, (anchor, r_block)) in anchors.iter().enumerate() {
3204            if anchor.nrows() != n {
3205                return Err(format!(
3206                    "ResidualisedDesignOperator: anchor[{idx}] has {} rows but inner has {n}",
3207                    anchor.nrows(),
3208                ));
3209            }
3210            if r_block.nrows() != anchor.ncols() || r_block.ncols() != p_out {
3211                return Err(format!(
3212                    "ResidualisedDesignOperator: anchor[{idx}] r_block is {}x{} but expected {}x{}",
3213                    r_block.nrows(),
3214                    r_block.ncols(),
3215                    anchor.ncols(),
3216                    p_out,
3217                ));
3218            }
3219        }
3220        Ok(Self {
3221            inner,
3222            transform: Arc::new(transform),
3223            anchors,
3224            n,
3225            p_out,
3226            materialized: OnceLock::new(),
3227        })
3228    }
3229
3230    /// Get-or-build the cached n × p_out materialised block. Mirrors
3231    /// [`CoefficientTransformOperator::materialized_combined`]: computed
3232    /// outside the lock to avoid the OnceLock+rayon deadlock pattern, with
3233    /// the per-cache 1 GiB byte cap routed through a relaxed policy so the
3234    /// inner densification is admitted on the cache's own budget.
3235    fn materialized_combined(&self) -> Option<&Array2<f64>> {
3236        if let Some(slot) = self.materialized.get() {
3237            return slot.as_ref().map(|a| a.as_ref());
3238        }
3239        let bytes = self
3240            .n
3241            .checked_mul(self.p_out)
3242            .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
3243        let computed = match bytes {
3244            Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
3245                let auto_policy = ResourcePolicy::for_problem(
3246                    self.n,
3247                    self.p_out,
3248                    gam_runtime::resource::ProblemHints::default(),
3249                );
3250                let cache_policy = ResourcePolicy {
3251                    max_single_materialization_bytes: Self::MATERIALIZE_MAX_BYTES,
3252                    derivative_storage_mode: auto_policy.derivative_storage_mode,
3253                    ..ResourcePolicy::default_library()
3254                };
3255                self.inner
3256                    .try_to_dense_arc_with_policy(
3257                        "ResidualisedDesignOperator materialization",
3258                        &cache_policy,
3259                    )
3260                    .ok()
3261                    .and_then(|x| {
3262                        let mut combined = fast_ab(x.as_ref(), &self.transform);
3263                        for (anchor, r_block) in &self.anchors {
3264                            let anchor_dense = match anchor {
3265                                DesignMatrix::Dense(d) => d
3266                                    .try_to_dense_arc_with_policy(
3267                                        "ResidualisedDesignOperator anchor materialization",
3268                                        &cache_policy,
3269                                    )
3270                                    .ok()?,
3271                                DesignMatrix::Sparse(s) => s
3272                                    .try_to_dense_arc(
3273                                        "ResidualisedDesignOperator anchor materialization",
3274                                    )
3275                                    .ok()?,
3276                            };
3277                            let contribution = fast_ab(anchor_dense.as_ref(), r_block.as_ref());
3278                            combined -= &contribution;
3279                        }
3280                        Some(Arc::new(combined))
3281                    })
3282            }
3283            _ => None,
3284        };
3285        if self.materialized.set(computed).is_err() {
3286            return self
3287                .materialized
3288                .get()
3289                .and_then(|opt| opt.as_ref().map(|a| a.as_ref()));
3290        }
3291        self.materialized
3292            .get()
3293            .and_then(|opt| opt.as_ref().map(|a| a.as_ref()))
3294    }
3295
3296    /// Public lazy materialisation handle: returns the cached combined block
3297    /// when available, falling back to chunked row-wise materialisation via
3298    /// the operator-backed path on the caller's behalf.
3299    pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
3300        if let Some(combined) = self.materialized.get().and_then(|opt| opt.clone()) {
3301            return Ok(combined);
3302        }
3303        if let Some(_combined_ref) = self.materialized_combined() {
3304            if let Some(arc) = self.materialized.get().and_then(|opt| opt.clone()) {
3305                return Ok(arc);
3306            }
3307        }
3308        dense_operator_to_dense_by_chunks(self)
3309            .map(Arc::new)
3310            .map_err(|err| format!("{context}: failed to materialize dense row chunks: {err}"))
3311    }
3312}
3313
3314impl LinearOperator for ResidualisedDesignOperator {
3315    fn nrows(&self) -> usize {
3316        self.n
3317    }
3318    fn ncols(&self) -> usize {
3319        self.p_out
3320    }
3321    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3322        if let Some(combined) = self.materialized_combined() {
3323            return fast_av(combined, vector);
3324        }
3325        // y = C_b · (V_b · v) − Σ_a A_a · (R_{a,b} · v)
3326        let tv = fast_av(&self.transform, vector);
3327        let mut out = self.inner.apply(&tv);
3328        for (anchor, r_block) in &self.anchors {
3329            let rv = fast_av(r_block.as_ref(), vector);
3330            let contrib = anchor.apply(&rv);
3331            out -= &contrib;
3332        }
3333        out
3334    }
3335    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
3336        if let Some(combined) = self.materialized_combined() {
3337            return fast_atv(combined, vector);
3338        }
3339        let xtv = self.inner.apply_transpose(vector);
3340        let mut out = fast_atv(&self.transform, &xtv);
3341        for (anchor, r_block) in &self.anchors {
3342            let atv = anchor.apply_transpose(vector);
3343            let contrib = fast_atv(r_block.as_ref(), &atv);
3344            out -= &contrib;
3345        }
3346        out
3347    }
3348    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
3349        if let Some(combined) = self.materialized_combined() {
3350            let mut xtwx = Array2::<f64>::zeros((self.p_out, self.p_out));
3351            stream_weighted_crossprod_into(
3352                combined,
3353                weights,
3354                &mut xtwx,
3355                CrossprodStructure::Full,
3356                CrossprodAccum::Replace,
3357                effective_global_parallelism(),
3358            );
3359            return Ok(xtwx);
3360        }
3361        // Fall back to the default DenseDesignOperator chunked path via
3362        // explicit materialisation: emit chunks through row_chunk_into and
3363        // accumulate XᵀWX without ever holding the full n × p_out block.
3364        let n = self.n;
3365        if weights.len() != n {
3366            return Err(format!(
3367                "ResidualisedDesignOperator::diag_xtw_x weights len {} != nrows {n}",
3368                weights.len()
3369            ));
3370        }
3371        let p = self.p_out;
3372        let chunk_rows = (8 * 1024 * 1024 / (p.max(1) * 8 * 2)).max(16).min(n.max(1));
3373        let mut xtwx = Array2::<f64>::zeros((p, p));
3374        let mut start = 0;
3375        while start < n {
3376            let end = (start + chunk_rows).min(n);
3377            let chunk = <Self as DenseDesignOperator>::try_row_chunk(self, start..end)
3378                .map_err(|e| e.to_string())?;
3379            let w_slice = weights.slice(s![start..end]).to_owned();
3380            let mut local = Array2::<f64>::zeros((p, p));
3381            stream_weighted_crossprod_into(
3382                &chunk,
3383                &w_slice,
3384                &mut local,
3385                CrossprodStructure::Full,
3386                CrossprodAccum::Replace,
3387                effective_global_parallelism(),
3388            );
3389            xtwx += &local;
3390            start = end;
3391        }
3392        Ok(xtwx)
3393    }
3394}
3395
3396impl DenseDesignOperator for ResidualisedDesignOperator {
3397    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
3398        self.materialized_combined()
3399    }
3400
3401    fn to_dense(&self) -> Array2<f64> {
3402        if let Some(combined) = self.materialized_combined() {
3403            return combined.clone();
3404        }
3405        // Chunked fallback when the cache refuses (oversize block).
3406        dense_operator_to_dense_by_chunks(self).unwrap_or_else(|err| {
3407            std::panic::panic_any(format!(
3408                "ResidualisedDesignOperator::to_dense: failed to materialize {}x{} \
3409                 via row chunks: {err}",
3410                self.n, self.p_out,
3411            ))
3412        })
3413    }
3414
3415    fn row_chunk_into(
3416        &self,
3417        rows: Range<usize>,
3418        mut out: ArrayViewMut2<'_, f64>,
3419    ) -> Result<(), MatrixMaterializationError> {
3420        if out.nrows() != rows.end - rows.start || out.ncols() != self.p_out {
3421            return Err(MatrixMaterializationError::MissingRowChunk {
3422                context: "ResidualisedDesignOperator::row_chunk_into shape mismatch",
3423            });
3424        }
3425        if let Some(combined) = self.materialized_combined() {
3426            out.assign(&combined.slice(s![rows, ..]));
3427            return Ok(());
3428        }
3429        // C_b chunk in raw width, then project: out = (inner_chunk) · V_b
3430        let inner_chunk = self.inner.try_row_chunk(rows.clone())?;
3431        let mut combined = fast_ab(&inner_chunk, &self.transform);
3432        // Subtract Σ_a (anchor_chunk · r_block)
3433        for (anchor, r_block) in &self.anchors {
3434            let anchor_chunk = anchor.try_row_chunk(rows.clone())?;
3435            let contribution = fast_ab(&anchor_chunk, r_block.as_ref());
3436            combined -= &contribution;
3437        }
3438        out.assign(&combined);
3439        Ok(())
3440    }
3441}
3442
3443// ---------------------------------------------------------------------------
3444// ConditionedDesign — lazy per-column affine transform
3445// ---------------------------------------------------------------------------
3446
/// A design matrix wrapper that lazily applies per-column centering and scaling
/// without materializing a new dense matrix.
///
/// For each conditioned column `j`, the effective column is
/// `(X[:,j] - mean_j) / scale_j`.  All other columns pass through unchanged.
/// Algebraically this is `X·diag(a) - 1·d'` where `a[j] = 1/scale` for
/// conditioned columns (1 otherwise) and `d[j] = mean/scale` for conditioned
/// columns (0 otherwise).
pub struct ConditionedDesign {
    /// The unconditioned design `X` that every operation delegates to.
    inner: DesignMatrix,
    /// Per-conditioned-column: (global_col_idx, mean, scale).
    /// The operator methods divide by `scale` and index by `global_col_idx`
    /// without checking either value.
    columns: Vec<(usize, f64, f64)>,
}
3460
3461impl ConditionedDesign {
3462    pub fn new(inner: DesignMatrix, columns: Vec<(usize, f64, f64)>) -> Self {
3463        Self { inner, columns }
3464    }
3465}
3466
3467impl LinearOperator for ConditionedDesign {
3468    fn nrows(&self) -> usize {
3469        self.inner.nrows()
3470    }
3471
3472    fn ncols(&self) -> usize {
3473        self.inner.ncols()
3474    }
3475
3476    /// X_c v = X(a⊙v) - (d·v)·1
3477    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3478        let mut scaled = vector.clone();
3479        let mut shift = 0.0;
3480        for &(j, mean, scale) in &self.columns {
3481            scaled[j] /= scale;
3482            shift += mean * scaled[j];
3483        }
3484        let mut result = self.inner.apply(&scaled);
3485        if shift != 0.0 {
3486            result.mapv_inplace(|v| v - shift);
3487        }
3488        result
3489    }
3490
3491    /// X_c'u = a⊙(X'u) - d·Σu
3492    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
3493        let mut result = self.inner.apply_transpose(vector);
3494        let sum_u: f64 = vector.iter().sum();
3495        for &(j, mean, scale) in &self.columns {
3496            result[j] = (result[j] - mean * sum_u) / scale;
3497        }
3498        result
3499    }
3500
3501    /// X_c'WX_c = D_a(X'WX)D_a - D_a(X'w)d' - d(X'w)'D_a + Σw·dd'
3502    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
3503        let mut base = self.inner.diag_xtw_x(weights)?;
3504        if self.columns.is_empty() {
3505            return Ok(base);
3506        }
3507        let p = base.ncols();
3508        let w_pos: Array1<f64> = weights.mapv(|w| w.max(0.0));
3509        let sum_w: f64 = w_pos.sum();
3510        let cw = self.inner.apply_transpose(&w_pos);
3511
3512        // Precompute a[j] and d[j] for all columns.
3513        let mut a = vec![1.0_f64; p];
3514        let mut d = vec![0.0_f64; p];
3515        for &(j, mean, scale) in &self.columns {
3516            a[j] = 1.0 / scale;
3517            d[j] = mean / scale;
3518        }
3519
3520        // Apply the full transformation in one pass (symmetric).
3521        for i in 0..p {
3522            for j in i..p {
3523                let val = a[i] * base[[i, j]] * a[j] - a[i] * cw[i] * d[j] - d[i] * cw[j] * a[j]
3524                    + sum_w * d[i] * d[j];
3525                base[[i, j]] = val;
3526                base[[j, i]] = val;
3527            }
3528        }
3529        Ok(base)
3530    }
3531
3532    /// Diagonal of X_c'WX_c — only conditioned columns change.
3533    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
3534        let mut result = self.inner.diag_gram(weights)?;
3535        if self.columns.is_empty() {
3536            return Ok(result);
3537        }
3538        let w_pos: Array1<f64> = weights.mapv(|w| w.max(0.0));
3539        let sum_w: f64 = w_pos.sum();
3540        let cw = self.inner.apply_transpose(&w_pos);
3541        for &(j, mean, scale) in &self.columns {
3542            let a_j = 1.0 / scale;
3543            let d_j = mean / scale;
3544            result[j] = a_j * a_j * result[j] - 2.0 * a_j * cw[j] * d_j + sum_w * d_j * d_j;
3545        }
3546        Ok(result)
3547    }
3548
3549    fn uses_matrix_free_pcg(&self) -> bool {
3550        match &self.inner {
3551            DesignMatrix::Dense(_) => true,
3552            DesignMatrix::Sparse(_) => false,
3553        }
3554    }
3555}
3556
3557impl DenseDesignOperator for ConditionedDesign {
3558    /// X_c'(w⊙y) = a⊙(X'(w⊙y)) - d·Σ(w⊙y)
3559    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
3560        let mut result = self.inner.compute_xtwy(weights, y)?;
3561        if self.columns.is_empty() {
3562            return Ok(result);
3563        }
3564        let sum_wy: f64 = weights
3565            .iter()
3566            .zip(y.iter())
3567            .map(|(&w, &yi)| w.max(0.0) * yi)
3568            .sum();
3569        for &(j, mean, scale) in &self.columns {
3570            result[j] = (result[j] - mean * sum_wy) / scale;
3571        }
3572        Ok(result)
3573    }
3574
3575    /// diag(X_c M X_c') = diag(X(D_a M D_a)X') - 2·X(D_a M d) + d'Md
3576    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
3577        if self.columns.is_empty() {
3578            return self.inner.quadratic_form_diag(middle);
3579        }
3580        let p = self.ncols();
3581        let mut d = Array1::zeros(p);
3582        for &(j, mean, scale) in &self.columns {
3583            d[j] = mean / scale;
3584        }
3585
3586        // D_a M D_a: scale rows and columns for conditioned indices.
3587        let mut ama = middle.clone();
3588        for &(j, _, scale) in &self.columns {
3589            for k in 0..p {
3590                ama[[j, k]] /= scale;
3591                ama[[k, j]] /= scale;
3592            }
3593        }
3594
3595        // D_a M d
3596        let md = middle.dot(&d);
3597        let mut amd = md;
3598        for &(j, _, scale) in &self.columns {
3599            amd[j] /= scale;
3600        }
3601
3602        let dtmd: f64 = d.dot(&middle.dot(&d));
3603
3604        let mut result = self.inner.quadratic_form_diag(&ama)?;
3605        let x_amd = self.inner.apply(&amd);
3606        for i in 0..result.len() {
3607            result[i] = (result[i] - 2.0 * x_amd[i] + dtmd).max(0.0);
3608        }
3609        Ok(result)
3610    }
3611
3612    fn row_chunk_into(
3613        &self,
3614        rows: Range<usize>,
3615        mut out: ArrayViewMut2<'_, f64>,
3616    ) -> Result<(), MatrixMaterializationError> {
3617        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
3618            return Err(MatrixMaterializationError::MissingRowChunk {
3619                context: "ConditionedDesign::row_chunk_into shape mismatch",
3620            });
3621        }
3622        let mut chunk = self.inner.try_row_chunk(rows)?;
3623        for &(j, mean, scale) in &self.columns {
3624            chunk.column_mut(j).mapv_inplace(|v| (v - mean) / scale);
3625        }
3626        out.assign(&chunk);
3627        Ok(())
3628    }
3629
3630    fn to_dense(&self) -> Array2<f64> {
3631        let mut dense = self.inner.to_dense();
3632        for &(j, mean, scale) in &self.columns {
3633            dense.column_mut(j).mapv_inplace(|v| (v - mean) / scale);
3634        }
3635        dense
3636    }
3637}
3638
/// Unified design matrix representation for dense and sparse workflows.
///
/// Dense matrices are wrapped in Arc for O(1) cloning — at large scale
/// design matrices are 100-500MB and get cloned repeatedly during GAMLSS
/// family construction, warm-start caching, and prediction.
///
/// The `Dense` variant wraps both materialized dense matrices and lazy
/// dense-backed operators (`DenseDesignMatrix::Lazy`) that implement
/// `DenseDesignOperator` without reopening a third top-level storage state.
#[derive(Clone)]
pub enum DesignMatrix {
    /// Dense-backed design: either materialized or a lazy dense operator.
    Dense(DenseDesignMatrix),
    /// Sparse-backed design.
    Sparse(SparseDesignMatrix),
}
3653
3654impl std::fmt::Debug for DesignMatrix {
3655    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3656        match self {
3657            Self::Dense(m) => write!(f, "DesignMatrix::Dense({}x{})", m.nrows(), m.ncols()),
3658            Self::Sparse(s) => write!(f, "DesignMatrix::Sparse({}x{})", s.nrows(), s.ncols()),
3659        }
3660    }
3661}
3662
3663// Symmetric-matrix container + Gram assembly (#1145): see `symmetric.rs`.
3664mod symmetric;
3665pub use symmetric::*;
/// A generic abstraction over a factorized symmetric positive-definite (or regularized) system.
///
/// Implementors must be `Send + Sync`. Solve failures are reported as a
/// `String` message.
pub trait FactorizedSystem: Send + Sync {
    /// Solve $H x = b$ for a single right-hand side.
    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String>;

    /// Solve $H X = B$ for multiple right-hand sides.
    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String>;

    /// Return the log-determinant of the factorized matrix.
    fn logdet(&self) -> f64;
}
3677
3678pub trait LinearOperator {
3679    fn nrows(&self) -> usize;
3680    fn ncols(&self) -> usize;
3681    fn apply(&self, vector: &Array1<f64>) -> Array1<f64>;
3682    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64>;
3683    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String>;
3684
3685    /// Observed-Hessian / non-canonical-link Gram: `XᵀWX` with sign-honest
3686    /// weights. Returns a dense `Array2<f64>` because the result is symmetric
3687    /// but not guaranteed PSD (so consumers cannot assume the `SymmetricMatrix`
3688    /// PSD contract). Default impl delegates to `diag_xtw_x` for legacy
3689    /// operators; overriding impls may take a sign-aware fast path.
3690    fn xt_diag_x_signed_op(&self, weights: SignedWeightsView<'_>) -> Result<Array2<f64>, String> {
3691        self.diag_xtw_x(&weights.view().to_owned())
3692    }
3693
3694    /// PSD-precondition Gram: `XᵀWX` with `w ≥ 0` discharged at the
3695    /// `PsdWeightsView` constructor. Returns a typed `SymmetricMatrix` so
3696    /// downstream consumers can route through PSD-only solvers (Cholesky).
3697    /// Default impl wraps the signed path's `Array2` in `SymmetricMatrix::Dense`.
3698    fn xt_diag_x_psd_op(&self, weights: PsdWeightsView<'_>) -> Result<SymmetricMatrix, String> {
3699        let xtwx = self.diag_xtw_x(&weights.view().to_owned())?;
3700        Ok(SymmetricMatrix::Dense(xtwx))
3701    }
3702
3703    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
3704        let xtwx = self.diag_xtw_x(weights)?;
3705        Ok(Array1::from_iter((0..self.ncols()).map(|j| xtwx[[j, j]])))
3706    }
3707    fn apply_weighted_normal(
3708        &self,
3709        weights: &Array1<f64>,
3710        vector: &Array1<f64>,
3711        penalty: Option<&Array2<f64>>,
3712        ridge: f64,
3713    ) -> Array1<f64> {
3714        assert_eq!(
3715            weights.len(),
3716            self.nrows(),
3717            "apply_weighted_normal weight length mismatch"
3718        );
3719        assert_eq!(
3720            vector.len(),
3721            self.ncols(),
3722            "apply_weighted_normal vector length mismatch"
3723        );
3724        let xv = self.apply(vector);
3725        let mut weighted_xv = xv;
3726        for i in 0..weighted_xv.len() {
3727            weighted_xv[i] *= weights[i].max(0.0);
3728        }
3729        let mut out = self.apply_transpose(&weighted_xv);
3730        if let Some(pen) = penalty {
3731            out += &fast_av(pen, vector);
3732        }
3733        if ridge > 0.0 {
3734            // BLAS axpy: out += ridge * vector, no temporary allocation.
3735            out.scaled_add(ridge, vector);
3736        }
3737        out
3738    }
    /// Whether this operator opts in to the matrix-free PCG solve path.
    /// Defaults to `false`; `solve_system_matrix_free_pcg_with_info_try`
    /// returns an error for operators that do not return `true` here.
    fn uses_matrix_free_pcg(&self) -> bool {
        false
    }
3742    fn solve_system_matrix_free_pcg_try(
3743        &self,
3744        weights: &Array1<f64>,
3745        rhs: &Array1<f64>,
3746        penalty: Option<&Array2<f64>>,
3747        baseridge: f64,
3748    ) -> Result<Array1<f64>, String> {
3749        self.solve_system_matrix_free_pcg_with_info_try(weights, rhs, penalty, baseridge)
3750            .map(|(solution, _)| solution)
3751    }
3752    fn solve_system_matrix_free_pcg_with_info_try(
3753        &self,
3754        weights: &Array1<f64>,
3755        rhs: &Array1<f64>,
3756        penalty: Option<&Array2<f64>>,
3757        baseridge: f64,
3758    ) -> Result<(Array1<f64>, PcgSolveInfo), String> {
3759        if rhs.len() != self.ncols() {
3760            return Err(format!(
3761                "solve_system_matrix_free_pcg rhs dimension mismatch: rhs length {} != ncols {}",
3762                rhs.len(),
3763                self.ncols()
3764            ));
3765        }
3766        if !self.uses_matrix_free_pcg() {
3767            return Err("matrix-free PCG is only enabled for eligible operator types".to_string());
3768        }
3769        if let Some(pen) = penalty
3770            && (pen.nrows() != self.ncols() || pen.ncols() != self.ncols())
3771        {
3772            return Err(format!(
3773                "solve_system_matrix_free_pcg penalty shape mismatch: got {}x{}, expected {}x{}",
3774                pen.nrows(),
3775                pen.ncols(),
3776                self.ncols(),
3777                self.ncols()
3778            ));
3779        }
3780        let p = self.ncols();
3781        for retry in 0..8 {
3782            let ridge = if baseridge > 0.0 {
3783                baseridge * 10f64.powi(retry)
3784            } else {
3785                0.0
3786            };
3787            let normal_op = PenalizedWeightedNormalOperator {
3788                operator: self,
3789                weights,
3790                penalty,
3791                ridge,
3792            };
3793            let preconditioner = normal_op.jacobi_preconditioner()?;
3794            let attempt_started = std::time::Instant::now();
3795            let solved = crate::utils::solve_spd_pcg_with_info(
3796                |v| normal_op.apply(v),
3797                rhs,
3798                &preconditioner,
3799                MATRIX_FREE_PCG_REL_TOL,
3800                MATRIX_FREE_PCG_MAX_ITER.max(4 * p),
3801            );
3802            let elapsed = attempt_started.elapsed().as_secs_f64();
3803            // Progress diagnostics for the matrix-free inner solve. The happy
3804            // path (retry==0, converged) logs at debug to stay quiet inside the
3805            // inner-Newton loop; any ridge escalation — the trust-region-retry
3806            // analog and a strong signal that the operator is ill-conditioned —
3807            // logs at info so a slow fit shows whether time is going into CG
3808            // iterations, repeated HVPs, or escalation churn.
3809            match solved {
3810                Some((solution, info)) if solution.iter().all(|v| v.is_finite()) => {
3811                    if retry > 0 {
3812                        log::info!(
3813                            "[matrix-free PCG] converged after ridge escalation: p={p} retry={retry} ridge={ridge:.3e} iters={} converged={} rel_resid={:.3e} elapsed={elapsed:.3}s",
3814                            info.iterations,
3815                            info.converged,
3816                            info.relative_residual_norm,
3817                        );
3818                    } else {
3819                        log::debug!(
3820                            "[matrix-free PCG] solved: p={p} iters={} converged={} rel_resid={:.3e} elapsed={elapsed:.3}s",
3821                            info.iterations,
3822                            info.converged,
3823                            info.relative_residual_norm,
3824                        );
3825                    }
3826                    return Ok((solution, info));
3827                }
3828                Some((_, info)) => {
3829                    log::info!(
3830                        "[matrix-free PCG] non-finite solution, escalating ridge: p={p} retry={retry} ridge={ridge:.3e} iters={} converged={} rel_resid={:.3e} elapsed={elapsed:.3}s",
3831                        info.iterations,
3832                        info.converged,
3833                        info.relative_residual_norm,
3834                    );
3835                }
3836                None => {
3837                    log::info!(
3838                        "[matrix-free PCG] CG breakdown (non-SPD/NaN), escalating ridge: p={p} retry={retry} ridge={ridge:.3e} elapsed={elapsed:.3}s",
3839                    );
3840                }
3841            }
3842        }
3843        Err("matrix-free PCG failed after ridge retries".to_string())
3844    }
3845    fn factorize_system(
3846        &self,
3847        weights: &Array1<f64>,
3848        penalty: Option<&Array2<f64>>,
3849    ) -> Result<Box<dyn FactorizedSystem>, String> {
3850        let mut system = self.diag_xtw_x(weights)?;
3851        if let Some(pen) = penalty {
3852            if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
3853                return Err(format!(
3854                    "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
3855                    pen.nrows(),
3856                    pen.ncols(),
3857                    system.nrows(),
3858                    system.ncols()
3859                ));
3860            }
3861            system += pen;
3862        }
3863        let factor = crate::utils::StableSolver::new("linear operator system")
3864            .factorize(&system)
3865            .map_err(|e| format!("factorize_system failed: {e:?}"))?;
3866        Ok(Box::new(factor))
3867    }
3868    fn solve_system(
3869        &self,
3870        weights: &Array1<f64>,
3871        rhs: &Array1<f64>,
3872        penalty: Option<&Array2<f64>>,
3873    ) -> Result<Array1<f64>, String> {
3874        self.solve_systemwith_policy(
3875            weights,
3876            rhs,
3877            penalty,
3878            SPD_SOLVE_RIDGE_FLOOR,
3879            RidgePolicy::explicit_stabilization_pospart(),
3880        )
3881    }
3882    fn solve_systemwith_policy(
3883        &self,
3884        weights: &Array1<f64>,
3885        rhs: &Array1<f64>,
3886        penalty: Option<&Array2<f64>>,
3887        ridge_floor: f64,
3888        ridge_policy: RidgePolicy,
3889    ) -> Result<Array1<f64>, String> {
3890        if rhs.len() != self.ncols() {
3891            return Err(format!(
3892                "solve_systemwith_policy rhs dimension mismatch: rhs length {} != ncols {}",
3893                rhs.len(),
3894                self.ncols()
3895            ));
3896        }
3897        let baseridge = if ridge_policy.include_laplacehessian {
3898            ridge_floor.max(SPD_SOLVE_RIDGE_FLOOR)
3899        } else {
3900            0.0
3901        };
3902        // Try matrix-free PCG first to avoid assembling the dense p×p normal matrix.
3903        if self.uses_matrix_free_pcg()
3904            && self.ncols() >= MATRIX_FREE_PCG_MIN_P
3905            && let Ok(solution) =
3906                self.solve_system_matrix_free_pcg_try(weights, rhs, penalty, baseridge)
3907        {
3908            return Ok(solution);
3909        }
3910        // Fallback: assemble dense system and solve via Cholesky with ridge retries.
3911        let mut system = self.diag_xtw_x(weights)?;
3912        if let Some(pen) = penalty {
3913            if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
3914                return Err(format!(
3915                    "solve_systemwith_policy penalty shape mismatch: got {}x{}, expected {}x{}",
3916                    pen.nrows(),
3917                    pen.ncols(),
3918                    system.nrows(),
3919                    system.ncols()
3920                ));
3921            }
3922            system += pen;
3923        }
3924        crate::utils::StableSolver::new("linear operator system")
3925            .solvevectorwithridge_retries(&system, rhs, baseridge)
3926            .ok_or_else(|| "solve_systemwith_policy failed after ridge retries".to_string())
3927    }
3928}
3929
3930impl LinearOperator for DesignMatrix {
3931    fn uses_matrix_free_pcg(&self) -> bool {
3932        match self {
3933            Self::Dense(matrix) => matrix.uses_matrix_free_pcg(),
3934            Self::Sparse(_) => false,
3935        }
3936    }
3937
3938    fn nrows(&self) -> usize {
3939        match self {
3940            Self::Dense(matrix) => matrix.nrows(),
3941            Self::Sparse(matrix) => matrix.nrows(),
3942        }
3943    }
3944
3945    fn ncols(&self) -> usize {
3946        match self {
3947            Self::Dense(matrix) => matrix.ncols(),
3948            Self::Sparse(matrix) => matrix.ncols(),
3949        }
3950    }
3951
3952    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3953        match self {
3954            Self::Dense(matrix) => matrix.apply(vector),
3955            Self::Sparse(matrix) => {
3956                let mut output = Array1::<f64>::zeros(matrix.nrows());
3957                let (symbolic, values) = matrix.parts();
3958                let col_ptr = symbolic.col_ptr();
3959                let row_idx = symbolic.row_idx();
3960                for col in 0..matrix.ncols() {
3961                    let start = col_ptr[col];
3962                    let end = col_ptr[col + 1];
3963                    let x = vector[col];
3964                    for idx in start..end {
3965                        let row = row_idx[idx];
3966                        output[row] += values[idx] * x;
3967                    }
3968                }
3969                output
3970            }
3971        }
3972    }
3973
3974    fn apply_weighted_normal(
3975        &self,
3976        weights: &Array1<f64>,
3977        vector: &Array1<f64>,
3978        penalty: Option<&Array2<f64>>,
3979        ridge: f64,
3980    ) -> Array1<f64> {
3981        assert_eq!(
3982            weights.len(),
3983            self.nrows(),
3984            "DesignMatrix::apply_weighted_normal weight length mismatch"
3985        );
3986        assert_eq!(
3987            vector.len(),
3988            self.ncols(),
3989            "DesignMatrix::apply_weighted_normal vector length mismatch"
3990        );
3991        match self {
3992            Self::Dense(matrix) => matrix.apply_weighted_normal(weights, vector, penalty, ridge),
3993            Self::Sparse(_) => {
3994                let sparse = self
3995                    .as_sparse()
3996                    .expect("DesignMatrix::Sparse must expose sparse view");
3997                let mut out = if let Some(csr) = sparse.to_csr_arc() {
3998                    let sym = csr.symbolic();
3999                    let row_ptr = sym.row_ptr();
4000                    let col_idx = sym.col_idx();
4001                    let vals = csr.val();
4002                    let mut fused = Array1::<f64>::zeros(self.ncols());
4003                    for i in 0..self.nrows() {
4004                        let wi = weights[i].max(0.0);
4005                        if wi == 0.0 {
4006                            continue;
4007                        }
4008                        let start = row_ptr[i];
4009                        let end = row_ptr[i + 1];
4010                        let mut row_dot = 0.0_f64;
4011                        for ptr in start..end {
4012                            row_dot += vals[ptr] * vector[col_idx[ptr]];
4013                        }
4014                        if row_dot == 0.0 {
4015                            continue;
4016                        }
4017                        let scaled = wi * row_dot;
4018                        for ptr in start..end {
4019                            fused[col_idx[ptr]] += vals[ptr] * scaled;
4020                        }
4021                    }
4022                    fused
4023                } else {
4024                    let xv = self.apply(vector);
4025                    let mut weighted_xv = xv;
4026                    for i in 0..weighted_xv.len() {
4027                        weighted_xv[i] *= weights[i].max(0.0);
4028                    }
4029                    self.apply_transpose(&weighted_xv)
4030                };
4031                if let Some(pen) = penalty {
4032                    out += &fast_av(pen, vector);
4033                }
4034                if ridge > 0.0 {
4035                    for j in 0..out.len() {
4036                        out[j] += ridge * vector[j];
4037                    }
4038                }
4039                out
4040            }
4041        }
4042    }
4043
4044    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
4045        match self {
4046            Self::Dense(matrix) => matrix.apply_transpose(vector),
4047            Self::Sparse(matrix) => {
4048                let mut output = Array1::<f64>::zeros(matrix.ncols());
4049                let (symbolic, values) = matrix.parts();
4050                let col_ptr = symbolic.col_ptr();
4051                let row_idx = symbolic.row_idx();
4052                for col in 0..matrix.ncols() {
4053                    let mut acc = 0.0;
4054                    let start = col_ptr[col];
4055                    let end = col_ptr[col + 1];
4056                    for idx in start..end {
4057                        let row = row_idx[idx];
4058                        acc += values[idx] * vector[row];
4059                    }
4060                    output[col] = acc;
4061                }
4062                output
4063            }
4064        }
4065    }
4066
4067    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
4068        if weights.len() != self.nrows() {
4069            return Err(format!(
4070                "xt_diag_x dimension mismatch: weights length {} != nrows {}",
4071                weights.len(),
4072                self.nrows()
4073            ));
4074        }
4075        let p = self.ncols();
4076        match self {
4077            Self::Dense(x) => x.diag_xtw_x(weights),
4078            Self::Sparse(xs) => {
4079                // Two regimes for sparse-stored designs:
4080                //
4081                //   (A) Numerically dense — Matern / Duchon radial bases place
4082                //       a nonzero in every column for every row, so XᵀWX has
4083                //       O(p²) fills and the scalar row-loop is dominated by
4084                //       memory traffic over O(n·nnz_row²) ≈ O(n·p²) ops.  Faer
4085                //       hand-tuned BLAS3 runs parallel + SIMD over either the
4086                //       cached dense design or bounded CSC-materialized row
4087                //       chunks, depending on the dense materialization policy.
4088                //
4089                //   (B) Genuinely sparse — B-spline / banded bases keep
4090                //       nnz_row at a small constant (4–6), so the per-row
4091                //       O(nnz_row²) work is ~25× fewer FLOPs than the dense
4092                //       matmul and densification is a regression.  Run a
4093                //       row-parallel scalar accumulation in that regime.
4094                //
4095                // Heuristic: average nnz_per_row >= p/4 picks (A).  In practice
4096                // the upstream `should_use_sparse_native_pirls` already routes
4097                // banded sparse XᵀWX designs to a separate sparse-native PIRLS
4098                // path that does NOT call this function, so the (A) branch
4099                // covers every actual call site we have today; the (B) branch
4100                // is a correctness-preserving safety net for future callers.
4101                let n = self.nrows();
4102                let nnz_x = xs.as_ref().val().len();
4103                let avg_nnz_row = if n > 0 { nnz_x / n } else { p };
4104                let dense_regime = 4 * avg_nnz_row >= p;
4105                if dense_regime {
4106                    let mut xtwx = Array2::<f64>::zeros((p, p));
4107                    let dense_bytes =
4108                        checked_dense_nbytes(n, p, "DesignMatrix::diag_xtw_x dense sparse route")?;
4109                    if dense_bytes <= MAX_SPARSE_TO_DENSE_BYTES {
4110                        let xd =
4111                            xs.try_to_dense_arc("DesignMatrix::diag_xtw_x dense sparse route")?;
4112                        stream_weighted_crossprod_into(
4113                            xd.as_ref(),
4114                            weights,
4115                            &mut xtwx,
4116                            CrossprodStructure::Full,
4117                            CrossprodAccum::Replace,
4118                            effective_global_parallelism(),
4119                        );
4120                    } else {
4121                        let (symbolic, values) = xs.parts();
4122                        streaming_sparse_csc_xt_diag_x(
4123                            symbolic.col_ptr(),
4124                            symbolic.row_idx(),
4125                            values,
4126                            n,
4127                            p,
4128                            weights.view(),
4129                            &mut xtwx,
4130                        );
4131                    }
4132                    return Ok(xtwx);
4133                }
4134                let csr = xs
4135                    .to_csr_arc()
4136                    .ok_or_else(|| "failed to obtain CSR view in xt_diag_x".to_string())?;
4137                let sym = csr.symbolic();
4138                Ok(sparse_csr_weighted_xtwx(
4139                    sym.row_ptr(),
4140                    sym.col_idx(),
4141                    csr.val(),
4142                    n,
4143                    p,
4144                    weights.view(),
4145                ))
4146            }
4147        }
4148    }
4149
4150    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
4151        if weights.len() != self.nrows() {
4152            return Err(format!(
4153                "diag_gram dimension mismatch: weights length {} != nrows {}",
4154                weights.len(),
4155                self.nrows()
4156            ));
4157        }
4158        let p = self.ncols();
4159        match self {
4160            Self::Dense(x) => x.diag_gram(weights),
4161            Self::Sparse(xs) => {
4162                let csr = xs
4163                    .to_csr_arc()
4164                    .ok_or_else(|| "failed to obtain CSR view in diag_gram".to_string())?;
4165                let sym = csr.symbolic();
4166                Ok(sparse_csr_diag_gram(
4167                    sym.row_ptr(),
4168                    sym.col_idx(),
4169                    csr.val(),
4170                    self.nrows(),
4171                    p,
4172                    weights.view(),
4173                ))
4174            }
4175        }
4176    }
4177
4178    fn factorize_system(
4179        &self,
4180        weights: &Array1<f64>,
4181        penalty: Option<&Array2<f64>>,
4182    ) -> Result<Box<dyn FactorizedSystem>, String> {
4183        if weights.len() != self.nrows() {
4184            return Err(format!(
4185                "factorize_system dimension mismatch: weights length {} != nrows {}",
4186                weights.len(),
4187                self.nrows()
4188            ));
4189        }
4190        match self {
4191            Self::Dense(_) => self.factorize_system_dense(weights, penalty),
4192            Self::Sparse(matrix) => {
4193                let system = assemble_sparseweighted_gram_system(matrix, weights, penalty)?;
4194                let factor = crate::sparse_exact::factorize_sparse_spd(&system)
4195                    .map_err(|e| format!("factorize_system failed: {e:?}"))?;
4196                Ok(Box::new(factor))
4197            }
4198        }
4199    }
4200}
4201
4202impl DenseDesignOperator for DesignMatrix {
4203    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
4204        if weights.len() != self.nrows() || y.len() != self.nrows() {
4205            return Err(format!(
4206                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
4207                weights.len(),
4208                y.len(),
4209                self.nrows()
4210            ));
4211        }
4212        match self {
4213            Self::Dense(x) => x.compute_xtwy(weights, y),
4214            Self::Sparse(xs) => {
4215                let csr = xs
4216                    .as_ref()
4217                    .to_row_major()
4218                    .map_err(|_| "failed to obtain CSR view in compute_xtwy".to_string())?;
4219                let sym = csr.symbolic();
4220                let row_ptr = sym.row_ptr();
4221                let col_idx = sym.col_idx();
4222                let vals = csr.val();
4223                let mut out = Array1::<f64>::zeros(xs.ncols());
4224                for i in 0..xs.nrows() {
4225                    let scaled = weights[i].max(0.0) * y[i];
4226                    if scaled == 0.0 {
4227                        continue;
4228                    }
4229                    for idx in row_ptr[i]..row_ptr[i + 1] {
4230                        out[col_idx[idx]] += vals[idx] * scaled;
4231                    }
4232                }
4233                Ok(out)
4234            }
4235        }
4236    }
4237
4238    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
4239        if middle.nrows() != self.ncols() || middle.ncols() != self.ncols() {
4240            return Err(format!(
4241                "quadratic_form_diag dimension mismatch: matrix is {}x{}, expected {}x{}",
4242                middle.nrows(),
4243                middle.ncols(),
4244                self.ncols(),
4245                self.ncols()
4246            ));
4247        }
4248
4249        match self {
4250            Self::Dense(xd) => xd.quadratic_form_diag(middle),
4251            Self::Sparse(xs) => {
4252                let csr = xs
4253                    .to_csr_arc()
4254                    .ok_or_else(|| "quadratic_form_diag: failed to obtain CSR view".to_string())?;
4255                let sym = csr.symbolic();
4256                let row_ptr = sym.row_ptr();
4257                let col_idx = sym.col_idx();
4258                let vals = csr.val();
4259                let mut out = Array1::<f64>::zeros(self.nrows());
4260                for i in 0..xs.nrows() {
4261                    let start = row_ptr[i];
4262                    let end = row_ptr[i + 1];
4263                    let mut acc = 0.0_f64;
4264                    for a in start..end {
4265                        let j = col_idx[a];
4266                        let xij = vals[a];
4267                        for b in start..end {
4268                            let k = col_idx[b];
4269                            let xik = vals[b];
4270                            acc += xij * middle[[j, k]] * xik;
4271                        }
4272                    }
4273                    out[i] = acc.max(0.0);
4274                }
4275                Ok(out)
4276            }
4277        }
4278    }
4279
4280    fn row_chunk_into(
4281        &self,
4282        rows: Range<usize>,
4283        mut out: ArrayViewMut2<'_, f64>,
4284    ) -> Result<(), MatrixMaterializationError> {
4285        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
4286            return Err(MatrixMaterializationError::MissingRowChunk {
4287                context: "DesignMatrix::row_chunk_into shape mismatch",
4288            });
4289        }
4290        match self {
4291            Self::Dense(matrix) => matrix.row_chunk_into(rows, out),
4292            Self::Sparse(matrix) => {
4293                out.fill(0.0);
4294                let csr =
4295                    matrix
4296                        .to_csr_arc()
4297                        .ok_or(MatrixMaterializationError::MissingRowChunk {
4298                            context: "DesignMatrix::row_chunk_into: failed to obtain CSR view",
4299                        })?;
4300                let sym = csr.symbolic();
4301                let row_ptr = sym.row_ptr();
4302                let col_idx = sym.col_idx();
4303                let vals = csr.val();
4304                for (local_row, row) in rows.enumerate() {
4305                    for ptr in row_ptr[row]..row_ptr[row + 1] {
4306                        out[[local_row, col_idx[ptr]]] = vals[ptr];
4307                    }
4308                }
4309                Ok(())
4310            }
4311        }
4312    }
4313
4314    fn to_dense(&self) -> Array2<f64> {
4315        DesignMatrix::to_dense(self)
4316    }
4317}
4318
4319impl LinearOperator for DenseRightProductView<'_> {
4320    fn nrows(&self) -> usize {
4321        self.base.nrows()
4322    }
4323
4324    fn ncols(&self) -> usize {
4325        self.transformed_ncols()
4326    }
4327
4328    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
4329        let rhs;
4330        let v = match (self.second, self.first) {
4331            (None, None) => vector,
4332            (Some(s), None) => {
4333                rhs = fast_av(s, vector);
4334                &rhs
4335            }
4336            (None, Some(f)) => {
4337                rhs = fast_av(f, vector);
4338                &rhs
4339            }
4340            (Some(s), Some(f)) => {
4341                let tmp = fast_av(s, vector);
4342                rhs = fast_av(f, &tmp);
4343                &rhs
4344            }
4345        };
4346        fast_av(self.base, v)
4347    }
4348
4349    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
4350        let mut out = fast_atv(self.base, vector);
4351        if let Some(factor) = self.first {
4352            out = fast_atv(factor, &out);
4353        }
4354        if let Some(factor) = self.second {
4355            out = fast_atv(factor, &out);
4356        }
4357        out
4358    }
4359
4360    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
4361        if weights.len() != self.nrows() {
4362            return Err(format!(
4363                "xt_diag_x dimension mismatch: weights length {} != nrows {}",
4364                weights.len(),
4365                self.nrows()
4366            ));
4367        }
4368        let mut gram = fast_xt_diag_x(self.base, weights);
4369        if let Some(factor) = self.first {
4370            gram = fast_ab(&fast_atb(factor, &gram), factor);
4371        }
4372        if let Some(factor) = self.second {
4373            gram = fast_ab(&fast_atb(factor, &gram), factor);
4374        }
4375        Ok(gram)
4376    }
4377
4378    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
4379        Ok(self.diag_xtw_x(weights)?.diag().to_owned())
4380    }
4381}
4382
4383impl DenseRightProductView<'_> {
4384    pub fn compute_xtwy(
4385        &self,
4386        weights: &Array1<f64>,
4387        y: &Array1<f64>,
4388    ) -> Result<Array1<f64>, String> {
4389        if weights.len() != self.nrows() || y.len() != self.nrows() {
4390            return Err(format!(
4391                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
4392                weights.len(),
4393                y.len(),
4394                self.nrows()
4395            ));
4396        }
4397        let weighted_xty = dense_transpose_weighted_response(self.base, weights, y, None);
4398        let mut out = weighted_xty;
4399        if let Some(factor) = self.first {
4400            out = fast_atv(factor, &out);
4401        }
4402        if let Some(factor) = self.second {
4403            out = fast_atv(factor, &out);
4404        }
4405        Ok(out)
4406    }
4407
4408    pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
4409        let dense = self.materialize();
4410        DesignMatrix::Dense(DenseDesignMatrix::from(dense)).quadratic_form_diag(middle)
4411    }
4412}
4413
4414impl LinearOperator for EmbeddedColumnBlock<'_> {
4415    fn nrows(&self) -> usize {
4416        self.local.nrows()
4417    }
4418
4419    fn ncols(&self) -> usize {
4420        self.total_cols
4421    }
4422
4423    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
4424        fast_av(
4425            self.local,
4426            &vector.slice(ndarray::s![self.global_range.clone()]),
4427        )
4428    }
4429
4430    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
4431        let mut out = Array1::<f64>::zeros(self.total_cols);
4432        out.slice_mut(ndarray::s![self.global_range.clone()])
4433            .assign(&fast_atv(self.local, vector));
4434        out
4435    }
4436
4437    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
4438        if weights.len() != self.nrows() {
4439            return Err(format!(
4440                "xt_diag_x dimension mismatch: weights length {} != nrows {}",
4441                weights.len(),
4442                self.nrows()
4443            ));
4444        }
4445        let mut out = Array2::<f64>::zeros((self.total_cols, self.total_cols));
4446        let local = fast_xt_diag_x(self.local, weights);
4447        out.slice_mut(ndarray::s![
4448            self.global_range.clone(),
4449            self.global_range.clone()
4450        ])
4451        .assign(&local);
4452        Ok(out)
4453    }
4454
4455    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
4456        let mut out = Array1::<f64>::zeros(self.total_cols);
4457        let local =
4458            DesignMatrix::Dense(DenseDesignMatrix::from(self.local.clone())).diag_gram(weights)?;
4459        out.slice_mut(ndarray::s![self.global_range.clone()])
4460            .assign(&local);
4461        Ok(out)
4462    }
4463}
4464
4465impl EmbeddedColumnBlock<'_> {
4466    pub fn compute_xtwy(
4467        &self,
4468        weights: &Array1<f64>,
4469        y: &Array1<f64>,
4470    ) -> Result<Array1<f64>, String> {
4471        if weights.len() != self.nrows() || y.len() != self.nrows() {
4472            return Err(format!(
4473                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
4474                weights.len(),
4475                y.len(),
4476                self.nrows()
4477            ));
4478        }
4479        let local = dense_transpose_weighted_response(self.local, weights, y, None);
4480        let mut out = Array1::<f64>::zeros(self.total_cols);
4481        out.slice_mut(ndarray::s![self.global_range.clone()])
4482            .assign(&local);
4483        Ok(out)
4484    }
4485
4486    pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
4487        let middle_local = middle
4488            .slice(ndarray::s![
4489                self.global_range.clone(),
4490                self.global_range.clone()
4491            ])
4492            .to_owned();
4493        DesignMatrix::Dense(DenseDesignMatrix::from(self.local.clone()))
4494            .quadratic_form_diag(&middle_local)
4495    }
4496}
4497
4498impl DesignMatrix {
4499    fn factorize_system_dense(
4500        &self,
4501        weights: &Array1<f64>,
4502        penalty: Option<&Array2<f64>>,
4503    ) -> Result<Box<dyn FactorizedSystem>, String> {
4504        let mut system = self.diag_xtw_x(weights)?;
4505        if let Some(pen) = penalty {
4506            if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
4507                return Err(format!(
4508                    "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
4509                    pen.nrows(),
4510                    pen.ncols(),
4511                    system.nrows(),
4512                    system.ncols()
4513                ));
4514            }
4515            system += pen;
4516        }
4517        let factor = crate::utils::StableSolver::new("linear operator system")
4518            .factorize(&system)
4519            .map_err(|e| format!("factorize_system failed: {e:?}"))?;
4520        Ok(Box::new(factor))
4521    }
4522}
4523
4524fn assemble_sparseweighted_gram_system(
4525    matrix: &SparseDesignMatrix,
4526    weights: &Array1<f64>,
4527    penalty: Option<&Array2<f64>>,
4528) -> Result<SparseColMat<usize, f64>, String> {
4529    let csr = matrix
4530        .to_csr_arc()
4531        .ok_or_else(|| "failed to obtain CSR view in factorize_system".to_string())?;
4532    let sym = csr.symbolic();
4533    let row_ptr = sym.row_ptr();
4534    let col_idx = sym.col_idx();
4535    let vals = csr.val();
4536    let p = matrix.ncols();
4537    let mut upper = BTreeMap::<(usize, usize), f64>::new();
4538
4539    for i in 0..csr.nrows() {
4540        let wi = weights[i].max(0.0);
4541        if wi == 0.0 {
4542            continue;
4543        }
4544        let start = row_ptr[i];
4545        let end = row_ptr[i + 1];
4546        for a_ptr in start..end {
4547            let a = col_idx[a_ptr];
4548            let xa = vals[a_ptr];
4549            for b_ptr in a_ptr..end {
4550                let b = col_idx[b_ptr];
4551                let xb = vals[b_ptr];
4552                let key = if a <= b { (a, b) } else { (b, a) };
4553                *upper.entry(key).or_insert(0.0) += wi * xa * xb;
4554            }
4555        }
4556    }
4557
4558    if let Some(pen) = penalty {
4559        if pen.nrows() != p || pen.ncols() != p {
4560            return Err(format!(
4561                "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
4562                pen.nrows(),
4563                pen.ncols(),
4564                p,
4565                p
4566            ));
4567        }
4568        for i in 0..p {
4569            for j in i..p {
4570                let value = pen[[i, j]];
4571                if value != 0.0 {
4572                    *upper.entry((i, j)).or_insert(0.0) += value;
4573                }
4574            }
4575        }
4576    }
4577
4578    let mut triplets = Vec::with_capacity(upper.len());
4579    for ((row, col), value) in upper {
4580        if value != 0.0 {
4581            triplets.push(Triplet::new(row, col, value));
4582        }
4583    }
4584    SparseColMat::try_new_from_triplets(p, p, &triplets)
4585        .map_err(|_| "failed to build sparse penalized system".to_string())
4586}
4587
4588impl DesignMatrix {
4589    /// Horizontally concatenate design blocks without forcing eager densification.
4590    ///
4591    /// The returned matrix is a lazy `BlockDesignOperator` when more than one
4592    /// block is provided, so operator-backed inputs stay chunkable on the
4593    /// prediction path.
4594    pub fn hstack(blocks: Vec<DesignMatrix>) -> Result<Self, String> {
4595        if blocks.is_empty() {
4596            return Err("DesignMatrix::hstack requires at least one block".to_string());
4597        }
4598        if blocks.len() == 1 {
4599            return Ok(blocks.into_iter().next().expect("non-empty block list"));
4600        }
4601        let operator =
4602            BlockDesignOperator::new(blocks.into_iter().map(DesignBlock::from).collect())?;
4603        Ok(Self::Dense(DenseDesignMatrix::from(Arc::new(operator))))
4604    }
4605
4606    pub fn nrows(&self) -> usize {
4607        <Self as LinearOperator>::nrows(self)
4608    }
4609
4610    pub fn ncols(&self) -> usize {
4611        <Self as LinearOperator>::ncols(self)
4612    }
4613
4614    /// Extract a dense row chunk without materializing the full matrix.
4615    ///
4616    /// Returns a `(rows.len(), ncols())` dense `Array2` for the requested row
4617    /// range. For lazy dense designs this delegates to the operator-backed
4618    /// implementation, which should remain O(chunk).
4619    pub fn try_row_chunk(
4620        &self,
4621        rows: Range<usize>,
4622    ) -> Result<Array2<f64>, MatrixMaterializationError> {
4623        match self {
4624            Self::Dense(matrix) => matrix.try_row_chunk(rows),
4625            Self::Sparse(matrix) => {
4626                let csr =
4627                    matrix
4628                        .to_csr_arc()
4629                        .ok_or(MatrixMaterializationError::MissingRowChunk {
4630                            context: "DesignMatrix::try_row_chunk: failed to obtain CSR view",
4631                        })?;
4632                let sym = csr.symbolic();
4633                let row_ptr = sym.row_ptr();
4634                let col_idx = sym.col_idx();
4635                let vals = csr.val();
4636                let chunk_rows = rows.end - rows.start;
4637                let ncols = self.ncols();
4638                let mut out = Array2::<f64>::zeros((chunk_rows, ncols));
4639                for (local_row, row) in rows.enumerate() {
4640                    for ptr in row_ptr[row]..row_ptr[row + 1] {
4641                        out[[local_row, col_idx[ptr]]] = vals[ptr];
4642                    }
4643                }
4644                Ok(out)
4645            }
4646        }
4647    }
4648
4649    /// Borrow-only row-chunk accessor: writes the requested rows into an
4650    /// existing `(rows.len(), ncols())` buffer instead of allocating a fresh
4651    /// `Array2<f64>` like [`Self::try_row_chunk`]. Used by hot per-row loops
4652    /// (e.g. latent-survival evaluate) that want to reuse a single 1-row
4653    /// scratch buffer across iterations.
4654    pub fn row_chunk_into(
4655        &self,
4656        rows: Range<usize>,
4657        out: ArrayViewMut2<'_, f64>,
4658    ) -> Result<(), MatrixMaterializationError> {
4659        <Self as DenseDesignOperator>::row_chunk_into(self, rows, out)
4660    }
4661
4662    pub fn try_to_dense_by_chunks(&self, context: &str) -> Result<Array2<f64>, String> {
4663        let n = self.nrows();
4664        let p = self.ncols();
4665        let chunk_rows = dense_materialization_chunk_rows(n, p);
4666        let mut out = Array2::<f64>::zeros((n, p));
4667        for start in (0..n).step_by(chunk_rows) {
4668            let end = (start + chunk_rows).min(n);
4669            let slice = out.slice_mut(s![start..end, ..]);
4670            self.row_chunk_into(start..end, slice)
4671                .map_err(|err| format!("{context}: failed to materialize row chunk: {err}"))?;
4672        }
4673        Ok(out)
4674    }
4675
4676    /// Like [`Self::try_to_dense_by_chunks`] but refuses to allocate when the
4677    /// dense footprint would exceed `max_bytes`. Returned `Err` is the same
4678    /// shape as a densification-refused error from the resource policy, so
4679    /// observability-only callers can convert it into a `warn!` and skip
4680    /// without ever touching the allocator at huge `n`.
4681    pub fn try_to_dense_by_chunks_budgeted(
4682        &self,
4683        context: &str,
4684        max_bytes: usize,
4685    ) -> Result<Array2<f64>, String> {
4686        let n = self.nrows();
4687        let p = self.ncols();
4688        let dense_bytes = checked_dense_nbytes(n, p, context)?;
4689        if dense_bytes > max_bytes {
4690            let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
4691            let cap_gib = max_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
4692            return Err(MatrixError::DensificationRefused {
4693                reason: format!(
4694                    "{context}: refusing to densify {n}x{p} (~{gib:.2} GiB, cap ~{cap_gib:.2} GiB)"
4695                ),
4696            }
4697            .into());
4698        }
4699        self.try_to_dense_by_chunks(context)
4700    }
4701
4702    /// Dot a single design row against a coefficient vector without allocating
4703    /// a standalone row buffer when the underlying storage permits.
4704    pub fn dot_row(&self, row: usize, beta: &Array1<f64>) -> f64 {
4705        self.dot_row_view(row, beta.view())
4706    }
4707
4708    pub fn dot_row_view(&self, row: usize, beta: ArrayView1<'_, f64>) -> f64 {
4709        assert_eq!(
4710            beta.len(),
4711            self.ncols(),
4712            "DesignMatrix::dot_row_view length mismatch: beta={}, ncols={}",
4713            beta.len(),
4714            self.ncols()
4715        );
4716        match self {
4717            Self::Dense(matrix) => {
4718                if let Some(dense) = matrix.as_dense_ref() {
4719                    dense.row(row).dot(&beta)
4720                } else {
4721                    matrix
4722                        .try_row_chunk(row..row + 1)
4723                        .expect("DesignMatrix::dot_row_view: try_row_chunk must succeed")
4724                        .row(0)
4725                        .dot(&beta)
4726                }
4727            }
4728            Self::Sparse(matrix) => {
4729                // SAFETY: `to_csr_arc` only returns `None` if the underlying
4730                // SparseColMat fails `to_row_major`, which is infallible for
4731                // any well-formed sparse matrix the surrounding type system
4732                // permits (csc → csr conversion requires only valid column
4733                // pointers). Reaching `None` would mean the SparseDesignMatrix
4734                // invariant was violated upstream.
4735                // SAFETY: SparseDesignMatrix invariants guarantee csc→csr conversion succeeds.
4736                let csr = matrix
4737                    .to_csr_arc()
4738                    .expect("DesignMatrix::dot_row: failed to obtain CSR view");
4739                let sym = csr.symbolic();
4740                let row_ptr = sym.row_ptr();
4741                let col_idx = sym.col_idx();
4742                let vals = csr.val();
4743                let mut out = 0.0;
4744                for ptr in row_ptr[row]..row_ptr[row + 1] {
4745                    out += vals[ptr] * beta[col_idx[ptr]];
4746                }
4747                out
4748            }
4749        }
4750    }
4751
4752    /// Add `alpha * X[row, :]` into `out` without allocating a row buffer.
4753    pub fn axpy_row_into(
4754        &self,
4755        row: usize,
4756        alpha: f64,
4757        out: &mut ArrayViewMut1<'_, f64>,
4758    ) -> Result<(), String> {
4759        self.axpy_row_into_impl(row, alpha, out, false, "axpy_row_into")
4760    }
4761
4762    /// Add `alpha * X[row, :]^2` elementwise into `out` without allocating a
4763    /// standalone row buffer.
4764    pub fn squared_axpy_row_into(
4765        &self,
4766        row: usize,
4767        alpha: f64,
4768        out: &mut ArrayViewMut1<'_, f64>,
4769    ) -> Result<(), String> {
4770        self.axpy_row_into_impl(row, alpha, out, true, "squared_axpy_row_into")
4771    }
4772
4773    /// Shared kernel for [`axpy_row_into`](Self::axpy_row_into) and
4774    /// [`squared_axpy_row_into`](Self::squared_axpy_row_into): adds
4775    /// `alpha * X[row, :]` (when `square` is `false`) or
4776    /// `alpha * X[row, :]^2` elementwise (when `square` is `true`) into `out`
4777    /// without allocating a row buffer. `method` names the public entry point
4778    /// in error messages.
4779    #[inline]
4780    fn axpy_row_into_impl(
4781        &self,
4782        row: usize,
4783        alpha: f64,
4784        out: &mut ArrayViewMut1<'_, f64>,
4785        square: bool,
4786        method: &str,
4787    ) -> Result<(), String> {
4788        if out.len() != self.ncols() {
4789            return Err(format!(
4790                "DesignMatrix::{method} length mismatch: out={}, ncols={}",
4791                out.len(),
4792                self.ncols()
4793            ));
4794        }
4795        if alpha == 0.0 {
4796            return Ok(());
4797        }
4798        // Per-element scaling: `alpha * v` (axpy) or `alpha * v^2` (squared).
4799        let scale = |value: f64| {
4800            if square {
4801                alpha * value * value
4802            } else {
4803                alpha * value
4804            }
4805        };
4806        match self {
4807            Self::Dense(matrix) => {
4808                if let Some(dense) = matrix.as_dense_ref() {
4809                    for (dst, &value) in out.iter_mut().zip(dense.row(row).iter()) {
4810                        *dst += scale(value);
4811                    }
4812                } else {
4813                    let chunk = matrix
4814                        .try_row_chunk(row..row + 1)
4815                        .map_err(|e| format!("DesignMatrix::{method}: {e}"))?;
4816                    for (dst, &value) in out.iter_mut().zip(chunk.row(0).iter()) {
4817                        *dst += scale(value);
4818                    }
4819                }
4820            }
4821            Self::Sparse(matrix) => {
4822                // SAFETY: `to_csr_arc` returns `None` only if csc→csr conversion
4823                // fails, which is infallible for the well-formed sparse matrices
4824                // that `SparseDesignMatrix` is contractually allowed to hold.
4825                // SAFETY: SparseDesignMatrix invariants guarantee csc→csr conversion succeeds.
4826                let csr = matrix
4827                    .to_csr_arc()
4828                    .ok_or_else(|| format!("DesignMatrix::{method}: failed to obtain CSR view"))?;
4829                let sym = csr.symbolic();
4830                let row_ptr = sym.row_ptr();
4831                let col_idx = sym.col_idx();
4832                let vals = csr.val();
4833                for ptr in row_ptr[row]..row_ptr[row + 1] {
4834                    out[col_idx[ptr]] += scale(vals[ptr]);
4835                }
4836            }
4837        }
4838        Ok(())
4839    }
4840
4841    /// Add `alpha * self[row, :] * other[row, :]` elementwise into `out`.
4842    ///
4843    /// Both matrices must have the same number of columns (== `out.len()`).
4844    /// For Sparse×Sparse this runs in O(nnz_lhs + nnz_rhs) via sorted
4845    /// merge-intersection on the CSR column indices — no dense expansion.
4846    pub fn crossdiag_axpy_row_into(
4847        &self,
4848        row: usize,
4849        other: &DesignMatrix,
4850        alpha: f64,
4851        out: &mut ArrayViewMut1<'_, f64>,
4852    ) -> Result<(), String> {
4853        assert_eq!(self.ncols(), other.ncols());
4854        assert_eq!(out.len(), self.ncols());
4855        if alpha == 0.0 {
4856            return Ok(());
4857        }
4858        match (self, other) {
4859            (Self::Dense(lhs), Self::Dense(rhs)) => {
4860                let lhs_chunk;
4861                let rhs_chunk;
4862                let x = if let Some(lhs_dense) = lhs.as_dense_ref() {
4863                    lhs_dense.row(row)
4864                } else {
4865                    lhs_chunk = lhs
4866                        .try_row_chunk(row..row + 1)
4867                        .map_err(|e| format!("crossdiag_axpy_row_into lhs: {e}"))?;
4868                    lhs_chunk.row(0)
4869                };
4870                let y = if let Some(rhs_dense) = rhs.as_dense_ref() {
4871                    rhs_dense.row(row)
4872                } else {
4873                    rhs_chunk = rhs
4874                        .try_row_chunk(row..row + 1)
4875                        .map_err(|e| format!("crossdiag_axpy_row_into rhs: {e}"))?;
4876                    rhs_chunk.row(0)
4877                };
4878                for (dst, (&xi, &yi)) in out.iter_mut().zip(x.iter().zip(y.iter())) {
4879                    *dst += alpha * xi * yi;
4880                }
4881            }
4882            (Self::Sparse(lhs), Self::Sparse(rhs)) => {
4883                // `to_csr_arc` returns `None` only if csc→csr conversion fails;
4884                // `SparseDesignMatrix`'s validation invariants make that
4885                // structurally impossible, but the function returns `Result`
4886                // so propagate rather than panic if a future invariant break
4887                // surfaces here.
4888                let lhs_csr = lhs.to_csr_arc().ok_or_else(|| {
4889                    "crossdiag_axpy_row_into: failed to obtain lhs CSR view".to_string()
4890                })?;
4891                let rhs_csr = rhs.to_csr_arc().ok_or_else(|| {
4892                    "crossdiag_axpy_row_into: failed to obtain rhs CSR view".to_string()
4893                })?;
4894                let lhs_sym = lhs_csr.symbolic();
4895                let rhs_sym = rhs_csr.symbolic();
4896                let lhs_rp = lhs_sym.row_ptr();
4897                let rhs_rp = rhs_sym.row_ptr();
4898                let lhs_ci = lhs_sym.col_idx();
4899                let rhs_ci = rhs_sym.col_idx();
4900                let lhs_v = lhs_csr.val();
4901                let rhs_v = rhs_csr.val();
4902                // Merge-intersection: both col_idx slices are sorted.
4903                let mut li = lhs_rp[row];
4904                let mut ri = rhs_rp[row];
4905                let l_end = lhs_rp[row + 1];
4906                let r_end = rhs_rp[row + 1];
4907                while li < l_end && ri < r_end {
4908                    let lc = lhs_ci[li];
4909                    let rc = rhs_ci[ri];
4910                    if lc == rc {
4911                        out[lc] += alpha * lhs_v[li] * rhs_v[ri];
4912                        li += 1;
4913                        ri += 1;
4914                    } else if lc < rc {
4915                        li += 1;
4916                    } else {
4917                        ri += 1;
4918                    }
4919                }
4920            }
4921            _ => {
4922                // Mixed dense/sparse: iterate the sparse side, index into dense.
4923                let (sparse_mat, dense_mat) = match (self, other) {
4924                    (Self::Sparse(s), Self::Dense(d)) => (s, d),
4925                    (Self::Dense(d), Self::Sparse(s)) => (s, d),
4926                    // Outer match's first two arms already handled (Dense,Dense)
4927                    // and (Sparse,Sparse); only mixed pairs reach this fallback.
4928                    _ => {
4929                        return Err(
4930                            "crossdiag_axpy_row_into: mixed-arm dispatch reached non-mixed pair"
4931                                .to_string(),
4932                        );
4933                    }
4934                };
4935                // Same CSR conversion contract as the (Sparse, Sparse) arm
4936                // above — propagate the (structurally impossible) failure
4937                // through this fn's `Result` rather than panicking.
4938                let csr = sparse_mat.to_csr_arc().ok_or_else(|| {
4939                    "crossdiag_axpy_row_into: failed to obtain CSR view".to_string()
4940                })?;
4941                let sym = csr.symbolic();
4942                let row_ptr = sym.row_ptr();
4943                let col_idx = sym.col_idx();
4944                let vals = csr.val();
4945                let dense_chunk;
4946                let dense_row = if let Some(dense_ref) = dense_mat.as_dense_ref() {
4947                    dense_ref.row(row)
4948                } else {
4949                    dense_chunk = dense_mat
4950                        .try_row_chunk(row..row + 1)
4951                        .map_err(|e| format!("crossdiag_axpy_row_into dense chunk: {e}"))?;
4952                    dense_chunk.row(0)
4953                };
4954                for ptr in row_ptr[row]..row_ptr[row + 1] {
4955                    let c = col_idx[ptr];
4956                    out[c] += alpha * vals[ptr] * dense_row[c];
4957                }
4958            }
4959        }
4960        Ok(())
4961    }
4962
4963    /// Symmetric rank-1 update `target += alpha * x_row x_row^T` for one row.
4964    pub fn syr_row_into(
4965        &self,
4966        row: usize,
4967        alpha: f64,
4968        target: &mut Array2<f64>,
4969    ) -> Result<(), String> {
4970        self.syr_row_into_view(row, alpha, target.view_mut())
4971    }
4972
4973    /// Like `syr_row_into` but accepts a mutable view, so callers can pass
4974    /// a slice of a larger matrix without allocating a temporary.
4975    pub fn syr_row_into_view(
4976        &self,
4977        row: usize,
4978        alpha: f64,
4979        mut target: ArrayViewMut2<'_, f64>,
4980    ) -> Result<(), String> {
4981        if target.nrows() != self.ncols() || target.ncols() != self.ncols() {
4982            return Err(format!(
4983                "DesignMatrix::syr_row_into shape mismatch: target={}x{}, ncols={}",
4984                target.nrows(),
4985                target.ncols(),
4986                self.ncols()
4987            ));
4988        }
4989        if alpha == 0.0 {
4990            return Ok(());
4991        }
4992        match self {
4993            Self::Dense(matrix) => {
4994                if let Some(dense) = matrix.as_dense_ref() {
4995                    let x = dense.row(row);
4996                    for i in 0..x.len() {
4997                        let xi = x[i];
4998                        if xi == 0.0 {
4999                            continue;
5000                        }
5001                        for j in 0..x.len() {
5002                            target[[i, j]] += alpha * xi * x[j];
5003                        }
5004                    }
5005                } else {
5006                    let chunk = matrix
5007                        .try_row_chunk(row..row + 1)
5008                        .map_err(|e| format!("DesignMatrix::syr_row_into: {e}"))?;
5009                    let x = chunk.row(0);
5010                    for i in 0..x.len() {
5011                        let xi = x[i];
5012                        if xi == 0.0 {
5013                            continue;
5014                        }
5015                        for j in 0..x.len() {
5016                            target[[i, j]] += alpha * xi * x[j];
5017                        }
5018                    }
5019                }
5020            }
5021            Self::Sparse(matrix) => {
5022                // SAFETY: `to_csr_arc` returns `None` only on csc→csr conversion
5023                // failure for a malformed sparse matrix; `SparseDesignMatrix`
5024                // invariants forbid that case.
5025                // SAFETY: SparseDesignMatrix invariants guarantee csc→csr conversion succeeds.
5026                let csr = matrix.to_csr_arc().ok_or_else(|| {
5027                    "DesignMatrix::syr_row_into: failed to obtain CSR view".to_string()
5028                })?;
5029                let sym = csr.symbolic();
5030                let row_ptr = sym.row_ptr();
5031                let col_idx = sym.col_idx();
5032                let vals = csr.val();
5033                for ptr_i in row_ptr[row]..row_ptr[row + 1] {
5034                    let i = col_idx[ptr_i];
5035                    let xi = vals[ptr_i];
5036                    for ptr_j in row_ptr[row]..row_ptr[row + 1] {
5037                        let j = col_idx[ptr_j];
5038                        target[[i, j]] += alpha * xi * vals[ptr_j];
5039                    }
5040                }
5041            }
5042        }
5043        Ok(())
5044    }
5045
5046    /// Asymmetric rank-1 update: `target += alpha * lhs_row * rhs_row^T`.
5047    ///
5048    /// `self` provides `lhs_row`, `other` provides `rhs_row`.
5049    /// `target` must be `self.ncols() x other.ncols()`.
5050    pub fn row_outer_into(
5051        &self,
5052        row: usize,
5053        other: &DesignMatrix,
5054        alpha: f64,
5055        target: &mut Array2<f64>,
5056    ) -> Result<(), String> {
5057        self.row_outer_into_view(row, other, alpha, target.view_mut())
5058    }
5059
5060    /// Like `row_outer_into` but accepts a mutable view, so callers can pass
5061    /// a slice of a larger matrix without allocating a temporary.
5062    pub fn row_outer_into_view(
5063        &self,
5064        row: usize,
5065        other: &DesignMatrix,
5066        alpha: f64,
5067        mut target: ArrayViewMut2<'_, f64>,
5068    ) -> Result<(), String> {
5069        if target.nrows() != self.ncols() || target.ncols() != other.ncols() {
5070            return Err(format!(
5071                "DesignMatrix::row_outer_into shape mismatch: target={}x{}, lhs={}, rhs={}",
5072                target.nrows(),
5073                target.ncols(),
5074                self.ncols(),
5075                other.ncols()
5076            ));
5077        }
5078        if alpha == 0.0 {
5079            return Ok(());
5080        }
5081        match (self, other) {
5082            (Self::Dense(lhs), Self::Dense(rhs)) => {
5083                let lhs_chunk;
5084                let rhs_chunk;
5085                let x = if let Some(lhs_dense) = lhs.as_dense_ref() {
5086                    lhs_dense.row(row)
5087                } else {
5088                    lhs_chunk = lhs
5089                        .try_row_chunk(row..row + 1)
5090                        .map_err(|e| format!("row_outer_into_view lhs: {e}"))?;
5091                    lhs_chunk.row(0)
5092                };
5093                let y = if let Some(rhs_dense) = rhs.as_dense_ref() {
5094                    rhs_dense.row(row)
5095                } else {
5096                    rhs_chunk = rhs
5097                        .try_row_chunk(row..row + 1)
5098                        .map_err(|e| format!("row_outer_into_view rhs: {e}"))?;
5099                    rhs_chunk.row(0)
5100                };
5101                for i in 0..x.len() {
5102                    let xi = x[i];
5103                    if xi == 0.0 {
5104                        continue;
5105                    }
5106                    for j in 0..y.len() {
5107                        target[[i, j]] += alpha * xi * y[j];
5108                    }
5109                }
5110            }
5111            (Self::Sparse(lhs), Self::Sparse(rhs)) => {
5112                // SAFETY: both `to_csr_arc` calls only fail on csc→csr conversion
5113                // of a malformed sparse matrix; `SparseDesignMatrix` invariants
5114                // upstream guarantee both inputs round-trip to CSR.
5115                // SAFETY: SparseDesignMatrix invariants guarantee csc→csr conversion succeeds.
5116                let lhs_csr = lhs
5117                    .to_csr_arc()
5118                    .ok_or_else(|| "row_outer_into: failed to obtain lhs CSR view".to_string())?;
5119                // SAFETY: SparseDesignMatrix invariants guarantee csc→csr conversion succeeds.
5120                let rhs_csr = rhs
5121                    .to_csr_arc()
5122                    .ok_or_else(|| "row_outer_into: failed to obtain rhs CSR view".to_string())?;
5123                let lhs_sym = lhs_csr.symbolic();
5124                let rhs_sym = rhs_csr.symbolic();
5125                let lhs_rp = lhs_sym.row_ptr();
5126                let rhs_rp = rhs_sym.row_ptr();
5127                let lhs_ci = lhs_sym.col_idx();
5128                let rhs_ci = rhs_sym.col_idx();
5129                let lhs_v = lhs_csr.val();
5130                let rhs_v = rhs_csr.val();
5131                for pi in lhs_rp[row]..lhs_rp[row + 1] {
5132                    let i = lhs_ci[pi];
5133                    let xi = lhs_v[pi];
5134                    for pj in rhs_rp[row]..rhs_rp[row + 1] {
5135                        let j = rhs_ci[pj];
5136                        target[[i, j]] += alpha * xi * rhs_v[pj];
5137                    }
5138                }
5139            }
5140            _ => {
5141                // Mixed dense/sparse: materialize both rows.
5142                let x = self
5143                    .try_row_chunk(row..row + 1)
5144                    .map_err(|e| format!("row_outer_into_view lhs: {e}"))?;
5145                let x_row = x.row(0);
5146                let y = other
5147                    .try_row_chunk(row..row + 1)
5148                    .map_err(|e| format!("row_outer_into_view rhs: {e}"))?;
5149                let y_row = y.row(0);
5150                for i in 0..x_row.len() {
5151                    let xi = x_row[i];
5152                    if xi == 0.0 {
5153                        continue;
5154                    }
5155                    for j in 0..y_row.len() {
5156                        target[[i, j]] += alpha * xi * y_row[j];
5157                    }
5158                }
5159            }
5160        }
5161        Ok(())
5162    }
5163
5164    /// Element access: returns the value at row `i`, column `j`.
5165    ///
5166    /// For materialized dense matrices this is O(1). For sparse matrices,
5167    /// the dense form is cached on the `SparseDesignMatrix` itself, so
5168    /// repeated calls amortize to O(1) after the first call populates the
5169    /// cache. For operator-backed (Lazy) dense matrices this call performs
5170    /// an O(n) single-column materialization via `extract_column`; callers
5171    /// sweeping many cells should call `as_dense_cow()` or `to_dense()`
5172    /// once and index the returned array directly — calling `get` in a
5173    /// per-cell loop on a Lazy operator is O(nrows · ncols) per call
5174    /// because the operator has no dense cache.
5175    #[inline]
5176    pub fn get(&self, i: usize, j: usize) -> f64 {
5177        match self {
5178            Self::Dense(matrix) => match matrix.as_dense_ref() {
5179                Some(dense) => dense[[i, j]],
5180                // Lazy operator: pull a single column via apply(e_j) so that
5181                // each call is O(n) instead of O(nrows · ncols); the default
5182                // `try_to_dense_arc` would re-materialize the full operator
5183                // on every call because Lazy operators have no dense cache.
5184                None => {
5185                    let mut e_j = Array1::<f64>::zeros(matrix.ncols());
5186                    e_j[j] = 1.0;
5187                    matrix.apply(&e_j)[i]
5188                }
5189            },
5190            Self::Sparse(sp) => {
5191                // SAFETY: `DesignMatrix::get` is documented as an
5192                // infallible scalar accessor; callers that take this path
5193                // have already accepted dense materialization. A
5194                // densification failure here means the sparse matrix exceeds
5195                // the conservative byte budget, which `DesignMatrix::get`
5196                // contractually forbids.
5197                // SAFETY: `get` is an infallible scalar accessor; caller has accepted dense materialization budget.
5198                let dense = sp
5199                    .try_to_dense_arc("DesignMatrix::get")
5200                    .unwrap_or_else(|msg| std::panic::panic_any(msg));
5201                dense[[i, j]]
5202            }
5203        }
5204    }
5205
5206    /// Extract a single column as a dense vector without full densification.
5207    ///
5208    /// - `Dense`: O(n) column copy.
5209    /// - `Sparse` (CSC): O(nnz_j) using the column pointer structure.
5210    /// - lazy `Dense`: O(matvec) via unit-vector application.
5211    pub fn extract_column(&self, j: usize) -> Array1<f64> {
5212        match self {
5213            Self::Dense(m) => {
5214                if let Some(dense) = m.as_dense_ref() {
5215                    dense.column(j).to_owned()
5216                } else {
5217                    let mut e_j = Array1::zeros(m.ncols());
5218                    e_j[j] = 1.0;
5219                    m.apply(&e_j)
5220                }
5221            }
5222            Self::Sparse(sp) => {
5223                let n = sp.nrows();
5224                let mut col = Array1::zeros(n);
5225                let (symbolic, values) = sp.parts();
5226                let col_ptr = symbolic.col_ptr();
5227                let row_idx = symbolic.row_idx();
5228                let start = col_ptr[j];
5229                let end = col_ptr[j + 1];
5230                for idx in start..end {
5231                    col[row_idx[idx]] = values[idx];
5232                }
5233                col
5234            }
5235        }
5236    }
5237
5238    /// Batched column extraction: returns an `nrows × cols.len()` dense block
5239    /// whose k-th column equals `extract_column(cols[k])`.
5240    ///
5241    /// For lazy operator-backed designs this routes through the operator's
5242    /// `apply_columns`, which `ReparamOperator` implements as a single GEMM
5243    /// (`X · Qs[:, cols]`) instead of one matvec dispatch per column.
5244    pub fn extract_columns(&self, cols: &[usize]) -> Array2<f64> {
5245        match self {
5246            Self::Dense(m) => match m {
5247                DenseDesignMatrix::Materialized(mat) => mat.select(Axis(1), cols),
5248                DenseDesignMatrix::Lazy(op) => op.apply_columns(cols),
5249            },
5250            Self::Sparse(sp) => {
5251                let n = sp.nrows();
5252                let mut out = Array2::<f64>::zeros((n, cols.len()));
5253                let (symbolic, values) = sp.parts();
5254                let col_ptr = symbolic.col_ptr();
5255                let row_idx = symbolic.row_idx();
5256                for (k, &j) in cols.iter().enumerate() {
5257                    let start = col_ptr[j];
5258                    let end = col_ptr[j + 1];
5259                    let mut out_col = out.column_mut(k);
5260                    for idx in start..end {
5261                        out_col[row_idx[idx]] = values[idx];
5262                    }
5263                }
5264                out
5265            }
5266        }
5267    }
5268
5269    /// Returns a reference to the inner dense array if this is a `Dense` variant.
5270    pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
5271        match self {
5272            Self::Dense(matrix) => matrix.as_dense_ref(),
5273            Self::Sparse(_) => None,
5274        }
5275    }
5276
5277    pub const fn is_materialized_dense(&self) -> bool {
5278        matches!(self, Self::Dense(DenseDesignMatrix::Materialized(_)))
5279    }
5280
5281    pub const fn is_operator_backed(&self) -> bool {
5282        match self {
5283            Self::Dense(matrix) => matrix.is_operator_backed(),
5284            Self::Sparse(_) => false,
5285        }
5286    }
5287
5288    /// Whether this design is backed by a sparse (CSR/COO) representation
5289    /// rather than a dense or dense-operator backing. Used to gate the
5290    /// row-chunked `Xᵀ diag(w) X` BLAS-3 Gram path, which is structurally
5291    /// applicable only to dense / dense-operator designs (a sparse block must
5292    /// keep the generic sparse-aware per-row pullback).
5293    pub const fn is_sparse(&self) -> bool {
5294        matches!(self, Self::Sparse(_))
5295    }
5296
5297    /// Zero-copy borrow when `Dense`, materialized conversion when `Sparse`.
5298    ///
5299    /// This avoids the unconditional clone that `to_dense()` performs on dense
5300    /// matrices.  Callers that only need a `&Array2<f64>` should use this and
5301    /// then call `Cow::as_ref()` or `&*cow`.
5302    pub fn as_dense_cow(&self) -> Cow<'_, Array2<f64>> {
5303        match self {
5304            Self::Dense(DenseDesignMatrix::Materialized(matrix)) => Cow::Borrowed(matrix.as_ref()),
5305            Self::Dense(DenseDesignMatrix::Lazy(op)) => match op.as_dense_ref() {
5306                Some(dense) => Cow::Borrowed(dense),
5307                // SAFETY: `as_dense_cow` is the zero-copy view accessor; its
5308                // contract forbids operator-backed designs that cannot expose
5309                // a pre-materialized dense view. A caller that reached this
5310                // arm used the borrow API on an operator representation it
5311                // should have streamed through row chunks instead.
5312                // SAFETY: as_dense_cow's zero-copy contract forbids operator-backed designs without a materialized view.
5313                None => std::panic::panic_any(format!(
5314                    "DesignMatrix::as_dense_cow called on operator-backed design ({}x{}); use row chunks or matrix-vector products",
5315                    op.nrows(),
5316                    op.ncols()
5317                )),
5318            },
5319            Self::Sparse(matrix) => Cow::Owned(
5320                matrix
5321                    .try_to_dense_arc("DesignMatrix::as_dense_cow")
5322                    // SAFETY: callers of `as_dense_cow` have accepted dense
5323                    // materialization; densification failure here means the
5324                    // sparse matrix exceeds the byte-cap that this accessor
5325                    // contractually forbids.
5326                    // SAFETY: caller of as_dense_cow has accepted dense materialization budget.
5327                    .unwrap_or_else(|msg| std::panic::panic_any(msg))
5328                    .as_ref()
5329                    .clone(),
5330            ),
5331        }
5332    }
5333
5334    /// Borrow when already-materialized dense, otherwise materialize via
5335    /// chunks (or via the sparse conversion path) and return an owned `Cow`.
5336    ///
5337    /// Use this when a code path genuinely needs a contiguous `Array2<f64>`
5338    /// view of an operator-backed design (e.g. legacy dense linear-algebra
5339    /// helpers that the operator-aware code paths have not yet replaced).
5340    /// Prefer `try_row_chunk` / `matrixvectormultiply` when chunked or
5341    /// matrix-free access suffices.
5342    pub fn to_dense_cow(&self) -> Cow<'_, Array2<f64>> {
5343        match self {
5344            Self::Dense(DenseDesignMatrix::Materialized(matrix)) => Cow::Borrowed(matrix.as_ref()),
5345            Self::Dense(DenseDesignMatrix::Lazy(op)) => {
5346                if let Some(dense) = op.as_dense_ref() {
5347                    Cow::Borrowed(dense)
5348                } else {
5349                    // Bypass the size-capped policy guard: callers reaching
5350                    // `to_dense_cow` are committing to a dense consumer.
5351                    Cow::Owned(
5352                        dense_operator_to_dense_by_chunks(op.as_ref()).unwrap_or_else(|err| {
5353                            // SAFETY: documented bypass — callers of
5354                            // `to_dense_cow` are infallible-by-contract dense
5355                            // consumers; the row-chunk path has no byte cap
5356                            // and only fails on operator implementation bugs,
5357                            // which the operator trait contract forbids.
5358                            // SAFETY: row_chunk_into is infallible-by-contract for valid operators.
5359                            std::panic::panic_any(format!(
5360                                "DesignMatrix::to_dense_cow: failed to materialize {}x{} \
5361                                 operator-backed design via row chunks: {err}",
5362                                op.nrows(),
5363                                op.ncols(),
5364                            ))
5365                        }),
5366                    )
5367                }
5368            }
5369            Self::Sparse(matrix) => Cow::Owned(
5370                matrix
5371                    .try_to_dense_arc("DesignMatrix::to_dense_cow")
5372                    // SAFETY: callers of `to_dense_cow` have committed to a
5373                    // dense `Array2<f64>` consumer; densification failure
5374                    // would mean the sparse matrix exceeds the conservative
5375                    // byte cap which this accessor's contract forbids.
5376                    // SAFETY: caller of to_dense_cow has accepted dense materialization budget.
5377                    .unwrap_or_else(|msg| std::panic::panic_any(msg))
5378                    .as_ref()
5379                    .clone(),
5380            ),
5381        }
5382    }
5383
5384    /// Returns the design as a contiguous `Array2<f64>`.
5385    ///
5386    /// **Bypass contract** (regression-pinned at
5387    /// `to_dense_arc_bypasses_policy_cap_strict_policy_still_refuses`): for
5388    /// operator-backed dense designs this streams row chunks via the
5389    /// operator's `row_chunk_into` and does NOT consult the
5390    /// `ResourcePolicy::max_single_materialization_bytes` cap. A caller
5391    /// reaching this method has already committed to a dense
5392    /// `Array2<f64>` consumer and owns the memory budget; consulting a
5393    /// conservative byte cap here would refuse legitimate workloads
5394    /// (e.g. the 4194304×10 Duchon basis at large scale that
5395    /// historically panicked with "refusing to densify operator-backed
5396    /// design").
5397    ///
5398    /// Strict-operator math (`DerivativeStorageMode::AnalyticOperatorRequired`)
5399    /// must instead call
5400    /// `try_to_dense_arc_with_policy(ctx, &ResourcePolicy::analytic_operator_required())`,
5401    /// which keeps refusal semantics intact for large-scale invariants.
5402    ///
5403    /// Sparse designs still honor their own internal
5404    /// `MAX_SPARSE_TO_DENSE_BYTES` cap (which is a separate hard limit
5405    /// guarding against accidental n×p dense materialization of a sparse
5406    /// design that should have stayed sparse).
5407    pub fn to_dense(&self) -> Array2<f64> {
5408        match self {
5409            Self::Dense(matrix) => matrix.to_dense(),
5410            Self::Sparse(matrix) => matrix
5411                .try_to_dense_arc("DesignMatrix::to_dense")
5412                // SAFETY: `to_dense` is documented as a dense-by-contract
5413                // accessor (see the bypass-contract doc above). Sparse path
5414                // honours `MAX_SPARSE_TO_DENSE_BYTES` internally; failure
5415                // means the matrix exceeded that hard limit which callers of
5416                // `to_dense` are forbidden from triggering.
5417                // SAFETY: dense-by-contract accessor; failure means caller broke MAX_SPARSE_TO_DENSE_BYTES contract.
5418                .unwrap_or_else(|msg| std::panic::panic_any(msg))
5419                .as_ref()
5420                .clone(),
5421        }
5422    }
5423
5424    /// Arc-shared variant of [`Self::to_dense`]; same bypass contract for the
5425    /// operator-backed dense branch (no policy guard), same sparse cap.
5426    pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
5427        match self {
5428            Self::Dense(matrix) => matrix.to_dense_arc(),
5429            Self::Sparse(matrix) => matrix
5430                .try_to_dense_arc("DesignMatrix::to_dense_arc")
5431                // SAFETY: arc-shared variant of `to_dense` — same bypass
5432                // contract; callers have committed to a dense consumer, so a
5433                // densification failure would mean the sparse matrix exceeded
5434                // `MAX_SPARSE_TO_DENSE_BYTES`, which this method's contract
5435                // forbids.
5436                // SAFETY: dense-by-contract accessor; failure means caller broke MAX_SPARSE_TO_DENSE_BYTES contract.
5437                .unwrap_or_else(|msg| std::panic::panic_any(msg)),
5438        }
5439    }
5440
5441    pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
5442        match self {
5443            Self::Dense(matrix) => matrix.try_to_dense_arc(context),
5444            Self::Sparse(matrix) => matrix.try_to_dense_arc(context),
5445        }
5446    }
5447
5448    /// Policy-aware densify: callers that own the consumer's dense budget can
5449    /// override the conservative default cap used by [`Self::try_to_dense_arc`].
5450    pub fn try_to_dense_arc_with_policy(
5451        &self,
5452        context: &str,
5453        policy: &ResourcePolicy,
5454    ) -> Result<Arc<Array2<f64>>, String> {
5455        match self {
5456            Self::Dense(matrix) => matrix.try_to_dense_arc_with_policy(context, policy),
5457            Self::Sparse(matrix) => matrix.try_to_dense_arc(context),
5458        }
5459    }
5460
5461    pub fn to_csr_cache(&self) -> Option<SparseRowMat<usize, f64>> {
5462        match self {
5463            Self::Dense(_) => None,
5464            Self::Sparse(matrix) => matrix.to_csr_arc().map(|arc| (*arc).clone()),
5465        }
5466    }
5467
5468    pub fn as_sparse(&self) -> Option<&SparseDesignMatrix> {
5469        match self {
5470            Self::Sparse(matrix) => Some(matrix),
5471            Self::Dense(_) => None,
5472        }
5473    }
5474
5475    pub fn as_dense(&self) -> Option<&Array2<f64>> {
5476        match self {
5477            Self::Dense(matrix) => matrix.as_dense_ref(),
5478            Self::Sparse(_) => None,
5479        }
5480    }
5481
5482    fn apply_transpose_view(&self, vector: ArrayView1<'_, f64>) -> Array1<f64> {
5483        match self {
5484            Self::Dense(DenseDesignMatrix::Materialized(matrix)) => fast_atv(matrix, &vector),
5485            Self::Dense(DenseDesignMatrix::Lazy(op)) => op.apply_transpose(&vector.to_owned()),
5486            Self::Sparse(matrix) => {
5487                let mut output = Array1::<f64>::zeros(matrix.ncols());
5488                let (symbolic, values) = matrix.parts();
5489                let col_ptr = symbolic.col_ptr();
5490                let row_idx = symbolic.row_idx();
5491                for col in 0..matrix.ncols() {
5492                    let mut acc = 0.0;
5493                    let start = col_ptr[col];
5494                    let end = col_ptr[col + 1];
5495                    for idx in start..end {
5496                        acc += values[idx] * vector[row_idx[idx]];
5497                    }
5498                    output[col] = acc;
5499                }
5500                output
5501            }
5502        }
5503    }
5504
5505    fn diag_gram_view(&self, weights: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
5506        if weights.len() != self.nrows() {
5507            return Err(format!(
5508                "diag_gram dimension mismatch: weights length {} != nrows {}",
5509                weights.len(),
5510                self.nrows()
5511            ));
5512        }
5513        match self {
5514            Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
5515                // Diagonal-of-Gram is a PSD-precondition kernel (used as a
5516                // Jacobi preconditioner diagonal). Discharge `w ≥ 0` once at
5517                // the boundary so the kernel below operates on a typed PSD
5518                // view without re-scanning.
5519                let psd = PsdWeightsView::try_new(weights)?;
5520                Ok(dense_diag_gram_view(matrix, psd))
5521            }
5522            Self::Dense(DenseDesignMatrix::Lazy(op)) => op.diag_gram(&weights.to_owned()),
5523            Self::Sparse(xs) => {
5524                let p = xs.ncols();
5525                let csr = xs
5526                    .to_csr_arc()
5527                    .ok_or_else(|| "failed to obtain CSR view in diag_gram".to_string())?;
5528                let sym = csr.symbolic();
5529                Ok(sparse_csr_diag_gram(
5530                    sym.row_ptr(),
5531                    sym.col_idx(),
5532                    csr.val(),
5533                    xs.nrows(),
5534                    p,
5535                    weights,
5536                ))
5537            }
5538        }
5539    }
5540
5541    fn compute_xtwy_view(
5542        &self,
5543        weights: ArrayView1<'_, f64>,
5544        y: ArrayView1<'_, f64>,
5545    ) -> Result<Array1<f64>, String> {
5546        if weights.len() != self.nrows() || y.len() != self.nrows() {
5547            return Err(format!(
5548                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
5549                weights.len(),
5550                y.len(),
5551                self.nrows()
5552            ));
5553        }
5554        match self {
5555            Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
5556                Ok(dense_transpose_weighted_response_view(matrix, weights, y))
5557            }
5558            Self::Dense(DenseDesignMatrix::Lazy(op)) => {
5559                op.compute_xtwy(&weights.to_owned(), &y.to_owned())
5560            }
5561            Self::Sparse(xs) => {
5562                let csr = xs
5563                    .as_ref()
5564                    .to_row_major()
5565                    .map_err(|_| "failed to obtain CSR view in compute_xtwy".to_string())?;
5566                let sym = csr.symbolic();
5567                let row_ptr = sym.row_ptr();
5568                let col_idx = sym.col_idx();
5569                let vals = csr.val();
5570                let mut out = Array1::<f64>::zeros(xs.ncols());
5571                for i in 0..xs.nrows() {
5572                    let scaled = weights[i].max(0.0) * y[i];
5573                    if scaled == 0.0 {
5574                        continue;
5575                    }
5576                    for idx in row_ptr[i]..row_ptr[i + 1] {
5577                        out[col_idx[idx]] += vals[idx] * scaled;
5578                    }
5579                }
5580                Ok(out)
5581            }
5582        }
5583    }
5584
5585    pub fn dot(&self, vector: &Array1<f64>) -> Array1<f64> {
5586        <Self as LinearOperator>::apply(self, vector)
5587    }
5588
5589    pub fn matrixvectormultiply(&self, vector: &Array1<f64>) -> Array1<f64> {
5590        <Self as LinearOperator>::apply(self, vector)
5591    }
5592
5593    pub fn transpose_vector_multiply(&self, vector: &Array1<f64>) -> Array1<f64> {
5594        <Self as LinearOperator>::apply_transpose(self, vector)
5595    }
5596
5597    pub fn compute_xtwy(
5598        &self,
5599        weights: &Array1<f64>,
5600        y: &Array1<f64>,
5601    ) -> Result<Array1<f64>, String> {
5602        <Self as DenseDesignOperator>::compute_xtwy(self, weights, y)
5603    }
5604
5605    pub fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
5606        <Self as LinearOperator>::diag_gram(self, weights)
5607    }
5608
5609    pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
5610        <Self as DenseDesignOperator>::quadratic_form_diag(self, middle)
5611    }
5612
5613    pub fn apply_weighted_normal(
5614        &self,
5615        weights: &Array1<f64>,
5616        vector: &Array1<f64>,
5617        penalty: Option<&Array2<f64>>,
5618        ridge: f64,
5619    ) -> Array1<f64> {
5620        <Self as LinearOperator>::apply_weighted_normal(self, weights, vector, penalty, ridge)
5621    }
5622
5623    pub fn solve_system(
5624        &self,
5625        weights: &Array1<f64>,
5626        rhs: &Array1<f64>,
5627        penalty: Option<&Array2<f64>>,
5628    ) -> Result<Array1<f64>, String> {
5629        <Self as LinearOperator>::solve_system(self, weights, rhs, penalty)
5630    }
5631
5632    pub fn solve_systemwith_policy(
5633        &self,
5634        weights: &Array1<f64>,
5635        rhs: &Array1<f64>,
5636        penalty: Option<&Array2<f64>>,
5637        ridge_floor: f64,
5638        ridge_policy: RidgePolicy,
5639    ) -> Result<Array1<f64>, String> {
5640        <Self as LinearOperator>::solve_systemwith_policy(
5641            self,
5642            weights,
5643            rhs,
5644            penalty,
5645            ridge_floor,
5646            ridge_policy,
5647        )
5648    }
5649
5650    pub fn solve_system_matrix_free_pcg(
5651        &self,
5652        weights: &Array1<f64>,
5653        rhs: &Array1<f64>,
5654        penalty: Option<&Array2<f64>>,
5655        ridge_floor: f64,
5656    ) -> Result<Array1<f64>, String> {
5657        <Self as LinearOperator>::solve_system_matrix_free_pcg_try(
5658            self,
5659            weights,
5660            rhs,
5661            penalty,
5662            ridge_floor.max(SPD_SOLVE_RIDGE_FLOOR),
5663        )
5664    }
5665
5666    pub fn solve_system_matrix_free_pcg_with_info(
5667        &self,
5668        weights: &Array1<f64>,
5669        rhs: &Array1<f64>,
5670        penalty: Option<&Array2<f64>>,
5671        ridge_floor: f64,
5672    ) -> Result<(Array1<f64>, PcgSolveInfo), String> {
5673        <Self as LinearOperator>::solve_system_matrix_free_pcg_with_info_try(
5674            self,
5675            weights,
5676            rhs,
5677            penalty,
5678            ridge_floor.max(SPD_SOLVE_RIDGE_FLOOR),
5679        )
5680    }
5681
5682    pub fn should_use_matrix_free_pcg(&self) -> bool {
5683        <Self as LinearOperator>::uses_matrix_free_pcg(self)
5684            && self.ncols() >= MATRIX_FREE_PCG_MIN_P
5685    }
5686
5687    pub fn factorize_system(
5688        &self,
5689        weights: &Array1<f64>,
5690        penalty: Option<&Array2<f64>>,
5691    ) -> Result<Box<dyn FactorizedSystem>, String> {
5692        <Self as LinearOperator>::factorize_system(self, weights, penalty)
5693    }
5694}
5695
5696impl<'a> From<ArrayView2<'a, f64>> for DesignMatrix {
5697    fn from(value: ArrayView2<'a, f64>) -> Self {
5698        Self::Dense(DenseDesignMatrix::from(value.to_owned()))
5699    }
5700}
5701
5702impl From<Array2<f64>> for DesignMatrix {
5703    fn from(value: Array2<f64>) -> Self {
5704        Self::Dense(DenseDesignMatrix::from(value))
5705    }
5706}
5707
5708impl From<Arc<Array2<f64>>> for DesignMatrix {
5709    fn from(value: Arc<Array2<f64>>) -> Self {
5710        Self::Dense(DenseDesignMatrix::from(value))
5711    }
5712}
5713
5714impl From<&Array2<f64>> for DesignMatrix {
5715    fn from(value: &Array2<f64>) -> Self {
5716        Self::Dense(DenseDesignMatrix::from(value.clone()))
5717    }
5718}
5719
5720impl From<DenseDesignMatrix> for DesignMatrix {
5721    fn from(value: DenseDesignMatrix) -> Self {
5722        Self::Dense(value)
5723    }
5724}
5725
5726impl From<SparseColMat<usize, f64>> for DesignMatrix {
5727    fn from(value: SparseColMat<usize, f64>) -> Self {
5728        Self::Sparse(SparseDesignMatrix::new(value))
5729    }
5730}
5731
5732impl From<&SparseColMat<usize, f64>> for DesignMatrix {
5733    fn from(value: &SparseColMat<usize, f64>) -> Self {
5734        Self::Sparse(SparseDesignMatrix::new(value.clone()))
5735    }
5736}
5737
5738impl From<&DesignMatrix> for DesignMatrix {
5739    fn from(value: &DesignMatrix) -> Self {
5740        value.clone()
5741    }
5742}
5743
5744impl From<DesignMatrix> for DesignBlock {
5745    fn from(value: DesignMatrix) -> Self {
5746        match value {
5747            DesignMatrix::Dense(matrix) => Self::Dense(matrix),
5748            DesignMatrix::Sparse(matrix) => Self::Sparse(matrix),
5749        }
5750    }
5751}
5752
5753impl From<&DesignMatrix> for DesignBlock {
5754    fn from(value: &DesignMatrix) -> Self {
5755        match value {
5756            DesignMatrix::Dense(matrix) => Self::Dense(matrix.clone()),
5757            DesignMatrix::Sparse(matrix) => Self::Sparse(matrix.clone()),
5758        }
5759    }
5760}
5761
5762#[cfg(test)]
5763mod tests {
5764    use super::{
5765        BlockDesignOperator, CoefficientTransformOperator, DenseDesignMatrix, DenseDesignOperator,
5766        DesignBlock, DesignMatrix, EmbeddedColumnBlock, MultiChannelOperator, PsdWeightsView,
5767        ReparamOperator, ResidualisedDesignOperator, RowwiseKroneckerOperator, SignedWeightsView,
5768        SparseDesignMatrix, dense_operator_to_dense_by_chunks, dense_transpose_weighted_response,
5769        fast_atv, fast_av, streaming_sparse_csc_xt_diag_x, weighted_crossprod_dense_view,
5770    };
5771    use crate::matrix::LinearOperator;
5772    use crate::test_support::no_densify_design;
5773    use crate::types::RidgePolicy;
5774    use crate::utils::{PcgSolveInfo, StableSolver};
5775    use faer::sparse::{SparseColMat, SymbolicSparseColMat, Triplet};
5776    use gam_runtime::resource::{MatrixMaterializationError, ResourcePolicy};
5777    use ndarray::{Array1, Array2, ArrayViewMut2, Axis, array, s};
5778    use std::ops::Range;
5779    use std::sync::Arc;
5780    use std::sync::atomic::{AtomicUsize, Ordering};
5781
5782    struct ChunkOnlyOperator {
5783        n: usize,
5784        p: usize,
5785        row_chunk_calls: AtomicUsize,
5786    }
5787
5788    impl ChunkOnlyOperator {
5789        fn value(&self, i: usize, j: usize) -> f64 {
5790            ((i % 251) as f64) * 0.25 - ((j % 127) as f64) * 0.5 + ((i + j) % 7) as f64
5791        }
5792    }
5793
5794    impl LinearOperator for ChunkOnlyOperator {
5795        fn nrows(&self) -> usize {
5796            self.n
5797        }
5798
5799        fn ncols(&self) -> usize {
5800            self.p
5801        }
5802
5803        fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5804            let mut out = Array1::<f64>::zeros(self.n);
5805            for i in 0..self.n {
5806                let mut acc = 0.0;
5807                for j in 0..self.p {
5808                    acc += self.value(i, j) * vector[j];
5809                }
5810                out[i] = acc;
5811            }
5812            out
5813        }
5814
5815        fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5816            let mut out = Array1::<f64>::zeros(self.p);
5817            for i in 0..self.n {
5818                for j in 0..self.p {
5819                    out[j] += self.value(i, j) * vector[i];
5820                }
5821            }
5822            out
5823        }
5824
5825        fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5826            let dense = dense_operator_to_dense_by_chunks(self).map_err(|err| err.to_string())?;
5827            let psd = PsdWeightsView::try_new(weights.view())?;
5828            Ok(weighted_crossprod_dense_view(&dense, psd.view(), &dense))
5829        }
5830    }
5831
5832    impl DenseDesignOperator for ChunkOnlyOperator {
5833        fn row_chunk_into(
5834            &self,
5835            rows: Range<usize>,
5836            mut out: ArrayViewMut2<'_, f64>,
5837        ) -> Result<(), MatrixMaterializationError> {
5838            self.row_chunk_calls.fetch_add(1, Ordering::SeqCst);
5839            if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
5840                return Err(MatrixMaterializationError::MissingRowChunk {
5841                    context: "ChunkOnlyOperator::row_chunk_into shape mismatch",
5842                });
5843            }
5844            for (local, row) in rows.enumerate() {
5845                for col in 0..self.p {
5846                    out[[local, col]] = self.value(row, col);
5847                }
5848            }
5849            Ok(())
5850        }
5851
5852        fn to_dense(&self) -> Array2<f64> {
5853            // SAFETY: test-only mock asserting row_chunk_into is exercised; reaching to_dense indicates a routing regression.
5854            panic!("ChunkOnlyOperator::to_dense fallback must not be used")
5855        }
5856    }
5857
5858    fn exact_weighted_penalized_solve(
5859        design: &Array2<f64>,
5860        weights: &Array1<f64>,
5861        rhs: &Array1<f64>,
5862        penalty: &Array2<f64>,
5863        ridge: f64,
5864    ) -> Array1<f64> {
5865        let mut h = design
5866            .t()
5867            .dot(&(design * &weights.view().insert_axis(Axis(1))));
5868        h += penalty;
5869        if ridge > 0.0 {
5870            for i in 0..h.nrows() {
5871                h[[i, i]] += ridge;
5872            }
5873        }
5874        StableSolver::new("matrix-free pcg exact reference")
5875            .solvevectorwithridge_retries(&h, rhs, 0.0)
5876            .expect("exact reference solve")
5877    }
5878
5879    #[test]
5880    fn fast_av_matches_ndarray_dot() {
5881        let x = array![[1.0, 2.0, -1.0], [0.5, -3.0, 4.0], [2.0, 0.0, 1.5]];
5882        let v = array![0.25, -1.0, 2.0];
5883        let expected = x.dot(&v);
5884        let got = fast_av(&x, &v);
5885        for i in 0..expected.len() {
5886            assert!((expected[i] - got[i]).abs() < 1e-12);
5887        }
5888    }
5889
5890    #[test]
5891    fn fast_atv_matches_ndarray_dot() {
5892        let x = array![[1.0, 2.0, -1.0], [0.5, -3.0, 4.0], [2.0, 0.0, 1.5]];
5893        let v = array![0.25, -1.0, 2.0];
5894        let expected = x.t().dot(&v);
5895        let got = fast_atv(&x, &v);
5896        for i in 0..expected.len() {
5897            assert!((expected[i] - got[i]).abs() < 1e-12);
5898        }
5899    }
5900
5901    #[test]
5902    fn sparse_to_dense_accumulates_duplicate_entries() {
5903        // Build a non-canonical CSC with duplicate row index in the same column.
5904        // This can happen if a caller bypasses canonical constructors.
5905        let symbolic = SymbolicSparseColMat::new_unsorted_checked(
5906            3,
5907            2,
5908            vec![0_usize, 2, 3],
5909            None,
5910            vec![1_usize, 1, 0],
5911        );
5912        let sparse = SparseColMat::new(symbolic, vec![2.0_f64, 3.5, -1.0]);
5913        let design = DesignMatrix::from(sparse);
5914        let dense = design.to_dense_arc();
5915
5916        assert!((dense[[1, 0]] - 5.5).abs() < 1e-12);
5917        assert!((dense[[0, 1]] + 1.0).abs() < 1e-12);
5918
5919        let v = array![4.0, -2.0];
5920        let y_sparse = design.matrixvectormultiply(&v);
5921        let y_dense = dense.dot(&v);
5922        for i in 0..y_sparse.len() {
5923            assert!((y_sparse[i] - y_dense[i]).abs() < 1e-12);
5924        }
5925    }
5926
5927    #[test]
5928    fn huge_sparse_densification_is_rejected_before_allocation() {
5929        let sparse = SparseColMat::try_new_from_triplets(500_000, 10_000, &[])
5930            .expect("empty sparse matrix should build");
5931        let design = SparseDesignMatrix::new(sparse);
5932        let err = design
5933            .try_to_dense_arc("matrix test")
5934            .expect_err("huge sparse densification should be rejected");
5935        assert!(err.contains("refusing to densify sparse design"));
5936    }
5937
5938    #[test]
5939    fn streaming_sparse_csc_xt_diag_x_matches_dense_signed_weights() {
5940        let sparse = SparseColMat::try_new_from_triplets(
5941            4,
5942            3,
5943            &[
5944                Triplet::new(0, 0, 1.0),
5945                Triplet::new(1, 0, 2.0),
5946                Triplet::new(2, 0, -1.0),
5947                Triplet::new(0, 1, 0.5),
5948                Triplet::new(1, 1, -3.0),
5949                Triplet::new(3, 1, 4.0),
5950                Triplet::new(0, 2, 2.0),
5951                Triplet::new(2, 2, 1.5),
5952                Triplet::new(3, 2, -0.25),
5953            ],
5954        )
5955        .expect("sparse matrix");
5956        let design = SparseDesignMatrix::new(sparse.clone());
5957        let dense = design.to_dense_arc();
5958        let weights = array![1.0, -2.0, 0.5, -1.5];
5959        let (symbolic, values) = sparse.parts();
5960        let mut got = Array2::<f64>::zeros((3, 3));
5961        streaming_sparse_csc_xt_diag_x(
5962            symbolic.col_ptr(),
5963            symbolic.row_idx(),
5964            values,
5965            4,
5966            3,
5967            weights.view(),
5968            &mut got,
5969        );
5970
5971        let mut expected = Array2::<f64>::zeros((3, 3));
5972        for row in 0..4 {
5973            for a in 0..3 {
5974                for b in 0..3 {
5975                    expected[[a, b]] += weights[row] * dense[[row, a]] * dense[[row, b]];
5976                }
5977            }
5978        }
5979        let max_diff = (&got - &expected)
5980            .iter()
5981            .map(|v| v.abs())
5982            .fold(0.0_f64, f64::max);
5983        assert!(
5984            max_diff < 1e-12,
5985            "streamed sparse weighted Gram mismatch: max_diff={max_diff}"
5986        );
5987    }
5988
5989    #[test]
5990    fn multi_channel_operator_view_paths_match_stacked_dense_reference() {
5991        let dense_channel = array![[1.0, 2.0], [0.5, -1.0], [3.0, 0.25]];
5992        let sparse_dense = array![[0.0, 1.5], [2.0, 0.0], [-1.0, 0.75]];
5993        let sparse = SparseColMat::try_new_from_triplets(
5994            3,
5995            2,
5996            &[
5997                Triplet::new(1, 0, 2.0),
5998                Triplet::new(2, 0, -1.0),
5999                Triplet::new(0, 1, 1.5),
6000                Triplet::new(2, 1, 0.75),
6001            ],
6002        )
6003        .expect("sparse channel");
6004        let op = MultiChannelOperator::new(vec![
6005            DesignMatrix::Dense(DenseDesignMatrix::from(dense_channel.clone())),
6006            DesignMatrix::from(sparse),
6007        ])
6008        .expect("multi-channel operator");
6009        let mut stacked = Array2::<f64>::zeros((6, 2));
6010        stacked.slice_mut(s![0..3, ..]).assign(&dense_channel);
6011        stacked.slice_mut(s![3..6, ..]).assign(&sparse_dense);
6012
6013        let beta = array![0.25, -0.4];
6014        let expected_apply = stacked.dot(&beta);
6015        let got_apply = op.apply(&beta);
6016        for i in 0..expected_apply.len() {
6017            assert!((expected_apply[i] - got_apply[i]).abs() < 1e-12);
6018        }
6019
6020        let probe = array![0.5, -1.0, 0.25, 1.5, -0.75, 0.2];
6021        let expected_transpose = stacked.t().dot(&probe);
6022        let got_transpose = op.apply_transpose(&probe);
6023        for i in 0..expected_transpose.len() {
6024            assert!((expected_transpose[i] - got_transpose[i]).abs() < 1e-12);
6025        }
6026
6027        let weights = array![1.0, -0.5, 0.75, 2.0, 0.25, 1.5];
6028        let w_pos = weights.mapv(|w: f64| w.max(0.0));
6029        let weighted = stacked.clone() * w_pos.view().insert_axis(Axis(1));
6030        let expected_xtwx = stacked.t().dot(&weighted);
6031        let got_xtwx = op.diag_xtw_x(&weights).expect("multi-channel xtwx");
6032        for i in 0..expected_xtwx.nrows() {
6033            for j in 0..expected_xtwx.ncols() {
6034                assert!((expected_xtwx[[i, j]] - got_xtwx[[i, j]]).abs() < 1e-12);
6035            }
6036        }
6037
6038        let expected_diag = Array1::from_iter((0..2).map(|j| expected_xtwx[[j, j]]));
6039        let got_diag = op.diag_gram(&weights).expect("multi-channel diag gram");
6040        for i in 0..expected_diag.len() {
6041            assert!((expected_diag[i] - got_diag[i]).abs() < 1e-12);
6042        }
6043
6044        let y = array![1.0, 0.5, -0.25, 2.0, -1.0, 0.75];
6045        let expected_xtwy = stacked.t().dot(&(w_pos * &y));
6046        let got_xtwy = op.compute_xtwy(&weights, &y).expect("multi-channel xtwy");
6047        for i in 0..expected_xtwy.len() {
6048            assert!((expected_xtwy[i] - got_xtwy[i]).abs() < 1e-12);
6049        }
6050    }
6051
6052    /// Perf (#1017): the fused scale-once + `fast_atb` Dense×Dense cross-block
6053    /// assembly in `BlockDesignOperator::diag_xtw_x` must equal the full stacked
6054    /// reference Gram `Xᵀ diag(w) X` for a multi-block dense layout with SIGNED
6055    /// weights (the observed-Hessian regime that exercises the sign-correct
6056    /// asymmetric cross kernel). Several dense blocks of differing widths so the
6057    /// off-diagonal slicing and the symmetric transpose fill are both covered.
6058    #[test]
6059    fn block_design_fused_dense_cross_matches_stacked_reference_xtwx() {
6060        let b0 = array![
6061            [1.0, 2.0],
6062            [0.5, -1.0],
6063            [3.0, 0.25],
6064            [-2.0, 1.5],
6065            [0.75, -0.5],
6066        ];
6067        let b1 = array![
6068            [-1.0, 0.5, 2.0],
6069            [1.5, -0.25, 0.0],
6070            [0.0, 1.0, -1.5],
6071            [2.0, 0.5, 1.0],
6072            [-0.5, -1.0, 0.25],
6073        ];
6074        let b2 = array![[0.5], [-1.0], [2.0], [0.25], [-0.75]];
6075
6076        let mut stacked = Array2::<f64>::zeros((5, 6));
6077        stacked.slice_mut(s![.., 0..2]).assign(&b0);
6078        stacked.slice_mut(s![.., 2..5]).assign(&b1);
6079        stacked.slice_mut(s![.., 5..6]).assign(&b2);
6080
6081        let blocks = vec![
6082            DesignBlock::Dense(DenseDesignMatrix::from(b0)),
6083            DesignBlock::Dense(DenseDesignMatrix::from(b1)),
6084            DesignBlock::Dense(DenseDesignMatrix::from(b2)),
6085        ];
6086        let op = BlockDesignOperator::new(blocks).expect("block design");
6087
6088        // Signed weights: the cross kernel must NOT clamp to PSD here.
6089        let weights = array![1.5, -0.5, 2.0, -1.0, 0.75];
6090        let weighted = stacked.clone() * weights.view().insert_axis(Axis(1));
6091        let expected = stacked.t().dot(&weighted);
6092
6093        let got = op.diag_xtw_x(&weights).expect("block fused xtwx");
6094        assert_eq!(got.dim(), (6, 6));
6095        let max_diff = (&got - &expected)
6096            .iter()
6097            .map(|v| v.abs())
6098            .fold(0.0_f64, f64::max);
6099        assert!(
6100            max_diff < 1e-10,
6101            "fused block Dense×Dense Gram mismatch: max_diff={max_diff}"
6102        );
6103    }
6104
6105    #[test]
6106    #[should_panic(expected = "ReparamOperator: X cols (2) must match Qs rows (3)")]
6107    fn reparam_operator_rejects_incompatible_transform_shape() {
6108        let x = array![[1.0, 2.0], [0.5, -1.0]];
6109        let qs = Arc::new(Array2::<f64>::zeros((3, 1)));
6110        ReparamOperator::new(DesignMatrix::Dense(DenseDesignMatrix::from(x)), qs);
6111    }
6112
6113    /// Locks in the dispatch path for the BLAS-3 cross-block fast path:
6114    /// when a `CoefficientTransformOperator` is wrapped as
6115    /// `DenseDesignMatrix::Lazy`, `DenseDesignMatrix::as_dense_ref` must reach
6116    /// the operator's cached materialization. The dispatch goes
6117    /// `DenseDesignMatrix::as_dense_ref` → `DenseDesignOperator::as_dense_ref`,
6118    /// so the override has to live on `DenseDesignOperator`, not
6119    /// `LinearOperator`. A misplaced override on `LinearOperator` is a hard
6120    /// build break today (E0407, fixed in b516891), but if `LinearOperator`
6121    /// ever grew an `as_dense_ref` slot the silent failure would be
6122    /// `BlockDesignOperator::cross_block` falling back to the chunked scalar
6123    /// path with no test signal — this assertion is the missing signal.
6124    #[test]
6125    fn coefficient_transform_operator_exposes_cached_dense_to_block_dispatch() {
6126        let inner = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
6127        let transform = array![[0.5, -1.0, 2.0], [1.0, 0.0, -0.5]];
6128        let expected = inner.dot(&transform);
6129
6130        let op =
6131            CoefficientTransformOperator::new(DenseDesignMatrix::from(inner), transform.clone())
6132                .expect("coefficient transform operator");
6133        let dense_design = DenseDesignMatrix::from(Arc::new(op));
6134
6135        // Touch the cache through any LinearOperator path (`apply_transpose`
6136        // short-circuits through `materialized_combined`). The OnceLock is empty
6137        // until something exercises it, so `as_dense_ref` would otherwise
6138        // return None before the first hot call.
6139        let probe = Array1::from_elem(3, 1.0);
6140        let warmed = dense_design.apply_transpose(&probe);
6141        assert_eq!(warmed.len(), expected.ncols());
6142
6143        let dense_ref = dense_design
6144            .as_dense_ref()
6145            .expect("DenseDesignMatrix::as_dense_ref must reach the cached X·T");
6146        assert_eq!(dense_ref.dim(), expected.dim());
6147        for ((r, c), v) in expected.indexed_iter() {
6148            assert!((dense_ref[[r, c]] - v).abs() < 1e-12);
6149        }
6150    }
6151
6152    #[test]
6153    fn design_matrix_hstack_preserves_lazy_blocks() {
6154        let left_dense = array![[1.0, 2.0], [3.0, 4.0]];
6155        let right_dense = array![[5.0], [6.0]];
6156        let left = no_densify_design(left_dense.clone());
6157        let right = no_densify_design(right_dense.clone());
6158        let stacked = DesignMatrix::hstack(vec![left, right]).expect("stacked design");
6159
6160        assert!(stacked.as_dense_ref().is_none());
6161        assert!(!stacked.is_materialized_dense());
6162        assert!(stacked.is_operator_backed());
6163        assert_eq!(stacked.nrows(), 2);
6164        assert_eq!(stacked.ncols(), 3);
6165
6166        let beta = array![0.25, -0.5, 2.0];
6167        let expected = array![9.25, 10.75];
6168        let got = stacked.dot(&beta);
6169        for i in 0..expected.len() {
6170            assert!((got[i] - expected[i]).abs() < 1e-12);
6171        }
6172
6173        let chunk = stacked
6174            .try_row_chunk(0..2)
6175            .expect("stacked.try_row_chunk must succeed");
6176        assert_eq!(chunk, array![[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]);
6177    }
6178
6179    #[test]
6180    #[should_panic(expected = "DesignMatrix::as_dense_cow called on operator-backed design")]
6181    fn design_matrix_as_dense_cow_rejects_operator_backed_designs() {
6182        let design = no_densify_design(array![[1.0, 2.0], [3.0, 4.0]]);
6183        design.as_dense_cow();
6184    }
6185
6186    #[test]
6187    fn sparse_factorized_solve_matches_dense_operator_solve() {
6188        let triplets = vec![
6189            Triplet::new(0usize, 0usize, 1.0),
6190            Triplet::new(1, 0, 2.0),
6191            Triplet::new(1, 1, -1.0),
6192            Triplet::new(2, 1, 3.0),
6193            Triplet::new(2, 2, 0.5),
6194        ];
6195        let sparse = SparseColMat::try_new_from_triplets(3, 3, &triplets)
6196            .expect("sparse design should build");
6197        let sparse_design = DesignMatrix::from(sparse);
6198        let dense_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
6199            sparse_design.to_dense(),
6200        ));
6201        let weights = array![1.5, 0.75, 2.0];
6202        let rhs = array![1.0, -0.5, 2.0];
6203        let penalty = Array2::from_diag(&array![0.25, 0.5, 0.75]);
6204
6205        let sparse_sol = sparse_design
6206            .solve_system(&weights, &rhs, Some(&penalty))
6207            .expect("sparse solve should factorize natively");
6208        let dense_sol = dense_design
6209            .solve_system(&weights, &rhs, Some(&penalty))
6210            .expect("dense solve should factorize");
6211
6212        for i in 0..rhs.len() {
6213            assert!(
6214                (sparse_sol[i] - dense_sol[i]).abs() < 1e-10,
6215                "solution mismatch at {i}: sparse={} dense={}",
6216                sparse_sol[i],
6217                dense_sol[i]
6218            );
6219        }
6220    }
6221
6222    #[test]
6223    fn solve_system_stabilizes_indefinite_penalty_and_returns_finite_solution() {
6224        let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![
6225            [1.0, 0.0],
6226            [0.0, 0.0]
6227        ]));
6228        let weights = array![1.0, 1.0];
6229        let rhs = array![2.0, 0.0];
6230        let penalty = array![[0.0, 0.0], [0.0, -1e-12]];
6231
6232        let beta = design
6233            .solve_system(&weights, &rhs, Some(&penalty))
6234            .expect("solve_system should stabilize indefinite systems");
6235
6236        assert!(beta.iter().all(|v| v.is_finite()));
6237        assert!((beta[0] - 2.0).abs() < 1e-10);
6238        assert!(beta[1].abs() < 1e-8);
6239    }
6240
6241    #[test]
6242    fn explicit_matrix_free_pcg_matches_exact_large_dense_weighted_penalized_solve() {
6243        let n = 48usize;
6244        let p = 520usize;
6245        let mut x = Array2::<f64>::zeros((n, p));
6246        for i in 0..n {
6247            for j in 0..p {
6248                x[[i, j]] = (((i + 3) * (j + 5)) % 17) as f64 / 17.0
6249                    + 0.02 * (i as f64)
6250                    + 0.001 * (j as f64);
6251            }
6252        }
6253        let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone()));
6254        let weights = Array1::from_iter((0..n).map(|i| 0.5 + (i as f64) / (2.0 * n as f64)));
6255        let rhs = Array1::from_iter((0..p).map(|j| ((j % 13) as f64 - 6.0) / 13.0));
6256        let penalty = Array2::from_diag(&Array1::from_iter(
6257            (0..p).map(|j| 0.1 + 0.005 * ((j % 7) as f64)),
6258        ));
6259        let ridge = 1e-8;
6260
6261        let pcg = design
6262            .solve_system_matrix_free_pcg(&weights, &rhs, Some(&penalty), ridge)
6263            .expect("matrix-free pcg solve");
6264        let exact = exact_weighted_penalized_solve(&x, &weights, &rhs, &penalty, ridge);
6265        for i in 0..p {
6266            assert!(
6267                (pcg[i] - exact[i]).abs() < 1e-5,
6268                "solution mismatch at {i}: pcg={} exact={}",
6269                pcg[i],
6270                exact[i]
6271            );
6272        }
6273        let mut h = x
6274            .t()
6275            .dot(&(x.clone() * weights.view().insert_axis(Axis(1))));
6276        h += &penalty;
6277        for i in 0..p {
6278            h[[i, i]] += ridge;
6279        }
6280        let residual = h.dot(&pcg) - &rhs;
6281        let residual_norm = residual.dot(&residual).sqrt();
6282        assert!(residual_norm < 1e-4, "residual_norm={residual_norm}");
6283    }
6284
6285    #[test]
6286    fn policy_solve_matches_explicit_matrix_free_pcg_on_large_dense_system() {
6287        let n = 40usize;
6288        let p = 520usize;
6289        let mut x = Array2::<f64>::zeros((n, p));
6290        for i in 0..n {
6291            for j in 0..p {
6292                x[[i, j]] = (((2 * i + j + 11) % 23) as f64 / 23.0) + 0.0005 * (j as f64);
6293            }
6294        }
6295        let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x));
6296        let weights = Array1::from_iter((0..n).map(|i| 1.0 + 0.01 * i as f64));
6297        let rhs = Array1::from_iter((0..p).map(|j| ((j % 5) as f64) - 2.0));
6298        let penalty = Array2::from_diag(&Array1::from_iter(
6299            (0..p).map(|j| 0.2 + 0.01 * ((j % 3) as f64)),
6300        ));
6301        let ridge_floor = 1e-8;
6302
6303        let explicit = design
6304            .solve_system_matrix_free_pcg(&weights, &rhs, Some(&penalty), ridge_floor)
6305            .expect("explicit pcg");
6306        let policy = design
6307            .solve_systemwith_policy(
6308                &weights,
6309                &rhs,
6310                Some(&penalty),
6311                ridge_floor,
6312                RidgePolicy::explicit_stabilization_pospart(),
6313            )
6314            .expect("policy solve");
6315        for i in 0..p {
6316            // This system is heavily rank-deficient (rank ≤ n = 40, p = 520,
6317            // p ≫ n) with only a weak ~0.2 diagonal penalty + 1e-8 ridge_floor,
6318            // so the normal matrix is severely ill-conditioned. Both arms are
6319            // matrix-free PCG (explicit vs `explicit_stabilization_pospart`
6320            // policy); they terminate at slightly different points on the
6321            // near-null manifold. A fixed 1e-6 absolute gate is below what PCG
6322            // can guarantee at this conditioning; assert a relative tolerance
6323            // scaled by the coefficient magnitude instead (gam#846).
6324            let tol = 1e-5 * (1.0 + explicit[i].abs());
6325            assert!(
6326                (explicit[i] - policy[i]).abs() < tol,
6327                "policy mismatch at {i}: explicit={} policy={} (tol={tol})",
6328                explicit[i],
6329                policy[i]
6330            );
6331        }
6332    }
6333
6334    #[test]
6335    fn explicit_matrix_free_pcg_reports_convergence_diagnostics() {
6336        let n = 36usize;
6337        let p = 2160usize;
6338        let mut x = Array2::<f64>::zeros((n, p));
6339        for i in 0..n {
6340            for j in 0..p {
6341                x[[i, j]] = (((3 * i + 5 * j + 7) % 29) as f64 / 29.0)
6342                    + 0.015 * (i as f64)
6343                    + 1e-4 * j as f64;
6344            }
6345        }
6346        let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone()));
6347        assert!(design.should_use_matrix_free_pcg());
6348        let weights = Array1::from_iter((0..n).map(|i| 0.75 + 0.01 * i as f64));
6349        let rhs = Array1::from_iter((0..p).map(|j| ((j % 9) as f64 - 4.0) / 9.0));
6350        let penalty = Array2::from_diag(&Array1::from_iter(
6351            (0..p).map(|j| 0.05 + 0.002 * ((j % 11) as f64)),
6352        ));
6353        let ridge = 1e-8;
6354
6355        let (pcg, info): (Array1<f64>, PcgSolveInfo) = design
6356            .solve_system_matrix_free_pcg_with_info(&weights, &rhs, Some(&penalty), ridge)
6357            .expect("pcg with info");
6358        assert!(info.converged);
6359        assert!(info.iterations > 0);
6360        assert!(info.relative_residual_norm.is_finite());
6361        assert!(info.relative_residual_norm < 1e-6);
6362
6363        let exact = exact_weighted_penalized_solve(&x, &weights, &rhs, &penalty, ridge);
6364        for i in 0..p {
6365            assert!(
6366                (pcg[i] - exact[i]).abs() < 1e-5,
6367                "solution mismatch at {i}: pcg={} exact={}",
6368                pcg[i],
6369                exact[i]
6370            );
6371        }
6372    }
6373
6374    #[test]
6375    fn compute_xtwy_dense_allocationfree_matches_matvec() {
6376        let n = 2_000usize;
6377        let p = 64usize;
6378        let mut x = Array2::<f64>::zeros((n, p));
6379        let mut y = Array1::<f64>::zeros(n);
6380        let mut w = Array1::<f64>::zeros(n);
6381        for i in 0..n {
6382            y[i] = ((i % 17) as f64 - 8.0) * 0.1;
6383            w[i] = 0.25 + ((i % 11) as f64) * 0.05;
6384            for j in 0..p {
6385                x[[i, j]] = (((i * 13 + j * 7) % 97) as f64) / 97.0;
6386            }
6387        }
6388
6389        let reference = {
6390            let wy = Array1::from_shape_fn(n, |i| y[i] * w[i].max(0.0));
6391            fast_atv(&x, &wy)
6392        };
6393        let fused = dense_transpose_weighted_response(&x, &w, &y, None);
6394        for j in 0..p {
6395            assert!(
6396                (reference[j] - fused[j]).abs() < 1e-10,
6397                "mismatch at column {j}: ref={} fused={}",
6398                reference[j],
6399                fused[j]
6400            );
6401        }
6402    }
6403
6404    #[test]
6405    fn large_lazy_dense_materialization_streams_chunks_without_to_dense_fallback() {
6406        let n = 11_000usize;
6407        let p = 128usize;
6408        let op = Arc::new(ChunkOnlyOperator {
6409            n,
6410            p,
6411            row_chunk_calls: AtomicUsize::new(0),
6412        });
6413        let design = DenseDesignMatrix::from(Arc::clone(&op));
6414
6415        let dense = design.to_dense_arc();
6416
6417        assert_eq!(dense.dim(), (n, p));
6418        assert!(
6419            op.row_chunk_calls.load(Ordering::SeqCst) > 1,
6420            "expected dense materialization to stream more than one row chunk"
6421        );
6422        for &(i, j) in &[(0, 0), (8_191, 127), (8_192, 0), (10_999, 64)] {
6423            assert_eq!(dense[[i, j]], op.value(i, j));
6424        }
6425    }
6426
6427    /// Regression: the original `4194304x10 (~0.31 GiB)` Duchon panic was a
6428    /// policy-cap refusal inside `try_to_dense_arc_with_policy`. The infallible
6429    /// entry points (`to_dense`, `to_dense_arc`, `to_dense_cow`) must bypass
6430    /// that guard so callers asking for a contiguous `Array2<f64>` get one,
6431    /// while strict-operator callers still see refusal through the explicit
6432    /// `try_to_dense_arc_with_policy(_, &analytic_operator_required())` path.
6433    /// Use a tiny operator so the contract is pinned without allocating
6434    /// hundreds of MiB on CI.
6435    #[test]
6436    fn to_dense_arc_bypasses_policy_cap_strict_policy_still_refuses() {
6437        let op = Arc::new(ChunkOnlyOperator {
6438            n: 128,
6439            p: 4,
6440            row_chunk_calls: AtomicUsize::new(0),
6441        });
6442        let design = DenseDesignMatrix::from(Arc::clone(&op));
6443
6444        // The panic-free entry points must succeed regardless of the policy cap.
6445        let dense = design.to_dense_arc();
6446        assert_eq!(dense.dim(), (128, 4));
6447
6448        // Strict operator-required policy still refuses on the explicit
6449        // policy-aware API — large-scale invariants preserved.
6450        let strict = ResourcePolicy::analytic_operator_required();
6451        let err = design
6452            .try_to_dense_arc_with_policy("regression strict refuses", &strict)
6453            .expect_err("strict policy must refuse lazy materialization");
6454        assert!(
6455            err.contains("refusing to densify operator-backed design")
6456                && err.contains("AnalyticOperatorRequired"),
6457            "unexpected strict-policy error: {err}"
6458        );
6459
6460        // Tighten the size cap below the design footprint and confirm the
6461        // policy-aware API rejects on size (the contract the infallible path
6462        // is documented to bypass).
6463        let mut tight = ResourcePolicy::default_library();
6464        tight.max_single_materialization_bytes = 1;
6465        let size_err = design
6466            .try_to_dense_arc_with_policy("regression tight refuses", &tight)
6467            .expect_err("undersized cap must refuse lazy materialization");
6468        assert!(
6469            size_err.contains("refusing to densify operator-backed design"),
6470            "unexpected size-cap error: {size_err}"
6471        );
6472    }
6473
6474    #[test]
6475    fn try_to_dense_by_chunks_writes_directly_into_output_slices() {
6476        let n = 11_000usize;
6477        let p = 128usize;
6478        let op = Arc::new(ChunkOnlyOperator {
6479            n,
6480            p,
6481            row_chunk_calls: AtomicUsize::new(0),
6482        });
6483        let design = DesignMatrix::Dense(DenseDesignMatrix::from(Arc::clone(&op)));
6484
6485        let dense = design
6486            .try_to_dense_by_chunks("large chunked regression")
6487            .expect("chunked materialization");
6488
6489        assert_eq!(dense.dim(), (n, p));
6490        assert!(
6491            op.row_chunk_calls.load(Ordering::SeqCst) > 1,
6492            "expected direct chunked conversion to use bounded row chunks"
6493        );
6494        for &(i, j) in &[(1, 7), (4_096, 12), (8_193, 63), (10_998, 127)] {
6495            assert_eq!(dense[[i, j]], op.value(i, j));
6496        }
6497    }
6498
6499    #[test]
6500    fn tensor_product_design_operator_matches_dense_2d() {
6501        use super::{DenseDesignOperator, TensorProductDesignOperator};
6502
6503        // Two marginal B-spline-like bases: 10 rows, 4 and 3 columns.
6504        let n = 10;
6505        let q1 = 4;
6506        let q2 = 3;
6507        let mut b1 = Array2::<f64>::zeros((n, q1));
6508        let mut b2 = Array2::<f64>::zeros((n, q2));
6509        // Fill with simple hat-function-like patterns (sparse per row).
6510        for i in 0..n {
6511            let t1 = i as f64 / (n - 1) as f64 * (q1 - 1) as f64;
6512            let j1 = (t1.floor() as usize).min(q1 - 2);
6513            let frac1 = t1 - j1 as f64;
6514            b1[[i, j1]] = 1.0 - frac1;
6515            b1[[i, j1 + 1]] = frac1;
6516
6517            let t2 = i as f64 / (n - 1) as f64 * (q2 - 1) as f64;
6518            let j2 = (t2.floor() as usize).min(q2 - 2);
6519            let frac2 = t2 - j2 as f64;
6520            b2[[i, j2]] = 1.0 - frac2;
6521            b2[[i, j2 + 1]] = frac2;
6522        }
6523
6524        let op = TensorProductDesignOperator::new(vec![Arc::new(b1.clone()), Arc::new(b2.clone())])
6525            .unwrap();
6526
6527        // Build dense reference via explicit Kronecker row products.
6528        let p = q1 * q2;
6529        let mut dense = Array2::<f64>::zeros((n, p));
6530        for i in 0..n {
6531            for j1 in 0..q1 {
6532                for j2 in 0..q2 {
6533                    dense[[i, j1 * q2 + j2]] = b1[[i, j1]] * b2[[i, j2]];
6534                }
6535            }
6536        }
6537
6538        // Test to_dense.
6539        let op_dense = op.to_dense();
6540        let max_diff = (&op_dense - &dense)
6541            .iter()
6542            .map(|v: &f64| v.abs())
6543            .fold(0.0f64, f64::max);
6544        assert!(max_diff < 1e-14, "to_dense mismatch: max_diff={max_diff}");
6545
6546        // Test apply.
6547        let beta = Array1::from_vec((0..p).map(|j| (j as f64 + 1.0) * 0.1).collect());
6548        let ref_result = dense.dot(&beta);
6549        let op_result = op.apply(&beta);
6550        let max_diff = (&op_result - &ref_result)
6551            .iter()
6552            .map(|v: &f64| v.abs())
6553            .fold(0.0f64, f64::max);
6554        assert!(max_diff < 1e-12, "apply mismatch: max_diff={max_diff}");
6555
6556        // Test apply_transpose.
6557        let v = Array1::from_vec((0..n).map(|i| (i as f64 + 1.0) * 0.3).collect());
6558        let ref_xt_v = dense.t().dot(&v);
6559        let op_xt_v = op.apply_transpose(&v);
6560        let max_diff = (&op_xt_v - &ref_xt_v)
6561            .iter()
6562            .map(|v: &f64| v.abs())
6563            .fold(0.0f64, f64::max);
6564        assert!(
6565            max_diff < 1e-12,
6566            "apply_transpose mismatch: max_diff={max_diff}"
6567        );
6568
6569        // Test diag_xtw_x.
6570        let w = Array1::from_vec((0..n).map(|i| 1.0 + i as f64 * 0.1).collect());
6571        let ref_xtwx = {
6572            let mut out = Array2::<f64>::zeros((p, p));
6573            for i in 0..n {
6574                for a in 0..p {
6575                    for b in 0..p {
6576                        out[[a, b]] += w[i] * dense[[i, a]] * dense[[i, b]];
6577                    }
6578                }
6579            }
6580            out
6581        };
6582        let op_xtwx = op.diag_xtw_x(&w).unwrap();
6583        let max_diff = (&op_xtwx - &ref_xtwx)
6584            .iter()
6585            .map(|v: &f64| v.abs())
6586            .fold(0.0f64, f64::max);
6587        assert!(max_diff < 1e-10, "diag_xtw_x mismatch: max_diff={max_diff}");
6588    }
6589
6590    #[test]
6591    fn tensor_product_design_operator_3d() {
6592        use super::{DenseDesignOperator, TensorProductDesignOperator};
6593
6594        let n = 8;
6595        let dims = [3, 2, 2];
6596        let mut marginals: Vec<Array2<f64>> = Vec::new();
6597        for &q in &dims {
6598            let mut b = Array2::<f64>::zeros((n, q));
6599            for i in 0..n {
6600                let t = i as f64 / (n - 1) as f64 * (q - 1) as f64;
6601                let j = (t.floor() as usize).min(q - 2);
6602                let frac = t - j as f64;
6603                b[[i, j]] = 1.0 - frac;
6604                b[[i, j + 1]] = frac;
6605            }
6606            marginals.push(b);
6607        }
6608
6609        let op = TensorProductDesignOperator::new(
6610            marginals.iter().map(|m| Arc::new(m.clone())).collect(),
6611        )
6612        .unwrap();
6613
6614        // Dense reference.
6615        let p: usize = dims.iter().copied().product();
6616        let mut dense = Array2::<f64>::zeros((n, p));
6617        for i in 0..n {
6618            for j0 in 0..dims[0] {
6619                for j1 in 0..dims[1] {
6620                    for j2 in 0..dims[2] {
6621                        let col = j0 * dims[1] * dims[2] + j1 * dims[2] + j2;
6622                        dense[[i, col]] =
6623                            marginals[0][[i, j0]] * marginals[1][[i, j1]] * marginals[2][[i, j2]];
6624                    }
6625                }
6626            }
6627        }
6628
6629        let op_dense = op.to_dense();
6630        let max_diff = (&op_dense - &dense)
6631            .iter()
6632            .map(|v: &f64| v.abs())
6633            .fold(0.0f64, f64::max);
6634        assert!(
6635            max_diff < 1e-14,
6636            "3D to_dense mismatch: max_diff={max_diff}"
6637        );
6638
6639        // Test round-trip: apply then apply_transpose.
6640        let beta = Array1::from_vec((0..p).map(|j| (j as f64).sin()).collect());
6641        let xb = op.apply(&beta);
6642        let xtxb = op.apply_transpose(&xb);
6643        let ref_xtxb = dense.t().dot(&dense.dot(&beta));
6644        let max_diff = (&xtxb - &ref_xtxb)
6645            .iter()
6646            .map(|v: &f64| v.abs())
6647            .fold(0.0f64, f64::max);
6648        assert!(max_diff < 1e-10, "3D X'Xβ mismatch: max_diff={max_diff}");
6649    }
6650
6651    #[test]
6652    fn sparse_weighted_crossprod_parallel_path_matches_dense_reference() {
6653        use faer::sparse::Triplet;
6654
6655        let n = 4096;
6656        let p = 192;
6657        let mut triplets = Vec::with_capacity(n * 4);
6658        let mut dense = Array2::<f64>::zeros((n, p));
6659        for i in 0..n {
6660            let base = (i * 37) % p;
6661            for k in 0..4 {
6662                let col = (base + k * 11) % p;
6663                let val = ((i + 3 * k + 1) as f64).sin() * 0.25 + 0.5;
6664                triplets.push(Triplet::new(i, col, val));
6665                dense[[i, col]] = val;
6666            }
6667        }
6668        let sparse = faer::sparse::SparseColMat::try_new_from_triplets(n, p, &triplets).unwrap();
6669        let design = DesignMatrix::Sparse(SparseDesignMatrix::new(sparse));
6670        let weights = Array1::from_iter((0..n).map(|i| match i % 7 {
6671            0 => 0.0,
6672            r => 0.5 + r as f64 * 0.125,
6673        }));
6674
6675        let got = <DesignMatrix as LinearOperator>::xt_diag_x_signed_op(
6676            &design,
6677            SignedWeightsView::from_array(&weights),
6678        )
6679        .unwrap();
6680        let mut reference = Array2::<f64>::zeros((p, p));
6681        for i in 0..n {
6682            let wi = weights[i].max(0.0);
6683            if wi == 0.0 {
6684                continue;
6685            }
6686            for a in 0..p {
6687                let xa = dense[[i, a]];
6688                if xa == 0.0 {
6689                    continue;
6690                }
6691                for b in 0..p {
6692                    reference[[a, b]] += wi * xa * dense[[i, b]];
6693                }
6694            }
6695        }
6696        let max_diff = (&got - &reference)
6697            .iter()
6698            .map(|v: &f64| v.abs())
6699            .fold(0.0_f64, f64::max);
6700        assert!(
6701            max_diff < 1e-10,
6702            "sparse xtwx mismatch: max_diff={max_diff}"
6703        );
6704
6705        let got_diag = design.diag_gram(&weights).unwrap();
6706        let ref_diag = reference.diag().to_owned();
6707        let max_diag_diff = (&got_diag - &ref_diag)
6708            .iter()
6709            .map(|v: &f64| v.abs())
6710            .fold(0.0_f64, f64::max);
6711        assert!(
6712            max_diag_diff < 1e-10,
6713            "sparse diag gram mismatch: max_diff={max_diag_diff}"
6714        );
6715    }
6716
6717    #[test]
6718    fn rowwise_kronecker_sparse_structured_xtwx_matches_dense_reference() {
6719        use faer::sparse::Triplet;
6720
6721        let n = 2048;
6722        let p_cov = 64;
6723        let p_time = 6;
6724        let mut triplets = Vec::with_capacity(n * 3);
6725        let mut cov_dense = Array2::<f64>::zeros((n, p_cov));
6726        for i in 0..n {
6727            let base = (i * 17) % p_cov;
6728            for k in 0..3 {
6729                let col = (base + k * 7) % p_cov;
6730                let val = 0.2 + (((i + k) % 13) as f64) / 17.0;
6731                triplets.push(Triplet::new(i, col, val));
6732                cov_dense[[i, col]] = val;
6733            }
6734        }
6735        let cov_sparse =
6736            faer::sparse::SparseColMat::try_new_from_triplets(n, p_cov, &triplets).unwrap();
6737        let cov = DesignMatrix::Sparse(SparseDesignMatrix::new(cov_sparse));
6738        let mut time = Array2::<f64>::zeros((n, p_time));
6739        for i in 0..n {
6740            for t in 0..p_time {
6741                time[[i, t]] = (((i + 1) * (t + 3)) as f64).cos() * 0.1 + 0.4;
6742            }
6743        }
6744        let op = RowwiseKroneckerOperator::new(cov, Arc::new(time.clone())).unwrap();
6745        let weights = Array1::from_iter((0..n).map(|i| 0.25 + ((i % 11) as f64) * 0.05));
6746        let got = op.diag_xtw_x(&weights).unwrap();
6747
6748        let p_total = p_cov * p_time;
6749        let mut reference = Array2::<f64>::zeros((p_total, p_total));
6750        for i in 0..n {
6751            for c1 in 0..p_cov {
6752                let x1 = cov_dense[[i, c1]];
6753                if x1 == 0.0 {
6754                    continue;
6755                }
6756                for t1 in 0..p_time {
6757                    let a = c1 * p_time + t1;
6758                    let xa = x1 * time[[i, t1]];
6759                    for c2 in 0..p_cov {
6760                        let x2 = cov_dense[[i, c2]];
6761                        if x2 == 0.0 {
6762                            continue;
6763                        }
6764                        for t2 in 0..p_time {
6765                            let b = c2 * p_time + t2;
6766                            reference[[a, b]] += weights[i] * xa * x2 * time[[i, t2]];
6767                        }
6768                    }
6769                }
6770            }
6771        }
6772        let max_diff = (&got - &reference)
6773            .iter()
6774            .map(|v: &f64| v.abs())
6775            .fold(0.0_f64, f64::max);
6776        assert!(
6777            max_diff < 1e-9,
6778            "rowwise kronecker sparse xtwx mismatch: max_diff={max_diff}"
6779        );
6780    }
6781
6782    #[test]
6783    fn embedded_column_block_zero_row_local_materializes_empty_global_width() {
6784        let local = Array2::<f64>::zeros((0, 0));
6785        let out = EmbeddedColumnBlock::new(&local, 2..5, 7).materialize();
6786        assert_eq!(out.dim(), (0, 7));
6787    }
6788
6789    /// Identity case: with V_b = I and a zero r_block, the residualised
6790    /// operator must emit the raw inner block unchanged. Anchored by an
6791    /// arbitrary 3×2 anchor whose contribution is zeroed out by the all-zero
6792    /// r_block — verifies the subtraction path is wired but contributes
6793    /// nothing when the reparam happens to be identity-with-no-residual.
6794    #[test]
6795    fn residualised_design_operator_identity_passthrough() {
6796        let inner = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
6797        let transform = Array2::<f64>::eye(2);
6798        let anchor_raw = array![[7.0, -1.0], [0.5, 2.0], [-3.0, 1.5]];
6799        let r_block = Arc::new(Array2::<f64>::zeros((
6800            anchor_raw.ncols(),
6801            transform.ncols(),
6802        )));
6803        let anchor_design = DesignMatrix::from(anchor_raw);
6804
6805        let op = ResidualisedDesignOperator::new(
6806            DenseDesignMatrix::from(inner.clone()),
6807            transform,
6808            vec![(anchor_design, r_block)],
6809        )
6810        .expect("residualised operator constructs");
6811
6812        // Row-chunk path (cold — exercises the streaming branch before the
6813        // materialisation cache is warmed).
6814        let mut chunk = Array2::<f64>::zeros((3, 2));
6815        op.row_chunk_into(0..3, chunk.view_mut())
6816            .expect("row chunk");
6817        for ((r, c), v) in inner.indexed_iter() {
6818            assert!(
6819                (chunk[[r, c]] - v).abs() < 1e-12,
6820                "identity row_chunk mismatch at ({r},{c}): got {} expected {v}",
6821                chunk[[r, c]]
6822            );
6823        }
6824
6825        // Through the DenseDesignMatrix wrapper — confirms the
6826        // generic `From<Arc<T: DenseDesignOperator>>` integration carries
6827        // shape and row-access semantics into the rest of the design stack.
6828        let dense_design = DenseDesignMatrix::from(Arc::new(op));
6829        assert_eq!(dense_design.nrows(), 3);
6830        assert_eq!(dense_design.ncols(), 2);
6831        let probe = ndarray::Array1::from_vec(vec![1.0, -2.0]);
6832        let got = dense_design.apply(&probe);
6833        let expected = inner.dot(&probe);
6834        for i in 0..3 {
6835            assert!((got[i] - expected[i]).abs() < 1e-12);
6836        }
6837    }
6838
6839    /// Two-block case with a shared column: build raw A and B that overlap
6840    /// on one direction, hand-construct V_b and R_b so that
6841    /// `out_full = A·γ_A + (C_b·V_b − A·R_b)·θ_b` recovers the raw row
6842    /// prediction `A·γ_A + B·β_b` exactly. Anchors the contract that
6843    /// `R_b = M_b · V_b` projects out the anchor-overlapping direction so
6844    /// the emitted compiled column is orthogonal-to-the-anchor at the
6845    /// design-matrix level.
6846    #[test]
6847    fn residualised_design_operator_two_block_reconstruction() {
6848        // Anchor A (n × 2) and raw block B (n × 2); they share their first
6849        // column up to scale, so the "kept" direction of B is its second
6850        // column residualised against A's column space.
6851        let anchor = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
6852        let b_raw = array![[1.0, 2.0], [1.0, 1.5], [1.0, 0.5], [1.0, -1.0]];
6853
6854        // Choose V_b that picks the second raw direction of B (kept dim = 1).
6855        // V_b is (p_b_raw=2) × (p_b_kept=1) selecting column 1 of B.
6856        let v_b = array![[0.0], [1.0]];
6857
6858        // Solve M_b = (A'A)^{-1} A'B · V_b → here A'A is diag-ish so do
6859        // it via least squares directly. The contract is:
6860        //   R_b = M_b · V_b  where M_b ≈ (A'A)^{-1} A'B (size p_a × p_b_raw)
6861        // We can just compute the projection coefficients of B·V_b onto A.
6862        let bv = b_raw.dot(&v_b); // n × 1
6863        let ata = anchor.t().dot(&anchor); // 2x2
6864        let atbv = anchor.t().dot(&bv); // 2x1
6865        let ata_inv = {
6866            let det = ata[[0, 0]] * ata[[1, 1]] - ata[[0, 1]] * ata[[1, 0]];
6867            array![
6868                [ata[[1, 1]] / det, -ata[[0, 1]] / det],
6869                [-ata[[1, 0]] / det, ata[[0, 0]] / det],
6870            ]
6871        };
6872        let r_b: Array2<f64> = ata_inv.dot(&atbv); // 2 × 1  — already R_b
6873
6874        let op = ResidualisedDesignOperator::new(
6875            DenseDesignMatrix::from(b_raw.clone()),
6876            v_b.clone(),
6877            vec![(DesignMatrix::from(anchor.clone()), Arc::new(r_b.clone()))],
6878        )
6879        .expect("residualised operator constructs");
6880
6881        // Choose anchor coefficients γ_A and kept block coefficient θ_b.
6882        let gamma_a = ndarray::Array1::from_vec(vec![0.5, -1.25]);
6883        let theta_b = ndarray::Array1::from_vec(vec![2.5]);
6884
6885        // Expected via the explicit emitted row: A·γ_A + (C_b·V_b − A·R_b)·θ_b.
6886        let cv = b_raw.dot(&v_b); // C_b · V_b
6887        let ar = anchor.dot(&r_b); // A · R_b
6888        let emitted_b_chunk = &cv - &ar;
6889        let expected = anchor.dot(&gamma_a) + emitted_b_chunk.dot(&theta_b);
6890
6891        // Pull a streaming chunk through the operator and verify the
6892        // contribution matches the hand-computed (C_b·V_b − A·R_b)·θ_b row
6893        // by row.
6894        let mut got_chunk = Array2::<f64>::zeros((4, 1));
6895        op.row_chunk_into(0..4, got_chunk.view_mut())
6896            .expect("row chunk");
6897        let got = anchor.dot(&gamma_a) + got_chunk.dot(&theta_b);
6898        for i in 0..4 {
6899            assert!(
6900                (got[i] - expected[i]).abs() < 1e-10,
6901                "two-block reconstruction mismatch at row {i}: got {} expected {}",
6902                got[i],
6903                expected[i]
6904            );
6905        }
6906
6907        // Cross-check via the LinearOperator::apply path (vector-valued v_b
6908        // matmul against the compiled width). This goes through the
6909        // streaming inner.apply / fast_av routes, distinct from the
6910        // row_chunk_into path covered above.
6911        let applied = op.apply(&theta_b);
6912        for i in 0..4 {
6913            assert!((applied[i] - emitted_b_chunk[[i, 0]] * theta_b[0]).abs() < 1e-10);
6914        }
6915    }
6916}