1use crate::faer_ndarray::{
2 CrossprodAccum, CrossprodStructure, FaerArrayView, array2_to_matmut,
3 effective_global_parallelism, fast_ab, fast_atb, fast_atv, fast_atv_into, fast_av,
4 fast_av_into, fast_xt_diag_x, stream_weighted_crossprod_into,
5};
6use crate::types::RidgePolicy;
7use faer::Accum;
8use faer::linalg::matmul::matmul;
9use faer::sparse::{SparseColMat, SparseRowMat, Triplet};
10use gam_runtime::resource::{
11 MaterializationPolicy, MatrixMaterializationError, ResourcePolicy, rows_for_target_bytes,
12};
13use ndarray::{
14 Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis, ShapeBuilder, s,
15};
16use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
17use std::borrow::Cow;
18use std::collections::BTreeMap;
19use std::ops::Deref;
20use std::ops::Range;
21use std::sync::{Arc, OnceLock};
22
23const MATRIX_FREE_PCG_MIN_P: usize = 2048;
24const MATRIX_FREE_PCG_REL_TOL: f64 = 1e-8;
25const SPD_SOLVE_RIDGE_FLOOR: f64 = 1e-15;
30const MATRIX_FREE_PCG_MAX_ITER: usize = 2000;
31const MAX_SINGLE_DENSE_MATERIALIZATION_BYTES: usize = 256 * 1024 * 1024;
32const MAX_PERSISTENT_SPARSE_DENSE_CACHE_BYTES: usize = 256 * 1024 * 1024;
33const MAX_SPARSE_TO_DENSE_BYTES: usize = MAX_SINGLE_DENSE_MATERIALIZATION_BYTES;
34const CHUNKED_DENSE_MATERIALIZATION_BYTES: usize = 8 * 1024 * 1024;
35const OPERATOR_ROW_CHUNK_SIZE: usize = 256;
36const DENSE_ROW_PARALLEL_MIN_NP: u64 = 200_000;
40const WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS: u64 = 500_000;
41const SPARSE_ROW_PARALLEL_MIN_FLOPS: u64 = 100_000;
42const TENSOR_GEMM_MAX_INTERMEDIATE_BYTES: usize = 128 * 1024 * 1024; pub use crate::utils::PcgSolveInfo;
47
48mod sparse_hessian;
49pub use sparse_hessian::SparseHessianAccumulator;
50
51mod weights;
52pub use weights::{PsdWeightsView, SignedWeightsArc, SignedWeightsView};
53
54#[derive(Debug, Clone)]
59pub enum MatrixError {
60 DimensionMismatch { reason: String },
63 DensificationRefused { reason: String },
67}
68
69crate::impl_reason_error_boilerplate! {
70 MatrixError {
71 DimensionMismatch,
72 DensificationRefused,
73 }
74}
75
76#[inline]
77fn dense_materialization_chunk_rows(nrows: usize, ncols: usize) -> usize {
78 rows_for_target_bytes(CHUNKED_DENSE_MATERIALIZATION_BYTES, ncols)
79 .max(1)
80 .min(nrows.max(1))
81}
82
83fn dense_operator_to_dense_by_chunks<O: DenseDesignOperator + ?Sized>(
84 op: &O,
85) -> Result<Array2<f64>, MatrixMaterializationError> {
86 let n = op.nrows();
87 let p = op.ncols();
88 let chunk_rows = dense_materialization_chunk_rows(n, p);
89 let mut out = Array2::<f64>::zeros((n, p));
90 for start in (0..n).step_by(chunk_rows) {
91 let end = (start + chunk_rows).min(n);
92 let slice = out.slice_mut(s![start..end, ..]);
93 op.row_chunk_into(start..end, slice)?;
94 }
95 Ok(out)
96}
97
98pub fn checked_dense_nbytes(nrows: usize, ncols: usize, context: &str) -> Result<usize, String> {
99 nrows
100 .checked_mul(ncols)
101 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
102 .ok_or_else(|| {
103 MatrixError::DimensionMismatch {
104 reason: format!("{context}: dense size overflow for {nrows}x{ncols}"),
105 }
106 .into()
107 })
108}
109
110pub fn panic_or_error_if_large_scale_mode_and_to_dense_called_with_policy(
111 context: &str,
112 n: usize,
113 p: usize,
114 policy: &ResourcePolicy,
115) -> Result<(), String> {
116 if matches!(
122 policy.derivative_storage_mode,
123 gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired
124 ) {
125 return Err(MatrixError::DensificationRefused {
126 reason: format!(
127 "{context}: refusing to densify operator-backed design {n}x{p} under \
128 AnalyticOperatorRequired policy; provide an operator-form path"
129 ),
130 }
131 .into());
132 }
133 let dense_bytes = checked_dense_nbytes(n, p, context)?;
134 let limit = policy.max_single_materialization_bytes;
135 if dense_bytes > limit {
136 let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
137 return Err(MatrixError::DensificationRefused {
138 reason: format!(
139 "{context}: refusing to densify operator-backed design {n}x{p} (~{gib:.2} GiB); use matrix-free or chunked code"
140 ),
141 }
142 .into());
143 }
144 Ok(())
145}
146
147fn weighted_crossprod_dense(
148 left: &Array2<f64>,
149 weights: &Array1<f64>,
150 right: &Array2<f64>,
151) -> Result<Array2<f64>, String> {
152 if left.nrows() != weights.len() || right.nrows() != weights.len() {
153 return Err(MatrixError::DimensionMismatch {
154 reason: format!(
155 "weighted_crossprod_dense row mismatch: left={}, weights={}, right={}",
156 left.nrows(),
157 weights.len(),
158 right.nrows()
159 ),
160 }
161 .into());
162 }
163 Ok(weighted_crossprod_dense_view(left, weights.view(), right))
164}
165
166fn weighted_crossprod_dense_view(
167 left: &Array2<f64>,
168 weights: ArrayView1<'_, f64>,
169 right: &Array2<f64>,
170) -> Array2<f64> {
171 let n = weights.len();
172 let p_left = left.ncols();
173 let p_right = right.ncols();
174 let work = (n as u64)
175 .saturating_mul(p_left as u64)
176 .saturating_mul(p_right as u64);
177 if rayon::current_num_threads() <= 1 || work < WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS {
178 return weighted_crossprod_dense_rows(left, weights, right, 0..n);
179 }
180
181 let min_parallel_work = WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS.min(usize::MAX as u64) as usize;
182 let Some(chunk_rows) = crate::parallel::row_reduction_chunk_rows(
183 n,
184 p_left.saturating_mul(p_right),
185 p_left.saturating_mul(p_right),
186 min_parallel_work,
187 ) else {
188 return weighted_crossprod_dense_rows(left, weights, right, 0..n);
189 };
190 let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
191 let partials: Vec<Array2<f64>> = starts
192 .into_par_iter()
193 .map(|start| {
194 weighted_crossprod_dense_rows(left, weights, right, start..(start + chunk_rows).min(n))
195 })
196 .collect();
197 let mut out = Array2::<f64>::zeros((p_left, p_right));
198 for partial in &partials {
199 out += partial;
200 }
201 out
202}
203
204fn weighted_crossprod_dense_rows(
205 left: &Array2<f64>,
206 weights: ArrayView1<'_, f64>,
207 right: &Array2<f64>,
208 rows: Range<usize>,
209) -> Array2<f64> {
210 let p_left = left.ncols();
218 let p_right = right.ncols();
219 let mut out = Array2::<f64>::zeros((p_left, p_right));
220 if left.is_standard_layout()
221 && right.is_standard_layout()
222 && let (Some(lx), Some(rx), Some(w)) =
223 (left.as_slice(), right.as_slice(), weights.as_slice())
224 {
225 let out_slice = out.as_slice_mut().expect("zeros are contiguous");
226 for i in rows {
227 let wi = w[i];
228 if wi == 0.0 {
229 continue;
230 }
231 let l_row = &lx[i * p_left..i * p_left + p_left];
232 let r_row = &rx[i * p_right..i * p_right + p_right];
233 for a in 0..p_left {
234 let scaled = wi * l_row[a];
235 if scaled == 0.0 {
236 continue;
237 }
238 let out_row = &mut out_slice[a * p_right..a * p_right + p_right];
239 for b in 0..p_right {
240 out_row[b] += scaled * r_row[b];
241 }
242 }
243 }
244 return out;
245 }
246 for i in rows {
247 let wi = weights[i];
248 if wi == 0.0 {
249 continue;
250 }
251 for a in 0..p_left {
252 let scaled = wi * left[[i, a]];
253 if scaled == 0.0 {
254 continue;
255 }
256 for b in 0..p_right {
257 out[[a, b]] += scaled * right[[i, b]];
258 }
259 }
260 }
261 out
262}
263
264pub struct DenseRightProductView<'a> {
265 base: &'a Array2<f64>,
266 first: Option<&'a Array2<f64>>,
267 second: Option<&'a Array2<f64>>,
268}
269
270impl<'a> DenseRightProductView<'a> {
271 pub fn new(base: &'a Array2<f64>) -> Self {
272 Self {
273 base,
274 first: None,
275 second: None,
276 }
277 }
278
279 pub fn with_factor(mut self, factor: &'a Array2<f64>) -> Self {
280 if self.first.is_none() {
281 self.first = Some(factor);
282 } else if self.second.is_none() {
283 self.second = Some(factor);
284 } else {
285 std::panic::panic_any("DenseRightProductView supports at most two right factors");
291 }
292 self
293 }
294
295 pub fn with_optional_factor(self, factor: Option<&'a Array2<f64>>) -> Self {
296 match factor {
297 Some(factor) => self.with_factor(factor),
298 None => self,
299 }
300 }
301
302 pub fn materialize(&self) -> Array2<f64> {
303 let mut out = self.base.clone();
304 if let Some(factor) = self.first {
305 out = fast_ab(&out, factor);
306 }
307 if let Some(factor) = self.second {
308 out = fast_ab(&out, factor);
309 }
310 out
311 }
312
313 fn transformed_ncols(&self) -> usize {
314 if let Some(factor) = self.second {
315 factor.ncols()
316 } else if let Some(factor) = self.first {
317 factor.ncols()
318 } else {
319 self.base.ncols()
320 }
321 }
322}
323
324pub struct EmbeddedColumnBlock<'a> {
325 local: &'a Array2<f64>,
326 global_range: Range<usize>,
327 total_cols: usize,
328}
329
330impl<'a> EmbeddedColumnBlock<'a> {
331 pub fn new(local: &'a Array2<f64>, global_range: Range<usize>, total_cols: usize) -> Self {
332 Self {
333 local,
334 global_range,
335 total_cols,
336 }
337 }
338
339 pub fn materialize(&self) -> Array2<f64> {
340 if self.local.nrows() == 0 {
341 return Array2::<f64>::zeros((0, self.total_cols));
342 }
343 assert_eq!(
344 self.local.ncols(),
345 self.global_range.len(),
346 "embedded column block width mismatch"
347 );
348 let mut out = Array2::<f64>::zeros((self.local.nrows(), self.total_cols));
349 out.slice_mut(ndarray::s![.., self.global_range.clone()])
350 .assign(self.local);
351 out
352 }
353}
354
355pub struct EmbeddedSquareBlock<'a> {
356 local: &'a Array2<f64>,
357 global_range: Range<usize>,
358 total_dim: usize,
359}
360
361impl<'a> EmbeddedSquareBlock<'a> {
362 pub fn new(local: &'a Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
363 Self {
364 local,
365 global_range,
366 total_dim,
367 }
368 }
369
370 pub fn materialize(&self) -> Array2<f64> {
371 let mut out = Array2::<f64>::zeros((self.total_dim, self.total_dim));
372 out.slice_mut(ndarray::s![
373 self.global_range.clone(),
374 self.global_range.clone()
375 ])
376 .assign(self.local);
377 out
378 }
379}
380
381struct PenalizedWeightedNormalOperator<'a, O: LinearOperator + ?Sized> {
382 operator: &'a O,
383 weights: &'a Array1<f64>,
384 penalty: Option<&'a Array2<f64>>,
385 ridge: f64,
386}
387
388impl<'a, O: LinearOperator + ?Sized> PenalizedWeightedNormalOperator<'a, O> {
389 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
390 self.operator
391 .apply_weighted_normal(self.weights, vector, self.penalty, self.ridge)
392 }
393
394 fn jacobi_preconditioner(&self) -> Result<Array1<f64>, String> {
395 let mut diag = self.operator.diag_gram(self.weights)?;
396 if let Some(pen) = self.penalty {
397 for i in 0..diag.len() {
398 diag[i] += pen[[i, i]];
399 }
400 }
401 if self.ridge > 0.0 {
402 for i in 0..diag.len() {
403 diag[i] += self.ridge;
404 }
405 }
406 Ok(diag)
407 }
408}
409
410#[inline]
411fn dense_diag_gram_view(matrix: &Array2<f64>, weights: PsdWeightsView<'_>) -> Array1<f64> {
412 let weights = weights.view();
418 let p = matrix.ncols();
419 let n = matrix.nrows();
420 let large = (n as u64) * (p as u64) >= DENSE_ROW_PARALLEL_MIN_NP;
421 let parallel = large && rayon::current_thread_index().is_none();
422 if matrix.is_standard_layout()
425 && let (Some(x), Some(w)) = (matrix.as_slice(), weights.as_slice())
426 {
427 if parallel {
428 return (0..n)
429 .into_par_iter()
430 .fold(
431 || vec![0.0_f64; p],
432 |mut acc, i| {
433 let wi = w[i];
434 if wi != 0.0 {
435 let row = &x[i * p..i * p + p];
436 for j in 0..p {
437 let xij = row[j];
438 acc[j] += wi * xij * xij;
439 }
440 }
441 acc
442 },
443 )
444 .reduce(
445 || vec![0.0_f64; p],
446 |mut a, b| {
447 for (av, bv) in a.iter_mut().zip(b) {
448 *av += bv;
449 }
450 a
451 },
452 )
453 .into();
454 }
455 let mut diag = Array1::<f64>::zeros(p);
456 let diag_slice = diag.as_slice_mut().expect("zeros are contiguous");
457 for i in 0..n {
458 let wi = w[i];
459 if wi == 0.0 {
460 continue;
461 }
462 let row = &x[i * p..i * p + p];
463 for j in 0..p {
464 let xij = row[j];
465 diag_slice[j] += wi * xij * xij;
466 }
467 }
468 return diag;
469 }
470 let mut diag = Array1::<f64>::zeros(p);
471 for i in 0..n {
472 let wi = weights[i];
473 if wi == 0.0 {
474 continue;
475 }
476 for j in 0..p {
477 let xij = matrix[[i, j]];
478 diag[j] += wi * xij * xij;
479 }
480 }
481 diag
482}
483
484fn sparse_csr_weighted_xtwx(
485 row_ptr: &[usize],
486 col_idx: &[usize],
487 vals: &[f64],
488 n: usize,
489 p: usize,
490 weights: ArrayView1<'_, f64>,
491) -> Array2<f64> {
492 let nnz = vals.len() as u64;
493 let avg = nnz.checked_div(n.max(1) as u64).unwrap_or(0);
494 let work = (n as u64).saturating_mul(avg.saturating_mul(avg));
495 if rayon::current_num_threads() <= 1 || work < SPARSE_ROW_PARALLEL_MIN_FLOPS {
496 return sparse_csr_weighted_xtwx_rows(row_ptr, col_idx, vals, p, weights, 0..n);
497 }
498
499 let min_parallel_work = SPARSE_ROW_PARALLEL_MIN_FLOPS.min(usize::MAX as u64) as usize;
500 let Some(chunk_rows) = crate::parallel::row_reduction_chunk_rows(
501 n,
502 avg.min(usize::MAX as u64) as usize,
503 p.saturating_mul(p),
504 min_parallel_work,
505 ) else {
506 return sparse_csr_weighted_xtwx_rows(row_ptr, col_idx, vals, p, weights, 0..n);
507 };
508 let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
509 let partials: Vec<Array2<f64>> = starts
510 .into_par_iter()
511 .map(|start| {
512 sparse_csr_weighted_xtwx_rows(
513 row_ptr,
514 col_idx,
515 vals,
516 p,
517 weights,
518 start..(start + chunk_rows).min(n),
519 )
520 })
521 .collect();
522 let mut xtwx = Array2::<f64>::zeros((p, p));
523 for partial in &partials {
524 xtwx += partial;
525 }
526 xtwx
527}
528
529fn sparse_csr_weighted_xtwx_rows(
530 row_ptr: &[usize],
531 col_idx: &[usize],
532 vals: &[f64],
533 p: usize,
534 weights: ArrayView1<'_, f64>,
535 rows: Range<usize>,
536) -> Array2<f64> {
537 let mut xtwx = Array2::<f64>::zeros((p, p));
544 for i in rows {
545 let wi = weights[i];
546 if wi == 0.0 {
547 continue;
548 }
549 let start = row_ptr[i];
550 let end = row_ptr[i + 1];
551 for a_ptr in start..end {
552 let a = col_idx[a_ptr];
553 let wxa = wi * vals[a_ptr];
554 for b_ptr in a_ptr..end {
555 let b = col_idx[b_ptr];
556 let v = wxa * vals[b_ptr];
557 xtwx[[a, b]] += v;
558 if a != b {
559 xtwx[[b, a]] += v;
560 }
561 }
562 }
563 }
564 xtwx
565}
566
567pub fn streaming_sparse_csc_xt_diag_x(
568 col_ptr: &[usize],
569 row_idx: &[usize],
570 vals: &[f64],
571 n: usize,
572 p: usize,
573 weights: ArrayView1<'_, f64>,
574 out: &mut Array2<f64>,
575) {
576 if n == 0 || p == 0 {
577 return;
578 }
579
580 let chunk_rows = dense_materialization_chunk_rows(n, p);
581 let par = effective_global_parallelism();
582 let mut x_chunk = Array2::<f64>::zeros((chunk_rows, p).f());
583 let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p).f());
584
585 {
586 let mut out_view = array2_to_matmut(out);
587
588 for start in (0..n).step_by(chunk_rows) {
589 let rows = (n - start).min(chunk_rows);
590 {
591 let mut x_slice = x_chunk.slice_mut(s![0..rows, ..]);
592 let mut wx_slice = wx_chunk.slice_mut(s![0..rows, ..]);
593 x_slice.fill(0.0);
594 wx_slice.fill(0.0);
595 let end = start + rows;
596 for col in 0..p {
597 let col_start = col_ptr[col];
598 let col_end = col_ptr[col + 1];
599 let rows_for_col = &row_idx[col_start..col_end];
600 let local_start = rows_for_col.partition_point(|&row| row < start);
601 let local_end = rows_for_col.partition_point(|&row| row < end);
602 for local_ptr in local_start..local_end {
603 let ptr = col_start + local_ptr;
604 let row = row_idx[ptr];
605 let local = row - start;
606 let wi = weights[row];
607 let value = vals[ptr];
608 x_slice[[local, col]] += value;
609 wx_slice[[local, col]] += wi * value;
610 }
611 }
612 }
613 let x_slice = x_chunk.slice(s![0..rows, ..]);
614 let wx_slice = wx_chunk.slice(s![0..rows, ..]);
615 let x_view = FaerArrayView::new(&x_slice);
616 let wx_view = FaerArrayView::new(&wx_slice);
617 matmul(
618 out_view.as_mut(),
619 Accum::Add,
620 x_view.as_ref().transpose(),
621 wx_view.as_ref(),
622 1.0,
623 par,
624 );
625 }
626 }
627}
628
629fn sparse_csr_diag_gram(
630 row_ptr: &[usize],
631 col_idx: &[usize],
632 vals: &[f64],
633 n: usize,
634 p: usize,
635 weights: ArrayView1<'_, f64>,
636) -> Array1<f64> {
637 let work = vals.len() as u64;
638 if rayon::current_num_threads() <= 1 || work < SPARSE_ROW_PARALLEL_MIN_FLOPS {
639 return sparse_csr_diag_gram_rows(row_ptr, col_idx, vals, p, weights, 0..n);
640 }
641 let min_parallel_work = SPARSE_ROW_PARALLEL_MIN_FLOPS.min(usize::MAX as u64) as usize;
642 let Some(chunk_rows) = crate::parallel::row_reduction_chunk_rows(n, 1, p, min_parallel_work)
643 else {
644 return sparse_csr_diag_gram_rows(row_ptr, col_idx, vals, p, weights, 0..n);
645 };
646 let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
647 let partials: Vec<Array1<f64>> = starts
648 .into_par_iter()
649 .map(|start| {
650 sparse_csr_diag_gram_rows(
651 row_ptr,
652 col_idx,
653 vals,
654 p,
655 weights,
656 start..(start + chunk_rows).min(n),
657 )
658 })
659 .collect();
660 let mut diag = Array1::<f64>::zeros(p);
661 for partial in &partials {
662 diag += partial;
663 }
664 diag
665}
666
667fn sparse_csr_diag_gram_rows(
668 row_ptr: &[usize],
669 col_idx: &[usize],
670 vals: &[f64],
671 p: usize,
672 weights: ArrayView1<'_, f64>,
673 rows: Range<usize>,
674) -> Array1<f64> {
675 let mut diag = Array1::<f64>::zeros(p);
680 for i in rows {
681 let wi = weights[i];
682 if wi == 0.0 {
683 continue;
684 }
685 for idx in row_ptr[i]..row_ptr[i + 1] {
686 let j = col_idx[idx];
687 let xij = vals[idx];
688 diag[j] += wi * xij * xij;
689 }
690 }
691 diag
692}
693
694#[inline]
695fn dense_transpose_weighted_response(
696 matrix: &Array2<f64>,
697 weights: &Array1<f64>,
698 y: &Array1<f64>,
699 row_scale: Option<&Array1<f64>>,
700) -> Array1<f64> {
701 let p = matrix.ncols();
706 let n = matrix.nrows();
707 let mut out = Array1::<f64>::zeros(p);
708 if matrix.is_standard_layout()
709 && let (Some(x), Some(w), Some(yslice)) =
710 (matrix.as_slice(), weights.as_slice(), y.as_slice())
711 {
712 let scale_slice = row_scale.and_then(|s| s.as_slice());
713 let out_slice = out.as_slice_mut().expect("zeros are contiguous");
714 for i in 0..n {
715 let mut scaled = yslice[i] * w[i];
716 if let Some(s) = scale_slice {
717 scaled *= s[i];
718 } else if let Some(scale) = row_scale {
719 scaled *= scale[i];
720 }
721 if scaled == 0.0 {
722 continue;
723 }
724 let row = &x[i * p..i * p + p];
725 for j in 0..p {
726 out_slice[j] += row[j] * scaled;
727 }
728 }
729 return out;
730 }
731 for i in 0..n {
732 let mut scaled = y[i] * weights[i];
733 if let Some(scale) = row_scale {
734 scaled *= scale[i];
735 }
736 if scaled == 0.0 {
737 continue;
738 }
739 for j in 0..p {
740 out[j] += matrix[[i, j]] * scaled;
741 }
742 }
743 out
744}
745
746#[inline]
747fn dense_transpose_weighted_response_view(
748 matrix: &Array2<f64>,
749 weights: ArrayView1<'_, f64>,
750 y: ArrayView1<'_, f64>,
751) -> Array1<f64> {
752 let p = matrix.ncols();
755 let n = matrix.nrows();
756 let mut out = Array1::<f64>::zeros(p);
757 if matrix.is_standard_layout()
758 && let (Some(x), Some(w), Some(yslice)) =
759 (matrix.as_slice(), weights.as_slice(), y.as_slice())
760 {
761 let out_slice = out.as_slice_mut().expect("zeros are contiguous");
762 for i in 0..n {
763 let scaled = yslice[i] * w[i];
764 if scaled == 0.0 {
765 continue;
766 }
767 let row = &x[i * p..i * p + p];
768 for j in 0..p {
769 out_slice[j] += row[j] * scaled;
770 }
771 }
772 return out;
773 }
774 for i in 0..n {
775 let scaled = y[i] * weights[i];
776 if scaled == 0.0 {
777 continue;
778 }
779 for j in 0..p {
780 out[j] += matrix[[i, j]] * scaled;
781 }
782 }
783 out
784}
785
786#[derive(Clone)]
787pub struct SparseDesignMatrix {
788 matrix: SparseColMat<usize, f64>,
789 dense_cache: Arc<OnceLock<Arc<Array2<f64>>>>,
790 csr_cache: Arc<OnceLock<Arc<SparseRowMat<usize, f64>>>>,
791}
792
793impl SparseDesignMatrix {
794 pub fn new(matrix: SparseColMat<usize, f64>) -> Self {
795 Self {
796 matrix,
797 dense_cache: Arc::new(OnceLock::new()),
798 csr_cache: Arc::new(OnceLock::new()),
799 }
800 }
801
802 fn dense_nbytes(&self) -> Result<usize, String> {
803 self.matrix
804 .nrows()
805 .checked_mul(self.matrix.ncols())
806 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
807 .ok_or_else(|| {
808 format!(
809 "dense size overflow for sparse design {}x{}",
810 self.matrix.nrows(),
811 self.matrix.ncols()
812 )
813 })
814 }
815
816 fn materialize_dense_arc(&self) -> Arc<Array2<f64>> {
817 let mut out = Array2::<f64>::zeros((self.matrix.nrows(), self.matrix.ncols()));
818 let (symbolic, values) = self.matrix.parts();
819 let col_ptr = symbolic.col_ptr();
820 let row_idx = symbolic.row_idx();
821 for col in 0..self.matrix.ncols() {
822 let start = col_ptr[col];
823 let end = col_ptr[col + 1];
824 for idx in start..end {
825 out[[row_idx[idx], col]] += values[idx];
826 }
827 }
828 Arc::new(out)
829 }
830
831 pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
832 let dense_bytes = self.dense_nbytes()?;
833 if dense_bytes > MAX_SPARSE_TO_DENSE_BYTES {
834 let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
835 return Err(MatrixError::DensificationRefused {
836 reason: format!(
837 "{context}: refusing to densify sparse design {}x{} (~{gib:.2} GiB); use sparse or matrix-free code",
838 self.matrix.nrows(),
839 self.matrix.ncols(),
840 ),
841 }
842 .into());
843 }
844 if dense_bytes <= MAX_PERSISTENT_SPARSE_DENSE_CACHE_BYTES {
845 Ok(self
846 .dense_cache
847 .get_or_init(|| self.materialize_dense_arc())
848 .clone())
849 } else {
850 Ok(self.materialize_dense_arc())
851 }
852 }
853
854 pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
855 self.try_to_dense_arc("SparseDesignMatrix::to_dense_arc")
856 .unwrap_or_else(|msg| {
857 let bt = std::backtrace::Backtrace::force_capture();
858 std::panic::panic_any(format!("{msg}\nbacktrace:\n{bt}"))
865 })
866 }
867
868 pub fn to_csr_arc(&self) -> Option<Arc<SparseRowMat<usize, f64>>> {
869 if let Some(cached) = self.csr_cache.get() {
870 return Some(cached.clone());
871 }
872 let csr = self.matrix.as_ref().to_row_major().ok()?;
873 let arc = Arc::new(csr);
874 self.csr_cache.set(arc.clone()).ok();
875 Some(arc)
876 }
877}
878
879impl Deref for SparseDesignMatrix {
880 type Target = SparseColMat<usize, f64>;
881 fn deref(&self) -> &Self::Target {
882 &self.matrix
883 }
884}
885
886impl AsRef<SparseColMat<usize, f64>> for SparseDesignMatrix {
887 fn as_ref(&self) -> &SparseColMat<usize, f64> {
888 &self.matrix
889 }
890}
891
892pub trait DenseDesignOperator: LinearOperator + Send + Sync {
900 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
901 let n = self.nrows();
903 if weights.len() != n || y.len() != n {
904 return Err(format!(
905 "DenseDesignOperator::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
906 weights.len(),
907 y.len(),
908 n
909 ));
910 }
911 let mut wy = Array1::<f64>::zeros(n);
914 ndarray::Zip::from(&mut wy)
915 .and(weights)
916 .and(y)
917 .par_for_each(|o, &w, &yi| *o = w * yi);
918 Ok(self.apply_transpose(&wy))
919 }
920
921 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
922 if middle.nrows() != self.ncols() || middle.ncols() != self.ncols() {
925 return Err(format!(
926 "DenseDesignOperator::quadratic_form_diag dimension mismatch: {}x{} vs expected {}x{}",
927 middle.nrows(),
928 middle.ncols(),
929 self.ncols(),
930 self.ncols()
931 ));
932 }
933 let n = self.nrows();
934 let mut out = Array1::<f64>::zeros(n);
935 let chunk_size = (8 * 1024 * 1024 / (self.ncols().max(1) * 8 * 2))
937 .max(16)
938 .min(n.max(1));
939 let mut start = 0;
940 while start < n {
941 let end = (start + chunk_size).min(n);
942 let x_chunk = self.try_row_chunk(start..end).map_err(|e| e.to_string())?;
943 let xm_chunk = fast_ab(&x_chunk, middle);
944 let mut chunk_out = out.slice_mut(ndarray::s![start..end]);
945 ndarray::Zip::from(&mut chunk_out)
946 .and(x_chunk.rows())
947 .and(xm_chunk.rows())
948 .par_for_each(|o, xr, xmr| *o = xr.dot(&xmr).max(0.0));
951 start = end;
952 }
953 Ok(out)
954 }
955
956 fn row_chunk_into(
959 &self,
960 rows: Range<usize>,
961 out: ArrayViewMut2<'_, f64>,
962 ) -> Result<(), MatrixMaterializationError>;
963
964 fn try_row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, MatrixMaterializationError> {
967 let mut out = Array2::<f64>::zeros((rows.end - rows.start, self.ncols()));
968 self.row_chunk_into(rows, out.view_mut())?;
969 Ok(out)
970 }
971
972 fn as_dense_ref(&self) -> Option<&Array2<f64>> {
974 None
975 }
976
977 fn apply_columns(&self, cols: &[usize]) -> Array2<f64> {
985 let n = self.nrows();
986 let p = self.ncols();
987 let mut out = Array2::<f64>::zeros((n, cols.len()));
988 let mut e = Array1::<f64>::zeros(p);
989 for (k, &j) in cols.iter().enumerate() {
990 assert!(
991 j < p,
992 "DenseDesignOperator::apply_columns: column index {j} out of bounds (ncols={p})"
993 );
994 e[j] = 1.0;
995 let col = self.apply(&e);
996 e[j] = 0.0;
997 out.column_mut(k).assign(&col);
998 }
999 out
1000 }
1001
1002 fn to_dense(&self) -> Array2<f64>;
1006
1007 fn estimated_dense_bytes(&self) -> usize {
1008 self.nrows()
1009 .saturating_mul(self.ncols())
1010 .saturating_mul(std::mem::size_of::<f64>())
1011 }
1012
1013 fn try_to_dense_with_policy(
1014 &self,
1015 policy: &MaterializationPolicy,
1016 context: &'static str,
1017 ) -> Result<Arc<Array2<f64>>, MatrixMaterializationError> {
1018 let bytes = self.estimated_dense_bytes();
1019 if !policy.allow_operator_materialization {
1020 return Err(MatrixMaterializationError::Forbidden {
1021 context,
1022 mode: gam_runtime::resource::DerivativeStorageMode::AnalyticOperatorRequired,
1023 });
1024 }
1025 if bytes > policy.max_single_dense_bytes {
1026 return Err(MatrixMaterializationError::TooLarge {
1027 context,
1028 nrows: self.nrows(),
1029 ncols: self.ncols(),
1030 bytes,
1031 limit_bytes: policy.max_single_dense_bytes,
1032 });
1033 }
1034 dense_operator_to_dense_by_chunks(self).map(Arc::new)
1035 }
1036
1037 fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1044 Arc::new(
1045 dense_operator_to_dense_by_chunks(self)
1046 .expect("DenseDesignOperator::to_dense_arc: row-chunk materialization failed"),
1047 )
1048 }
1049}
1050
1051#[derive(Clone)]
1052pub enum DenseDesignMatrix {
1053 Materialized(Arc<Array2<f64>>),
1054 Lazy(Arc<dyn DenseDesignOperator>),
1055}
1056
1057impl std::fmt::Debug for DenseDesignMatrix {
1058 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1059 match self {
1060 Self::Materialized(matrix) => {
1061 write!(
1062 f,
1063 "DenseDesignMatrix::Materialized({}x{})",
1064 matrix.nrows(),
1065 matrix.ncols()
1066 )
1067 }
1068 Self::Lazy(op) => write!(f, "DenseDesignMatrix::Lazy({}x{})", op.nrows(), op.ncols()),
1069 }
1070 }
1071}
1072
1073impl From<Arc<Array2<f64>>> for DenseDesignMatrix {
1074 fn from(value: Arc<Array2<f64>>) -> Self {
1075 Self::Materialized(value)
1076 }
1077}
1078
1079impl From<Array2<f64>> for DenseDesignMatrix {
1080 fn from(value: Array2<f64>) -> Self {
1081 Self::Materialized(Arc::new(value))
1082 }
1083}
1084
1085impl<T> From<Arc<T>> for DenseDesignMatrix
1086where
1087 T: DenseDesignOperator + 'static,
1088{
1089 fn from(value: Arc<T>) -> Self {
1090 Self::Lazy(value)
1091 }
1092}
1093
1094impl DenseDesignMatrix {
1095 pub fn cache_identity(&self) -> usize {
1104 match self {
1105 Self::Materialized(matrix) => Arc::as_ptr(matrix) as *const () as usize,
1106 Self::Lazy(op) => Arc::as_ptr(op) as *const () as usize,
1107 }
1108 }
1109
1110 pub fn nrows(&self) -> usize {
1111 match self {
1112 Self::Materialized(matrix) => matrix.nrows(),
1113 Self::Lazy(op) => op.nrows(),
1114 }
1115 }
1116
1117 pub fn ncols(&self) -> usize {
1118 match self {
1119 Self::Materialized(matrix) => matrix.ncols(),
1120 Self::Lazy(op) => op.ncols(),
1121 }
1122 }
1123
1124 pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
1125 match self {
1126 Self::Materialized(matrix) => Some(matrix.as_ref()),
1127 Self::Lazy(op) => op.as_dense_ref(),
1128 }
1129 }
1130
1131 pub const fn is_materialized_dense(&self) -> bool {
1132 matches!(self, Self::Materialized(_))
1133 }
1134
1135 pub const fn is_operator_backed(&self) -> bool {
1136 matches!(self, Self::Lazy(_))
1137 }
1138
1139 pub fn to_dense(&self) -> Array2<f64> {
1140 match self {
1141 Self::Materialized(matrix) => matrix.as_ref().clone(),
1142 Self::Lazy(op) => {
1151 dense_operator_to_dense_by_chunks(op.as_ref()).unwrap_or_else(|err| {
1152 std::panic::panic_any(format!(
1159 "DenseDesignMatrix::to_dense: failed to materialize {}x{} \
1160 operator-backed design via row chunks: {err}",
1161 op.nrows(),
1162 op.ncols(),
1163 ))
1164 })
1165 }
1166 }
1167 }
1168
1169 pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1170 match self {
1171 Self::Materialized(matrix) => Arc::clone(matrix),
1172 Self::Lazy(op) => Arc::new(
1173 dense_operator_to_dense_by_chunks(op.as_ref()).unwrap_or_else(|err| {
1174 std::panic::panic_any(format!(
1181 "DenseDesignMatrix::to_dense_arc: failed to materialize {}x{} \
1182 operator-backed design via row chunks: {err}",
1183 op.nrows(),
1184 op.ncols(),
1185 ))
1186 }),
1187 ),
1188 }
1189 }
1190
1191 pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
1192 let policy = ResourcePolicy::default_library();
1204 self.try_to_dense_arc_with_policy(context, &policy)
1205 }
1206
1207 pub fn try_to_dense_arc_with_policy(
1219 &self,
1220 context: &str,
1221 policy: &ResourcePolicy,
1222 ) -> Result<Arc<Array2<f64>>, String> {
1223 match self {
1224 Self::Materialized(matrix) => Ok(Arc::clone(matrix)),
1225 Self::Lazy(op) => {
1226 panic_or_error_if_large_scale_mode_and_to_dense_called_with_policy(
1227 context,
1228 op.nrows(),
1229 op.ncols(),
1230 policy,
1231 )?;
1232 dense_operator_to_dense_by_chunks(op.as_ref())
1233 .map(Arc::new)
1234 .map_err(|err| {
1235 format!("{context}: failed to materialize dense row chunks: {err}")
1236 })
1237 }
1238 }
1239 }
1240
1241 pub fn try_row_chunk(
1242 &self,
1243 rows: Range<usize>,
1244 ) -> Result<Array2<f64>, MatrixMaterializationError> {
1245 match self {
1246 Self::Materialized(matrix) => Ok(matrix.slice(s![rows, ..]).to_owned()),
1247 Self::Lazy(op) => op.try_row_chunk(rows),
1248 }
1249 }
1250
1251 pub fn row_chunk_into(
1252 &self,
1253 rows: Range<usize>,
1254 out: ArrayViewMut2<'_, f64>,
1255 ) -> Result<(), MatrixMaterializationError> {
1256 match self {
1257 Self::Materialized(matrix) => {
1258 let mut out = out;
1259 out.assign(&matrix.slice(s![rows, ..]));
1260 Ok(())
1261 }
1262 Self::Lazy(op) => op.row_chunk_into(rows, out),
1263 }
1264 }
1265}
1266
1267impl LinearOperator for DenseDesignMatrix {
1268 fn nrows(&self) -> usize {
1269 DenseDesignMatrix::nrows(self)
1270 }
1271
1272 fn ncols(&self) -> usize {
1273 DenseDesignMatrix::ncols(self)
1274 }
1275
1276 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
1277 match self {
1278 Self::Materialized(matrix) => fast_av(matrix, vector),
1279 Self::Lazy(op) => op.apply(vector),
1280 }
1281 }
1282
1283 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
1284 match self {
1285 Self::Materialized(matrix) => fast_atv(matrix, vector),
1286 Self::Lazy(op) => op.apply_transpose(vector),
1287 }
1288 }
1289
1290 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
1291 match self {
1292 Self::Materialized(matrix) => {
1293 if weights.len() != matrix.nrows() {
1294 return Err(format!(
1295 "DenseDesignMatrix::diag_xtw_x weight length mismatch: weights={}, nrows={}",
1296 weights.len(),
1297 matrix.nrows()
1298 ));
1299 }
1300 let mut xtwx = Array2::<f64>::zeros((matrix.ncols(), matrix.ncols()));
1301 stream_weighted_crossprod_into(
1302 matrix,
1303 weights,
1304 &mut xtwx,
1305 CrossprodStructure::Full,
1306 CrossprodAccum::Replace,
1307 effective_global_parallelism(),
1308 );
1309 Ok(xtwx)
1310 }
1311 Self::Lazy(op) => op.diag_xtw_x(weights),
1312 }
1313 }
1314
1315 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
1316 match self {
1320 Self::Materialized(matrix) => {
1321 let n = matrix.nrows();
1322 let p = matrix.ncols();
1323 if weights.len() != n {
1324 return Err(format!(
1325 "DenseDesignMatrix::diag_gram weight length mismatch: weights={}, nrows={}",
1326 weights.len(),
1327 n
1328 ));
1329 }
1330 if (n as u64) * (p as u64) < DENSE_ROW_PARALLEL_MIN_NP {
1331 let mut diag = Array1::<f64>::zeros(p);
1332 for i in 0..n {
1333 let wi = weights[i];
1334 if wi == 0.0 {
1335 continue;
1336 }
1337 for j in 0..p {
1338 let xij = matrix[[i, j]];
1339 diag[j] += wi * xij * xij;
1340 }
1341 }
1342 return Ok(diag);
1343 }
1344 let diag = (0..n)
1345 .into_par_iter()
1346 .fold(
1347 || Array1::<f64>::zeros(p),
1348 |mut acc, i| {
1349 let wi = weights[i];
1350 if wi != 0.0 {
1351 for j in 0..p {
1352 let xij = matrix[[i, j]];
1353 acc[j] += wi * xij * xij;
1354 }
1355 }
1356 acc
1357 },
1358 )
1359 .reduce(
1360 || Array1::<f64>::zeros(p),
1361 |mut a, b| {
1362 a += &b;
1363 a
1364 },
1365 );
1366 Ok(diag)
1367 }
1368 Self::Lazy(op) => op.diag_gram(weights),
1369 }
1370 }
1371
1372 fn apply_weighted_normal(
1373 &self,
1374 weights: &Array1<f64>,
1375 vector: &Array1<f64>,
1376 penalty: Option<&Array2<f64>>,
1377 ridge: f64,
1378 ) -> Array1<f64> {
1379 assert_eq!(
1380 weights.len(),
1381 self.nrows(),
1382 "DenseDesignMatrix::apply_weighted_normal weight length mismatch"
1383 );
1384 assert_eq!(
1385 vector.len(),
1386 self.ncols(),
1387 "DenseDesignMatrix::apply_weighted_normal vector length mismatch"
1388 );
1389 match self {
1394 Self::Materialized(matrix) => {
1395 let n = matrix.nrows();
1396 let p = matrix.ncols();
1397 let mut out = if (n as u64) * (p as u64) < DENSE_ROW_PARALLEL_MIN_NP {
1398 let mut out = Array1::<f64>::zeros(p);
1399 for i in 0..n {
1400 let wi = weights[i];
1401 if wi == 0.0 {
1402 continue;
1403 }
1404 let mut row_dot = 0.0_f64;
1405 for j in 0..p {
1406 row_dot += matrix[[i, j]] * vector[j];
1407 }
1408 if row_dot == 0.0 {
1409 continue;
1410 }
1411 let scaled = wi * row_dot;
1412 for j in 0..p {
1413 out[j] += scaled * matrix[[i, j]];
1414 }
1415 }
1416 out
1417 } else {
1418 (0..n)
1419 .into_par_iter()
1420 .fold(
1421 || Array1::<f64>::zeros(p),
1422 |mut acc, i| {
1423 let wi = weights[i];
1424 if wi != 0.0 {
1425 let mut row_dot = 0.0_f64;
1426 for j in 0..p {
1427 row_dot += matrix[[i, j]] * vector[j];
1428 }
1429 if row_dot != 0.0 {
1430 let scaled = wi * row_dot;
1431 for j in 0..p {
1432 acc[j] += scaled * matrix[[i, j]];
1433 }
1434 }
1435 }
1436 acc
1437 },
1438 )
1439 .reduce(
1440 || Array1::<f64>::zeros(p),
1441 |mut a, b| {
1442 a += &b;
1443 a
1444 },
1445 )
1446 };
1447 if let Some(pen) = penalty {
1448 out += &fast_av(pen, vector);
1449 }
1450 if ridge > 0.0 {
1451 for j in 0..p {
1452 out[j] += ridge * vector[j];
1453 }
1454 }
1455 out
1456 }
1457 Self::Lazy(op) => op.apply_weighted_normal(weights, vector, penalty, ridge),
1458 }
1459 }
1460
1461 fn uses_matrix_free_pcg(&self) -> bool {
1462 match self {
1463 Self::Materialized(_) => true,
1464 Self::Lazy(op) => op.uses_matrix_free_pcg(),
1465 }
1466 }
1467}
1468
1469impl DenseDesignOperator for DenseDesignMatrix {
1470 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
1471 match self {
1472 Self::Materialized(matrix) => {
1473 if weights.len() != matrix.nrows() || y.len() != matrix.nrows() {
1474 return Err(format!(
1475 "DenseDesignMatrix::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
1476 weights.len(),
1477 y.len(),
1478 matrix.nrows()
1479 ));
1480 }
1481 Ok(dense_transpose_weighted_response(matrix, weights, y, None))
1482 }
1483 Self::Lazy(op) => op.compute_xtwy(weights, y),
1484 }
1485 }
1486
1487 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
1488 match self {
1489 Self::Materialized(matrix) => {
1490 if middle.nrows() != matrix.ncols() || middle.ncols() != matrix.ncols() {
1491 return Err(format!(
1492 "quadratic_form_diag dimension mismatch: matrix is {}x{}, expected {}x{}",
1493 middle.nrows(),
1494 middle.ncols(),
1495 matrix.ncols(),
1496 matrix.ncols()
1497 ));
1498 }
1499 let xc = fast_ab(matrix, middle);
1500 let n = matrix.nrows();
1501 let p = matrix.ncols();
1502 let mut out = Array1::<f64>::zeros(n);
1503 if matrix.is_standard_layout()
1504 && xc.is_standard_layout()
1505 && let (Some(m_all), Some(xc_all), Some(out_slice)) =
1506 (matrix.as_slice(), xc.as_slice(), out.as_slice_mut())
1507 {
1508 use rayon::iter::{IndexedParallelIterator, ParallelIterator};
1513 use rayon::slice::ParallelSliceMut;
1514 out_slice
1515 .par_chunks_mut(1)
1516 .enumerate()
1517 .for_each(|(i, slot)| {
1518 let off = i * p;
1519 let m_row = &m_all[off..off + p];
1520 let xc_row = &xc_all[off..off + p];
1521 let mut acc = 0.0_f64;
1522 for j in 0..p {
1523 acc += m_row[j] * xc_row[j];
1524 }
1525 slot[0] = acc.max(0.0);
1528 });
1529 } else {
1530 for i in 0..n {
1531 out[i] = matrix.row(i).dot(&xc.row(i)).max(0.0);
1534 }
1535 }
1536 Ok(out)
1537 }
1538 Self::Lazy(op) => op.quadratic_form_diag(middle),
1539 }
1540 }
1541
1542 fn as_dense_ref(&self) -> Option<&Array2<f64>> {
1543 DenseDesignMatrix::as_dense_ref(self)
1544 }
1545
1546 fn row_chunk_into(
1547 &self,
1548 rows: Range<usize>,
1549 mut out: ArrayViewMut2<'_, f64>,
1550 ) -> Result<(), MatrixMaterializationError> {
1551 if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
1552 return Err(MatrixMaterializationError::MissingRowChunk {
1553 context: "DenseDesignMatrix::row_chunk_into shape mismatch",
1554 });
1555 }
1556 match self {
1557 Self::Materialized(matrix) => {
1558 out.assign(&matrix.slice(s![rows, ..]));
1559 Ok(())
1560 }
1561 Self::Lazy(op) => op.row_chunk_into(rows, out),
1562 }
1563 }
1564
1565 fn to_dense(&self) -> Array2<f64> {
1566 DenseDesignMatrix::to_dense(self)
1567 }
1568
1569 fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1570 DenseDesignMatrix::to_dense_arc(self)
1571 }
1572}
1573
1574pub struct ReparamOperator {
1589 x_original: DesignMatrix,
1590 qs: Arc<Array2<f64>>,
1591 n: usize,
1592 p: usize,
1593}
1594
1595impl ReparamOperator {
1596 pub fn new(x_original: DesignMatrix, qs: Arc<Array2<f64>>) -> Self {
1597 let n = x_original.nrows();
1598 let p = qs.ncols();
1599 assert_eq!(
1600 x_original.ncols(),
1601 qs.nrows(),
1602 "ReparamOperator: X cols ({}) must match Qs rows ({})",
1603 x_original.ncols(),
1604 qs.nrows()
1605 );
1606 Self {
1607 x_original,
1608 qs,
1609 n,
1610 p,
1611 }
1612 }
1613
1614 pub fn x_original(&self) -> &DesignMatrix {
1616 &self.x_original
1617 }
1618
1619 pub fn qs(&self) -> &Array2<f64> {
1621 &self.qs
1622 }
1623}
1624
1625impl LinearOperator for ReparamOperator {
1626 fn nrows(&self) -> usize {
1627 self.n
1628 }
1629
1630 fn ncols(&self) -> usize {
1631 self.p
1632 }
1633
1634 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
1635 let qv = self.qs.dot(vector);
1637 self.x_original.apply(&qv)
1638 }
1639
1640 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
1641 let xtv = self.x_original.apply_transpose(vector);
1643 fast_atv(&self.qs, &xtv)
1644 }
1645
1646 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
1647 let xtwx = self.x_original.diag_xtw_x(weights)?;
1650 let tmp = fast_atb(&self.qs, &xtwx);
1651 Ok(fast_ab(&tmp, &self.qs))
1652 }
1653
1654 fn apply_weighted_normal(
1655 &self,
1656 weights: &Array1<f64>,
1657 vector: &Array1<f64>,
1658 penalty: Option<&Array2<f64>>,
1659 ridge: f64,
1660 ) -> Array1<f64> {
1661 assert_eq!(
1662 weights.len(),
1663 self.x_original.nrows(),
1664 "ReparamOperator::apply_weighted_normal weight length mismatch"
1665 );
1666 assert_eq!(
1667 vector.len(),
1668 self.qs.ncols(),
1669 "ReparamOperator::apply_weighted_normal vector length mismatch"
1670 );
1671 let qv = self.qs.dot(vector);
1676 let xqv = self.x_original.apply(&qv);
1677 let mut wxqv = xqv;
1678 for i in 0..wxqv.len() {
1679 wxqv[i] *= weights[i];
1680 }
1681 let xtw = self.x_original.apply_transpose(&wxqv);
1682 let mut out = fast_atv(&self.qs, &xtw);
1683 if let Some(pen) = penalty {
1684 out += &fast_av(pen, vector);
1685 }
1686 if ridge > 0.0 {
1687 out.scaled_add(ridge, vector);
1689 }
1690 out
1691 }
1692}
1693
1694impl DenseDesignOperator for ReparamOperator {
1695 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
1696 let xtwy = self.x_original.compute_xtwy(weights, y)?;
1698 Ok(fast_atv(&self.qs, &xtwy))
1699 }
1700
1701 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
1702 let qm = fast_ab(&self.qs, middle);
1705 let m_orig = fast_ab(&qm, &self.qs.t().to_owned());
1706 self.x_original.quadratic_form_diag(&m_orig)
1707 }
1708
1709 fn to_dense(&self) -> Array2<f64> {
1710 match &self.x_original {
1711 DesignMatrix::Dense(x) => fast_ab(x.to_dense_arc().as_ref(), &self.qs),
1712 _ => {
1713 let x_dense = self.x_original.to_dense();
1714 fast_ab(&x_dense, &self.qs)
1715 }
1716 }
1717 }
1718
1719 fn to_dense_arc(&self) -> Arc<Array2<f64>> {
1720 Arc::new(self.to_dense())
1721 }
1722
1723 fn as_dense_ref(&self) -> Option<&Array2<f64>> {
1724 None
1725 }
1726
1727 fn apply_columns(&self, cols: &[usize]) -> Array2<f64> {
1728 let qs_cols = self.qs.select(Axis(1), cols);
1731 match &self.x_original {
1732 DesignMatrix::Dense(x) => match x.as_dense_ref() {
1733 Some(x_dense) => fast_ab(x_dense, &qs_cols),
1734 None => {
1735 let n = self.n;
1736 let mut out = Array2::<f64>::zeros((n, cols.len()));
1737 for k in 0..cols.len() {
1738 let col = qs_cols.column(k).to_owned();
1739 let xc = self.x_original.apply(&col);
1740 out.column_mut(k).assign(&xc);
1741 }
1742 out
1743 }
1744 },
1745 DesignMatrix::Sparse(_) => {
1746 let n = self.n;
1748 let mut out = Array2::<f64>::zeros((n, cols.len()));
1749 for k in 0..cols.len() {
1750 let col = qs_cols.column(k).to_owned();
1751 let xc = self.x_original.apply(&col);
1752 out.column_mut(k).assign(&xc);
1753 }
1754 out
1755 }
1756 }
1757 }
1758
1759 fn row_chunk_into(
1760 &self,
1761 rows: Range<usize>,
1762 mut out: ArrayViewMut2<'_, f64>,
1763 ) -> Result<(), MatrixMaterializationError> {
1764 if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
1765 return Err(MatrixMaterializationError::MissingRowChunk {
1766 context: "ReparamOperator::row_chunk_into shape mismatch",
1767 });
1768 }
1769 match &self.x_original {
1770 DesignMatrix::Dense(x) => {
1771 let chunk = x.try_row_chunk(rows)?;
1772 out.assign(&fast_ab(&chunk, &self.qs));
1773 }
1774 DesignMatrix::Sparse(sdm) => {
1775 let csr = sdm
1777 .to_csr_arc()
1778 .ok_or(MatrixMaterializationError::MissingRowChunk {
1779 context: "ReparamOperator::row_chunk_into: failed to obtain CSR view",
1780 })?;
1781 let sym = csr.symbolic();
1782 let row_ptr = sym.row_ptr();
1783 let col_idx = sym.col_idx();
1784 let vals = csr.val();
1785 let chunk_rows = rows.end - rows.start;
1786 let p_inner = sdm.ncols();
1787 let mut chunk = Array2::<f64>::zeros((chunk_rows, p_inner));
1788 for (local, global) in (rows.start..rows.end).enumerate() {
1789 for ptr in row_ptr[global]..row_ptr[global + 1] {
1790 chunk[[local, col_idx[ptr]]] = vals[ptr];
1791 }
1792 }
1793 out.assign(&fast_ab(&chunk, &self.qs));
1794 }
1795 }
1796 Ok(())
1797 }
1798}
1799
1800#[derive(Clone)]
1810pub struct RandomEffectOperator {
1811 pub group_ids: Vec<Option<usize>>,
1815 pub n: usize,
1817 pub num_groups: usize,
1819}
1820
1821impl RandomEffectOperator {
1822 pub fn new(group_ids: Vec<Option<usize>>, num_groups: usize) -> Self {
1823 let n = group_ids.len();
1824 Self {
1825 group_ids,
1826 n,
1827 num_groups,
1828 }
1829 }
1830
1831 pub fn weighted_cross_with_dense(
1837 &self,
1838 dense: &Array2<f64>,
1839 weights: &Array1<f64>,
1840 ) -> Array2<f64> {
1841 assert_eq!(
1842 dense.nrows(),
1843 self.n,
1844 "RandomEffectOperator::weighted_cross_with_dense row mismatch"
1845 );
1846 assert_eq!(
1847 weights.len(),
1848 self.n,
1849 "RandomEffectOperator::weighted_cross_with_dense weight length mismatch"
1850 );
1851 let p_dense = dense.ncols();
1852 let mut cross = Array2::<f64>::zeros((p_dense, self.num_groups));
1853 for i in 0..self.n {
1854 if let Some(g) = self.group_ids[i] {
1855 let wi = weights[i].max(0.0);
1856 if wi == 0.0 {
1857 continue;
1858 }
1859 for j in 0..p_dense {
1860 cross[[j, g]] += wi * dense[[i, j]];
1861 }
1862 }
1863 }
1864 cross
1865 }
1866
1867 pub fn weighted_cross_with_re(
1871 &self,
1872 other: &RandomEffectOperator,
1873 weights: &Array1<f64>,
1874 ) -> Array2<f64> {
1875 assert_eq!(
1876 other.n, self.n,
1877 "RandomEffectOperator::weighted_cross_with_re row mismatch"
1878 );
1879 assert_eq!(
1880 weights.len(),
1881 self.n,
1882 "RandomEffectOperator::weighted_cross_with_re weight length mismatch"
1883 );
1884 let mut cross = Array2::<f64>::zeros((self.num_groups, other.num_groups));
1885 for i in 0..self.n {
1886 if let (Some(a), Some(b)) = (self.group_ids[i], other.group_ids[i]) {
1887 let wi = weights[i].max(0.0);
1888 if wi != 0.0 {
1889 cross[[a, b]] += wi;
1890 }
1891 }
1892 }
1893 cross
1894 }
1895}
1896
1897impl LinearOperator for RandomEffectOperator {
1898 fn nrows(&self) -> usize {
1899 self.n
1900 }
1901
1902 fn ncols(&self) -> usize {
1903 self.num_groups
1904 }
1905
1906 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
1908 use rayon::prelude::*;
1909 let out: Vec<f64> = self
1910 .group_ids
1911 .par_iter()
1912 .map(|g| g.map(|g| vector[g]).unwrap_or(0.0))
1913 .collect();
1914 Array1::from(out)
1915 }
1916
1917 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
1919 let mut out = Array1::<f64>::zeros(self.num_groups);
1920 for i in 0..self.n {
1921 if let Some(g) = self.group_ids[i] {
1922 out[g] += vector[i];
1923 }
1924 }
1925 out
1926 }
1927
1928 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
1930 if weights.len() != self.n {
1931 return Err(format!(
1932 "RandomEffectOperator::diag_xtw_x weight length mismatch: weights={}, nrows={}",
1933 weights.len(),
1934 self.n
1935 ));
1936 }
1937 let q = self.num_groups;
1938 let mut xtwx = Array2::<f64>::zeros((q, q));
1939 for i in 0..self.n {
1940 if let Some(g) = self.group_ids[i] {
1941 xtwx[[g, g]] += weights[i].max(0.0);
1942 }
1943 }
1944 Ok(xtwx)
1945 }
1946
1947 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
1949 if weights.len() != self.n {
1950 return Err(format!(
1951 "RandomEffectOperator::diag_gram weight length mismatch: weights={}, nrows={}",
1952 weights.len(),
1953 self.n
1954 ));
1955 }
1956 let mut diag = Array1::<f64>::zeros(self.num_groups);
1957 for i in 0..self.n {
1958 if let Some(g) = self.group_ids[i] {
1959 diag[g] += weights[i].max(0.0);
1960 }
1961 }
1962 Ok(diag)
1963 }
1964
1965 fn apply_weighted_normal(
1967 &self,
1968 weights: &Array1<f64>,
1969 vector: &Array1<f64>,
1970 penalty: Option<&Array2<f64>>,
1971 ridge: f64,
1972 ) -> Array1<f64> {
1973 assert_eq!(
1974 weights.len(),
1975 self.n,
1976 "RandomEffectOperator::apply_weighted_normal weight length mismatch"
1977 );
1978 assert_eq!(
1979 vector.len(),
1980 self.num_groups,
1981 "RandomEffectOperator::apply_weighted_normal vector length mismatch"
1982 );
1983 let mut group_wacc = Array1::<f64>::zeros(self.num_groups);
1987 for i in 0..self.n {
1988 if let Some(g) = self.group_ids[i] {
1989 group_wacc[g] += weights[i].max(0.0);
1990 }
1991 }
1992 let mut out = Array1::<f64>::zeros(self.num_groups);
1993 for g in 0..self.num_groups {
1994 out[g] = group_wacc[g] * vector[g];
1995 }
1996 if let Some(pen) = penalty {
1997 out += &pen.dot(vector);
1998 }
1999 if ridge > 0.0 {
2000 for g in 0..self.num_groups {
2001 out[g] += ridge * vector[g];
2002 }
2003 }
2004 out
2005 }
2006
2007 fn uses_matrix_free_pcg(&self) -> bool {
2008 true
2009 }
2010}
2011
2012impl DenseDesignOperator for RandomEffectOperator {
2013 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
2014 if weights.len() != self.n || y.len() != self.n {
2015 return Err(format!(
2016 "RandomEffectOperator::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
2017 weights.len(),
2018 y.len(),
2019 self.n
2020 ));
2021 }
2022 let mut out = Array1::<f64>::zeros(self.num_groups);
2023 for i in 0..self.n {
2024 if let Some(g) = self.group_ids[i] {
2025 let wi = weights[i].max(0.0);
2026 out[g] += wi * y[i];
2027 }
2028 }
2029 Ok(out)
2030 }
2031
2032 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
2034 use rayon::prelude::*;
2035 let out: Vec<f64> = self
2036 .group_ids
2037 .par_iter()
2038 .map(|g| g.map(|g| middle[[g, g]].max(0.0)).unwrap_or(0.0))
2039 .collect();
2040 Ok(Array1::from(out))
2041 }
2042
2043 fn row_chunk_into(
2044 &self,
2045 rows: Range<usize>,
2046 mut out: ArrayViewMut2<'_, f64>,
2047 ) -> Result<(), MatrixMaterializationError> {
2048 if out.nrows() != rows.end - rows.start || out.ncols() != self.num_groups {
2049 return Err(MatrixMaterializationError::MissingRowChunk {
2050 context: "RandomEffectOperator::row_chunk_into shape mismatch",
2051 });
2052 }
2053 out.fill(0.0);
2054 for (local, global) in rows.enumerate() {
2055 if let Some(g) = self.group_ids[global] {
2056 out[[local, g]] = 1.0;
2057 }
2058 }
2059 Ok(())
2060 }
2061
2062 fn to_dense(&self) -> Array2<f64> {
2064 let mut out = Array2::<f64>::zeros((self.n, self.num_groups));
2065 ndarray::Zip::indexed(out.rows_mut()).par_for_each(|i, mut row| {
2066 if let Some(g) = self.group_ids[i] {
2067 row[g] = 1.0;
2068 }
2069 });
2070 out
2071 }
2072}
2073
2074#[derive(Clone)]
2080pub enum DesignBlock {
2081 Dense(DenseDesignMatrix),
2082 Sparse(SparseDesignMatrix),
2083 RandomEffect(Arc<RandomEffectOperator>),
2084 Intercept(usize),
2086}
2087
2088impl DesignBlock {
2089 pub fn nrows(&self) -> usize {
2090 match self {
2091 Self::Dense(d) => d.nrows(),
2092 Self::Sparse(s) => s.nrows(),
2093 Self::RandomEffect(op) => op.nrows(),
2094 Self::Intercept(n) => *n,
2095 }
2096 }
2097
2098 pub fn ncols(&self) -> usize {
2099 match self {
2100 Self::Dense(d) => d.ncols(),
2101 Self::Sparse(s) => s.ncols(),
2102 Self::RandomEffect(op) => op.ncols(),
2103 Self::Intercept(_) => 1,
2104 }
2105 }
2106
2107 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
2108 match self {
2109 Self::Dense(d) => d.apply(vector),
2110 Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).apply(vector),
2111 Self::RandomEffect(op) => op.apply(vector),
2112 Self::Intercept(n) => Array1::from_elem(*n, vector[0]),
2113 }
2114 }
2115
2116 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
2117 match self {
2118 Self::Dense(d) => d.apply_transpose(vector),
2119 Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).apply_transpose(vector),
2120 Self::RandomEffect(op) => op.apply_transpose(vector),
2121 Self::Intercept(_) => {
2122 let sum: f64 = vector.iter().sum();
2123 Array1::from_vec(vec![sum])
2124 }
2125 }
2126 }
2127
2128 fn try_row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, MatrixMaterializationError> {
2129 match self {
2130 Self::Dense(d) => d.try_row_chunk(rows),
2131 Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).try_row_chunk(rows),
2132 Self::RandomEffect(op) => op.try_row_chunk(rows),
2133 Self::Intercept(_) => Ok(Array2::ones((rows.end - rows.start, 1))),
2134 }
2135 }
2136
2137 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
2138 match self {
2139 Self::Dense(d) => d.diag_xtw_x(weights),
2140 Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).diag_xtw_x(weights),
2141 Self::RandomEffect(op) => op.diag_xtw_x(weights),
2142 Self::Intercept(_) => {
2143 let sum: f64 = weights.iter().map(|w| w.max(0.0)).sum();
2144 Ok(Array2::from_elem((1, 1), sum))
2145 }
2146 }
2147 }
2148
2149 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
2150 match self {
2151 Self::Dense(d) => d.diag_gram(weights),
2152 Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).diag_gram(weights),
2153 Self::RandomEffect(op) => op.diag_gram(weights),
2154 Self::Intercept(_) => {
2155 let sum: f64 = weights.iter().map(|w| w.max(0.0)).sum();
2156 Ok(Array1::from_vec(vec![sum]))
2157 }
2158 }
2159 }
2160
2161 fn to_dense(&self) -> Array2<f64> {
2163 match self {
2164 Self::Dense(d) => d.to_dense(),
2165 Self::Sparse(s) => s.to_dense_arc().as_ref().clone(),
2166 Self::RandomEffect(op) => op.to_dense(),
2167 Self::Intercept(n) => Array2::ones((*n, 1)),
2168 }
2169 }
2170}
2171
2172#[derive(Clone)]
2179pub struct BlockDesignOperator {
2180 pub blocks: Vec<DesignBlock>,
2181 pub col_offsets: Vec<usize>,
2183 pub total_cols: usize,
2184 pub n: usize,
2185}
2186
2187impl BlockDesignOperator {
2188 pub fn new(blocks: Vec<DesignBlock>) -> Result<Self, String> {
2189 if blocks.is_empty() {
2190 return Err("BlockDesignOperator: need at least one block".to_string());
2191 }
2192 let n = blocks[0].nrows();
2193 for (i, b) in blocks.iter().enumerate() {
2194 if b.nrows() != n {
2195 return Err(format!(
2196 "BlockDesignOperator: block {i} has {} rows, expected {n}",
2197 b.nrows()
2198 ));
2199 }
2200 }
2201 let mut col_offsets = Vec::with_capacity(blocks.len() + 1);
2202 col_offsets.push(0);
2203 for b in &blocks {
2204 col_offsets.push(col_offsets.last().unwrap() + b.ncols());
2205 }
2206 let total_cols = *col_offsets.last().unwrap();
2207 Ok(Self {
2208 blocks,
2209 col_offsets,
2210 total_cols,
2211 n,
2212 })
2213 }
2214
2215 fn weighted_cross_chunked(
2216 &self,
2217 left: &DesignBlock,
2218 right: &DesignBlock,
2219 weights: &Array1<f64>,
2220 ) -> Result<Array2<f64>, String> {
2221 let pi = left.ncols();
2222 let pj = right.ncols();
2223 let mut cross = Array2::<f64>::zeros((pi, pj));
2224 for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2225 let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2226 let left_chunk = left.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2227 let right_chunk = right.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2228 for local in 0..(end - start) {
2229 let wi = weights[start + local];
2237 if wi == 0.0 {
2238 continue;
2239 }
2240 for a in 0..pi {
2241 let scaled = wi * left_chunk[[local, a]];
2242 if scaled == 0.0 {
2243 continue;
2244 }
2245 for b in 0..pj {
2246 cross[[a, b]] += scaled * right_chunk[[local, b]];
2247 }
2248 }
2249 }
2250 }
2251 Ok(cross)
2252 }
2253
2254 fn quadratic_form_diag_cross_chunked(
2255 &self,
2256 block_a: &DesignBlock,
2257 block_b: &DesignBlock,
2258 m_ab: &Array2<f64>,
2259 ) -> Result<Array1<f64>, String> {
2260 let mut out = Array1::<f64>::zeros(self.n);
2261 for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2262 let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2263 let a_chunk = block_a
2264 .try_row_chunk(start..end)
2265 .map_err(|e| e.to_string())?;
2266 let b_chunk = block_b
2267 .try_row_chunk(start..end)
2268 .map_err(|e| e.to_string())?;
2269 let a_m = fast_ab(&a_chunk, m_ab);
2270 for local in 0..(end - start) {
2271 out[start + local] = a_m.row(local).dot(&b_chunk.row(local));
2272 }
2273 }
2274 Ok(out)
2275 }
2276
2277 fn cross_block(
2279 &self,
2280 i: usize,
2281 j: usize,
2282 weights: &Array1<f64>,
2283 ) -> Result<Array2<f64>, String> {
2284 match (&self.blocks[i], &self.blocks[j]) {
2285 (DesignBlock::Dense(d_i), DesignBlock::Dense(d_j)) => {
2287 if let (Some(xi), Some(xj)) = (d_i.as_dense_ref(), d_j.as_dense_ref()) {
2288 weighted_crossprod_dense(xi, weights, xj)
2289 } else {
2290 self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2291 }
2292 }
2293 (DesignBlock::Dense(_), DesignBlock::Sparse(_))
2294 | (DesignBlock::Sparse(_), DesignBlock::Dense(_))
2295 | (DesignBlock::Sparse(_), DesignBlock::Sparse(_))
2296 | (DesignBlock::Sparse(_), DesignBlock::RandomEffect(_))
2297 | (DesignBlock::RandomEffect(_), DesignBlock::Sparse(_)) => {
2298 self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2299 }
2300
2301 (DesignBlock::Dense(d), DesignBlock::RandomEffect(re)) => {
2303 if let Some(dense) = d.as_dense_ref() {
2304 Ok(re.weighted_cross_with_dense(dense, weights))
2305 } else {
2306 self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2307 }
2308 }
2309 (DesignBlock::RandomEffect(re), DesignBlock::Dense(d)) => {
2310 if let Some(dense) = d.as_dense_ref() {
2311 let cross_t = re.weighted_cross_with_dense(dense, weights);
2312 Ok(cross_t.t().to_owned())
2313 } else {
2314 self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
2315 }
2316 }
2317
2318 (DesignBlock::RandomEffect(re_a), DesignBlock::RandomEffect(re_b)) => {
2320 Ok(re_a.weighted_cross_with_re(re_b, weights))
2321 }
2322
2323 (DesignBlock::Intercept(_), other) => {
2326 let pj = other.ncols();
2327 let mut cross = Array2::<f64>::zeros((1, pj));
2328 let weighted = Array1::from_shape_fn(self.n, |idx| weights[idx].max(0.0));
2329 let row = other.apply_transpose(&weighted);
2330 cross.row_mut(0).assign(&row);
2331 Ok(cross)
2332 }
2333 (other, DesignBlock::Intercept(_)) => {
2334 let pi = other.ncols();
2335 let mut cross = Array2::<f64>::zeros((pi, 1));
2336 let weighted = Array1::from_shape_fn(self.n, |idx| weights[idx].max(0.0));
2337 let col = other.apply_transpose(&weighted);
2338 cross.column_mut(0).assign(&col);
2339 Ok(cross)
2340 }
2341 }
2342 }
2343
2344 fn quadratic_form_diag_block(
2346 &self,
2347 block: &DesignBlock,
2348 m_kk: &Array2<f64>,
2349 ) -> Result<Array1<f64>, String> {
2350 match block {
2351 DesignBlock::Dense(d) => {
2352 if let Some(dense) = d.as_dense_ref() {
2353 let xm = fast_ab(dense, m_kk);
2354 let mut out = Array1::<f64>::zeros(self.n);
2355 ndarray::Zip::from(&mut out)
2356 .and(dense.rows())
2357 .and(xm.rows())
2358 .par_for_each(|o, dr, xmr| *o = dr.dot(&xmr));
2359 Ok(out)
2360 } else {
2361 d.quadratic_form_diag(m_kk)
2362 }
2363 }
2364 DesignBlock::Sparse(s) => {
2365 let sparse = DesignMatrix::Sparse(s.clone());
2366 sparse.quadratic_form_diag(m_kk)
2367 }
2368 DesignBlock::RandomEffect(re) => {
2369 use rayon::prelude::*;
2370 let out: Vec<f64> = re
2371 .group_ids
2372 .par_iter()
2373 .map(|g| g.map(|g| m_kk[[g, g]]).unwrap_or(0.0))
2374 .collect();
2375 Ok(Array1::from(out))
2376 }
2377 DesignBlock::Intercept(_) => {
2378 Ok(Array1::from_elem(self.n, m_kk[[0, 0]]))
2380 }
2381 }
2382 }
2383
2384 fn quadratic_form_diag_cross(
2386 &self,
2387 block_a: &DesignBlock,
2388 block_b: &DesignBlock,
2389 m_ab: &Array2<f64>,
2390 ) -> Result<Array1<f64>, String> {
2391 match (block_a, block_b) {
2392 (DesignBlock::Dense(da), DesignBlock::Dense(db)) => {
2393 if let (Some(da), Some(db)) = (da.as_dense_ref(), db.as_dense_ref()) {
2394 let da_m = fast_ab(da, m_ab);
2395 let mut out = Array1::<f64>::zeros(self.n);
2396 ndarray::Zip::from(&mut out)
2397 .and(da_m.rows())
2398 .and(db.rows())
2399 .par_for_each(|o, ar, br| *o = ar.dot(&br));
2400 Ok(out)
2401 } else {
2402 self.quadratic_form_diag_cross_chunked(block_a, block_b, m_ab)
2403 }
2404 }
2405 (DesignBlock::Dense(_), DesignBlock::Sparse(_))
2406 | (DesignBlock::Sparse(_), DesignBlock::Dense(_))
2407 | (DesignBlock::Sparse(_), DesignBlock::Sparse(_))
2408 | (DesignBlock::Sparse(_), DesignBlock::RandomEffect(_))
2409 | (DesignBlock::RandomEffect(_), DesignBlock::Sparse(_)) => {
2410 self.quadratic_form_diag_cross_chunked(block_a, block_b, m_ab)
2411 }
2412 (DesignBlock::Dense(d), DesignBlock::RandomEffect(re)) => {
2413 let mut out = Array1::<f64>::zeros(self.n);
2414 for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2415 let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2416 let chunk = d.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2417 for local in 0..chunk.nrows() {
2418 let i = start + local;
2419 if let Some(g) = re.group_ids[i] {
2420 let mut val = 0.0;
2421 for j in 0..chunk.ncols() {
2422 val += chunk[[local, j]] * m_ab[[j, g]];
2423 }
2424 out[i] = val;
2425 }
2426 }
2427 }
2428 Ok(out)
2429 }
2430 (DesignBlock::RandomEffect(re), DesignBlock::Dense(d)) => {
2431 let mut out = Array1::<f64>::zeros(self.n);
2432 for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2433 let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2434 let chunk = d.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2435 for local in 0..chunk.nrows() {
2436 let i = start + local;
2437 if let Some(g) = re.group_ids[i] {
2438 let mut val = 0.0;
2439 for j in 0..chunk.ncols() {
2440 val += m_ab[[g, j]] * chunk[[local, j]];
2441 }
2442 out[i] = val;
2443 }
2444 }
2445 }
2446 Ok(out)
2447 }
2448 (DesignBlock::RandomEffect(re_a), DesignBlock::RandomEffect(re_b)) => {
2449 use rayon::prelude::*;
2450 let out: Vec<f64> = re_a
2451 .group_ids
2452 .par_iter()
2453 .zip(re_b.group_ids.par_iter())
2454 .map(|(ga, gb)| match (ga, gb) {
2455 (Some(ga), Some(gb)) => m_ab[[*ga, *gb]],
2456 _ => 0.0,
2457 })
2458 .collect();
2459 Ok(Array1::from(out))
2460 }
2461
2462 (DesignBlock::Intercept(_), other) => {
2464 let m_row = m_ab.row(0);
2465 let mut out = Array1::<f64>::zeros(self.n);
2466 for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2467 let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2468 let chunk = other.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2469 for local in 0..(end - start) {
2470 out[start + local] = chunk.row(local).dot(&m_row);
2471 }
2472 }
2473 Ok(out)
2474 }
2475 (other, DesignBlock::Intercept(_)) => {
2476 let m_col = m_ab.column(0);
2477 let mut out = Array1::<f64>::zeros(self.n);
2478 for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
2479 let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
2480 let chunk = other.try_row_chunk(start..end).map_err(|e| e.to_string())?;
2481 for local in 0..(end - start) {
2482 out[start + local] = chunk.row(local).dot(&m_col);
2483 }
2484 }
2485 Ok(out)
2486 }
2487 }
2488 }
2489}
2490
2491impl LinearOperator for BlockDesignOperator {
2492 fn nrows(&self) -> usize {
2493 self.n
2494 }
2495
2496 fn ncols(&self) -> usize {
2497 self.total_cols
2498 }
2499
2500 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
2501 let mut out = Array1::<f64>::zeros(self.n);
2502 for (idx, block) in self.blocks.iter().enumerate() {
2503 let start = self.col_offsets[idx];
2504 let end = self.col_offsets[idx + 1];
2505 let slice = vector.slice(s![start..end]).to_owned();
2506 let contribution = block.apply(&slice);
2507 out += &contribution;
2508 }
2509 out
2510 }
2511
2512 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
2513 let mut out = Array1::<f64>::zeros(self.total_cols);
2514 for (idx, block) in self.blocks.iter().enumerate() {
2515 let start = self.col_offsets[idx];
2516 let end = self.col_offsets[idx + 1];
2517 let transposed = block.apply_transpose(vector);
2518 out.slice_mut(s![start..end]).assign(&transposed);
2519 }
2520 out
2521 }
2522
2523 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
2524 if weights.len() != self.n {
2525 return Err(format!(
2526 "BlockDesignOperator::diag_xtw_x weight length mismatch: weights={}, nrows={}",
2527 weights.len(),
2528 self.n
2529 ));
2530 }
2531 let p = self.total_cols;
2532 let mut result = Array2::<f64>::zeros((p, p));
2533
2534 for (idx, block) in self.blocks.iter().enumerate() {
2536 let start = self.col_offsets[idx];
2537 let end = self.col_offsets[idx + 1];
2538 let block_xtwx = block.diag_xtw_x(weights)?;
2539 result
2540 .slice_mut(s![start..end, start..end])
2541 .assign(&block_xtwx);
2542 }
2543
2544 let weighted_dense: Vec<Option<Array2<f64>>> = self
2560 .blocks
2561 .iter()
2562 .map(|block| match block {
2563 DesignBlock::Dense(d) => d.as_dense_ref().map(|x| {
2564 x * &weights.view().insert_axis(Axis(1))
2567 }),
2568 _ => None,
2569 })
2570 .collect();
2571
2572 for i in 0..self.blocks.len() {
2573 for j in (i + 1)..self.blocks.len() {
2574 let cross = match (&weighted_dense[i], &self.blocks[j]) {
2575 (Some(wx_i), DesignBlock::Dense(d_j)) => match d_j.as_dense_ref() {
2578 Some(x_j) => fast_atb(wx_i, x_j),
2579 None => self.cross_block(i, j, weights)?,
2580 },
2581 _ => self.cross_block(i, j, weights)?,
2582 };
2583 let si = self.col_offsets[i];
2584 let ei = self.col_offsets[i + 1];
2585 let sj = self.col_offsets[j];
2586 let ej = self.col_offsets[j + 1];
2587 result.slice_mut(s![si..ei, sj..ej]).assign(&cross);
2588 result.slice_mut(s![sj..ej, si..ei]).assign(&cross.t());
2589 }
2590 }
2591
2592 Ok(result)
2593 }
2594
2595 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
2596 if weights.len() != self.n {
2597 return Err(format!(
2598 "BlockDesignOperator::diag_gram weight length mismatch: weights={}, nrows={}",
2599 weights.len(),
2600 self.n
2601 ));
2602 }
2603 let mut out = Array1::<f64>::zeros(self.total_cols);
2604 for (idx, block) in self.blocks.iter().enumerate() {
2605 let start = self.col_offsets[idx];
2606 let end = self.col_offsets[idx + 1];
2607 let block_diag = block.diag_gram(weights)?;
2608 out.slice_mut(s![start..end]).assign(&block_diag);
2609 }
2610 Ok(out)
2611 }
2612
2613 fn apply_weighted_normal(
2614 &self,
2615 weights: &Array1<f64>,
2616 vector: &Array1<f64>,
2617 penalty: Option<&Array2<f64>>,
2618 ridge: f64,
2619 ) -> Array1<f64> {
2620 assert_eq!(
2621 weights.len(),
2622 self.n,
2623 "BlockDesignOperator::apply_weighted_normal weight length mismatch"
2624 );
2625 assert_eq!(
2626 vector.len(),
2627 self.total_cols,
2628 "BlockDesignOperator::apply_weighted_normal vector length mismatch"
2629 );
2630 let xv = self.apply(vector);
2632 let mut weighted = xv;
2633 for i in 0..weighted.len() {
2634 weighted[i] *= weights[i].max(0.0);
2635 }
2636 let mut out = self.apply_transpose(&weighted);
2637 if let Some(pen) = penalty {
2638 out += &fast_av(pen, vector);
2639 }
2640 if ridge > 0.0 {
2641 out.scaled_add(ridge, vector);
2643 }
2644 out
2645 }
2646
2647 fn uses_matrix_free_pcg(&self) -> bool {
2648 self.blocks
2650 .iter()
2651 .any(|b| matches!(b, DesignBlock::RandomEffect(_) | DesignBlock::Intercept(_)))
2652 }
2653}
2654
2655impl DenseDesignOperator for BlockDesignOperator {
2656 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
2657 if weights.len() != self.n || y.len() != self.n {
2658 return Err(format!(
2659 "BlockDesignOperator::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
2660 weights.len(),
2661 y.len(),
2662 self.n
2663 ));
2664 }
2665 let mut wy = Array1::<f64>::zeros(self.n);
2666 ndarray::Zip::from(&mut wy)
2667 .and(weights)
2668 .and(y)
2669 .par_for_each(|o, &w, &yi| *o = w.max(0.0) * yi);
2670 Ok(self.apply_transpose(&wy))
2671 }
2672
2673 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
2674 let mut out = Array1::<f64>::zeros(self.n);
2677 let nb = self.blocks.len();
2678
2679 for k in 0..nb {
2681 let sk = self.col_offsets[k];
2682 let ek = self.col_offsets[k + 1];
2683 let m_kk = middle.slice(s![sk..ek, sk..ek]).to_owned();
2684 let block_diag = self.quadratic_form_diag_block(&self.blocks[k], &m_kk)?;
2685 out += &block_diag;
2686 }
2687
2688 for a in 0..nb {
2690 for b in (a + 1)..nb {
2691 let sa = self.col_offsets[a];
2692 let ea = self.col_offsets[a + 1];
2693 let sb = self.col_offsets[b];
2694 let eb = self.col_offsets[b + 1];
2695 let m_ab = middle.slice(s![sa..ea, sb..eb]);
2696
2697 let cross_diag = self.quadratic_form_diag_cross(
2698 &self.blocks[a],
2699 &self.blocks[b],
2700 &m_ab.to_owned(),
2701 )?;
2702 for i in 0..self.n {
2703 out[i] += 2.0 * cross_diag[i];
2704 }
2705 }
2706 }
2707
2708 for v in out.iter_mut() {
2710 *v = v.max(0.0);
2711 }
2712 Ok(out)
2713 }
2714
2715 fn row_chunk_into(
2716 &self,
2717 rows: Range<usize>,
2718 mut out: ArrayViewMut2<'_, f64>,
2719 ) -> Result<(), MatrixMaterializationError> {
2720 if out.nrows() != rows.end - rows.start || out.ncols() != self.total_cols {
2721 return Err(MatrixMaterializationError::MissingRowChunk {
2722 context: "BlockDesignOperator::row_chunk_into shape mismatch",
2723 });
2724 }
2725 for (idx, block) in self.blocks.iter().enumerate() {
2726 let cs = self.col_offsets[idx];
2727 let ce = self.col_offsets[idx + 1];
2728 let block_chunk = block.try_row_chunk(rows.clone())?;
2729 out.slice_mut(s![.., cs..ce]).assign(&block_chunk);
2730 }
2731 Ok(())
2732 }
2733
2734 fn to_dense(&self) -> Array2<f64> {
2735 let mut out = Array2::<f64>::zeros((self.n, self.total_cols));
2736 for (idx, block) in self.blocks.iter().enumerate() {
2737 let start = self.col_offsets[idx];
2738 let end = self.col_offsets[idx + 1];
2739 let dense_block = block.to_dense();
2740 out.slice_mut(s![.., start..end]).assign(&dense_block);
2741 }
2742 out
2743 }
2744}
2745
/// Vertical stack of design matrices ("channels") that all act on the same
/// coefficient vector.
///
/// The constructor requires every channel to have the same number of rows and
/// columns; the stacked operator has `n_per_channel * channels.len()` rows and
/// `p` columns.
#[derive(Clone)]
pub struct MultiChannelOperator {
    /// Per-channel design matrices, stacked in order along the row axis.
    pub channels: Vec<DesignMatrix>,
    /// Row count shared by every channel (taken from the first channel in `new`).
    pub n_per_channel: usize,
    /// Column count shared by every channel (taken from the first channel in `new`).
    pub p: usize,
}
2768
2769impl MultiChannelOperator {
2770 pub fn new(channels: Vec<DesignMatrix>) -> Result<Self, String> {
2771 if channels.is_empty() {
2772 return Err("MultiChannelOperator: need at least one channel".to_string());
2773 }
2774 let n = channels[0].nrows();
2775 let p = channels[0].ncols();
2776 for (i, ch) in channels.iter().enumerate() {
2777 if ch.nrows() != n {
2778 return Err(format!(
2779 "MultiChannelOperator: channel {i} has {} rows, expected {n}",
2780 ch.nrows()
2781 ));
2782 }
2783 if ch.ncols() != p {
2784 return Err(format!(
2785 "MultiChannelOperator: channel {i} has {} cols, expected {p}",
2786 ch.ncols()
2787 ));
2788 }
2789 }
2790 Ok(Self {
2791 channels,
2792 n_per_channel: n,
2793 p,
2794 })
2795 }
2796}
2797
2798impl LinearOperator for MultiChannelOperator {
2799 fn nrows(&self) -> usize {
2800 self.n_per_channel * self.channels.len()
2801 }
2802
2803 fn ncols(&self) -> usize {
2804 self.p
2805 }
2806
2807 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
2808 let total = self.nrows();
2809 let mut out = Array1::<f64>::zeros(total);
2810 let n = self.n_per_channel;
2811 for (i, ch) in self.channels.iter().enumerate() {
2812 let ch_result = ch.matrixvectormultiply(vector);
2813 out.slice_mut(s![i * n..(i + 1) * n]).assign(&ch_result);
2814 }
2815 out
2816 }
2817
2818 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
2819 let n = self.n_per_channel;
2820 let mut out = Array1::<f64>::zeros(self.p);
2821 for (i, ch) in self.channels.iter().enumerate() {
2822 out += &ch.apply_transpose_view(vector.slice(s![i * n..(i + 1) * n]));
2823 }
2824 out
2825 }
2826
2827 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
2828 let n = self.n_per_channel;
2829 if weights.len() != self.nrows() {
2830 return Err(format!(
2831 "MultiChannelOperator::diag_xtw_x: weights length {} != nrows {}",
2832 weights.len(),
2833 self.nrows()
2834 ));
2835 }
2836 let w_pos = weights.mapv(|w: f64| w.max(0.0));
2845 let mut xtwx = Array2::<f64>::zeros((self.p, self.p));
2846 for (i, ch) in self.channels.iter().enumerate() {
2847 let ch_xtwx = ch
2848 .xt_diag_x_signed_op(SignedWeightsView::new(w_pos.slice(s![i * n..(i + 1) * n])))?;
2849 xtwx += &ch_xtwx;
2850 }
2851 Ok(xtwx)
2852 }
2853
2854 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
2855 let n = self.n_per_channel;
2856 if weights.len() != self.nrows() {
2857 return Err(format!(
2858 "MultiChannelOperator::diag_gram: weights length {} != nrows {}",
2859 weights.len(),
2860 self.nrows()
2861 ));
2862 }
2863 let w_pos = weights.mapv(|w: f64| w.max(0.0));
2873 let mut diag = Array1::<f64>::zeros(self.p);
2874 for (i, ch) in self.channels.iter().enumerate() {
2875 diag += &ch.diag_gram_view(w_pos.slice(s![i * n..(i + 1) * n]))?;
2876 }
2877 Ok(diag)
2878 }
2879
2880 fn uses_matrix_free_pcg(&self) -> bool {
2881 true
2882 }
2883}
2884
2885impl DenseDesignOperator for MultiChannelOperator {
2886 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
2887 let n = self.n_per_channel;
2888 let total = self.nrows();
2889 if weights.len() != total || y.len() != total {
2890 return Err(format!(
2891 "MultiChannelOperator::compute_xtwy: weights={}, y={}, nrows={}",
2892 weights.len(),
2893 y.len(),
2894 total
2895 ));
2896 }
2897 let w_pos = weights.mapv(|w: f64| w.max(0.0));
2906 let mut out = Array1::<f64>::zeros(self.p);
2907 for (i, ch) in self.channels.iter().enumerate() {
2908 out += &ch.compute_xtwy_view(
2909 w_pos.slice(s![i * n..(i + 1) * n]),
2910 y.slice(s![i * n..(i + 1) * n]),
2911 )?;
2912 }
2913 Ok(out)
2914 }
2915
2916 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
2917 let n = self.n_per_channel;
2918 let mut out = Array1::<f64>::zeros(self.nrows());
2919 for (i, ch) in self.channels.iter().enumerate() {
2920 let ch_diag = ch.quadratic_form_diag(middle)?;
2921 out.slice_mut(s![i * n..(i + 1) * n]).assign(&ch_diag);
2922 }
2923 Ok(out)
2924 }
2925
2926 fn to_dense(&self) -> Array2<f64> {
2927 let total = self.nrows();
2928 let n = self.n_per_channel;
2929 let mut out = Array2::<f64>::zeros((total, self.p));
2930 for (i, ch) in self.channels.iter().enumerate() {
2931 let dense = ch.to_dense();
2932 out.slice_mut(s![i * n..(i + 1) * n, ..]).assign(&dense);
2933 }
2934 out
2935 }
2936
2937 fn row_chunk_into(
2938 &self,
2939 rows: Range<usize>,
2940 mut out: ArrayViewMut2<'_, f64>,
2941 ) -> Result<(), MatrixMaterializationError> {
2942 if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
2943 return Err(MatrixMaterializationError::MissingRowChunk {
2944 context: "MultiChannelOperator::row_chunk_into shape mismatch",
2945 });
2946 }
2947 let n = self.n_per_channel;
2948 let mut local = 0usize;
2949 let mut global = rows.start;
2950 while global < rows.end {
2951 let ch_idx = global / n;
2952 let ch_local_start = global % n;
2953 let ch_local_end = ((ch_idx + 1) * n).min(rows.end) - ch_idx * n;
2954 let segment_len = ch_local_end - ch_local_start;
2955 let ch_chunk = self.channels[ch_idx].try_row_chunk(ch_local_start..ch_local_end)?;
2956 out.slice_mut(s![local..local + segment_len, ..])
2957 .assign(&ch_chunk);
2958 local += segment_len;
2959 global += segment_len;
2960 }
2961 Ok(())
2962 }
2963}
2964
2965mod kronecker;
2967pub use kronecker::*;
2968
/// Design operator representing `inner · transform` without necessarily
/// forming the product.
///
/// Products are evaluated either against a lazily cached dense copy of the
/// combined matrix or, when no cached copy is available, by applying
/// `transform` and `inner` in sequence.
pub struct CoefficientTransformOperator {
    /// Underlying design whose columns are recombined by `transform`.
    inner: DenseDesignMatrix,
    /// Coefficient-space map; the constructor requires its row count to equal
    /// `inner.ncols()`.
    transform: Arc<Array2<f64>>,
    /// Row count, copied from `inner` at construction.
    n: usize,
    /// Column count of the combined operator, i.e. `transform.ncols()`.
    p_out: usize,
    /// One-shot cache of the dense combined matrix. An inner `None` records
    /// that materialization was attempted and declined or failed.
    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
2990
2991impl CoefficientTransformOperator {
2992 const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
2996
2997 pub fn new(inner: DenseDesignMatrix, transform: Array2<f64>) -> Result<Self, String> {
2998 let p_inner = inner.ncols();
2999 if transform.nrows() != p_inner {
3000 return Err(format!(
3001 "CoefficientTransformOperator: inner has {} cols but transform has {} rows",
3002 p_inner,
3003 transform.nrows(),
3004 ));
3005 }
3006 let n = inner.nrows();
3007 let p_out = transform.ncols();
3008 Ok(Self {
3009 inner,
3010 transform: Arc::new(transform),
3011 n,
3012 p_out,
3013 materialized: OnceLock::new(),
3014 })
3015 }
3016
3017 fn materialized_combined(&self) -> Option<&Array2<f64>> {
3034 if let Some(slot) = self.materialized.get() {
3035 return slot.as_ref().map(|a| a.as_ref());
3036 }
3037 let bytes = self
3038 .n
3039 .checked_mul(self.p_out)
3040 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
3041 let computed = match bytes {
3042 Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
3043 let auto_policy = ResourcePolicy::for_problem(
3049 self.n,
3050 self.p_out,
3051 gam_runtime::resource::ProblemHints::default(),
3052 );
3053 let cache_policy = ResourcePolicy {
3054 max_single_materialization_bytes: Self::MATERIALIZE_MAX_BYTES,
3055 derivative_storage_mode: auto_policy.derivative_storage_mode,
3056 ..ResourcePolicy::default_library()
3057 };
3058 self.inner
3059 .try_to_dense_arc_with_policy(
3060 "CoefficientTransformOperator materialization",
3061 &cache_policy,
3062 )
3063 .ok()
3064 .map(|x| Arc::new(fast_ab(x.as_ref(), &self.transform)))
3065 }
3066 _ => None,
3067 };
3068 if self.materialized.set(computed).is_err() {
3069 return self
3070 .materialized
3071 .get()
3072 .and_then(|opt| opt.as_ref().map(|a| a.as_ref()));
3073 }
3074 self.materialized
3075 .get()
3076 .and_then(|opt| opt.as_ref().map(|a| a.as_ref()))
3077 }
3078}
3079
3080impl LinearOperator for CoefficientTransformOperator {
3081 fn nrows(&self) -> usize {
3082 self.n
3083 }
3084 fn ncols(&self) -> usize {
3085 self.p_out
3086 }
3087 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3088 if let Some(combined) = self.materialized_combined() {
3089 return fast_av(combined, vector);
3090 }
3091 let tv = fast_av(&self.transform, vector);
3092 self.inner.apply(&tv)
3093 }
3094 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
3095 if let Some(combined) = self.materialized_combined() {
3096 return fast_atv(combined, vector);
3097 }
3098 let xtv = self.inner.apply_transpose(vector);
3099 fast_atv(&self.transform, &xtv)
3100 }
3101 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
3102 if let Some(combined) = self.materialized_combined() {
3103 let mut xtwx = Array2::<f64>::zeros((self.p_out, self.p_out));
3104 stream_weighted_crossprod_into(
3105 combined,
3106 weights,
3107 &mut xtwx,
3108 CrossprodStructure::Full,
3109 CrossprodAccum::Replace,
3110 effective_global_parallelism(),
3111 );
3112 return Ok(xtwx);
3113 }
3114 let inner_xtwx = self.inner.diag_xtw_x(weights)?;
3115 let tmp = fast_ab(&self.transform.t().to_owned(), &inner_xtwx);
3117 Ok(fast_ab(&tmp, &self.transform))
3118 }
3119}
3120
3121impl DenseDesignOperator for CoefficientTransformOperator {
3122 fn as_dense_ref(&self) -> Option<&Array2<f64>> {
3130 self.materialized_combined()
3131 }
3132
3133 fn to_dense(&self) -> Array2<f64> {
3134 if let Some(combined) = self.materialized_combined() {
3135 return combined.clone();
3136 }
3137 let x = self.inner.to_dense();
3138 fast_ab(&x, &self.transform)
3139 }
3140 fn row_chunk_into(
3141 &self,
3142 rows: Range<usize>,
3143 mut out: ArrayViewMut2<'_, f64>,
3144 ) -> Result<(), MatrixMaterializationError> {
3145 if out.nrows() != rows.end - rows.start || out.ncols() != self.p_out {
3146 return Err(MatrixMaterializationError::MissingRowChunk {
3147 context: "CoefficientTransformOperator::row_chunk_into shape mismatch",
3148 });
3149 }
3150 if let Some(combined) = self.materialized_combined() {
3151 out.assign(&combined.slice(s![rows, ..]));
3152 return Ok(());
3153 }
3154 let chunk = self.inner.try_row_chunk(rows)?;
3155 out.assign(&fast_ab(&chunk, &self.transform));
3156 Ok(())
3157 }
3158}
3159
/// Design operator representing `inner · transform − Σₖ anchorₖ · r_blockₖ`.
///
/// As with `CoefficientTransformOperator`, the combined matrix is cached
/// densely when it can be materialized; otherwise products are evaluated
/// term by term.
pub struct ResidualisedDesignOperator {
    /// Underlying design whose columns are recombined by `transform`.
    inner: DenseDesignMatrix,
    /// Coefficient-space map; the constructor requires its row count to equal
    /// `inner.ncols()`.
    transform: Arc<Array2<f64>>,
    /// Pairs `(anchor, r_block)` whose product is subtracted. The constructor
    /// requires `anchor` to have `n` rows and `r_block` to be
    /// `anchor.ncols() × p_out`.
    anchors: Vec<(DesignMatrix, Arc<Array2<f64>>)>,
    /// Row count, copied from `inner` at construction.
    n: usize,
    /// Column count of the combined operator, i.e. `transform.ncols()`.
    p_out: usize,
    /// One-shot cache of the dense combined matrix. An inner `None` records
    /// that materialization was attempted and declined or failed.
    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
3181
3182impl ResidualisedDesignOperator {
3183 const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
3187
3188 pub fn new(
3189 inner: DenseDesignMatrix,
3190 transform: Array2<f64>,
3191 anchors: Vec<(DesignMatrix, Arc<Array2<f64>>)>,
3192 ) -> Result<Self, String> {
3193 let p_inner = inner.ncols();
3194 if transform.nrows() != p_inner {
3195 return Err(format!(
3196 "ResidualisedDesignOperator: inner has {} cols but transform has {} rows",
3197 p_inner,
3198 transform.nrows(),
3199 ));
3200 }
3201 let n = inner.nrows();
3202 let p_out = transform.ncols();
3203 for (idx, (anchor, r_block)) in anchors.iter().enumerate() {
3204 if anchor.nrows() != n {
3205 return Err(format!(
3206 "ResidualisedDesignOperator: anchor[{idx}] has {} rows but inner has {n}",
3207 anchor.nrows(),
3208 ));
3209 }
3210 if r_block.nrows() != anchor.ncols() || r_block.ncols() != p_out {
3211 return Err(format!(
3212 "ResidualisedDesignOperator: anchor[{idx}] r_block is {}x{} but expected {}x{}",
3213 r_block.nrows(),
3214 r_block.ncols(),
3215 anchor.ncols(),
3216 p_out,
3217 ));
3218 }
3219 }
3220 Ok(Self {
3221 inner,
3222 transform: Arc::new(transform),
3223 anchors,
3224 n,
3225 p_out,
3226 materialized: OnceLock::new(),
3227 })
3228 }
3229
3230 fn materialized_combined(&self) -> Option<&Array2<f64>> {
3236 if let Some(slot) = self.materialized.get() {
3237 return slot.as_ref().map(|a| a.as_ref());
3238 }
3239 let bytes = self
3240 .n
3241 .checked_mul(self.p_out)
3242 .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
3243 let computed = match bytes {
3244 Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
3245 let auto_policy = ResourcePolicy::for_problem(
3246 self.n,
3247 self.p_out,
3248 gam_runtime::resource::ProblemHints::default(),
3249 );
3250 let cache_policy = ResourcePolicy {
3251 max_single_materialization_bytes: Self::MATERIALIZE_MAX_BYTES,
3252 derivative_storage_mode: auto_policy.derivative_storage_mode,
3253 ..ResourcePolicy::default_library()
3254 };
3255 self.inner
3256 .try_to_dense_arc_with_policy(
3257 "ResidualisedDesignOperator materialization",
3258 &cache_policy,
3259 )
3260 .ok()
3261 .and_then(|x| {
3262 let mut combined = fast_ab(x.as_ref(), &self.transform);
3263 for (anchor, r_block) in &self.anchors {
3264 let anchor_dense = match anchor {
3265 DesignMatrix::Dense(d) => d
3266 .try_to_dense_arc_with_policy(
3267 "ResidualisedDesignOperator anchor materialization",
3268 &cache_policy,
3269 )
3270 .ok()?,
3271 DesignMatrix::Sparse(s) => s
3272 .try_to_dense_arc(
3273 "ResidualisedDesignOperator anchor materialization",
3274 )
3275 .ok()?,
3276 };
3277 let contribution = fast_ab(anchor_dense.as_ref(), r_block.as_ref());
3278 combined -= &contribution;
3279 }
3280 Some(Arc::new(combined))
3281 })
3282 }
3283 _ => None,
3284 };
3285 if self.materialized.set(computed).is_err() {
3286 return self
3287 .materialized
3288 .get()
3289 .and_then(|opt| opt.as_ref().map(|a| a.as_ref()));
3290 }
3291 self.materialized
3292 .get()
3293 .and_then(|opt| opt.as_ref().map(|a| a.as_ref()))
3294 }
3295
3296 pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
3300 if let Some(combined) = self.materialized.get().and_then(|opt| opt.clone()) {
3301 return Ok(combined);
3302 }
3303 if let Some(_combined_ref) = self.materialized_combined() {
3304 if let Some(arc) = self.materialized.get().and_then(|opt| opt.clone()) {
3305 return Ok(arc);
3306 }
3307 }
3308 dense_operator_to_dense_by_chunks(self)
3309 .map(Arc::new)
3310 .map_err(|err| format!("{context}: failed to materialize dense row chunks: {err}"))
3311 }
3312}
3313
3314impl LinearOperator for ResidualisedDesignOperator {
3315 fn nrows(&self) -> usize {
3316 self.n
3317 }
3318 fn ncols(&self) -> usize {
3319 self.p_out
3320 }
3321 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3322 if let Some(combined) = self.materialized_combined() {
3323 return fast_av(combined, vector);
3324 }
3325 let tv = fast_av(&self.transform, vector);
3327 let mut out = self.inner.apply(&tv);
3328 for (anchor, r_block) in &self.anchors {
3329 let rv = fast_av(r_block.as_ref(), vector);
3330 let contrib = anchor.apply(&rv);
3331 out -= &contrib;
3332 }
3333 out
3334 }
3335 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
3336 if let Some(combined) = self.materialized_combined() {
3337 return fast_atv(combined, vector);
3338 }
3339 let xtv = self.inner.apply_transpose(vector);
3340 let mut out = fast_atv(&self.transform, &xtv);
3341 for (anchor, r_block) in &self.anchors {
3342 let atv = anchor.apply_transpose(vector);
3343 let contrib = fast_atv(r_block.as_ref(), &atv);
3344 out -= &contrib;
3345 }
3346 out
3347 }
3348 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
3349 if let Some(combined) = self.materialized_combined() {
3350 let mut xtwx = Array2::<f64>::zeros((self.p_out, self.p_out));
3351 stream_weighted_crossprod_into(
3352 combined,
3353 weights,
3354 &mut xtwx,
3355 CrossprodStructure::Full,
3356 CrossprodAccum::Replace,
3357 effective_global_parallelism(),
3358 );
3359 return Ok(xtwx);
3360 }
3361 let n = self.n;
3365 if weights.len() != n {
3366 return Err(format!(
3367 "ResidualisedDesignOperator::diag_xtw_x weights len {} != nrows {n}",
3368 weights.len()
3369 ));
3370 }
3371 let p = self.p_out;
3372 let chunk_rows = (8 * 1024 * 1024 / (p.max(1) * 8 * 2)).max(16).min(n.max(1));
3373 let mut xtwx = Array2::<f64>::zeros((p, p));
3374 let mut start = 0;
3375 while start < n {
3376 let end = (start + chunk_rows).min(n);
3377 let chunk = <Self as DenseDesignOperator>::try_row_chunk(self, start..end)
3378 .map_err(|e| e.to_string())?;
3379 let w_slice = weights.slice(s![start..end]).to_owned();
3380 let mut local = Array2::<f64>::zeros((p, p));
3381 stream_weighted_crossprod_into(
3382 &chunk,
3383 &w_slice,
3384 &mut local,
3385 CrossprodStructure::Full,
3386 CrossprodAccum::Replace,
3387 effective_global_parallelism(),
3388 );
3389 xtwx += &local;
3390 start = end;
3391 }
3392 Ok(xtwx)
3393 }
3394}
3395
3396impl DenseDesignOperator for ResidualisedDesignOperator {
3397 fn as_dense_ref(&self) -> Option<&Array2<f64>> {
3398 self.materialized_combined()
3399 }
3400
3401 fn to_dense(&self) -> Array2<f64> {
3402 if let Some(combined) = self.materialized_combined() {
3403 return combined.clone();
3404 }
3405 dense_operator_to_dense_by_chunks(self).unwrap_or_else(|err| {
3407 std::panic::panic_any(format!(
3408 "ResidualisedDesignOperator::to_dense: failed to materialize {}x{} \
3409 via row chunks: {err}",
3410 self.n, self.p_out,
3411 ))
3412 })
3413 }
3414
3415 fn row_chunk_into(
3416 &self,
3417 rows: Range<usize>,
3418 mut out: ArrayViewMut2<'_, f64>,
3419 ) -> Result<(), MatrixMaterializationError> {
3420 if out.nrows() != rows.end - rows.start || out.ncols() != self.p_out {
3421 return Err(MatrixMaterializationError::MissingRowChunk {
3422 context: "ResidualisedDesignOperator::row_chunk_into shape mismatch",
3423 });
3424 }
3425 if let Some(combined) = self.materialized_combined() {
3426 out.assign(&combined.slice(s![rows, ..]));
3427 return Ok(());
3428 }
3429 let inner_chunk = self.inner.try_row_chunk(rows.clone())?;
3431 let mut combined = fast_ab(&inner_chunk, &self.transform);
3432 for (anchor, r_block) in &self.anchors {
3434 let anchor_chunk = anchor.try_row_chunk(rows.clone())?;
3435 let contribution = fast_ab(&anchor_chunk, r_block.as_ref());
3436 combined -= &contribution;
3437 }
3438 out.assign(&combined);
3439 Ok(())
3440 }
3441}
3442
/// Design matrix with selected columns affinely rescaled on the fly.
///
/// For each listed column `j`, every entry `v` of the wrapped design is
/// treated as `(v - mean) / scale`; unlisted columns pass through unchanged.
pub struct ConditionedDesign {
    /// Wrapped, unconditioned design.
    inner: DesignMatrix,
    /// `(column index, mean, scale)` triples describing the rescaled columns.
    columns: Vec<(usize, f64, f64)>,
}
3460
impl ConditionedDesign {
    /// Wraps `inner`, recording the `(column index, mean, scale)` triples to
    /// apply. No validation is performed here: indices and scales are stored
    /// as given.
    pub fn new(inner: DesignMatrix, columns: Vec<(usize, f64, f64)>) -> Self {
        Self { inner, columns }
    }
}
3466
3467impl LinearOperator for ConditionedDesign {
3468 fn nrows(&self) -> usize {
3469 self.inner.nrows()
3470 }
3471
3472 fn ncols(&self) -> usize {
3473 self.inner.ncols()
3474 }
3475
3476 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3478 let mut scaled = vector.clone();
3479 let mut shift = 0.0;
3480 for &(j, mean, scale) in &self.columns {
3481 scaled[j] /= scale;
3482 shift += mean * scaled[j];
3483 }
3484 let mut result = self.inner.apply(&scaled);
3485 if shift != 0.0 {
3486 result.mapv_inplace(|v| v - shift);
3487 }
3488 result
3489 }
3490
3491 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
3493 let mut result = self.inner.apply_transpose(vector);
3494 let sum_u: f64 = vector.iter().sum();
3495 for &(j, mean, scale) in &self.columns {
3496 result[j] = (result[j] - mean * sum_u) / scale;
3497 }
3498 result
3499 }
3500
3501 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
3503 let mut base = self.inner.diag_xtw_x(weights)?;
3504 if self.columns.is_empty() {
3505 return Ok(base);
3506 }
3507 let p = base.ncols();
3508 let w_pos: Array1<f64> = weights.mapv(|w| w.max(0.0));
3509 let sum_w: f64 = w_pos.sum();
3510 let cw = self.inner.apply_transpose(&w_pos);
3511
3512 let mut a = vec![1.0_f64; p];
3514 let mut d = vec![0.0_f64; p];
3515 for &(j, mean, scale) in &self.columns {
3516 a[j] = 1.0 / scale;
3517 d[j] = mean / scale;
3518 }
3519
3520 for i in 0..p {
3522 for j in i..p {
3523 let val = a[i] * base[[i, j]] * a[j] - a[i] * cw[i] * d[j] - d[i] * cw[j] * a[j]
3524 + sum_w * d[i] * d[j];
3525 base[[i, j]] = val;
3526 base[[j, i]] = val;
3527 }
3528 }
3529 Ok(base)
3530 }
3531
3532 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
3534 let mut result = self.inner.diag_gram(weights)?;
3535 if self.columns.is_empty() {
3536 return Ok(result);
3537 }
3538 let w_pos: Array1<f64> = weights.mapv(|w| w.max(0.0));
3539 let sum_w: f64 = w_pos.sum();
3540 let cw = self.inner.apply_transpose(&w_pos);
3541 for &(j, mean, scale) in &self.columns {
3542 let a_j = 1.0 / scale;
3543 let d_j = mean / scale;
3544 result[j] = a_j * a_j * result[j] - 2.0 * a_j * cw[j] * d_j + sum_w * d_j * d_j;
3545 }
3546 Ok(result)
3547 }
3548
3549 fn uses_matrix_free_pcg(&self) -> bool {
3550 match &self.inner {
3551 DesignMatrix::Dense(_) => true,
3552 DesignMatrix::Sparse(_) => false,
3553 }
3554 }
3555}
3556
3557impl DenseDesignOperator for ConditionedDesign {
3558 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
3560 let mut result = self.inner.compute_xtwy(weights, y)?;
3561 if self.columns.is_empty() {
3562 return Ok(result);
3563 }
3564 let sum_wy: f64 = weights
3565 .iter()
3566 .zip(y.iter())
3567 .map(|(&w, &yi)| w.max(0.0) * yi)
3568 .sum();
3569 for &(j, mean, scale) in &self.columns {
3570 result[j] = (result[j] - mean * sum_wy) / scale;
3571 }
3572 Ok(result)
3573 }
3574
3575 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
3577 if self.columns.is_empty() {
3578 return self.inner.quadratic_form_diag(middle);
3579 }
3580 let p = self.ncols();
3581 let mut d = Array1::zeros(p);
3582 for &(j, mean, scale) in &self.columns {
3583 d[j] = mean / scale;
3584 }
3585
3586 let mut ama = middle.clone();
3588 for &(j, _, scale) in &self.columns {
3589 for k in 0..p {
3590 ama[[j, k]] /= scale;
3591 ama[[k, j]] /= scale;
3592 }
3593 }
3594
3595 let md = middle.dot(&d);
3597 let mut amd = md;
3598 for &(j, _, scale) in &self.columns {
3599 amd[j] /= scale;
3600 }
3601
3602 let dtmd: f64 = d.dot(&middle.dot(&d));
3603
3604 let mut result = self.inner.quadratic_form_diag(&ama)?;
3605 let x_amd = self.inner.apply(&amd);
3606 for i in 0..result.len() {
3607 result[i] = (result[i] - 2.0 * x_amd[i] + dtmd).max(0.0);
3608 }
3609 Ok(result)
3610 }
3611
3612 fn row_chunk_into(
3613 &self,
3614 rows: Range<usize>,
3615 mut out: ArrayViewMut2<'_, f64>,
3616 ) -> Result<(), MatrixMaterializationError> {
3617 if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
3618 return Err(MatrixMaterializationError::MissingRowChunk {
3619 context: "ConditionedDesign::row_chunk_into shape mismatch",
3620 });
3621 }
3622 let mut chunk = self.inner.try_row_chunk(rows)?;
3623 for &(j, mean, scale) in &self.columns {
3624 chunk.column_mut(j).mapv_inplace(|v| (v - mean) / scale);
3625 }
3626 out.assign(&chunk);
3627 Ok(())
3628 }
3629
3630 fn to_dense(&self) -> Array2<f64> {
3631 let mut dense = self.inner.to_dense();
3632 for &(j, mean, scale) in &self.columns {
3633 dense.column_mut(j).mapv_inplace(|v| (v - mean) / scale);
3634 }
3635 dense
3636 }
3637}
3638
/// A design matrix held in one of two storage forms.
#[derive(Clone)]
pub enum DesignMatrix {
    /// Dense-backed design (`DenseDesignMatrix`).
    Dense(DenseDesignMatrix),
    /// Sparse-backed design (`SparseDesignMatrix`).
    Sparse(SparseDesignMatrix),
}
3653
3654impl std::fmt::Debug for DesignMatrix {
3655 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3656 match self {
3657 Self::Dense(m) => write!(f, "DesignMatrix::Dense({}x{})", m.nrows(), m.ncols()),
3658 Self::Sparse(s) => write!(f, "DesignMatrix::Sparse({}x{})", s.nrows(), s.ncols()),
3659 }
3660 }
3661}
3662
3663mod symmetric;
3665pub use symmetric::*;
/// Handle to an already-factorized linear system that can be reused across
/// solves. Implementations must be `Send + Sync`.
pub trait FactorizedSystem: Send + Sync {
    /// Solves against a single right-hand-side vector. Presumably returns `x`
    /// with `A x = rhs` for the factorized matrix `A`; confirm against the
    /// implementors. Failures are reported as strings.
    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String>;

    /// Solves against several right-hand sides at once, passed as a matrix.
    /// Presumably one right-hand side per column; confirm against the
    /// implementors.
    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String>;

    /// Presumably the log-determinant of the factorized matrix; confirm
    /// against the implementors.
    fn logdet(&self) -> f64;
}
3677
3678pub trait LinearOperator {
3679 fn nrows(&self) -> usize;
3680 fn ncols(&self) -> usize;
3681 fn apply(&self, vector: &Array1<f64>) -> Array1<f64>;
3682 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64>;
3683 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String>;
3684
3685 fn xt_diag_x_signed_op(&self, weights: SignedWeightsView<'_>) -> Result<Array2<f64>, String> {
3691 self.diag_xtw_x(&weights.view().to_owned())
3692 }
3693
3694 fn xt_diag_x_psd_op(&self, weights: PsdWeightsView<'_>) -> Result<SymmetricMatrix, String> {
3699 let xtwx = self.diag_xtw_x(&weights.view().to_owned())?;
3700 Ok(SymmetricMatrix::Dense(xtwx))
3701 }
3702
3703 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
3704 let xtwx = self.diag_xtw_x(weights)?;
3705 Ok(Array1::from_iter((0..self.ncols()).map(|j| xtwx[[j, j]])))
3706 }
3707 fn apply_weighted_normal(
3708 &self,
3709 weights: &Array1<f64>,
3710 vector: &Array1<f64>,
3711 penalty: Option<&Array2<f64>>,
3712 ridge: f64,
3713 ) -> Array1<f64> {
3714 assert_eq!(
3715 weights.len(),
3716 self.nrows(),
3717 "apply_weighted_normal weight length mismatch"
3718 );
3719 assert_eq!(
3720 vector.len(),
3721 self.ncols(),
3722 "apply_weighted_normal vector length mismatch"
3723 );
3724 let xv = self.apply(vector);
3725 let mut weighted_xv = xv;
3726 for i in 0..weighted_xv.len() {
3727 weighted_xv[i] *= weights[i].max(0.0);
3728 }
3729 let mut out = self.apply_transpose(&weighted_xv);
3730 if let Some(pen) = penalty {
3731 out += &fast_av(pen, vector);
3732 }
3733 if ridge > 0.0 {
3734 out.scaled_add(ridge, vector);
3736 }
3737 out
3738 }
3739 fn uses_matrix_free_pcg(&self) -> bool {
3740 false
3741 }
3742 fn solve_system_matrix_free_pcg_try(
3743 &self,
3744 weights: &Array1<f64>,
3745 rhs: &Array1<f64>,
3746 penalty: Option<&Array2<f64>>,
3747 baseridge: f64,
3748 ) -> Result<Array1<f64>, String> {
3749 self.solve_system_matrix_free_pcg_with_info_try(weights, rhs, penalty, baseridge)
3750 .map(|(solution, _)| solution)
3751 }
3752 fn solve_system_matrix_free_pcg_with_info_try(
3753 &self,
3754 weights: &Array1<f64>,
3755 rhs: &Array1<f64>,
3756 penalty: Option<&Array2<f64>>,
3757 baseridge: f64,
3758 ) -> Result<(Array1<f64>, PcgSolveInfo), String> {
3759 if rhs.len() != self.ncols() {
3760 return Err(format!(
3761 "solve_system_matrix_free_pcg rhs dimension mismatch: rhs length {} != ncols {}",
3762 rhs.len(),
3763 self.ncols()
3764 ));
3765 }
3766 if !self.uses_matrix_free_pcg() {
3767 return Err("matrix-free PCG is only enabled for eligible operator types".to_string());
3768 }
3769 if let Some(pen) = penalty
3770 && (pen.nrows() != self.ncols() || pen.ncols() != self.ncols())
3771 {
3772 return Err(format!(
3773 "solve_system_matrix_free_pcg penalty shape mismatch: got {}x{}, expected {}x{}",
3774 pen.nrows(),
3775 pen.ncols(),
3776 self.ncols(),
3777 self.ncols()
3778 ));
3779 }
3780 let p = self.ncols();
3781 for retry in 0..8 {
3782 let ridge = if baseridge > 0.0 {
3783 baseridge * 10f64.powi(retry)
3784 } else {
3785 0.0
3786 };
3787 let normal_op = PenalizedWeightedNormalOperator {
3788 operator: self,
3789 weights,
3790 penalty,
3791 ridge,
3792 };
3793 let preconditioner = normal_op.jacobi_preconditioner()?;
3794 let attempt_started = std::time::Instant::now();
3795 let solved = crate::utils::solve_spd_pcg_with_info(
3796 |v| normal_op.apply(v),
3797 rhs,
3798 &preconditioner,
3799 MATRIX_FREE_PCG_REL_TOL,
3800 MATRIX_FREE_PCG_MAX_ITER.max(4 * p),
3801 );
3802 let elapsed = attempt_started.elapsed().as_secs_f64();
3803 match solved {
3810 Some((solution, info)) if solution.iter().all(|v| v.is_finite()) => {
3811 if retry > 0 {
3812 log::info!(
3813 "[matrix-free PCG] converged after ridge escalation: p={p} retry={retry} ridge={ridge:.3e} iters={} converged={} rel_resid={:.3e} elapsed={elapsed:.3}s",
3814 info.iterations,
3815 info.converged,
3816 info.relative_residual_norm,
3817 );
3818 } else {
3819 log::debug!(
3820 "[matrix-free PCG] solved: p={p} iters={} converged={} rel_resid={:.3e} elapsed={elapsed:.3}s",
3821 info.iterations,
3822 info.converged,
3823 info.relative_residual_norm,
3824 );
3825 }
3826 return Ok((solution, info));
3827 }
3828 Some((_, info)) => {
3829 log::info!(
3830 "[matrix-free PCG] non-finite solution, escalating ridge: p={p} retry={retry} ridge={ridge:.3e} iters={} converged={} rel_resid={:.3e} elapsed={elapsed:.3}s",
3831 info.iterations,
3832 info.converged,
3833 info.relative_residual_norm,
3834 );
3835 }
3836 None => {
3837 log::info!(
3838 "[matrix-free PCG] CG breakdown (non-SPD/NaN), escalating ridge: p={p} retry={retry} ridge={ridge:.3e} elapsed={elapsed:.3}s",
3839 );
3840 }
3841 }
3842 }
3843 Err("matrix-free PCG failed after ridge retries".to_string())
3844 }
3845 fn factorize_system(
3846 &self,
3847 weights: &Array1<f64>,
3848 penalty: Option<&Array2<f64>>,
3849 ) -> Result<Box<dyn FactorizedSystem>, String> {
3850 let mut system = self.diag_xtw_x(weights)?;
3851 if let Some(pen) = penalty {
3852 if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
3853 return Err(format!(
3854 "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
3855 pen.nrows(),
3856 pen.ncols(),
3857 system.nrows(),
3858 system.ncols()
3859 ));
3860 }
3861 system += pen;
3862 }
3863 let factor = crate::utils::StableSolver::new("linear operator system")
3864 .factorize(&system)
3865 .map_err(|e| format!("factorize_system failed: {e:?}"))?;
3866 Ok(Box::new(factor))
3867 }
3868 fn solve_system(
3869 &self,
3870 weights: &Array1<f64>,
3871 rhs: &Array1<f64>,
3872 penalty: Option<&Array2<f64>>,
3873 ) -> Result<Array1<f64>, String> {
3874 self.solve_systemwith_policy(
3875 weights,
3876 rhs,
3877 penalty,
3878 SPD_SOLVE_RIDGE_FLOOR,
3879 RidgePolicy::explicit_stabilization_pospart(),
3880 )
3881 }
3882 fn solve_systemwith_policy(
3883 &self,
3884 weights: &Array1<f64>,
3885 rhs: &Array1<f64>,
3886 penalty: Option<&Array2<f64>>,
3887 ridge_floor: f64,
3888 ridge_policy: RidgePolicy,
3889 ) -> Result<Array1<f64>, String> {
3890 if rhs.len() != self.ncols() {
3891 return Err(format!(
3892 "solve_systemwith_policy rhs dimension mismatch: rhs length {} != ncols {}",
3893 rhs.len(),
3894 self.ncols()
3895 ));
3896 }
3897 let baseridge = if ridge_policy.include_laplacehessian {
3898 ridge_floor.max(SPD_SOLVE_RIDGE_FLOOR)
3899 } else {
3900 0.0
3901 };
3902 if self.uses_matrix_free_pcg()
3904 && self.ncols() >= MATRIX_FREE_PCG_MIN_P
3905 && let Ok(solution) =
3906 self.solve_system_matrix_free_pcg_try(weights, rhs, penalty, baseridge)
3907 {
3908 return Ok(solution);
3909 }
3910 let mut system = self.diag_xtw_x(weights)?;
3912 if let Some(pen) = penalty {
3913 if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
3914 return Err(format!(
3915 "solve_systemwith_policy penalty shape mismatch: got {}x{}, expected {}x{}",
3916 pen.nrows(),
3917 pen.ncols(),
3918 system.nrows(),
3919 system.ncols()
3920 ));
3921 }
3922 system += pen;
3923 }
3924 crate::utils::StableSolver::new("linear operator system")
3925 .solvevectorwithridge_retries(&system, rhs, baseridge)
3926 .ok_or_else(|| "solve_systemwith_policy failed after ridge retries".to_string())
3927 }
3928}
3929
3930impl LinearOperator for DesignMatrix {
3931 fn uses_matrix_free_pcg(&self) -> bool {
3932 match self {
3933 Self::Dense(matrix) => matrix.uses_matrix_free_pcg(),
3934 Self::Sparse(_) => false,
3935 }
3936 }
3937
3938 fn nrows(&self) -> usize {
3939 match self {
3940 Self::Dense(matrix) => matrix.nrows(),
3941 Self::Sparse(matrix) => matrix.nrows(),
3942 }
3943 }
3944
3945 fn ncols(&self) -> usize {
3946 match self {
3947 Self::Dense(matrix) => matrix.ncols(),
3948 Self::Sparse(matrix) => matrix.ncols(),
3949 }
3950 }
3951
3952 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
3953 match self {
3954 Self::Dense(matrix) => matrix.apply(vector),
3955 Self::Sparse(matrix) => {
3956 let mut output = Array1::<f64>::zeros(matrix.nrows());
3957 let (symbolic, values) = matrix.parts();
3958 let col_ptr = symbolic.col_ptr();
3959 let row_idx = symbolic.row_idx();
3960 for col in 0..matrix.ncols() {
3961 let start = col_ptr[col];
3962 let end = col_ptr[col + 1];
3963 let x = vector[col];
3964 for idx in start..end {
3965 let row = row_idx[idx];
3966 output[row] += values[idx] * x;
3967 }
3968 }
3969 output
3970 }
3971 }
3972 }
3973
3974 fn apply_weighted_normal(
3975 &self,
3976 weights: &Array1<f64>,
3977 vector: &Array1<f64>,
3978 penalty: Option<&Array2<f64>>,
3979 ridge: f64,
3980 ) -> Array1<f64> {
3981 assert_eq!(
3982 weights.len(),
3983 self.nrows(),
3984 "DesignMatrix::apply_weighted_normal weight length mismatch"
3985 );
3986 assert_eq!(
3987 vector.len(),
3988 self.ncols(),
3989 "DesignMatrix::apply_weighted_normal vector length mismatch"
3990 );
3991 match self {
3992 Self::Dense(matrix) => matrix.apply_weighted_normal(weights, vector, penalty, ridge),
3993 Self::Sparse(_) => {
3994 let sparse = self
3995 .as_sparse()
3996 .expect("DesignMatrix::Sparse must expose sparse view");
3997 let mut out = if let Some(csr) = sparse.to_csr_arc() {
3998 let sym = csr.symbolic();
3999 let row_ptr = sym.row_ptr();
4000 let col_idx = sym.col_idx();
4001 let vals = csr.val();
4002 let mut fused = Array1::<f64>::zeros(self.ncols());
4003 for i in 0..self.nrows() {
4004 let wi = weights[i].max(0.0);
4005 if wi == 0.0 {
4006 continue;
4007 }
4008 let start = row_ptr[i];
4009 let end = row_ptr[i + 1];
4010 let mut row_dot = 0.0_f64;
4011 for ptr in start..end {
4012 row_dot += vals[ptr] * vector[col_idx[ptr]];
4013 }
4014 if row_dot == 0.0 {
4015 continue;
4016 }
4017 let scaled = wi * row_dot;
4018 for ptr in start..end {
4019 fused[col_idx[ptr]] += vals[ptr] * scaled;
4020 }
4021 }
4022 fused
4023 } else {
4024 let xv = self.apply(vector);
4025 let mut weighted_xv = xv;
4026 for i in 0..weighted_xv.len() {
4027 weighted_xv[i] *= weights[i].max(0.0);
4028 }
4029 self.apply_transpose(&weighted_xv)
4030 };
4031 if let Some(pen) = penalty {
4032 out += &fast_av(pen, vector);
4033 }
4034 if ridge > 0.0 {
4035 for j in 0..out.len() {
4036 out[j] += ridge * vector[j];
4037 }
4038 }
4039 out
4040 }
4041 }
4042 }
4043
4044 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
4045 match self {
4046 Self::Dense(matrix) => matrix.apply_transpose(vector),
4047 Self::Sparse(matrix) => {
4048 let mut output = Array1::<f64>::zeros(matrix.ncols());
4049 let (symbolic, values) = matrix.parts();
4050 let col_ptr = symbolic.col_ptr();
4051 let row_idx = symbolic.row_idx();
4052 for col in 0..matrix.ncols() {
4053 let mut acc = 0.0;
4054 let start = col_ptr[col];
4055 let end = col_ptr[col + 1];
4056 for idx in start..end {
4057 let row = row_idx[idx];
4058 acc += values[idx] * vector[row];
4059 }
4060 output[col] = acc;
4061 }
4062 output
4063 }
4064 }
4065 }
4066
4067 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
4068 if weights.len() != self.nrows() {
4069 return Err(format!(
4070 "xt_diag_x dimension mismatch: weights length {} != nrows {}",
4071 weights.len(),
4072 self.nrows()
4073 ));
4074 }
4075 let p = self.ncols();
4076 match self {
4077 Self::Dense(x) => x.diag_xtw_x(weights),
4078 Self::Sparse(xs) => {
4079 let n = self.nrows();
4102 let nnz_x = xs.as_ref().val().len();
4103 let avg_nnz_row = if n > 0 { nnz_x / n } else { p };
4104 let dense_regime = 4 * avg_nnz_row >= p;
4105 if dense_regime {
4106 let mut xtwx = Array2::<f64>::zeros((p, p));
4107 let dense_bytes =
4108 checked_dense_nbytes(n, p, "DesignMatrix::diag_xtw_x dense sparse route")?;
4109 if dense_bytes <= MAX_SPARSE_TO_DENSE_BYTES {
4110 let xd =
4111 xs.try_to_dense_arc("DesignMatrix::diag_xtw_x dense sparse route")?;
4112 stream_weighted_crossprod_into(
4113 xd.as_ref(),
4114 weights,
4115 &mut xtwx,
4116 CrossprodStructure::Full,
4117 CrossprodAccum::Replace,
4118 effective_global_parallelism(),
4119 );
4120 } else {
4121 let (symbolic, values) = xs.parts();
4122 streaming_sparse_csc_xt_diag_x(
4123 symbolic.col_ptr(),
4124 symbolic.row_idx(),
4125 values,
4126 n,
4127 p,
4128 weights.view(),
4129 &mut xtwx,
4130 );
4131 }
4132 return Ok(xtwx);
4133 }
4134 let csr = xs
4135 .to_csr_arc()
4136 .ok_or_else(|| "failed to obtain CSR view in xt_diag_x".to_string())?;
4137 let sym = csr.symbolic();
4138 Ok(sparse_csr_weighted_xtwx(
4139 sym.row_ptr(),
4140 sym.col_idx(),
4141 csr.val(),
4142 n,
4143 p,
4144 weights.view(),
4145 ))
4146 }
4147 }
4148 }
4149
4150 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
4151 if weights.len() != self.nrows() {
4152 return Err(format!(
4153 "diag_gram dimension mismatch: weights length {} != nrows {}",
4154 weights.len(),
4155 self.nrows()
4156 ));
4157 }
4158 let p = self.ncols();
4159 match self {
4160 Self::Dense(x) => x.diag_gram(weights),
4161 Self::Sparse(xs) => {
4162 let csr = xs
4163 .to_csr_arc()
4164 .ok_or_else(|| "failed to obtain CSR view in diag_gram".to_string())?;
4165 let sym = csr.symbolic();
4166 Ok(sparse_csr_diag_gram(
4167 sym.row_ptr(),
4168 sym.col_idx(),
4169 csr.val(),
4170 self.nrows(),
4171 p,
4172 weights.view(),
4173 ))
4174 }
4175 }
4176 }
4177
4178 fn factorize_system(
4179 &self,
4180 weights: &Array1<f64>,
4181 penalty: Option<&Array2<f64>>,
4182 ) -> Result<Box<dyn FactorizedSystem>, String> {
4183 if weights.len() != self.nrows() {
4184 return Err(format!(
4185 "factorize_system dimension mismatch: weights length {} != nrows {}",
4186 weights.len(),
4187 self.nrows()
4188 ));
4189 }
4190 match self {
4191 Self::Dense(_) => self.factorize_system_dense(weights, penalty),
4192 Self::Sparse(matrix) => {
4193 let system = assemble_sparseweighted_gram_system(matrix, weights, penalty)?;
4194 let factor = crate::sparse_exact::factorize_sparse_spd(&system)
4195 .map_err(|e| format!("factorize_system failed: {e:?}"))?;
4196 Ok(Box::new(factor))
4197 }
4198 }
4199 }
4200}
4201
4202impl DenseDesignOperator for DesignMatrix {
4203 fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
4204 if weights.len() != self.nrows() || y.len() != self.nrows() {
4205 return Err(format!(
4206 "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
4207 weights.len(),
4208 y.len(),
4209 self.nrows()
4210 ));
4211 }
4212 match self {
4213 Self::Dense(x) => x.compute_xtwy(weights, y),
4214 Self::Sparse(xs) => {
4215 let csr = xs
4216 .as_ref()
4217 .to_row_major()
4218 .map_err(|_| "failed to obtain CSR view in compute_xtwy".to_string())?;
4219 let sym = csr.symbolic();
4220 let row_ptr = sym.row_ptr();
4221 let col_idx = sym.col_idx();
4222 let vals = csr.val();
4223 let mut out = Array1::<f64>::zeros(xs.ncols());
4224 for i in 0..xs.nrows() {
4225 let scaled = weights[i].max(0.0) * y[i];
4226 if scaled == 0.0 {
4227 continue;
4228 }
4229 for idx in row_ptr[i]..row_ptr[i + 1] {
4230 out[col_idx[idx]] += vals[idx] * scaled;
4231 }
4232 }
4233 Ok(out)
4234 }
4235 }
4236 }
4237
4238 fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
4239 if middle.nrows() != self.ncols() || middle.ncols() != self.ncols() {
4240 return Err(format!(
4241 "quadratic_form_diag dimension mismatch: matrix is {}x{}, expected {}x{}",
4242 middle.nrows(),
4243 middle.ncols(),
4244 self.ncols(),
4245 self.ncols()
4246 ));
4247 }
4248
4249 match self {
4250 Self::Dense(xd) => xd.quadratic_form_diag(middle),
4251 Self::Sparse(xs) => {
4252 let csr = xs
4253 .to_csr_arc()
4254 .ok_or_else(|| "quadratic_form_diag: failed to obtain CSR view".to_string())?;
4255 let sym = csr.symbolic();
4256 let row_ptr = sym.row_ptr();
4257 let col_idx = sym.col_idx();
4258 let vals = csr.val();
4259 let mut out = Array1::<f64>::zeros(self.nrows());
4260 for i in 0..xs.nrows() {
4261 let start = row_ptr[i];
4262 let end = row_ptr[i + 1];
4263 let mut acc = 0.0_f64;
4264 for a in start..end {
4265 let j = col_idx[a];
4266 let xij = vals[a];
4267 for b in start..end {
4268 let k = col_idx[b];
4269 let xik = vals[b];
4270 acc += xij * middle[[j, k]] * xik;
4271 }
4272 }
4273 out[i] = acc.max(0.0);
4274 }
4275 Ok(out)
4276 }
4277 }
4278 }
4279
4280 fn row_chunk_into(
4281 &self,
4282 rows: Range<usize>,
4283 mut out: ArrayViewMut2<'_, f64>,
4284 ) -> Result<(), MatrixMaterializationError> {
4285 if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
4286 return Err(MatrixMaterializationError::MissingRowChunk {
4287 context: "DesignMatrix::row_chunk_into shape mismatch",
4288 });
4289 }
4290 match self {
4291 Self::Dense(matrix) => matrix.row_chunk_into(rows, out),
4292 Self::Sparse(matrix) => {
4293 out.fill(0.0);
4294 let csr =
4295 matrix
4296 .to_csr_arc()
4297 .ok_or(MatrixMaterializationError::MissingRowChunk {
4298 context: "DesignMatrix::row_chunk_into: failed to obtain CSR view",
4299 })?;
4300 let sym = csr.symbolic();
4301 let row_ptr = sym.row_ptr();
4302 let col_idx = sym.col_idx();
4303 let vals = csr.val();
4304 for (local_row, row) in rows.enumerate() {
4305 for ptr in row_ptr[row]..row_ptr[row + 1] {
4306 out[[local_row, col_idx[ptr]]] = vals[ptr];
4307 }
4308 }
4309 Ok(())
4310 }
4311 }
4312 }
4313
    /// Returns the design as a dense array by forwarding to
    /// `DesignMatrix::to_dense`.
    // NOTE(review): this presumably resolves to an inherent
    // `DesignMatrix::to_dense` defined elsewhere in the file (inherent
    // associated functions take precedence over trait methods). If no such
    // inherent method exists, this call would recurse into itself — confirm.
    fn to_dense(&self) -> Array2<f64> {
        DesignMatrix::to_dense(self)
    }
4317}
4318
4319impl LinearOperator for DenseRightProductView<'_> {
4320 fn nrows(&self) -> usize {
4321 self.base.nrows()
4322 }
4323
4324 fn ncols(&self) -> usize {
4325 self.transformed_ncols()
4326 }
4327
4328 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
4329 let rhs;
4330 let v = match (self.second, self.first) {
4331 (None, None) => vector,
4332 (Some(s), None) => {
4333 rhs = fast_av(s, vector);
4334 &rhs
4335 }
4336 (None, Some(f)) => {
4337 rhs = fast_av(f, vector);
4338 &rhs
4339 }
4340 (Some(s), Some(f)) => {
4341 let tmp = fast_av(s, vector);
4342 rhs = fast_av(f, &tmp);
4343 &rhs
4344 }
4345 };
4346 fast_av(self.base, v)
4347 }
4348
4349 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
4350 let mut out = fast_atv(self.base, vector);
4351 if let Some(factor) = self.first {
4352 out = fast_atv(factor, &out);
4353 }
4354 if let Some(factor) = self.second {
4355 out = fast_atv(factor, &out);
4356 }
4357 out
4358 }
4359
4360 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
4361 if weights.len() != self.nrows() {
4362 return Err(format!(
4363 "xt_diag_x dimension mismatch: weights length {} != nrows {}",
4364 weights.len(),
4365 self.nrows()
4366 ));
4367 }
4368 let mut gram = fast_xt_diag_x(self.base, weights);
4369 if let Some(factor) = self.first {
4370 gram = fast_ab(&fast_atb(factor, &gram), factor);
4371 }
4372 if let Some(factor) = self.second {
4373 gram = fast_ab(&fast_atb(factor, &gram), factor);
4374 }
4375 Ok(gram)
4376 }
4377
4378 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
4379 Ok(self.diag_xtw_x(weights)?.diag().to_owned())
4380 }
4381}
4382
4383impl DenseRightProductView<'_> {
4384 pub fn compute_xtwy(
4385 &self,
4386 weights: &Array1<f64>,
4387 y: &Array1<f64>,
4388 ) -> Result<Array1<f64>, String> {
4389 if weights.len() != self.nrows() || y.len() != self.nrows() {
4390 return Err(format!(
4391 "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
4392 weights.len(),
4393 y.len(),
4394 self.nrows()
4395 ));
4396 }
4397 let weighted_xty = dense_transpose_weighted_response(self.base, weights, y, None);
4398 let mut out = weighted_xty;
4399 if let Some(factor) = self.first {
4400 out = fast_atv(factor, &out);
4401 }
4402 if let Some(factor) = self.second {
4403 out = fast_atv(factor, &out);
4404 }
4405 Ok(out)
4406 }
4407
4408 pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
4409 let dense = self.materialize();
4410 DesignMatrix::Dense(DenseDesignMatrix::from(dense)).quadratic_form_diag(middle)
4411 }
4412}
4413
4414impl LinearOperator for EmbeddedColumnBlock<'_> {
4415 fn nrows(&self) -> usize {
4416 self.local.nrows()
4417 }
4418
4419 fn ncols(&self) -> usize {
4420 self.total_cols
4421 }
4422
4423 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
4424 fast_av(
4425 self.local,
4426 &vector.slice(ndarray::s![self.global_range.clone()]),
4427 )
4428 }
4429
4430 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
4431 let mut out = Array1::<f64>::zeros(self.total_cols);
4432 out.slice_mut(ndarray::s![self.global_range.clone()])
4433 .assign(&fast_atv(self.local, vector));
4434 out
4435 }
4436
4437 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
4438 if weights.len() != self.nrows() {
4439 return Err(format!(
4440 "xt_diag_x dimension mismatch: weights length {} != nrows {}",
4441 weights.len(),
4442 self.nrows()
4443 ));
4444 }
4445 let mut out = Array2::<f64>::zeros((self.total_cols, self.total_cols));
4446 let local = fast_xt_diag_x(self.local, weights);
4447 out.slice_mut(ndarray::s![
4448 self.global_range.clone(),
4449 self.global_range.clone()
4450 ])
4451 .assign(&local);
4452 Ok(out)
4453 }
4454
4455 fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
4456 let mut out = Array1::<f64>::zeros(self.total_cols);
4457 let local =
4458 DesignMatrix::Dense(DenseDesignMatrix::from(self.local.clone())).diag_gram(weights)?;
4459 out.slice_mut(ndarray::s![self.global_range.clone()])
4460 .assign(&local);
4461 Ok(out)
4462 }
4463}
4464
4465impl EmbeddedColumnBlock<'_> {
4466 pub fn compute_xtwy(
4467 &self,
4468 weights: &Array1<f64>,
4469 y: &Array1<f64>,
4470 ) -> Result<Array1<f64>, String> {
4471 if weights.len() != self.nrows() || y.len() != self.nrows() {
4472 return Err(format!(
4473 "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
4474 weights.len(),
4475 y.len(),
4476 self.nrows()
4477 ));
4478 }
4479 let local = dense_transpose_weighted_response(self.local, weights, y, None);
4480 let mut out = Array1::<f64>::zeros(self.total_cols);
4481 out.slice_mut(ndarray::s![self.global_range.clone()])
4482 .assign(&local);
4483 Ok(out)
4484 }
4485
4486 pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
4487 let middle_local = middle
4488 .slice(ndarray::s![
4489 self.global_range.clone(),
4490 self.global_range.clone()
4491 ])
4492 .to_owned();
4493 DesignMatrix::Dense(DenseDesignMatrix::from(self.local.clone()))
4494 .quadratic_form_diag(&middle_local)
4495 }
4496}
4497
4498impl DesignMatrix {
4499 fn factorize_system_dense(
4500 &self,
4501 weights: &Array1<f64>,
4502 penalty: Option<&Array2<f64>>,
4503 ) -> Result<Box<dyn FactorizedSystem>, String> {
4504 let mut system = self.diag_xtw_x(weights)?;
4505 if let Some(pen) = penalty {
4506 if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
4507 return Err(format!(
4508 "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
4509 pen.nrows(),
4510 pen.ncols(),
4511 system.nrows(),
4512 system.ncols()
4513 ));
4514 }
4515 system += pen;
4516 }
4517 let factor = crate::utils::StableSolver::new("linear operator system")
4518 .factorize(&system)
4519 .map_err(|e| format!("factorize_system failed: {e:?}"))?;
4520 Ok(Box::new(factor))
4521 }
4522}
4523
4524fn assemble_sparseweighted_gram_system(
4525 matrix: &SparseDesignMatrix,
4526 weights: &Array1<f64>,
4527 penalty: Option<&Array2<f64>>,
4528) -> Result<SparseColMat<usize, f64>, String> {
4529 let csr = matrix
4530 .to_csr_arc()
4531 .ok_or_else(|| "failed to obtain CSR view in factorize_system".to_string())?;
4532 let sym = csr.symbolic();
4533 let row_ptr = sym.row_ptr();
4534 let col_idx = sym.col_idx();
4535 let vals = csr.val();
4536 let p = matrix.ncols();
4537 let mut upper = BTreeMap::<(usize, usize), f64>::new();
4538
4539 for i in 0..csr.nrows() {
4540 let wi = weights[i].max(0.0);
4541 if wi == 0.0 {
4542 continue;
4543 }
4544 let start = row_ptr[i];
4545 let end = row_ptr[i + 1];
4546 for a_ptr in start..end {
4547 let a = col_idx[a_ptr];
4548 let xa = vals[a_ptr];
4549 for b_ptr in a_ptr..end {
4550 let b = col_idx[b_ptr];
4551 let xb = vals[b_ptr];
4552 let key = if a <= b { (a, b) } else { (b, a) };
4553 *upper.entry(key).or_insert(0.0) += wi * xa * xb;
4554 }
4555 }
4556 }
4557
4558 if let Some(pen) = penalty {
4559 if pen.nrows() != p || pen.ncols() != p {
4560 return Err(format!(
4561 "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
4562 pen.nrows(),
4563 pen.ncols(),
4564 p,
4565 p
4566 ));
4567 }
4568 for i in 0..p {
4569 for j in i..p {
4570 let value = pen[[i, j]];
4571 if value != 0.0 {
4572 *upper.entry((i, j)).or_insert(0.0) += value;
4573 }
4574 }
4575 }
4576 }
4577
4578 let mut triplets = Vec::with_capacity(upper.len());
4579 for ((row, col), value) in upper {
4580 if value != 0.0 {
4581 triplets.push(Triplet::new(row, col, value));
4582 }
4583 }
4584 SparseColMat::try_new_from_triplets(p, p, &triplets)
4585 .map_err(|_| "failed to build sparse penalized system".to_string())
4586}
4587
4588impl DesignMatrix {
4589 pub fn hstack(blocks: Vec<DesignMatrix>) -> Result<Self, String> {
4595 if blocks.is_empty() {
4596 return Err("DesignMatrix::hstack requires at least one block".to_string());
4597 }
4598 if blocks.len() == 1 {
4599 return Ok(blocks.into_iter().next().expect("non-empty block list"));
4600 }
4601 let operator =
4602 BlockDesignOperator::new(blocks.into_iter().map(DesignBlock::from).collect())?;
4603 Ok(Self::Dense(DenseDesignMatrix::from(Arc::new(operator))))
4604 }
4605
4606 pub fn nrows(&self) -> usize {
4607 <Self as LinearOperator>::nrows(self)
4608 }
4609
4610 pub fn ncols(&self) -> usize {
4611 <Self as LinearOperator>::ncols(self)
4612 }
4613
4614 pub fn try_row_chunk(
4620 &self,
4621 rows: Range<usize>,
4622 ) -> Result<Array2<f64>, MatrixMaterializationError> {
4623 match self {
4624 Self::Dense(matrix) => matrix.try_row_chunk(rows),
4625 Self::Sparse(matrix) => {
4626 let csr =
4627 matrix
4628 .to_csr_arc()
4629 .ok_or(MatrixMaterializationError::MissingRowChunk {
4630 context: "DesignMatrix::try_row_chunk: failed to obtain CSR view",
4631 })?;
4632 let sym = csr.symbolic();
4633 let row_ptr = sym.row_ptr();
4634 let col_idx = sym.col_idx();
4635 let vals = csr.val();
4636 let chunk_rows = rows.end - rows.start;
4637 let ncols = self.ncols();
4638 let mut out = Array2::<f64>::zeros((chunk_rows, ncols));
4639 for (local_row, row) in rows.enumerate() {
4640 for ptr in row_ptr[row]..row_ptr[row + 1] {
4641 out[[local_row, col_idx[ptr]]] = vals[ptr];
4642 }
4643 }
4644 Ok(out)
4645 }
4646 }
4647 }
4648
4649 pub fn row_chunk_into(
4655 &self,
4656 rows: Range<usize>,
4657 out: ArrayViewMut2<'_, f64>,
4658 ) -> Result<(), MatrixMaterializationError> {
4659 <Self as DenseDesignOperator>::row_chunk_into(self, rows, out)
4660 }
4661
4662 pub fn try_to_dense_by_chunks(&self, context: &str) -> Result<Array2<f64>, String> {
4663 let n = self.nrows();
4664 let p = self.ncols();
4665 let chunk_rows = dense_materialization_chunk_rows(n, p);
4666 let mut out = Array2::<f64>::zeros((n, p));
4667 for start in (0..n).step_by(chunk_rows) {
4668 let end = (start + chunk_rows).min(n);
4669 let slice = out.slice_mut(s![start..end, ..]);
4670 self.row_chunk_into(start..end, slice)
4671 .map_err(|err| format!("{context}: failed to materialize row chunk: {err}"))?;
4672 }
4673 Ok(out)
4674 }
4675
4676 pub fn try_to_dense_by_chunks_budgeted(
4682 &self,
4683 context: &str,
4684 max_bytes: usize,
4685 ) -> Result<Array2<f64>, String> {
4686 let n = self.nrows();
4687 let p = self.ncols();
4688 let dense_bytes = checked_dense_nbytes(n, p, context)?;
4689 if dense_bytes > max_bytes {
4690 let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
4691 let cap_gib = max_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
4692 return Err(MatrixError::DensificationRefused {
4693 reason: format!(
4694 "{context}: refusing to densify {n}x{p} (~{gib:.2} GiB, cap ~{cap_gib:.2} GiB)"
4695 ),
4696 }
4697 .into());
4698 }
4699 self.try_to_dense_by_chunks(context)
4700 }
4701
4702 pub fn dot_row(&self, row: usize, beta: &Array1<f64>) -> f64 {
4705 self.dot_row_view(row, beta.view())
4706 }
4707
4708 pub fn dot_row_view(&self, row: usize, beta: ArrayView1<'_, f64>) -> f64 {
4709 assert_eq!(
4710 beta.len(),
4711 self.ncols(),
4712 "DesignMatrix::dot_row_view length mismatch: beta={}, ncols={}",
4713 beta.len(),
4714 self.ncols()
4715 );
4716 match self {
4717 Self::Dense(matrix) => {
4718 if let Some(dense) = matrix.as_dense_ref() {
4719 dense.row(row).dot(&beta)
4720 } else {
4721 matrix
4722 .try_row_chunk(row..row + 1)
4723 .expect("DesignMatrix::dot_row_view: try_row_chunk must succeed")
4724 .row(0)
4725 .dot(&beta)
4726 }
4727 }
4728 Self::Sparse(matrix) => {
4729 let csr = matrix
4737 .to_csr_arc()
4738 .expect("DesignMatrix::dot_row: failed to obtain CSR view");
4739 let sym = csr.symbolic();
4740 let row_ptr = sym.row_ptr();
4741 let col_idx = sym.col_idx();
4742 let vals = csr.val();
4743 let mut out = 0.0;
4744 for ptr in row_ptr[row]..row_ptr[row + 1] {
4745 out += vals[ptr] * beta[col_idx[ptr]];
4746 }
4747 out
4748 }
4749 }
4750 }
4751
4752 pub fn axpy_row_into(
4754 &self,
4755 row: usize,
4756 alpha: f64,
4757 out: &mut ArrayViewMut1<'_, f64>,
4758 ) -> Result<(), String> {
4759 self.axpy_row_into_impl(row, alpha, out, false, "axpy_row_into")
4760 }
4761
4762 pub fn squared_axpy_row_into(
4765 &self,
4766 row: usize,
4767 alpha: f64,
4768 out: &mut ArrayViewMut1<'_, f64>,
4769 ) -> Result<(), String> {
4770 self.axpy_row_into_impl(row, alpha, out, true, "squared_axpy_row_into")
4771 }
4772
4773 #[inline]
4780 fn axpy_row_into_impl(
4781 &self,
4782 row: usize,
4783 alpha: f64,
4784 out: &mut ArrayViewMut1<'_, f64>,
4785 square: bool,
4786 method: &str,
4787 ) -> Result<(), String> {
4788 if out.len() != self.ncols() {
4789 return Err(format!(
4790 "DesignMatrix::{method} length mismatch: out={}, ncols={}",
4791 out.len(),
4792 self.ncols()
4793 ));
4794 }
4795 if alpha == 0.0 {
4796 return Ok(());
4797 }
4798 let scale = |value: f64| {
4800 if square {
4801 alpha * value * value
4802 } else {
4803 alpha * value
4804 }
4805 };
4806 match self {
4807 Self::Dense(matrix) => {
4808 if let Some(dense) = matrix.as_dense_ref() {
4809 for (dst, &value) in out.iter_mut().zip(dense.row(row).iter()) {
4810 *dst += scale(value);
4811 }
4812 } else {
4813 let chunk = matrix
4814 .try_row_chunk(row..row + 1)
4815 .map_err(|e| format!("DesignMatrix::{method}: {e}"))?;
4816 for (dst, &value) in out.iter_mut().zip(chunk.row(0).iter()) {
4817 *dst += scale(value);
4818 }
4819 }
4820 }
4821 Self::Sparse(matrix) => {
4822 let csr = matrix
4827 .to_csr_arc()
4828 .ok_or_else(|| format!("DesignMatrix::{method}: failed to obtain CSR view"))?;
4829 let sym = csr.symbolic();
4830 let row_ptr = sym.row_ptr();
4831 let col_idx = sym.col_idx();
4832 let vals = csr.val();
4833 for ptr in row_ptr[row]..row_ptr[row + 1] {
4834 out[col_idx[ptr]] += scale(vals[ptr]);
4835 }
4836 }
4837 }
4838 Ok(())
4839 }
4840
4841 pub fn crossdiag_axpy_row_into(
4847 &self,
4848 row: usize,
4849 other: &DesignMatrix,
4850 alpha: f64,
4851 out: &mut ArrayViewMut1<'_, f64>,
4852 ) -> Result<(), String> {
4853 assert_eq!(self.ncols(), other.ncols());
4854 assert_eq!(out.len(), self.ncols());
4855 if alpha == 0.0 {
4856 return Ok(());
4857 }
4858 match (self, other) {
4859 (Self::Dense(lhs), Self::Dense(rhs)) => {
4860 let lhs_chunk;
4861 let rhs_chunk;
4862 let x = if let Some(lhs_dense) = lhs.as_dense_ref() {
4863 lhs_dense.row(row)
4864 } else {
4865 lhs_chunk = lhs
4866 .try_row_chunk(row..row + 1)
4867 .map_err(|e| format!("crossdiag_axpy_row_into lhs: {e}"))?;
4868 lhs_chunk.row(0)
4869 };
4870 let y = if let Some(rhs_dense) = rhs.as_dense_ref() {
4871 rhs_dense.row(row)
4872 } else {
4873 rhs_chunk = rhs
4874 .try_row_chunk(row..row + 1)
4875 .map_err(|e| format!("crossdiag_axpy_row_into rhs: {e}"))?;
4876 rhs_chunk.row(0)
4877 };
4878 for (dst, (&xi, &yi)) in out.iter_mut().zip(x.iter().zip(y.iter())) {
4879 *dst += alpha * xi * yi;
4880 }
4881 }
4882 (Self::Sparse(lhs), Self::Sparse(rhs)) => {
4883 let lhs_csr = lhs.to_csr_arc().ok_or_else(|| {
4889 "crossdiag_axpy_row_into: failed to obtain lhs CSR view".to_string()
4890 })?;
4891 let rhs_csr = rhs.to_csr_arc().ok_or_else(|| {
4892 "crossdiag_axpy_row_into: failed to obtain rhs CSR view".to_string()
4893 })?;
4894 let lhs_sym = lhs_csr.symbolic();
4895 let rhs_sym = rhs_csr.symbolic();
4896 let lhs_rp = lhs_sym.row_ptr();
4897 let rhs_rp = rhs_sym.row_ptr();
4898 let lhs_ci = lhs_sym.col_idx();
4899 let rhs_ci = rhs_sym.col_idx();
4900 let lhs_v = lhs_csr.val();
4901 let rhs_v = rhs_csr.val();
4902 let mut li = lhs_rp[row];
4904 let mut ri = rhs_rp[row];
4905 let l_end = lhs_rp[row + 1];
4906 let r_end = rhs_rp[row + 1];
4907 while li < l_end && ri < r_end {
4908 let lc = lhs_ci[li];
4909 let rc = rhs_ci[ri];
4910 if lc == rc {
4911 out[lc] += alpha * lhs_v[li] * rhs_v[ri];
4912 li += 1;
4913 ri += 1;
4914 } else if lc < rc {
4915 li += 1;
4916 } else {
4917 ri += 1;
4918 }
4919 }
4920 }
4921 _ => {
4922 let (sparse_mat, dense_mat) = match (self, other) {
4924 (Self::Sparse(s), Self::Dense(d)) => (s, d),
4925 (Self::Dense(d), Self::Sparse(s)) => (s, d),
4926 _ => {
4929 return Err(
4930 "crossdiag_axpy_row_into: mixed-arm dispatch reached non-mixed pair"
4931 .to_string(),
4932 );
4933 }
4934 };
4935 let csr = sparse_mat.to_csr_arc().ok_or_else(|| {
4939 "crossdiag_axpy_row_into: failed to obtain CSR view".to_string()
4940 })?;
4941 let sym = csr.symbolic();
4942 let row_ptr = sym.row_ptr();
4943 let col_idx = sym.col_idx();
4944 let vals = csr.val();
4945 let dense_chunk;
4946 let dense_row = if let Some(dense_ref) = dense_mat.as_dense_ref() {
4947 dense_ref.row(row)
4948 } else {
4949 dense_chunk = dense_mat
4950 .try_row_chunk(row..row + 1)
4951 .map_err(|e| format!("crossdiag_axpy_row_into dense chunk: {e}"))?;
4952 dense_chunk.row(0)
4953 };
4954 for ptr in row_ptr[row]..row_ptr[row + 1] {
4955 let c = col_idx[ptr];
4956 out[c] += alpha * vals[ptr] * dense_row[c];
4957 }
4958 }
4959 }
4960 Ok(())
4961 }
4962
4963 pub fn syr_row_into(
4965 &self,
4966 row: usize,
4967 alpha: f64,
4968 target: &mut Array2<f64>,
4969 ) -> Result<(), String> {
4970 self.syr_row_into_view(row, alpha, target.view_mut())
4971 }
4972
4973 pub fn syr_row_into_view(
4976 &self,
4977 row: usize,
4978 alpha: f64,
4979 mut target: ArrayViewMut2<'_, f64>,
4980 ) -> Result<(), String> {
4981 if target.nrows() != self.ncols() || target.ncols() != self.ncols() {
4982 return Err(format!(
4983 "DesignMatrix::syr_row_into shape mismatch: target={}x{}, ncols={}",
4984 target.nrows(),
4985 target.ncols(),
4986 self.ncols()
4987 ));
4988 }
4989 if alpha == 0.0 {
4990 return Ok(());
4991 }
4992 match self {
4993 Self::Dense(matrix) => {
4994 if let Some(dense) = matrix.as_dense_ref() {
4995 let x = dense.row(row);
4996 for i in 0..x.len() {
4997 let xi = x[i];
4998 if xi == 0.0 {
4999 continue;
5000 }
5001 for j in 0..x.len() {
5002 target[[i, j]] += alpha * xi * x[j];
5003 }
5004 }
5005 } else {
5006 let chunk = matrix
5007 .try_row_chunk(row..row + 1)
5008 .map_err(|e| format!("DesignMatrix::syr_row_into: {e}"))?;
5009 let x = chunk.row(0);
5010 for i in 0..x.len() {
5011 let xi = x[i];
5012 if xi == 0.0 {
5013 continue;
5014 }
5015 for j in 0..x.len() {
5016 target[[i, j]] += alpha * xi * x[j];
5017 }
5018 }
5019 }
5020 }
5021 Self::Sparse(matrix) => {
5022 let csr = matrix.to_csr_arc().ok_or_else(|| {
5027 "DesignMatrix::syr_row_into: failed to obtain CSR view".to_string()
5028 })?;
5029 let sym = csr.symbolic();
5030 let row_ptr = sym.row_ptr();
5031 let col_idx = sym.col_idx();
5032 let vals = csr.val();
5033 for ptr_i in row_ptr[row]..row_ptr[row + 1] {
5034 let i = col_idx[ptr_i];
5035 let xi = vals[ptr_i];
5036 for ptr_j in row_ptr[row]..row_ptr[row + 1] {
5037 let j = col_idx[ptr_j];
5038 target[[i, j]] += alpha * xi * vals[ptr_j];
5039 }
5040 }
5041 }
5042 }
5043 Ok(())
5044 }
5045
5046 pub fn row_outer_into(
5051 &self,
5052 row: usize,
5053 other: &DesignMatrix,
5054 alpha: f64,
5055 target: &mut Array2<f64>,
5056 ) -> Result<(), String> {
5057 self.row_outer_into_view(row, other, alpha, target.view_mut())
5058 }
5059
5060 pub fn row_outer_into_view(
5063 &self,
5064 row: usize,
5065 other: &DesignMatrix,
5066 alpha: f64,
5067 mut target: ArrayViewMut2<'_, f64>,
5068 ) -> Result<(), String> {
5069 if target.nrows() != self.ncols() || target.ncols() != other.ncols() {
5070 return Err(format!(
5071 "DesignMatrix::row_outer_into shape mismatch: target={}x{}, lhs={}, rhs={}",
5072 target.nrows(),
5073 target.ncols(),
5074 self.ncols(),
5075 other.ncols()
5076 ));
5077 }
5078 if alpha == 0.0 {
5079 return Ok(());
5080 }
5081 match (self, other) {
5082 (Self::Dense(lhs), Self::Dense(rhs)) => {
5083 let lhs_chunk;
5084 let rhs_chunk;
5085 let x = if let Some(lhs_dense) = lhs.as_dense_ref() {
5086 lhs_dense.row(row)
5087 } else {
5088 lhs_chunk = lhs
5089 .try_row_chunk(row..row + 1)
5090 .map_err(|e| format!("row_outer_into_view lhs: {e}"))?;
5091 lhs_chunk.row(0)
5092 };
5093 let y = if let Some(rhs_dense) = rhs.as_dense_ref() {
5094 rhs_dense.row(row)
5095 } else {
5096 rhs_chunk = rhs
5097 .try_row_chunk(row..row + 1)
5098 .map_err(|e| format!("row_outer_into_view rhs: {e}"))?;
5099 rhs_chunk.row(0)
5100 };
5101 for i in 0..x.len() {
5102 let xi = x[i];
5103 if xi == 0.0 {
5104 continue;
5105 }
5106 for j in 0..y.len() {
5107 target[[i, j]] += alpha * xi * y[j];
5108 }
5109 }
5110 }
5111 (Self::Sparse(lhs), Self::Sparse(rhs)) => {
5112 let lhs_csr = lhs
5117 .to_csr_arc()
5118 .ok_or_else(|| "row_outer_into: failed to obtain lhs CSR view".to_string())?;
5119 let rhs_csr = rhs
5121 .to_csr_arc()
5122 .ok_or_else(|| "row_outer_into: failed to obtain rhs CSR view".to_string())?;
5123 let lhs_sym = lhs_csr.symbolic();
5124 let rhs_sym = rhs_csr.symbolic();
5125 let lhs_rp = lhs_sym.row_ptr();
5126 let rhs_rp = rhs_sym.row_ptr();
5127 let lhs_ci = lhs_sym.col_idx();
5128 let rhs_ci = rhs_sym.col_idx();
5129 let lhs_v = lhs_csr.val();
5130 let rhs_v = rhs_csr.val();
5131 for pi in lhs_rp[row]..lhs_rp[row + 1] {
5132 let i = lhs_ci[pi];
5133 let xi = lhs_v[pi];
5134 for pj in rhs_rp[row]..rhs_rp[row + 1] {
5135 let j = rhs_ci[pj];
5136 target[[i, j]] += alpha * xi * rhs_v[pj];
5137 }
5138 }
5139 }
5140 _ => {
5141 let x = self
5143 .try_row_chunk(row..row + 1)
5144 .map_err(|e| format!("row_outer_into_view lhs: {e}"))?;
5145 let x_row = x.row(0);
5146 let y = other
5147 .try_row_chunk(row..row + 1)
5148 .map_err(|e| format!("row_outer_into_view rhs: {e}"))?;
5149 let y_row = y.row(0);
5150 for i in 0..x_row.len() {
5151 let xi = x_row[i];
5152 if xi == 0.0 {
5153 continue;
5154 }
5155 for j in 0..y_row.len() {
5156 target[[i, j]] += alpha * xi * y_row[j];
5157 }
5158 }
5159 }
5160 }
5161 Ok(())
5162 }
5163
5164 #[inline]
5176 pub fn get(&self, i: usize, j: usize) -> f64 {
5177 match self {
5178 Self::Dense(matrix) => match matrix.as_dense_ref() {
5179 Some(dense) => dense[[i, j]],
5180 None => {
5185 let mut e_j = Array1::<f64>::zeros(matrix.ncols());
5186 e_j[j] = 1.0;
5187 matrix.apply(&e_j)[i]
5188 }
5189 },
5190 Self::Sparse(sp) => {
5191 let dense = sp
5199 .try_to_dense_arc("DesignMatrix::get")
5200 .unwrap_or_else(|msg| std::panic::panic_any(msg));
5201 dense[[i, j]]
5202 }
5203 }
5204 }
5205
5206 pub fn extract_column(&self, j: usize) -> Array1<f64> {
5212 match self {
5213 Self::Dense(m) => {
5214 if let Some(dense) = m.as_dense_ref() {
5215 dense.column(j).to_owned()
5216 } else {
5217 let mut e_j = Array1::zeros(m.ncols());
5218 e_j[j] = 1.0;
5219 m.apply(&e_j)
5220 }
5221 }
5222 Self::Sparse(sp) => {
5223 let n = sp.nrows();
5224 let mut col = Array1::zeros(n);
5225 let (symbolic, values) = sp.parts();
5226 let col_ptr = symbolic.col_ptr();
5227 let row_idx = symbolic.row_idx();
5228 let start = col_ptr[j];
5229 let end = col_ptr[j + 1];
5230 for idx in start..end {
5231 col[row_idx[idx]] = values[idx];
5232 }
5233 col
5234 }
5235 }
5236 }
5237
5238 pub fn extract_columns(&self, cols: &[usize]) -> Array2<f64> {
5245 match self {
5246 Self::Dense(m) => match m {
5247 DenseDesignMatrix::Materialized(mat) => mat.select(Axis(1), cols),
5248 DenseDesignMatrix::Lazy(op) => op.apply_columns(cols),
5249 },
5250 Self::Sparse(sp) => {
5251 let n = sp.nrows();
5252 let mut out = Array2::<f64>::zeros((n, cols.len()));
5253 let (symbolic, values) = sp.parts();
5254 let col_ptr = symbolic.col_ptr();
5255 let row_idx = symbolic.row_idx();
5256 for (k, &j) in cols.iter().enumerate() {
5257 let start = col_ptr[j];
5258 let end = col_ptr[j + 1];
5259 let mut out_col = out.column_mut(k);
5260 for idx in start..end {
5261 out_col[row_idx[idx]] = values[idx];
5262 }
5263 }
5264 out
5265 }
5266 }
5267 }
5268
5269 pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
5271 match self {
5272 Self::Dense(matrix) => matrix.as_dense_ref(),
5273 Self::Sparse(_) => None,
5274 }
5275 }
5276
5277 pub const fn is_materialized_dense(&self) -> bool {
5278 matches!(self, Self::Dense(DenseDesignMatrix::Materialized(_)))
5279 }
5280
5281 pub const fn is_operator_backed(&self) -> bool {
5282 match self {
5283 Self::Dense(matrix) => matrix.is_operator_backed(),
5284 Self::Sparse(_) => false,
5285 }
5286 }
5287
5288 pub const fn is_sparse(&self) -> bool {
5294 matches!(self, Self::Sparse(_))
5295 }
5296
5297 pub fn as_dense_cow(&self) -> Cow<'_, Array2<f64>> {
5303 match self {
5304 Self::Dense(DenseDesignMatrix::Materialized(matrix)) => Cow::Borrowed(matrix.as_ref()),
5305 Self::Dense(DenseDesignMatrix::Lazy(op)) => match op.as_dense_ref() {
5306 Some(dense) => Cow::Borrowed(dense),
5307 None => std::panic::panic_any(format!(
5314 "DesignMatrix::as_dense_cow called on operator-backed design ({}x{}); use row chunks or matrix-vector products",
5315 op.nrows(),
5316 op.ncols()
5317 )),
5318 },
5319 Self::Sparse(matrix) => Cow::Owned(
5320 matrix
5321 .try_to_dense_arc("DesignMatrix::as_dense_cow")
5322 .unwrap_or_else(|msg| std::panic::panic_any(msg))
5328 .as_ref()
5329 .clone(),
5330 ),
5331 }
5332 }
5333
5334 pub fn to_dense_cow(&self) -> Cow<'_, Array2<f64>> {
5343 match self {
5344 Self::Dense(DenseDesignMatrix::Materialized(matrix)) => Cow::Borrowed(matrix.as_ref()),
5345 Self::Dense(DenseDesignMatrix::Lazy(op)) => {
5346 if let Some(dense) = op.as_dense_ref() {
5347 Cow::Borrowed(dense)
5348 } else {
5349 Cow::Owned(
5352 dense_operator_to_dense_by_chunks(op.as_ref()).unwrap_or_else(|err| {
5353 std::panic::panic_any(format!(
5360 "DesignMatrix::to_dense_cow: failed to materialize {}x{} \
5361 operator-backed design via row chunks: {err}",
5362 op.nrows(),
5363 op.ncols(),
5364 ))
5365 }),
5366 )
5367 }
5368 }
5369 Self::Sparse(matrix) => Cow::Owned(
5370 matrix
5371 .try_to_dense_arc("DesignMatrix::to_dense_cow")
5372 .unwrap_or_else(|msg| std::panic::panic_any(msg))
5378 .as_ref()
5379 .clone(),
5380 ),
5381 }
5382 }
5383
5384 pub fn to_dense(&self) -> Array2<f64> {
5408 match self {
5409 Self::Dense(matrix) => matrix.to_dense(),
5410 Self::Sparse(matrix) => matrix
5411 .try_to_dense_arc("DesignMatrix::to_dense")
5412 .unwrap_or_else(|msg| std::panic::panic_any(msg))
5419 .as_ref()
5420 .clone(),
5421 }
5422 }
5423
5424 pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
5427 match self {
5428 Self::Dense(matrix) => matrix.to_dense_arc(),
5429 Self::Sparse(matrix) => matrix
5430 .try_to_dense_arc("DesignMatrix::to_dense_arc")
5431 .unwrap_or_else(|msg| std::panic::panic_any(msg)),
5438 }
5439 }
5440
5441 pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
5442 match self {
5443 Self::Dense(matrix) => matrix.try_to_dense_arc(context),
5444 Self::Sparse(matrix) => matrix.try_to_dense_arc(context),
5445 }
5446 }
5447
5448 pub fn try_to_dense_arc_with_policy(
5451 &self,
5452 context: &str,
5453 policy: &ResourcePolicy,
5454 ) -> Result<Arc<Array2<f64>>, String> {
5455 match self {
5456 Self::Dense(matrix) => matrix.try_to_dense_arc_with_policy(context, policy),
5457 Self::Sparse(matrix) => matrix.try_to_dense_arc(context),
5458 }
5459 }
5460
5461 pub fn to_csr_cache(&self) -> Option<SparseRowMat<usize, f64>> {
5462 match self {
5463 Self::Dense(_) => None,
5464 Self::Sparse(matrix) => matrix.to_csr_arc().map(|arc| (*arc).clone()),
5465 }
5466 }
5467
5468 pub fn as_sparse(&self) -> Option<&SparseDesignMatrix> {
5469 match self {
5470 Self::Sparse(matrix) => Some(matrix),
5471 Self::Dense(_) => None,
5472 }
5473 }
5474
5475 pub fn as_dense(&self) -> Option<&Array2<f64>> {
5476 match self {
5477 Self::Dense(matrix) => matrix.as_dense_ref(),
5478 Self::Sparse(_) => None,
5479 }
5480 }
5481
5482 fn apply_transpose_view(&self, vector: ArrayView1<'_, f64>) -> Array1<f64> {
5483 match self {
5484 Self::Dense(DenseDesignMatrix::Materialized(matrix)) => fast_atv(matrix, &vector),
5485 Self::Dense(DenseDesignMatrix::Lazy(op)) => op.apply_transpose(&vector.to_owned()),
5486 Self::Sparse(matrix) => {
5487 let mut output = Array1::<f64>::zeros(matrix.ncols());
5488 let (symbolic, values) = matrix.parts();
5489 let col_ptr = symbolic.col_ptr();
5490 let row_idx = symbolic.row_idx();
5491 for col in 0..matrix.ncols() {
5492 let mut acc = 0.0;
5493 let start = col_ptr[col];
5494 let end = col_ptr[col + 1];
5495 for idx in start..end {
5496 acc += values[idx] * vector[row_idx[idx]];
5497 }
5498 output[col] = acc;
5499 }
5500 output
5501 }
5502 }
5503 }
5504
5505 fn diag_gram_view(&self, weights: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
5506 if weights.len() != self.nrows() {
5507 return Err(format!(
5508 "diag_gram dimension mismatch: weights length {} != nrows {}",
5509 weights.len(),
5510 self.nrows()
5511 ));
5512 }
5513 match self {
5514 Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
5515 let psd = PsdWeightsView::try_new(weights)?;
5520 Ok(dense_diag_gram_view(matrix, psd))
5521 }
5522 Self::Dense(DenseDesignMatrix::Lazy(op)) => op.diag_gram(&weights.to_owned()),
5523 Self::Sparse(xs) => {
5524 let p = xs.ncols();
5525 let csr = xs
5526 .to_csr_arc()
5527 .ok_or_else(|| "failed to obtain CSR view in diag_gram".to_string())?;
5528 let sym = csr.symbolic();
5529 Ok(sparse_csr_diag_gram(
5530 sym.row_ptr(),
5531 sym.col_idx(),
5532 csr.val(),
5533 xs.nrows(),
5534 p,
5535 weights,
5536 ))
5537 }
5538 }
5539 }
5540
5541 fn compute_xtwy_view(
5542 &self,
5543 weights: ArrayView1<'_, f64>,
5544 y: ArrayView1<'_, f64>,
5545 ) -> Result<Array1<f64>, String> {
5546 if weights.len() != self.nrows() || y.len() != self.nrows() {
5547 return Err(format!(
5548 "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
5549 weights.len(),
5550 y.len(),
5551 self.nrows()
5552 ));
5553 }
5554 match self {
5555 Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
5556 Ok(dense_transpose_weighted_response_view(matrix, weights, y))
5557 }
5558 Self::Dense(DenseDesignMatrix::Lazy(op)) => {
5559 op.compute_xtwy(&weights.to_owned(), &y.to_owned())
5560 }
5561 Self::Sparse(xs) => {
5562 let csr = xs
5563 .as_ref()
5564 .to_row_major()
5565 .map_err(|_| "failed to obtain CSR view in compute_xtwy".to_string())?;
5566 let sym = csr.symbolic();
5567 let row_ptr = sym.row_ptr();
5568 let col_idx = sym.col_idx();
5569 let vals = csr.val();
5570 let mut out = Array1::<f64>::zeros(xs.ncols());
5571 for i in 0..xs.nrows() {
5572 let scaled = weights[i].max(0.0) * y[i];
5573 if scaled == 0.0 {
5574 continue;
5575 }
5576 for idx in row_ptr[i]..row_ptr[i + 1] {
5577 out[col_idx[idx]] += vals[idx] * scaled;
5578 }
5579 }
5580 Ok(out)
5581 }
5582 }
5583 }
5584
5585 pub fn dot(&self, vector: &Array1<f64>) -> Array1<f64> {
5586 <Self as LinearOperator>::apply(self, vector)
5587 }
5588
5589 pub fn matrixvectormultiply(&self, vector: &Array1<f64>) -> Array1<f64> {
5590 <Self as LinearOperator>::apply(self, vector)
5591 }
5592
5593 pub fn transpose_vector_multiply(&self, vector: &Array1<f64>) -> Array1<f64> {
5594 <Self as LinearOperator>::apply_transpose(self, vector)
5595 }
5596
5597 pub fn compute_xtwy(
5598 &self,
5599 weights: &Array1<f64>,
5600 y: &Array1<f64>,
5601 ) -> Result<Array1<f64>, String> {
5602 <Self as DenseDesignOperator>::compute_xtwy(self, weights, y)
5603 }
5604
5605 pub fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
5606 <Self as LinearOperator>::diag_gram(self, weights)
5607 }
5608
5609 pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
5610 <Self as DenseDesignOperator>::quadratic_form_diag(self, middle)
5611 }
5612
5613 pub fn apply_weighted_normal(
5614 &self,
5615 weights: &Array1<f64>,
5616 vector: &Array1<f64>,
5617 penalty: Option<&Array2<f64>>,
5618 ridge: f64,
5619 ) -> Array1<f64> {
5620 <Self as LinearOperator>::apply_weighted_normal(self, weights, vector, penalty, ridge)
5621 }
5622
5623 pub fn solve_system(
5624 &self,
5625 weights: &Array1<f64>,
5626 rhs: &Array1<f64>,
5627 penalty: Option<&Array2<f64>>,
5628 ) -> Result<Array1<f64>, String> {
5629 <Self as LinearOperator>::solve_system(self, weights, rhs, penalty)
5630 }
5631
5632 pub fn solve_systemwith_policy(
5633 &self,
5634 weights: &Array1<f64>,
5635 rhs: &Array1<f64>,
5636 penalty: Option<&Array2<f64>>,
5637 ridge_floor: f64,
5638 ridge_policy: RidgePolicy,
5639 ) -> Result<Array1<f64>, String> {
5640 <Self as LinearOperator>::solve_systemwith_policy(
5641 self,
5642 weights,
5643 rhs,
5644 penalty,
5645 ridge_floor,
5646 ridge_policy,
5647 )
5648 }
5649
5650 pub fn solve_system_matrix_free_pcg(
5651 &self,
5652 weights: &Array1<f64>,
5653 rhs: &Array1<f64>,
5654 penalty: Option<&Array2<f64>>,
5655 ridge_floor: f64,
5656 ) -> Result<Array1<f64>, String> {
5657 <Self as LinearOperator>::solve_system_matrix_free_pcg_try(
5658 self,
5659 weights,
5660 rhs,
5661 penalty,
5662 ridge_floor.max(SPD_SOLVE_RIDGE_FLOOR),
5663 )
5664 }
5665
5666 pub fn solve_system_matrix_free_pcg_with_info(
5667 &self,
5668 weights: &Array1<f64>,
5669 rhs: &Array1<f64>,
5670 penalty: Option<&Array2<f64>>,
5671 ridge_floor: f64,
5672 ) -> Result<(Array1<f64>, PcgSolveInfo), String> {
5673 <Self as LinearOperator>::solve_system_matrix_free_pcg_with_info_try(
5674 self,
5675 weights,
5676 rhs,
5677 penalty,
5678 ridge_floor.max(SPD_SOLVE_RIDGE_FLOOR),
5679 )
5680 }
5681
5682 pub fn should_use_matrix_free_pcg(&self) -> bool {
5683 <Self as LinearOperator>::uses_matrix_free_pcg(self)
5684 && self.ncols() >= MATRIX_FREE_PCG_MIN_P
5685 }
5686
5687 pub fn factorize_system(
5688 &self,
5689 weights: &Array1<f64>,
5690 penalty: Option<&Array2<f64>>,
5691 ) -> Result<Box<dyn FactorizedSystem>, String> {
5692 <Self as LinearOperator>::factorize_system(self, weights, penalty)
5693 }
5694}
5695
5696impl<'a> From<ArrayView2<'a, f64>> for DesignMatrix {
5697 fn from(value: ArrayView2<'a, f64>) -> Self {
5698 Self::Dense(DenseDesignMatrix::from(value.to_owned()))
5699 }
5700}
5701
5702impl From<Array2<f64>> for DesignMatrix {
5703 fn from(value: Array2<f64>) -> Self {
5704 Self::Dense(DenseDesignMatrix::from(value))
5705 }
5706}
5707
5708impl From<Arc<Array2<f64>>> for DesignMatrix {
5709 fn from(value: Arc<Array2<f64>>) -> Self {
5710 Self::Dense(DenseDesignMatrix::from(value))
5711 }
5712}
5713
5714impl From<&Array2<f64>> for DesignMatrix {
5715 fn from(value: &Array2<f64>) -> Self {
5716 Self::Dense(DenseDesignMatrix::from(value.clone()))
5717 }
5718}
5719
5720impl From<DenseDesignMatrix> for DesignMatrix {
5721 fn from(value: DenseDesignMatrix) -> Self {
5722 Self::Dense(value)
5723 }
5724}
5725
5726impl From<SparseColMat<usize, f64>> for DesignMatrix {
5727 fn from(value: SparseColMat<usize, f64>) -> Self {
5728 Self::Sparse(SparseDesignMatrix::new(value))
5729 }
5730}
5731
5732impl From<&SparseColMat<usize, f64>> for DesignMatrix {
5733 fn from(value: &SparseColMat<usize, f64>) -> Self {
5734 Self::Sparse(SparseDesignMatrix::new(value.clone()))
5735 }
5736}
5737
5738impl From<&DesignMatrix> for DesignMatrix {
5739 fn from(value: &DesignMatrix) -> Self {
5740 value.clone()
5741 }
5742}
5743
5744impl From<DesignMatrix> for DesignBlock {
5745 fn from(value: DesignMatrix) -> Self {
5746 match value {
5747 DesignMatrix::Dense(matrix) => Self::Dense(matrix),
5748 DesignMatrix::Sparse(matrix) => Self::Sparse(matrix),
5749 }
5750 }
5751}
5752
5753impl From<&DesignMatrix> for DesignBlock {
5754 fn from(value: &DesignMatrix) -> Self {
5755 match value {
5756 DesignMatrix::Dense(matrix) => Self::Dense(matrix.clone()),
5757 DesignMatrix::Sparse(matrix) => Self::Sparse(matrix.clone()),
5758 }
5759 }
5760}
5761
5762#[cfg(test)]
5763mod tests {
5764 use super::{
5765 BlockDesignOperator, CoefficientTransformOperator, DenseDesignMatrix, DenseDesignOperator,
5766 DesignBlock, DesignMatrix, EmbeddedColumnBlock, MultiChannelOperator, PsdWeightsView,
5767 ReparamOperator, ResidualisedDesignOperator, RowwiseKroneckerOperator, SignedWeightsView,
5768 SparseDesignMatrix, dense_operator_to_dense_by_chunks, dense_transpose_weighted_response,
5769 fast_atv, fast_av, streaming_sparse_csc_xt_diag_x, weighted_crossprod_dense_view,
5770 };
5771 use crate::matrix::LinearOperator;
5772 use crate::test_support::no_densify_design;
5773 use crate::types::RidgePolicy;
5774 use crate::utils::{PcgSolveInfo, StableSolver};
5775 use faer::sparse::{SparseColMat, SymbolicSparseColMat, Triplet};
5776 use gam_runtime::resource::{MatrixMaterializationError, ResourcePolicy};
5777 use ndarray::{Array1, Array2, ArrayViewMut2, Axis, array, s};
5778 use std::ops::Range;
5779 use std::sync::Arc;
5780 use std::sync::atomic::{AtomicUsize, Ordering};
5781
    /// Test operator whose entries come from a closed-form formula (`value`) and whose
    /// `to_dense` fallback panics, so any materialization has to go through
    /// `row_chunk_into`.
    struct ChunkOnlyOperator {
        // Number of rows reported by `nrows`.
        n: usize,
        // Number of columns reported by `ncols`.
        p: usize,
        // Incremented once per `row_chunk_into` call.
        row_chunk_calls: AtomicUsize,
    }
5787
5788 impl ChunkOnlyOperator {
5789 fn value(&self, i: usize, j: usize) -> f64 {
5790 ((i % 251) as f64) * 0.25 - ((j % 127) as f64) * 0.5 + ((i + j) % 7) as f64
5791 }
5792 }
5793
5794 impl LinearOperator for ChunkOnlyOperator {
5795 fn nrows(&self) -> usize {
5796 self.n
5797 }
5798
5799 fn ncols(&self) -> usize {
5800 self.p
5801 }
5802
5803 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
5804 let mut out = Array1::<f64>::zeros(self.n);
5805 for i in 0..self.n {
5806 let mut acc = 0.0;
5807 for j in 0..self.p {
5808 acc += self.value(i, j) * vector[j];
5809 }
5810 out[i] = acc;
5811 }
5812 out
5813 }
5814
5815 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
5816 let mut out = Array1::<f64>::zeros(self.p);
5817 for i in 0..self.n {
5818 for j in 0..self.p {
5819 out[j] += self.value(i, j) * vector[i];
5820 }
5821 }
5822 out
5823 }
5824
5825 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
5826 let dense = dense_operator_to_dense_by_chunks(self).map_err(|err| err.to_string())?;
5827 let psd = PsdWeightsView::try_new(weights.view())?;
5828 Ok(weighted_crossprod_dense_view(&dense, psd.view(), &dense))
5829 }
5830 }
5831
5832 impl DenseDesignOperator for ChunkOnlyOperator {
5833 fn row_chunk_into(
5834 &self,
5835 rows: Range<usize>,
5836 mut out: ArrayViewMut2<'_, f64>,
5837 ) -> Result<(), MatrixMaterializationError> {
5838 self.row_chunk_calls.fetch_add(1, Ordering::SeqCst);
5839 if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
5840 return Err(MatrixMaterializationError::MissingRowChunk {
5841 context: "ChunkOnlyOperator::row_chunk_into shape mismatch",
5842 });
5843 }
5844 for (local, row) in rows.enumerate() {
5845 for col in 0..self.p {
5846 out[[local, col]] = self.value(row, col);
5847 }
5848 }
5849 Ok(())
5850 }
5851
5852 fn to_dense(&self) -> Array2<f64> {
5853 panic!("ChunkOnlyOperator::to_dense fallback must not be used")
5855 }
5856 }
5857
5858 fn exact_weighted_penalized_solve(
5859 design: &Array2<f64>,
5860 weights: &Array1<f64>,
5861 rhs: &Array1<f64>,
5862 penalty: &Array2<f64>,
5863 ridge: f64,
5864 ) -> Array1<f64> {
5865 let mut h = design
5866 .t()
5867 .dot(&(design * &weights.view().insert_axis(Axis(1))));
5868 h += penalty;
5869 if ridge > 0.0 {
5870 for i in 0..h.nrows() {
5871 h[[i, i]] += ridge;
5872 }
5873 }
5874 StableSolver::new("matrix-free pcg exact reference")
5875 .solvevectorwithridge_retries(&h, rhs, 0.0)
5876 .expect("exact reference solve")
5877 }
5878
5879 #[test]
5880 fn fast_av_matches_ndarray_dot() {
5881 let x = array![[1.0, 2.0, -1.0], [0.5, -3.0, 4.0], [2.0, 0.0, 1.5]];
5882 let v = array![0.25, -1.0, 2.0];
5883 let expected = x.dot(&v);
5884 let got = fast_av(&x, &v);
5885 for i in 0..expected.len() {
5886 assert!((expected[i] - got[i]).abs() < 1e-12);
5887 }
5888 }
5889
5890 #[test]
5891 fn fast_atv_matches_ndarray_dot() {
5892 let x = array![[1.0, 2.0, -1.0], [0.5, -3.0, 4.0], [2.0, 0.0, 1.5]];
5893 let v = array![0.25, -1.0, 2.0];
5894 let expected = x.t().dot(&v);
5895 let got = fast_atv(&x, &v);
5896 for i in 0..expected.len() {
5897 assert!((expected[i] - got[i]).abs() < 1e-12);
5898 }
5899 }
5900
5901 #[test]
5902 fn sparse_to_dense_accumulates_duplicate_entries() {
5903 let symbolic = SymbolicSparseColMat::new_unsorted_checked(
5906 3,
5907 2,
5908 vec![0_usize, 2, 3],
5909 None,
5910 vec![1_usize, 1, 0],
5911 );
5912 let sparse = SparseColMat::new(symbolic, vec![2.0_f64, 3.5, -1.0]);
5913 let design = DesignMatrix::from(sparse);
5914 let dense = design.to_dense_arc();
5915
5916 assert!((dense[[1, 0]] - 5.5).abs() < 1e-12);
5917 assert!((dense[[0, 1]] + 1.0).abs() < 1e-12);
5918
5919 let v = array![4.0, -2.0];
5920 let y_sparse = design.matrixvectormultiply(&v);
5921 let y_dense = dense.dot(&v);
5922 for i in 0..y_sparse.len() {
5923 assert!((y_sparse[i] - y_dense[i]).abs() < 1e-12);
5924 }
5925 }
5926
5927 #[test]
5928 fn huge_sparse_densification_is_rejected_before_allocation() {
5929 let sparse = SparseColMat::try_new_from_triplets(500_000, 10_000, &[])
5930 .expect("empty sparse matrix should build");
5931 let design = SparseDesignMatrix::new(sparse);
5932 let err = design
5933 .try_to_dense_arc("matrix test")
5934 .expect_err("huge sparse densification should be rejected");
5935 assert!(err.contains("refusing to densify sparse design"));
5936 }
5937
5938 #[test]
5939 fn streaming_sparse_csc_xt_diag_x_matches_dense_signed_weights() {
5940 let sparse = SparseColMat::try_new_from_triplets(
5941 4,
5942 3,
5943 &[
5944 Triplet::new(0, 0, 1.0),
5945 Triplet::new(1, 0, 2.0),
5946 Triplet::new(2, 0, -1.0),
5947 Triplet::new(0, 1, 0.5),
5948 Triplet::new(1, 1, -3.0),
5949 Triplet::new(3, 1, 4.0),
5950 Triplet::new(0, 2, 2.0),
5951 Triplet::new(2, 2, 1.5),
5952 Triplet::new(3, 2, -0.25),
5953 ],
5954 )
5955 .expect("sparse matrix");
5956 let design = SparseDesignMatrix::new(sparse.clone());
5957 let dense = design.to_dense_arc();
5958 let weights = array![1.0, -2.0, 0.5, -1.5];
5959 let (symbolic, values) = sparse.parts();
5960 let mut got = Array2::<f64>::zeros((3, 3));
5961 streaming_sparse_csc_xt_diag_x(
5962 symbolic.col_ptr(),
5963 symbolic.row_idx(),
5964 values,
5965 4,
5966 3,
5967 weights.view(),
5968 &mut got,
5969 );
5970
5971 let mut expected = Array2::<f64>::zeros((3, 3));
5972 for row in 0..4 {
5973 for a in 0..3 {
5974 for b in 0..3 {
5975 expected[[a, b]] += weights[row] * dense[[row, a]] * dense[[row, b]];
5976 }
5977 }
5978 }
5979 let max_diff = (&got - &expected)
5980 .iter()
5981 .map(|v| v.abs())
5982 .fold(0.0_f64, f64::max);
5983 assert!(
5984 max_diff < 1e-12,
5985 "streamed sparse weighted Gram mismatch: max_diff={max_diff}"
5986 );
5987 }
5988
5989 #[test]
5990 fn multi_channel_operator_view_paths_match_stacked_dense_reference() {
5991 let dense_channel = array![[1.0, 2.0], [0.5, -1.0], [3.0, 0.25]];
5992 let sparse_dense = array![[0.0, 1.5], [2.0, 0.0], [-1.0, 0.75]];
5993 let sparse = SparseColMat::try_new_from_triplets(
5994 3,
5995 2,
5996 &[
5997 Triplet::new(1, 0, 2.0),
5998 Triplet::new(2, 0, -1.0),
5999 Triplet::new(0, 1, 1.5),
6000 Triplet::new(2, 1, 0.75),
6001 ],
6002 )
6003 .expect("sparse channel");
6004 let op = MultiChannelOperator::new(vec![
6005 DesignMatrix::Dense(DenseDesignMatrix::from(dense_channel.clone())),
6006 DesignMatrix::from(sparse),
6007 ])
6008 .expect("multi-channel operator");
6009 let mut stacked = Array2::<f64>::zeros((6, 2));
6010 stacked.slice_mut(s![0..3, ..]).assign(&dense_channel);
6011 stacked.slice_mut(s![3..6, ..]).assign(&sparse_dense);
6012
6013 let beta = array![0.25, -0.4];
6014 let expected_apply = stacked.dot(&beta);
6015 let got_apply = op.apply(&beta);
6016 for i in 0..expected_apply.len() {
6017 assert!((expected_apply[i] - got_apply[i]).abs() < 1e-12);
6018 }
6019
6020 let probe = array![0.5, -1.0, 0.25, 1.5, -0.75, 0.2];
6021 let expected_transpose = stacked.t().dot(&probe);
6022 let got_transpose = op.apply_transpose(&probe);
6023 for i in 0..expected_transpose.len() {
6024 assert!((expected_transpose[i] - got_transpose[i]).abs() < 1e-12);
6025 }
6026
6027 let weights = array![1.0, -0.5, 0.75, 2.0, 0.25, 1.5];
6028 let w_pos = weights.mapv(|w: f64| w.max(0.0));
6029 let weighted = stacked.clone() * w_pos.view().insert_axis(Axis(1));
6030 let expected_xtwx = stacked.t().dot(&weighted);
6031 let got_xtwx = op.diag_xtw_x(&weights).expect("multi-channel xtwx");
6032 for i in 0..expected_xtwx.nrows() {
6033 for j in 0..expected_xtwx.ncols() {
6034 assert!((expected_xtwx[[i, j]] - got_xtwx[[i, j]]).abs() < 1e-12);
6035 }
6036 }
6037
6038 let expected_diag = Array1::from_iter((0..2).map(|j| expected_xtwx[[j, j]]));
6039 let got_diag = op.diag_gram(&weights).expect("multi-channel diag gram");
6040 for i in 0..expected_diag.len() {
6041 assert!((expected_diag[i] - got_diag[i]).abs() < 1e-12);
6042 }
6043
6044 let y = array![1.0, 0.5, -0.25, 2.0, -1.0, 0.75];
6045 let expected_xtwy = stacked.t().dot(&(w_pos * &y));
6046 let got_xtwy = op.compute_xtwy(&weights, &y).expect("multi-channel xtwy");
6047 for i in 0..expected_xtwy.len() {
6048 assert!((expected_xtwy[i] - got_xtwy[i]).abs() < 1e-12);
6049 }
6050 }
6051
6052 #[test]
6059 fn block_design_fused_dense_cross_matches_stacked_reference_xtwx() {
6060 let b0 = array![
6061 [1.0, 2.0],
6062 [0.5, -1.0],
6063 [3.0, 0.25],
6064 [-2.0, 1.5],
6065 [0.75, -0.5],
6066 ];
6067 let b1 = array![
6068 [-1.0, 0.5, 2.0],
6069 [1.5, -0.25, 0.0],
6070 [0.0, 1.0, -1.5],
6071 [2.0, 0.5, 1.0],
6072 [-0.5, -1.0, 0.25],
6073 ];
6074 let b2 = array![[0.5], [-1.0], [2.0], [0.25], [-0.75]];
6075
6076 let mut stacked = Array2::<f64>::zeros((5, 6));
6077 stacked.slice_mut(s![.., 0..2]).assign(&b0);
6078 stacked.slice_mut(s![.., 2..5]).assign(&b1);
6079 stacked.slice_mut(s![.., 5..6]).assign(&b2);
6080
6081 let blocks = vec![
6082 DesignBlock::Dense(DenseDesignMatrix::from(b0)),
6083 DesignBlock::Dense(DenseDesignMatrix::from(b1)),
6084 DesignBlock::Dense(DenseDesignMatrix::from(b2)),
6085 ];
6086 let op = BlockDesignOperator::new(blocks).expect("block design");
6087
6088 let weights = array![1.5, -0.5, 2.0, -1.0, 0.75];
6090 let weighted = stacked.clone() * weights.view().insert_axis(Axis(1));
6091 let expected = stacked.t().dot(&weighted);
6092
6093 let got = op.diag_xtw_x(&weights).expect("block fused xtwx");
6094 assert_eq!(got.dim(), (6, 6));
6095 let max_diff = (&got - &expected)
6096 .iter()
6097 .map(|v| v.abs())
6098 .fold(0.0_f64, f64::max);
6099 assert!(
6100 max_diff < 1e-10,
6101 "fused block Dense×Dense Gram mismatch: max_diff={max_diff}"
6102 );
6103 }
6104
6105 #[test]
6106 #[should_panic(expected = "ReparamOperator: X cols (2) must match Qs rows (3)")]
6107 fn reparam_operator_rejects_incompatible_transform_shape() {
6108 let x = array![[1.0, 2.0], [0.5, -1.0]];
6109 let qs = Arc::new(Array2::<f64>::zeros((3, 1)));
6110 ReparamOperator::new(DesignMatrix::Dense(DenseDesignMatrix::from(x)), qs);
6111 }
6112
6113 #[test]
6125 fn coefficient_transform_operator_exposes_cached_dense_to_block_dispatch() {
6126 let inner = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
6127 let transform = array![[0.5, -1.0, 2.0], [1.0, 0.0, -0.5]];
6128 let expected = inner.dot(&transform);
6129
6130 let op =
6131 CoefficientTransformOperator::new(DenseDesignMatrix::from(inner), transform.clone())
6132 .expect("coefficient transform operator");
6133 let dense_design = DenseDesignMatrix::from(Arc::new(op));
6134
6135 let probe = Array1::from_elem(3, 1.0);
6140 let warmed = dense_design.apply_transpose(&probe);
6141 assert_eq!(warmed.len(), expected.ncols());
6142
6143 let dense_ref = dense_design
6144 .as_dense_ref()
6145 .expect("DenseDesignMatrix::as_dense_ref must reach the cached X·T");
6146 assert_eq!(dense_ref.dim(), expected.dim());
6147 for ((r, c), v) in expected.indexed_iter() {
6148 assert!((dense_ref[[r, c]] - v).abs() < 1e-12);
6149 }
6150 }
6151
6152 #[test]
6153 fn design_matrix_hstack_preserves_lazy_blocks() {
6154 let left_dense = array![[1.0, 2.0], [3.0, 4.0]];
6155 let right_dense = array![[5.0], [6.0]];
6156 let left = no_densify_design(left_dense.clone());
6157 let right = no_densify_design(right_dense.clone());
6158 let stacked = DesignMatrix::hstack(vec![left, right]).expect("stacked design");
6159
6160 assert!(stacked.as_dense_ref().is_none());
6161 assert!(!stacked.is_materialized_dense());
6162 assert!(stacked.is_operator_backed());
6163 assert_eq!(stacked.nrows(), 2);
6164 assert_eq!(stacked.ncols(), 3);
6165
6166 let beta = array![0.25, -0.5, 2.0];
6167 let expected = array![9.25, 10.75];
6168 let got = stacked.dot(&beta);
6169 for i in 0..expected.len() {
6170 assert!((got[i] - expected[i]).abs() < 1e-12);
6171 }
6172
6173 let chunk = stacked
6174 .try_row_chunk(0..2)
6175 .expect("stacked.try_row_chunk must succeed");
6176 assert_eq!(chunk, array![[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]);
6177 }
6178
6179 #[test]
6180 #[should_panic(expected = "DesignMatrix::as_dense_cow called on operator-backed design")]
6181 fn design_matrix_as_dense_cow_rejects_operator_backed_designs() {
6182 let design = no_densify_design(array![[1.0, 2.0], [3.0, 4.0]]);
6183 design.as_dense_cow();
6184 }
6185
6186 #[test]
6187 fn sparse_factorized_solve_matches_dense_operator_solve() {
6188 let triplets = vec![
6189 Triplet::new(0usize, 0usize, 1.0),
6190 Triplet::new(1, 0, 2.0),
6191 Triplet::new(1, 1, -1.0),
6192 Triplet::new(2, 1, 3.0),
6193 Triplet::new(2, 2, 0.5),
6194 ];
6195 let sparse = SparseColMat::try_new_from_triplets(3, 3, &triplets)
6196 .expect("sparse design should build");
6197 let sparse_design = DesignMatrix::from(sparse);
6198 let dense_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
6199 sparse_design.to_dense(),
6200 ));
6201 let weights = array![1.5, 0.75, 2.0];
6202 let rhs = array![1.0, -0.5, 2.0];
6203 let penalty = Array2::from_diag(&array![0.25, 0.5, 0.75]);
6204
6205 let sparse_sol = sparse_design
6206 .solve_system(&weights, &rhs, Some(&penalty))
6207 .expect("sparse solve should factorize natively");
6208 let dense_sol = dense_design
6209 .solve_system(&weights, &rhs, Some(&penalty))
6210 .expect("dense solve should factorize");
6211
6212 for i in 0..rhs.len() {
6213 assert!(
6214 (sparse_sol[i] - dense_sol[i]).abs() < 1e-10,
6215 "solution mismatch at {i}: sparse={} dense={}",
6216 sparse_sol[i],
6217 dense_sol[i]
6218 );
6219 }
6220 }
6221
6222 #[test]
6223 fn solve_system_stabilizes_indefinite_penalty_and_returns_finite_solution() {
6224 let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![
6225 [1.0, 0.0],
6226 [0.0, 0.0]
6227 ]));
6228 let weights = array![1.0, 1.0];
6229 let rhs = array![2.0, 0.0];
6230 let penalty = array![[0.0, 0.0], [0.0, -1e-12]];
6231
6232 let beta = design
6233 .solve_system(&weights, &rhs, Some(&penalty))
6234 .expect("solve_system should stabilize indefinite systems");
6235
6236 assert!(beta.iter().all(|v| v.is_finite()));
6237 assert!((beta[0] - 2.0).abs() < 1e-10);
6238 assert!(beta[1].abs() < 1e-8);
6239 }
6240
#[test]
// Solves a 48 x 520 weighted, penalized system with the explicit matrix-free PCG
// entry point and checks it two ways: coefficient-wise against
// `exact_weighted_penalized_solve`, and through the residual norm of the
// explicitly assembled system X'WX + S + ridge*I.
fn explicit_matrix_free_pcg_matches_exact_large_dense_weighted_penalized_solve() {
    let (n, p) = (48usize, 520usize);
    let ridge = 1e-8;

    // Deterministic synthetic inputs.
    let x = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        (((i + 3) * (j + 5)) % 17) as f64 / 17.0 + 0.02 * (i as f64) + 0.001 * (j as f64)
    });
    let weights = Array1::from_iter((0..n).map(|i| 0.5 + (i as f64) / (2.0 * n as f64)));
    let rhs = Array1::from_iter((0..p).map(|j| ((j % 13) as f64 - 6.0) / 13.0));
    let penalty_diag = Array1::from_iter((0..p).map(|j| 0.1 + 0.005 * ((j % 7) as f64)));
    let penalty = Array2::from_diag(&penalty_diag);
    let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone()));

    let pcg = design
        .solve_system_matrix_free_pcg(&weights, &rhs, Some(&penalty), ridge)
        .expect("matrix-free pcg solve");
    let exact = exact_weighted_penalized_solve(&x, &weights, &rhs, &penalty, ridge);

    // Coefficient-wise agreement with the exact reference solve.
    for i in 0..p {
        let gap = (pcg[i] - exact[i]).abs();
        assert!(
            gap < 1e-5,
            "solution mismatch at {i}: pcg={} exact={}",
            pcg[i],
            exact[i]
        );
    }

    // Residual check against the densely assembled system matrix.
    let weighted_x = x.clone() * weights.view().insert_axis(Axis(1));
    let mut h = x.t().dot(&weighted_x);
    h += &penalty;
    h.diag_mut().iter_mut().for_each(|d| *d += ridge);
    let residual = h.dot(&pcg) - &rhs;
    let residual_norm = residual.dot(&residual).sqrt();
    assert!(residual_norm < 1e-4, "residual_norm={residual_norm}");
}
6284
#[test]
// The policy-driven solve and the explicit matrix-free PCG solve are run on the
// same 40 x 520 dense system and must agree coefficient-wise, up to a tolerance
// that scales with the magnitude of each explicit coefficient.
fn policy_solve_matches_explicit_matrix_free_pcg_on_large_dense_system() {
    let (n, p) = (40usize, 520usize);
    let ridge_floor = 1e-8;

    let x = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        (((2 * i + j + 11) % 23) as f64 / 23.0) + 0.0005 * (j as f64)
    });
    let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x));
    let weights = Array1::from_iter((0..n).map(|i| 1.0 + 0.01 * i as f64));
    let rhs = Array1::from_iter((0..p).map(|j| ((j % 5) as f64) - 2.0));
    let penalty_diag = Array1::from_iter((0..p).map(|j| 0.2 + 0.01 * ((j % 3) as f64)));
    let penalty = Array2::from_diag(&penalty_diag);

    let explicit = design
        .solve_system_matrix_free_pcg(&weights, &rhs, Some(&penalty), ridge_floor)
        .expect("explicit pcg");
    let stabilization = RidgePolicy::explicit_stabilization_pospart();
    let policy = design
        .solve_systemwith_policy(&weights, &rhs, Some(&penalty), ridge_floor, stabilization)
        .expect("policy solve");

    for i in 0..p {
        // Mixed absolute/relative tolerance anchored on the explicit solution.
        let tol = 1e-5 * (1.0 + explicit[i].abs());
        let gap = (explicit[i] - policy[i]).abs();
        assert!(
            gap < tol,
            "policy mismatch at {i}: explicit={} policy={} (tol={tol})",
            explicit[i],
            policy[i]
        );
    }
}
6333
#[test]
// Runs the diagnostics-returning PCG entry point on a 36 x 2160 design, for which
// `should_use_matrix_free_pcg` must report true. The returned `PcgSolveInfo` must
// describe a converged run, and the solution must match the exact reference.
fn explicit_matrix_free_pcg_reports_convergence_diagnostics() {
    let (n, p) = (36usize, 2160usize);
    let x = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        (((3 * i + 5 * j + 7) % 29) as f64 / 29.0) + 0.015 * (i as f64) + 1e-4 * j as f64
    });
    let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone()));
    assert!(design.should_use_matrix_free_pcg());

    let weights = Array1::from_iter((0..n).map(|i| 0.75 + 0.01 * i as f64));
    let rhs = Array1::from_iter((0..p).map(|j| ((j % 9) as f64 - 4.0) / 9.0));
    let penalty_diag = Array1::from_iter((0..p).map(|j| 0.05 + 0.002 * ((j % 11) as f64)));
    let penalty = Array2::from_diag(&penalty_diag);
    let ridge = 1e-8;

    let solved: (Array1<f64>, PcgSolveInfo) = design
        .solve_system_matrix_free_pcg_with_info(&weights, &rhs, Some(&penalty), ridge)
        .expect("pcg with info");
    let (pcg, info) = solved;

    // Diagnostics: converged, did some work, and ended with a small finite residual.
    assert!(info.converged);
    assert!(info.iterations > 0);
    assert!(info.relative_residual_norm.is_finite());
    assert!(info.relative_residual_norm < 1e-6);

    let exact = exact_weighted_penalized_solve(&x, &weights, &rhs, &penalty, ridge);
    for i in 0..p {
        let gap = (pcg[i] - exact[i]).abs();
        assert!(
            gap < 1e-5,
            "solution mismatch at {i}: pcg={} exact={}",
            pcg[i],
            exact[i]
        );
    }
}
6373
#[test]
// `dense_transpose_weighted_response` must match the two-step reference:
// form the weighted response (weights clamped at zero) and then apply X'.
fn compute_xtwy_dense_allocationfree_matches_matvec() {
    let (n, p) = (2_000usize, 64usize);

    // Deterministic synthetic design, response, and strictly positive weights.
    let x = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        (((i * 13 + j * 7) % 97) as f64) / 97.0
    });
    let y = Array1::<f64>::from_shape_fn(n, |i| ((i % 17) as f64 - 8.0) * 0.1);
    let w = Array1::<f64>::from_shape_fn(n, |i| 0.25 + ((i % 11) as f64) * 0.05);

    let wy = Array1::from_shape_fn(n, |i| y[i] * w[i].max(0.0));
    let reference = fast_atv(&x, &wy);
    let fused = dense_transpose_weighted_response(&x, &w, &y, None);

    for j in 0..p {
        let gap = (reference[j] - fused[j]).abs();
        assert!(
            gap < 1e-10,
            "mismatch at column {j}: ref={} fused={}",
            reference[j],
            fused[j]
        );
    }
}
6403
#[test]
// Materializing an 11_000 x 128 operator-backed design through `to_dense_arc`
// must request more than one row chunk from the operator, and the resulting
// matrix must reproduce the operator's values at the spot-checked positions.
fn large_lazy_dense_materialization_streams_chunks_without_to_dense_fallback() {
    let (n, p) = (11_000usize, 128usize);
    let op = Arc::new(ChunkOnlyOperator {
        n,
        p,
        row_chunk_calls: AtomicUsize::new(0),
    });
    let design = DenseDesignMatrix::from(Arc::clone(&op));

    let dense = design.to_dense_arc();
    assert_eq!(dense.dim(), (n, p));

    let chunk_calls = op.row_chunk_calls.load(Ordering::SeqCst);
    assert!(
        chunk_calls > 1,
        "expected dense materialization to stream more than one row chunk"
    );

    // Spot-check individual entries against the operator's own values.
    let probes = [(0, 0), (8_191, 127), (8_192, 0), (10_999, 64)];
    for &(i, j) in probes.iter() {
        assert_eq!(dense[[i, j]], op.value(i, j));
    }
}
6426
#[test]
// `to_dense_arc` succeeds on a small operator-backed design, while the
// policy-aware `try_to_dense_arc_with_policy` must refuse both under the
// analytic-operator-required policy and under a one-byte materialization cap.
fn to_dense_arc_bypasses_policy_cap_strict_policy_still_refuses() {
    let op = Arc::new(ChunkOnlyOperator {
        n: 128,
        p: 4,
        row_chunk_calls: AtomicUsize::new(0),
    });
    let design = DenseDesignMatrix::from(Arc::clone(&op));

    // Unconditional materialization works.
    let dense = design.to_dense_arc();
    assert_eq!(dense.dim(), (128, 4));

    // Strict policy: refusal must name the policy in its message.
    let strict = ResourcePolicy::analytic_operator_required();
    let strict_attempt = design.try_to_dense_arc_with_policy("regression strict refuses", &strict);
    let err = strict_attempt.expect_err("strict policy must refuse lazy materialization");
    let mentions_refusal = err.contains("refusing to densify operator-backed design");
    let mentions_policy = err.contains("AnalyticOperatorRequired");
    assert!(
        mentions_refusal && mentions_policy,
        "unexpected strict-policy error: {err}"
    );

    // Default policy with the single-materialization cap squeezed down to one byte.
    let mut tight = ResourcePolicy::default_library();
    tight.max_single_materialization_bytes = 1;
    let tight_attempt = design.try_to_dense_arc_with_policy("regression tight refuses", &tight);
    let size_err = tight_attempt.expect_err("undersized cap must refuse lazy materialization");
    assert!(
        size_err.contains("refusing to densify operator-backed design"),
        "unexpected size-cap error: {size_err}"
    );
}
6473
#[test]
// `try_to_dense_by_chunks` on an 11_000 x 128 operator-backed design must use
// more than one row-chunk request and reproduce the operator's values at the
// spot-checked positions.
fn try_to_dense_by_chunks_writes_directly_into_output_slices() {
    let (n, p) = (11_000usize, 128usize);
    let op = Arc::new(ChunkOnlyOperator {
        n,
        p,
        row_chunk_calls: AtomicUsize::new(0),
    });
    let design = DesignMatrix::Dense(DenseDesignMatrix::from(Arc::clone(&op)));

    let materialized = design.try_to_dense_by_chunks("large chunked regression");
    let dense = materialized.expect("chunked materialization");
    assert_eq!(dense.dim(), (n, p));

    let chunk_calls = op.row_chunk_calls.load(Ordering::SeqCst);
    assert!(
        chunk_calls > 1,
        "expected direct chunked conversion to use bounded row chunks"
    );

    // Spot-check individual entries against the operator's own values.
    let probes = [(1, 7), (4_096, 12), (8_193, 63), (10_998, 127)];
    for &(i, j) in probes.iter() {
        assert_eq!(dense[[i, j]], op.value(i, j));
    }
}
6498
#[test]
// Checks a two-marginal `TensorProductDesignOperator` against an explicitly
// assembled row-wise tensor product: `to_dense`, `apply`, `apply_transpose`,
// and `diag_xtw_x` must all agree with the dense reference.
fn tensor_product_design_operator_matches_dense_2d() {
    use super::{DenseDesignOperator, TensorProductDesignOperator};

    let n: usize = 10;
    let q1: usize = 4;
    let q2: usize = 3;

    // Piecewise-linear marginal: each row has weight on two adjacent columns,
    // `1 - frac` and `frac`.
    let hat_basis = |q: usize| {
        let mut basis = Array2::<f64>::zeros((n, q));
        for i in 0..n {
            let t = i as f64 / (n - 1) as f64 * (q - 1) as f64;
            let j = (t.floor() as usize).min(q - 2);
            let frac = t - j as f64;
            basis[[i, j]] = 1.0 - frac;
            basis[[i, j + 1]] = frac;
        }
        basis
    };
    let b1 = hat_basis(q1);
    let b2 = hat_basis(q2);

    let op = TensorProductDesignOperator::new(vec![Arc::new(b1.clone()), Arc::new(b2.clone())])
        .unwrap();

    // Dense reference: column `j1 * q2 + j2` holds b1[:, j1] * b2[:, j2].
    let p = q1 * q2;
    let dense = Array2::<f64>::from_shape_fn((n, p), |(i, col)| {
        b1[[i, col / q2]] * b2[[i, col % q2]]
    });

    let op_dense = op.to_dense();
    let max_diff = (&op_dense - &dense)
        .iter()
        .fold(0.0f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(max_diff < 1e-14, "to_dense mismatch: max_diff={max_diff}");

    // Forward product X * beta.
    let beta = Array1::from_iter((0..p).map(|j| (j as f64 + 1.0) * 0.1));
    let ref_result = dense.dot(&beta);
    let op_result = op.apply(&beta);
    let max_diff = (&op_result - &ref_result)
        .iter()
        .fold(0.0f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(max_diff < 1e-12, "apply mismatch: max_diff={max_diff}");

    // Transposed product X' * v.
    let v = Array1::from_iter((0..n).map(|i| (i as f64 + 1.0) * 0.3));
    let ref_xt_v = dense.t().dot(&v);
    let op_xt_v = op.apply_transpose(&v);
    let max_diff = (&op_xt_v - &ref_xt_v)
        .iter()
        .fold(0.0f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(
        max_diff < 1e-12,
        "apply_transpose mismatch: max_diff={max_diff}"
    );

    // Weighted Gram matrix X' diag(w) X, accumulated row by row per entry.
    let w = Array1::from_iter((0..n).map(|i| 1.0 + i as f64 * 0.1));
    let ref_xtwx = Array2::<f64>::from_shape_fn((p, p), |(a, b)| {
        (0..n).fold(0.0, |acc, i| acc + w[i] * dense[[i, a]] * dense[[i, b]])
    });
    let op_xtwx = op.diag_xtw_x(&w).unwrap();
    let max_diff = (&op_xtwx - &ref_xtwx)
        .iter()
        .fold(0.0f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(max_diff < 1e-10, "diag_xtw_x mismatch: max_diff={max_diff}");
}
6589
#[test]
// Three-marginal variant: `to_dense` must match the explicit triple product and
// `apply_transpose(apply(beta))` must match X'X*beta from the dense reference.
fn tensor_product_design_operator_3d() {
    use super::{DenseDesignOperator, TensorProductDesignOperator};

    let n: usize = 8;
    let dims: [usize; 3] = [3, 2, 2];

    // One piecewise-linear marginal per dimension; each row has weight on two
    // adjacent columns.
    let marginals: Vec<Array2<f64>> = dims
        .iter()
        .map(|&q| {
            let mut b = Array2::<f64>::zeros((n, q));
            for i in 0..n {
                let t = i as f64 / (n - 1) as f64 * (q - 1) as f64;
                let j = (t.floor() as usize).min(q - 2);
                let frac = t - j as f64;
                b[[i, j]] = 1.0 - frac;
                b[[i, j + 1]] = frac;
            }
            b
        })
        .collect();

    let op = TensorProductDesignOperator::new(
        marginals.iter().map(|m| Arc::new(m.clone())).collect(),
    )
    .unwrap();

    // Dense reference with column index j0 * d1 * d2 + j1 * d2 + j2.
    let p: usize = dims.iter().copied().product();
    let (d1, d2) = (dims[1], dims[2]);
    let dense = Array2::<f64>::from_shape_fn((n, p), |(i, col)| {
        let j0 = col / (d1 * d2);
        let j1 = (col / d2) % d1;
        let j2 = col % d2;
        marginals[0][[i, j0]] * marginals[1][[i, j1]] * marginals[2][[i, j2]]
    });

    let op_dense = op.to_dense();
    let max_diff = (&op_dense - &dense)
        .iter()
        .fold(0.0f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(
        max_diff < 1e-14,
        "3D to_dense mismatch: max_diff={max_diff}"
    );

    // Round trip through apply and apply_transpose.
    let beta = Array1::from_iter((0..p).map(|j| (j as f64).sin()));
    let xb = op.apply(&beta);
    let xtxb = op.apply_transpose(&xb);
    let ref_xtxb = dense.t().dot(&dense.dot(&beta));
    let max_diff = (&xtxb - &ref_xtxb)
        .iter()
        .fold(0.0f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(max_diff < 1e-10, "3D X'Xβ mismatch: max_diff={max_diff}");
}
6650
#[test]
// Builds a 4096 x 192 sparse design with four entries per row, mirrors it in a
// dense array, and checks `xt_diag_x_signed_op` and `diag_gram` against a dense
// reference computed with weights clamped at zero. Every seventh weight is
// exactly zero.
fn sparse_weighted_crossprod_parallel_path_matches_dense_reference() {
    use faer::sparse::Triplet;

    let n: usize = 4096;
    let p: usize = 192;
    let per_row: usize = 4;

    // Fill the sparse triplets and the dense mirror in lock-step.
    let mut dense = Array2::<f64>::zeros((n, p));
    let mut triplets = Vec::with_capacity(n * per_row);
    for i in 0..n {
        let base = (i * 37) % p;
        for k in 0..per_row {
            let col = (base + k * 11) % p;
            let val = ((i + 3 * k + 1) as f64).sin() * 0.25 + 0.5;
            triplets.push(Triplet::new(i, col, val));
            dense[[i, col]] = val;
        }
    }
    let sparse = faer::sparse::SparseColMat::try_new_from_triplets(n, p, &triplets).unwrap();
    let design = DesignMatrix::Sparse(SparseDesignMatrix::new(sparse));
    let weights = Array1::from_iter((0..n).map(|i| {
        let r = i % 7;
        if r == 0 { 0.0 } else { 0.5 + r as f64 * 0.125 }
    }));

    let got = <DesignMatrix as LinearOperator>::xt_diag_x_signed_op(
        &design,
        SignedWeightsView::from_array(&weights),
    )
    .unwrap();

    // Dense reference, skipping zero weights and zero left-hand entries.
    let mut reference = Array2::<f64>::zeros((p, p));
    for (i, w_raw) in weights.iter().enumerate() {
        let wi = w_raw.max(0.0);
        if wi == 0.0 {
            continue;
        }
        let row = dense.row(i);
        for (a, &xa) in row.iter().enumerate() {
            if xa == 0.0 {
                continue;
            }
            for (b, &xb) in row.iter().enumerate() {
                reference[[a, b]] += wi * xa * xb;
            }
        }
    }
    let max_diff = (&got - &reference)
        .iter()
        .fold(0.0_f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(
        max_diff < 1e-10,
        "sparse xtwx mismatch: max_diff={max_diff}"
    );

    // The diagonal-only path must agree with the reference diagonal.
    let got_diag = design.diag_gram(&weights).unwrap();
    let ref_diag = reference.diag().to_owned();
    let max_diag_diff = (&got_diag - &ref_diag)
        .iter()
        .fold(0.0_f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(
        max_diag_diff < 1e-10,
        "sparse diag gram mismatch: max_diff={max_diag_diff}"
    );
}
6716
#[test]
// Pairs a sparse 2048 x 64 covariate design (three entries per row) with a dense
// 2048 x 6 time basis inside a `RowwiseKroneckerOperator`, and checks its
// `diag_xtw_x` against a dense reference in which column `c * p_time + t`
// is the product of covariate column `c` and time column `t`.
fn rowwise_kronecker_sparse_structured_xtwx_matches_dense_reference() {
    use faer::sparse::Triplet;

    let n: usize = 2048;
    let p_cov: usize = 64;
    let p_time: usize = 6;
    let per_row: usize = 3;

    // Sparse covariate block plus a dense mirror for the reference computation.
    let mut cov_dense = Array2::<f64>::zeros((n, p_cov));
    let mut triplets = Vec::with_capacity(n * per_row);
    for i in 0..n {
        let base = (i * 17) % p_cov;
        for k in 0..per_row {
            let col = (base + k * 7) % p_cov;
            let val = 0.2 + (((i + k) % 13) as f64) / 17.0;
            triplets.push(Triplet::new(i, col, val));
            cov_dense[[i, col]] = val;
        }
    }
    let cov_sparse =
        faer::sparse::SparseColMat::try_new_from_triplets(n, p_cov, &triplets).unwrap();
    let cov = DesignMatrix::Sparse(SparseDesignMatrix::new(cov_sparse));

    let time = Array2::<f64>::from_shape_fn((n, p_time), |(i, t)| {
        (((i + 1) * (t + 3)) as f64).cos() * 0.1 + 0.4
    });
    let op = RowwiseKroneckerOperator::new(cov, Arc::new(time.clone())).unwrap();
    let weights = Array1::from_iter((0..n).map(|i| 0.25 + ((i % 11) as f64) * 0.05));
    let got = op.diag_xtw_x(&weights).unwrap();

    // Dense reference: only the non-zero covariate columns of each row contribute.
    let p_total = p_cov * p_time;
    let mut reference = Array2::<f64>::zeros((p_total, p_total));
    for i in 0..n {
        let active: Vec<(usize, f64)> = (0..p_cov)
            .map(|c| (c, cov_dense[[i, c]]))
            .filter(|&(_, x)| x != 0.0)
            .collect();
        for &(c1, x1) in &active {
            for t1 in 0..p_time {
                let a = c1 * p_time + t1;
                let xa = x1 * time[[i, t1]];
                let w_xa = weights[i] * xa;
                for &(c2, x2) in &active {
                    for t2 in 0..p_time {
                        let b = c2 * p_time + t2;
                        reference[[a, b]] += w_xa * x2 * time[[i, t2]];
                    }
                }
            }
        }
    }
    let max_diff = (&got - &reference)
        .iter()
        .fold(0.0_f64, |worst, v: &f64| worst.max(v.abs()));
    assert!(
        max_diff < 1e-9,
        "rowwise kronecker sparse xtwx mismatch: max_diff={max_diff}"
    );
}
6781
#[test]
// A 0 x 0 local block embedded at columns 2..5 of a width-7 target must
// materialize as a 0 x 7 matrix.
fn embedded_column_block_zero_row_local_materializes_empty_global_width() {
    let local = Array2::<f64>::zeros((0, 0));
    let embedded = EmbeddedColumnBlock::new(&local, 2..5, 7);
    let out = embedded.materialize();
    assert_eq!(out.dim(), (0, 7));
}
6788
#[test]
// With an identity transform and an all-zero residualisation block, the
// operator's row chunk must reproduce the inner design, and applying it through
// a `DenseDesignMatrix` wrapper must match the plain matrix-vector product.
fn residualised_design_operator_identity_passthrough() {
    let inner = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
    let transform = Array2::<f64>::eye(2);
    let anchor_raw = array![[7.0, -1.0], [0.5, 2.0], [-3.0, 1.5]];

    // Zero block of shape (anchor columns, transform columns).
    let r_shape = (anchor_raw.ncols(), transform.ncols());
    let r_block = Arc::new(Array2::<f64>::zeros(r_shape));
    let anchor_design = DesignMatrix::from(anchor_raw);

    let built = ResidualisedDesignOperator::new(
        DenseDesignMatrix::from(inner.clone()),
        transform,
        vec![(anchor_design, r_block)],
    );
    let op = built.expect("residualised operator constructs");

    // Row-chunk path.
    let mut chunk = Array2::<f64>::zeros((3, 2));
    op.row_chunk_into(0..3, chunk.view_mut())
        .expect("row chunk");
    for ((r, c), v) in inner.indexed_iter() {
        let gap = (chunk[[r, c]] - v).abs();
        assert!(
            gap < 1e-12,
            "identity row_chunk mismatch at ({r},{c}): got {} expected {v}",
            chunk[[r, c]]
        );
    }

    // Operator-backed dense design: dimensions and forward product.
    let dense_design = DenseDesignMatrix::from(Arc::new(op));
    assert_eq!(dense_design.nrows(), 3);
    assert_eq!(dense_design.ncols(), 2);
    let probe = ndarray::Array1::from_vec(vec![1.0, -2.0]);
    let expected = inner.dot(&probe);
    let got = dense_design.apply(&probe);
    for i in 0..3 {
        assert!((got[i] - expected[i]).abs() < 1e-12);
    }
}
6838
#[test]
// Builds R = (A'A)^{-1} A'(B V) by hand for a 4 x 2 anchor A and a one-column
// transform V, then checks that the operator's row chunk and `apply` both match
// the explicitly computed B V - A R.
fn residualised_design_operator_two_block_reconstruction() {
    let anchor = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
    let b_raw = array![[1.0, 2.0], [1.0, 1.5], [1.0, 0.5], [1.0, -1.0]];
    let v_b = array![[0.0], [1.0]];

    // Least-squares coefficients via the closed-form 2 x 2 inverse of A'A.
    let bv = b_raw.dot(&v_b);
    let ata = anchor.t().dot(&anchor);
    let atbv = anchor.t().dot(&bv);
    let ata_inv = {
        let (a00, a01) = (ata[[0, 0]], ata[[0, 1]]);
        let (a10, a11) = (ata[[1, 0]], ata[[1, 1]]);
        let det = a00 * a11 - a01 * a10;
        array![[a11 / det, -a01 / det], [-a10 / det, a00 / det]]
    };
    let r_b: Array2<f64> = ata_inv.dot(&atbv);

    let built = ResidualisedDesignOperator::new(
        DenseDesignMatrix::from(b_raw.clone()),
        v_b.clone(),
        vec![(DesignMatrix::from(anchor.clone()), Arc::new(r_b.clone()))],
    );
    let op = built.expect("residualised operator constructs");

    let gamma_a = ndarray::Array1::from_vec(vec![0.5, -1.25]);
    let theta_b = ndarray::Array1::from_vec(vec![2.5]);

    // Hand-computed emitted block and the resulting two-block linear predictor.
    let cv = b_raw.dot(&v_b);
    let ar = anchor.dot(&r_b);
    let emitted_b_chunk = &cv - &ar;
    let anchor_part = anchor.dot(&gamma_a);
    let expected = anchor_part + emitted_b_chunk.dot(&theta_b);

    // Same predictor, with the second block taken from the operator's row chunk.
    let mut got_chunk = Array2::<f64>::zeros((4, 1));
    op.row_chunk_into(0..4, got_chunk.view_mut())
        .expect("row chunk");
    let got = anchor.dot(&gamma_a) + got_chunk.dot(&theta_b);
    for i in 0..4 {
        let gap = (got[i] - expected[i]).abs();
        assert!(
            gap < 1e-10,
            "two-block reconstruction mismatch at row {i}: got {} expected {}",
            got[i],
            expected[i]
        );
    }

    // `apply` must agree with the emitted block scaled by the single coefficient.
    let applied = op.apply(&theta_b);
    for i in 0..4 {
        assert!((applied[i] - emitted_b_chunk[[i, 0]] * theta_b[0]).abs() < 1e-10);
    }
}
6916}