use crate::faer_ndarray::{
FaerArrayView, array2_to_matmut, fast_ab, fast_atb, fast_atv, fast_atv_into, fast_av,
fast_av_into, fast_xt_diag_x,
};
use crate::resource::{
MaterializationPolicy, MatrixMaterializationError, ResourcePolicy, rows_for_target_bytes,
};
use crate::types::RidgePolicy;
use faer::Accum;
use faer::Par;
use faer::linalg::matmul::matmul;
use faer::sparse::{SparseColMat, SparseRowMat, Triplet};
use ndarray::{
Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, ShapeBuilder, s,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::ops::Deref;
use std::ops::Range;
use std::sync::{Arc, OnceLock};
const MATRIX_FREE_PCG_MIN_P: usize = 2048;
const MATRIX_FREE_PCG_REL_TOL: f64 = 1e-8;
const MATRIX_FREE_PCG_MAX_ITER: usize = 2000;
const MAX_SINGLE_DENSE_MATERIALIZATION_BYTES: usize = 256 * 1024 * 1024;
const MAX_PERSISTENT_SPARSE_DENSE_CACHE_BYTES: usize = 256 * 1024 * 1024;
const MAX_SPARSE_TO_DENSE_BYTES: usize = MAX_SINGLE_DENSE_MATERIALIZATION_BYTES;
const CHUNKED_DENSE_MATERIALIZATION_BYTES: usize = 8 * 1024 * 1024;
const OPERATOR_ROW_CHUNK_SIZE: usize = 256;
const KERNEL_OPERATOR_ROW_CHUNK_SIZE: usize = 2048;
const DENSE_ROW_PARALLEL_MIN_NP: u64 = 200_000;
const WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS: u64 = 500_000;
const SPARSE_ROW_PARALLEL_MIN_FLOPS: u64 = 100_000;
const TENSOR_GEMM_MAX_INTERMEDIATE_BYTES: usize = 128 * 1024 * 1024;
pub use crate::linalg::utils::PcgSolveInfo;
#[inline]
fn dense_materialization_chunk_rows(nrows: usize, ncols: usize) -> usize {
    // Row-chunk size that keeps each materialized chunk near the configured
    // byte budget, clamped to [1, nrows.max(1)] so empty shapes stay valid.
    let budget_rows = rows_for_target_bytes(CHUNKED_DENSE_MATERIALIZATION_BYTES, ncols);
    budget_rows.clamp(1, nrows.max(1))
}
/// Materializes a dense operator into an owned matrix, filling it one
/// bounded row chunk at a time to cap peak temporary memory.
fn dense_operator_to_dense_by_chunks<O: DenseDesignOperator + ?Sized>(
    op: &O,
) -> Result<Array2<f64>, MatrixMaterializationError> {
    let (n, p) = (op.nrows(), op.ncols());
    let step = dense_materialization_chunk_rows(n, p);
    let mut dense = Array2::<f64>::zeros((n, p));
    let mut start = 0;
    while start < n {
        let end = n.min(start + step);
        op.row_chunk_into(start..end, dense.slice_mut(s![start..end, ..]))?;
        start = end;
    }
    Ok(dense)
}
/// Computes the byte footprint of an `nrows x ncols` dense `f64` matrix,
/// reporting a contextual error instead of wrapping on overflow.
fn checked_dense_nbytes(nrows: usize, ncols: usize, context: &str) -> Result<usize, String> {
    let cells = nrows.checked_mul(ncols);
    match cells.and_then(|c| c.checked_mul(std::mem::size_of::<f64>())) {
        Some(bytes) => Ok(bytes),
        None => Err(format!("{context}: dense size overflow for {nrows}x{ncols}")),
    }
}
/// Guards against densifying an operator-backed design: errors when the
/// policy demands operator-only derivatives, or when the dense footprint
/// would exceed the policy's single-materialization byte budget.
pub fn panic_or_error_if_biobank_mode_and_to_dense_called_with_policy(
    context: &str,
    n: usize,
    p: usize,
    policy: &ResourcePolicy,
) -> Result<(), String> {
    let operator_only = matches!(
        policy.derivative_storage_mode,
        crate::resource::DerivativeStorageMode::AnalyticOperatorRequired
    );
    if operator_only {
        return Err(format!(
            "{context}: refusing to densify operator-backed design {n}x{p} under AnalyticOperatorRequired policy; provide an operator-form path"
        ));
    }
    let dense_bytes = checked_dense_nbytes(n, p, context)?;
    if dense_bytes > policy.max_single_materialization_bytes {
        let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
        return Err(format!(
            "{context}: refusing to densify operator-backed design {n}x{p} (~{gib:.2} GiB); use matrix-free or chunked code"
        ));
    }
    Ok(())
}
/// Validates row counts then delegates to the view-based weighted
/// cross product `left' * diag(w) * right`.
fn weighted_crossprod_dense(
    left: &Array2<f64>,
    weights: &Array1<f64>,
    right: &Array2<f64>,
) -> Result<Array2<f64>, String> {
    let n = weights.len();
    if left.nrows() == n && right.nrows() == n {
        Ok(weighted_crossprod_dense_view(left, weights.view(), right))
    } else {
        Err(format!(
            "weighted_crossprod_dense row mismatch: left={}, weights={}, right={}",
            left.nrows(),
            n,
            right.nrows()
        ))
    }
}
/// Computes `left' * diag(max(w,0)) * right` for a borrowed weight view,
/// splitting rows across rayon workers when the flop estimate justifies it.
fn weighted_crossprod_dense_view(
    left: &Array2<f64>,
    weights: ArrayView1<'_, f64>,
    right: &Array2<f64>,
) -> Array2<f64> {
    let n = weights.len();
    let p_left = left.ncols();
    let p_right = right.ncols();
    // Approximate flop count; saturating so pathological shapes cannot wrap.
    let work = (n as u64)
        .saturating_mul(p_left as u64)
        .saturating_mul(p_right as u64);
    if rayon::current_num_threads() <= 1 || work < WEIGHTED_CROSSPROD_PARALLEL_MIN_FLOPS {
        return weighted_crossprod_dense_rows(left, weights, right, 0..n);
    }
    // Oversubscribe ~4 chunks per thread so uneven row costs balance out.
    let n_threads = rayon::current_num_threads();
    let chunk_rows = n.div_ceil(n_threads * 4).max(1);
    let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
    let partials: Vec<Array2<f64>> = starts
        .into_par_iter()
        .map(|start| {
            weighted_crossprod_dense_rows(left, weights, right, start..(start + chunk_rows).min(n))
        })
        .collect();
    // Serially sum the per-chunk partial Gram matrices.
    let mut out = Array2::<f64>::zeros((p_left, p_right));
    for partial in &partials {
        out += partial;
    }
    out
}
/// Serial kernel for `left' * diag(max(w,0)) * right` over a row range.
/// Zero (or clamped-negative) weights and zero left entries are skipped.
fn weighted_crossprod_dense_rows(
    left: &Array2<f64>,
    weights: ArrayView1<'_, f64>,
    right: &Array2<f64>,
    rows: Range<usize>,
) -> Array2<f64> {
    let (pl, pr) = (left.ncols(), right.ncols());
    let mut acc = Array2::<f64>::zeros((pl, pr));
    for row in rows {
        let w = weights[row].max(0.0);
        if w == 0.0 {
            continue;
        }
        for a in 0..pl {
            let lw = w * left[[row, a]];
            if lw == 0.0 {
                continue;
            }
            for b in 0..pr {
                acc[[a, b]] += lw * right[[row, b]];
            }
        }
    }
    acc
}
/// Lazy product view `base * first * second`, where either right factor may
/// be absent; factors are applied left-to-right on materialization.
pub struct DenseRightProductView<'a> {
    base: &'a Array2<f64>,           // left-most matrix of the product
    first: Option<&'a Array2<f64>>,  // optional first right factor
    second: Option<&'a Array2<f64>>, // optional second right factor
}
impl<'a> DenseRightProductView<'a> {
    /// Starts a product view holding only the base matrix.
    pub fn new(base: &'a Array2<f64>) -> Self {
        Self {
            base,
            first: None,
            second: None,
        }
    }
    /// Appends a right factor; panics when both factor slots are occupied.
    pub fn with_factor(mut self, factor: &'a Array2<f64>) -> Self {
        match (self.first, self.second) {
            (None, _) => self.first = Some(factor),
            (Some(_), None) => self.second = Some(factor),
            _ => panic!("DenseRightProductView supports at most two right factors"),
        }
        self
    }
    /// Appends a factor only when one is provided.
    pub fn with_optional_factor(self, factor: Option<&'a Array2<f64>>) -> Self {
        if let Some(factor) = factor {
            self.with_factor(factor)
        } else {
            self
        }
    }
    /// Evaluates the full left-to-right product into an owned matrix.
    pub fn materialize(&self) -> Array2<f64> {
        let mut product = self.base.clone();
        for factor in [self.first, self.second].into_iter().flatten() {
            product = fast_ab(&product, factor);
        }
        product
    }
    /// Column count of the product without materializing it: the last
    /// present factor (or the base) determines the output width.
    fn transformed_ncols(&self) -> usize {
        self.second
            .or(self.first)
            .map(|factor| factor.ncols())
            .unwrap_or_else(|| self.base.ncols())
    }
}
/// Lazy view of `matrix` with row `i` scaled by `scale[i]` on materialization.
pub struct DenseRowScaledView<'a> {
    matrix: &'a Array2<f64>,
    scale: &'a Array1<f64>,
}
impl<'a> DenseRowScaledView<'a> {
    /// Pairs a matrix with per-row scaling coefficients.
    pub fn new(matrix: &'a Array2<f64>, scale: &'a Array1<f64>) -> Self {
        Self { matrix, scale }
    }
    /// Returns an owned copy with each row multiplied by its scale factor.
    /// Rows beyond `scale.len()` (if any) are left unscaled, as the zip
    /// stops at the shorter of the two sequences.
    pub fn materialize(&self) -> Array2<f64> {
        let mut scaled = self.matrix.clone();
        scaled
            .outer_iter_mut()
            .zip(self.scale.iter())
            .for_each(|(mut row, &factor)| row *= factor);
        scaled
    }
}
/// A local column block positioned at `global_range` inside a wider
/// (`total_cols`-column) matrix; columns outside the range are zero.
pub struct EmbeddedColumnBlock<'a> {
    local: &'a Array2<f64>,
    global_range: Range<usize>,
    total_cols: usize,
}
impl<'a> EmbeddedColumnBlock<'a> {
    /// Positions `local` at columns `global_range` of a `total_cols`-wide matrix.
    pub fn new(local: &'a Array2<f64>, global_range: Range<usize>, total_cols: usize) -> Self {
        Self {
            local,
            global_range,
            total_cols,
        }
    }
    /// Embeds the local block into a zero matrix of the full width.
    pub fn materialize(&self) -> Array2<f64> {
        let nrows = self.local.nrows();
        if nrows == 0 {
            return Array2::<f64>::zeros((0, self.total_cols));
        }
        debug_assert_eq!(
            self.local.ncols(),
            self.global_range.len(),
            "embedded column block width mismatch"
        );
        let mut embedded = Array2::<f64>::zeros((nrows, self.total_cols));
        embedded
            .slice_mut(ndarray::s![.., self.global_range.clone()])
            .assign(self.local);
        embedded
    }
}
/// A local square block positioned at `global_range` x `global_range` inside
/// a larger `total_dim` x `total_dim` matrix; the rest is zero.
pub struct EmbeddedSquareBlock<'a> {
    local: &'a Array2<f64>,
    global_range: Range<usize>,
    total_dim: usize,
}
impl<'a> EmbeddedSquareBlock<'a> {
    /// Positions `local` on the diagonal span `global_range` of a
    /// `total_dim`-square matrix.
    pub fn new(local: &'a Array2<f64>, global_range: Range<usize>, total_dim: usize) -> Self {
        Self {
            local,
            global_range,
            total_dim,
        }
    }
    /// Embeds the local block into a zeroed `total_dim` square matrix.
    pub fn materialize(&self) -> Array2<f64> {
        let mut embedded = Array2::<f64>::zeros((self.total_dim, self.total_dim));
        let span = self.global_range.clone();
        embedded
            .slice_mut(ndarray::s![span.clone(), span])
            .assign(self.local);
        embedded
    }
}
/// Matrix-free view of the penalized weighted normal operator
/// A = X' W X + S + ridge·I, for iterative (PCG) solves.
struct PenalizedWeightedNormalOperator<'a, O: LinearOperator + ?Sized> {
    operator: &'a O,                  // design operator X
    weights: &'a Array1<f64>,         // observation weights W
    penalty: Option<&'a Array2<f64>>, // optional penalty matrix S
    ridge: f64,                       // scalar ridge added to the diagonal
}
impl<'a, O: LinearOperator + ?Sized> PenalizedWeightedNormalOperator<'a, O> {
    /// Applies (X' W X + S + ridge·I) to `vector` via the wrapped operator.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.operator
            .apply_weighted_normal(self.weights, vector, self.penalty, self.ridge)
    }
    /// Diagonal of the penalized normal operator, used as a Jacobi
    /// preconditioner: diag(X'WX) + diag(S) + ridge.
    fn jacobi_preconditioner(&self) -> Result<Array1<f64>, String> {
        let mut diag = self.operator.diag_gram(self.weights)?;
        if let Some(pen) = self.penalty {
            for (i, d) in diag.iter_mut().enumerate() {
                *d += pen[[i, i]];
            }
        }
        if self.ridge > 0.0 {
            let ridge = self.ridge;
            diag.iter_mut().for_each(|d| *d += ridge);
        }
        Ok(diag)
    }
}
#[inline]
fn dense_matvec(matrix: &Array2<f64>, vector: &Array1<f64>) -> Array1<f64> {
    // Thin alias: y = A * v via the faer-backed fast path.
    fast_av(matrix, vector)
}
#[inline]
fn dense_transpose_matvec(matrix: &Array2<f64>, vector: &Array1<f64>) -> Array1<f64> {
    // Thin alias: y = A' * v via the faer-backed fast path.
    fast_atv(matrix, vector)
}
#[inline]
fn dense_transpose_matvec_view(matrix: &Array2<f64>, vector: ArrayView1<'_, f64>) -> Array1<f64> {
    // Computes A' * v for a borrowed vector view; rows with a zero vector
    // entry are skipped entirely.
    let n = matrix.nrows();
    let p = matrix.ncols();
    // Serial path below the parallelism cutoff to avoid rayon overhead.
    if (n as u64) * (p as u64) < DENSE_ROW_PARALLEL_MIN_NP {
        let mut out = Array1::<f64>::zeros(p);
        for i in 0..n {
            let vi = vector[i];
            if vi == 0.0 {
                continue;
            }
            for j in 0..p {
                out[j] += matrix[[i, j]] * vi;
            }
        }
        return out;
    }
    // Parallel fold over rows into per-thread accumulators, then a
    // pairwise reduce of the partial vectors.
    (0..n)
        .into_par_iter()
        .fold(
            || Array1::<f64>::zeros(p),
            |mut acc, i| {
                let vi = vector[i];
                if vi != 0.0 {
                    for j in 0..p {
                        acc[j] += matrix[[i, j]] * vi;
                    }
                }
                acc
            },
        )
        .reduce(
            || Array1::<f64>::zeros(p),
            |mut a, b| {
                a += &b;
                a
            },
        )
}
#[inline]
fn dense_xtwx_view(matrix: &Array2<f64>, weights: ArrayView1<'_, f64>) -> Array2<f64> {
    // X' diag(max(w,0)) X expressed as a weighted cross product with itself.
    weighted_crossprod_dense_view(matrix, weights, matrix)
}
#[inline]
fn dense_diag_gram_view(matrix: &Array2<f64>, weights: ArrayView1<'_, f64>) -> Array1<f64> {
    // diag(X' W X) with weights clamped at zero; serial row accumulation.
    let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
    let mut gram_diag = Array1::<f64>::zeros(ncols);
    for row in 0..nrows {
        let w = weights[row].max(0.0);
        if w == 0.0 {
            continue;
        }
        for col in 0..ncols {
            let x = matrix[[row, col]];
            gram_diag[col] += w * x * x;
        }
    }
    gram_diag
}
/// X' W X for a CSR design, parallelizing over row chunks when the estimated
/// per-row work (average nnz per row, squared) is large enough to amortize
/// scheduling overhead.
fn sparse_csr_weighted_xtwx(
    row_ptr: &[usize],
    col_idx: &[usize],
    vals: &[f64],
    n: usize,
    p: usize,
    weights: ArrayView1<'_, f64>,
) -> Array2<f64> {
    let nnz = vals.len() as u64;
    // n.max(1) makes the divisor non-zero; unwrap_or(0) is a formality.
    let avg = nnz.checked_div(n.max(1) as u64).unwrap_or(0);
    let work = (n as u64).saturating_mul(avg.saturating_mul(avg));
    if rayon::current_num_threads() <= 1 || work < SPARSE_ROW_PARALLEL_MIN_FLOPS {
        return sparse_csr_weighted_xtwx_rows(row_ptr, col_idx, vals, p, weights, 0..n);
    }
    // ~8 chunks per thread for load balancing across uneven row sparsity.
    let n_threads = rayon::current_num_threads();
    let target_chunks = (n_threads * 8).max(1);
    let chunk_rows = n.div_ceil(target_chunks).max(1);
    let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
    let partials: Vec<Array2<f64>> = starts
        .into_par_iter()
        .map(|start| {
            sparse_csr_weighted_xtwx_rows(
                row_ptr,
                col_idx,
                vals,
                p,
                weights,
                start..(start + chunk_rows).min(n),
            )
        })
        .collect();
    // Serially sum the per-chunk partial Gram matrices.
    let mut xtwx = Array2::<f64>::zeros((p, p));
    for partial in &partials {
        xtwx += partial;
    }
    xtwx
}
/// Serial X' W X over a row range: each row contributes the weighted outer
/// product of its sparse entries. Only the "upper" pairs (b_ptr >= a_ptr)
/// are computed; the symmetric entry is mirrored in the same pass.
fn sparse_csr_weighted_xtwx_rows(
    row_ptr: &[usize],
    col_idx: &[usize],
    vals: &[f64],
    p: usize,
    weights: ArrayView1<'_, f64>,
    rows: Range<usize>,
) -> Array2<f64> {
    let mut xtwx = Array2::<f64>::zeros((p, p));
    for i in rows {
        let wi = weights[i].max(0.0);
        if wi == 0.0 {
            continue;
        }
        let start = row_ptr[i];
        let end = row_ptr[i + 1];
        for a_ptr in start..end {
            let a = col_idx[a_ptr];
            let wxa = wi * vals[a_ptr];
            // b_ptr starts at a_ptr so each unordered pair is visited once.
            for b_ptr in a_ptr..end {
                let b = col_idx[b_ptr];
                let v = wxa * vals[b_ptr];
                xtwx[[a, b]] += v;
                if a != b {
                    xtwx[[b, a]] += v;
                }
            }
        }
    }
    xtwx
}
/// diag(X' W X) for a CSR design, split across rayon workers when the total
/// nonzero count is large enough to amortize scheduling.
fn sparse_csr_diag_gram(
    row_ptr: &[usize],
    col_idx: &[usize],
    vals: &[f64],
    n: usize,
    p: usize,
    weights: ArrayView1<'_, f64>,
) -> Array1<f64> {
    // Work is proportional to nnz for a diagonal accumulation.
    let work = vals.len() as u64;
    if rayon::current_num_threads() <= 1 || work < SPARSE_ROW_PARALLEL_MIN_FLOPS {
        return sparse_csr_diag_gram_rows(row_ptr, col_idx, vals, p, weights, 0..n);
    }
    // ~8 chunks per thread for load balancing.
    let n_threads = rayon::current_num_threads();
    let chunk_rows = n.div_ceil(n_threads * 8).max(1);
    let starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
    let partials: Vec<Array1<f64>> = starts
        .into_par_iter()
        .map(|start| {
            sparse_csr_diag_gram_rows(
                row_ptr,
                col_idx,
                vals,
                p,
                weights,
                start..(start + chunk_rows).min(n),
            )
        })
        .collect();
    let mut diag = Array1::<f64>::zeros(p);
    for partial in &partials {
        diag += partial;
    }
    diag
}
/// Serial diag(X' W X) over a CSR row range with weights clamped at zero.
fn sparse_csr_diag_gram_rows(
    row_ptr: &[usize],
    col_idx: &[usize],
    vals: &[f64],
    p: usize,
    weights: ArrayView1<'_, f64>,
    rows: Range<usize>,
) -> Array1<f64> {
    let mut gram_diag = Array1::<f64>::zeros(p);
    for row in rows {
        let w = weights[row].max(0.0);
        if w == 0.0 {
            continue;
        }
        let (lo, hi) = (row_ptr[row], row_ptr[row + 1]);
        for (&col, &x) in col_idx[lo..hi].iter().zip(&vals[lo..hi]) {
            gram_diag[col] += w * x * x;
        }
    }
    gram_diag
}
#[inline]
/// X' (W ∘ scale) y: transpose-weighted response with an optional extra
/// per-row scale factor; weights are clamped at zero.
fn dense_transpose_weighted_response(
    matrix: &Array2<f64>,
    weights: &Array1<f64>,
    y: &Array1<f64>,
    row_scale: Option<&Array1<f64>>,
) -> Array1<f64> {
    let (n, p) = (matrix.nrows(), matrix.ncols());
    let mut xtwy = Array1::<f64>::zeros(p);
    for i in 0..n {
        let mut coeff = y[i] * weights[i].max(0.0);
        if let Some(scale) = row_scale {
            coeff *= scale[i];
        }
        if coeff == 0.0 {
            continue;
        }
        for j in 0..p {
            xtwy[j] += matrix[[i, j]] * coeff;
        }
    }
    xtwy
}
#[inline]
/// X' W y for borrowed weight/response views; weights clamped at zero.
fn dense_transpose_weighted_response_view(
    matrix: &Array2<f64>,
    weights: ArrayView1<'_, f64>,
    y: ArrayView1<'_, f64>,
) -> Array1<f64> {
    let (n, p) = (matrix.nrows(), matrix.ncols());
    let mut xtwy = Array1::<f64>::zeros(p);
    for i in 0..n {
        let coeff = y[i] * weights[i].max(0.0);
        if coeff == 0.0 {
            continue;
        }
        for j in 0..p {
            xtwy[j] += matrix[[i, j]] * coeff;
        }
    }
    xtwy
}
/// CSC sparse design matrix with shared, lazily-built dense and CSR caches.
#[derive(Clone)]
pub struct SparseDesignMatrix {
    matrix: SparseColMat<usize, f64>,
    // Dense copy, cached only when it fits the persistent-cache byte budget.
    dense_cache: Arc<OnceLock<Arc<Array2<f64>>>>,
    // Row-major twin used by row-chunk extraction paths.
    csr_cache: Arc<OnceLock<Arc<SparseRowMat<usize, f64>>>>,
}
impl SparseDesignMatrix {
    /// Wraps a CSC matrix with empty, lazily-populated dense and CSR caches.
    pub fn new(matrix: SparseColMat<usize, f64>) -> Self {
        Self {
            matrix,
            dense_cache: Arc::new(OnceLock::new()),
            csr_cache: Arc::new(OnceLock::new()),
        }
    }
    /// Byte footprint of the fully dense equivalent, guarding against
    /// `usize` overflow on pathological shapes.
    fn dense_nbytes(&self) -> Result<usize, String> {
        self.matrix
            .nrows()
            .checked_mul(self.matrix.ncols())
            .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
            .ok_or_else(|| {
                format!(
                    "dense size overflow for sparse design {}x{}",
                    self.matrix.nrows(),
                    self.matrix.ncols()
                )
            })
    }
    /// Scatters the CSC entries into a freshly zeroed dense matrix.
    /// Duplicate (row, col) entries are summed, matching sparse semantics.
    fn materialize_dense_arc(&self) -> Arc<Array2<f64>> {
        let mut out = Array2::<f64>::zeros((self.matrix.nrows(), self.matrix.ncols()));
        let (symbolic, values) = self.matrix.parts();
        let col_ptr = symbolic.col_ptr();
        let row_idx = symbolic.row_idx();
        for col in 0..self.matrix.ncols() {
            for idx in col_ptr[col]..col_ptr[col + 1] {
                out[[row_idx[idx], col]] += values[idx];
            }
        }
        Arc::new(out)
    }
    /// Densifies the design, refusing when the result would exceed the hard
    /// byte cap; results below the cache budget are memoized for reuse.
    pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
        let dense_bytes = self.dense_nbytes()?;
        if dense_bytes > MAX_SPARSE_TO_DENSE_BYTES {
            let gib = dense_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
            return Err(format!(
                "{context}: refusing to densify sparse design {}x{} (~{gib:.2} GiB); use sparse or matrix-free code",
                self.matrix.nrows(),
                self.matrix.ncols(),
            ));
        }
        if dense_bytes <= MAX_PERSISTENT_SPARSE_DENSE_CACHE_BYTES {
            Ok(self
                .dense_cache
                .get_or_init(|| self.materialize_dense_arc())
                .clone())
        } else {
            // Too large to keep resident: build a fresh copy every call.
            Ok(self.materialize_dense_arc())
        }
    }
    /// Panicking wrapper around [`Self::try_to_dense_arc`].
    pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
        self.try_to_dense_arc("SparseDesignMatrix::to_dense_arc")
            .unwrap_or_else(|msg| panic!("{msg}"))
    }
    /// Converts to row-major (CSR) form, caching the result; returns `None`
    /// when the conversion fails.
    ///
    /// FIX: when a concurrent caller wins the `OnceLock::set` race, return
    /// the cached winner instead of our own copy — previously both CSR
    /// allocations could stay alive, doubling memory for large designs.
    pub fn to_csr_arc(&self) -> Option<Arc<SparseRowMat<usize, f64>>> {
        if let Some(cached) = self.csr_cache.get() {
            return Some(cached.clone());
        }
        let csr = self.matrix.as_ref().to_row_major().ok()?;
        let arc = Arc::new(csr);
        match self.csr_cache.set(arc.clone()) {
            Ok(()) => Some(arc),
            // Another thread cached first; drop our duplicate, share theirs.
            Err(_) => self.csr_cache.get().cloned(),
        }
    }
}
impl Deref for SparseDesignMatrix {
    type Target = SparseColMat<usize, f64>;
    /// Transparent access to the underlying CSC matrix.
    fn deref(&self) -> &Self::Target {
        &self.matrix
    }
}
impl AsRef<SparseColMat<usize, f64>> for SparseDesignMatrix {
    /// Borrow the underlying CSC matrix.
    fn as_ref(&self) -> &SparseColMat<usize, f64> {
        &self.matrix
    }
}
/// Operators that can act like a dense design matrix: they expose row chunks
/// and can — policy permitting — be materialized to a dense array.
pub trait DenseDesignOperator: LinearOperator + Send + Sync {
    /// X' W y with weights clamped at zero; the default streams through a
    /// temporary weighted-response vector and the transpose apply.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        let n = self.nrows();
        if weights.len() != n || y.len() != n {
            return Err(format!(
                "DenseDesignOperator::compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
                weights.len(),
                y.len(),
                n
            ));
        }
        // wy[i] = max(w_i, 0) * y_i, filled in parallel.
        let mut wy = Array1::<f64>::zeros(n);
        ndarray::Zip::from(&mut wy)
            .and(weights)
            .and(y)
            .par_for_each(|o, &w, &yi| *o = w.max(0.0) * yi);
        Ok(self.apply_transpose(&wy))
    }
    /// Per-row diagonal of X M X' (clamped at zero), computed in bounded row
    /// chunks so the full X·M product is never materialized.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        if middle.nrows() != self.ncols() || middle.ncols() != self.ncols() {
            return Err(format!(
                "DenseDesignOperator::quadratic_form_diag dimension mismatch: {}x{} vs expected {}x{}",
                middle.nrows(),
                middle.ncols(),
                self.ncols(),
                self.ncols()
            ));
        }
        let n = self.nrows();
        let mut out = Array1::<f64>::zeros(n);
        // ~8 MiB budget split over the chunk and its X*M product (two buffers).
        let chunk_size = (8 * 1024 * 1024 / (self.ncols().max(1) * 8 * 2))
            .max(16)
            .min(n.max(1));
        let mut start = 0;
        while start < n {
            let end = (start + chunk_size).min(n);
            let x_chunk = self.try_row_chunk(start..end).map_err(|e| e.to_string())?;
            let xm_chunk = fast_ab(&x_chunk, middle);
            let mut chunk_out = out.slice_mut(ndarray::s![start..end]);
            ndarray::Zip::from(&mut chunk_out)
                .and(x_chunk.rows())
                .and(xm_chunk.rows())
                .par_for_each(|o, xr, xmr| *o = xr.dot(&xmr).max(0.0));
            start = end;
        }
        Ok(out)
    }
    /// Writes rows `rows` of the design into `out`; shapes must match.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError>;
    /// Convenience wrapper that allocates the chunk before delegating.
    fn try_row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, MatrixMaterializationError> {
        let mut out = Array2::<f64>::zeros((rows.end - rows.start, self.ncols()));
        self.row_chunk_into(rows, out.view_mut())?;
        Ok(out)
    }
    /// Borrow an already-dense backing array when one exists.
    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
        None
    }
    /// Unconditional densification (no policy check).
    fn to_dense(&self) -> Array2<f64>;
    /// Dense footprint estimate; saturates rather than overflowing.
    fn estimated_dense_bytes(&self) -> usize {
        self.nrows()
            .saturating_mul(self.ncols())
            .saturating_mul(std::mem::size_of::<f64>())
    }
    /// Policy-checked densification: refuses under operator-only mode or
    /// when the estimated size exceeds the policy's dense byte cap.
    fn try_to_dense_with_policy(
        &self,
        policy: &MaterializationPolicy,
        context: &'static str,
    ) -> Result<Arc<Array2<f64>>, MatrixMaterializationError> {
        let bytes = self.estimated_dense_bytes();
        if !policy.allow_operator_materialization {
            return Err(MatrixMaterializationError::Forbidden {
                context,
                mode: crate::resource::DerivativeStorageMode::AnalyticOperatorRequired,
            });
        }
        if bytes > policy.max_single_dense_bytes {
            return Err(MatrixMaterializationError::TooLarge {
                context,
                nrows: self.nrows(),
                ncols: self.ncols(),
                bytes,
                limit_bytes: policy.max_single_dense_bytes,
            });
        }
        dense_operator_to_dense_by_chunks(self).map(Arc::new)
    }
    /// Densify without a policy check; panics if chunk materialization fails.
    fn to_dense_arc(&self) -> Arc<Array2<f64>> {
        Arc::new(
            dense_operator_to_dense_by_chunks(self)
                .expect("DenseDesignOperator::to_dense_arc: row-chunk materialization failed"),
        )
    }
}
/// A dense design that is either fully materialized or backed by a lazy
/// row-chunk operator.
#[derive(Clone)]
pub enum DenseDesignMatrix {
    Materialized(Arc<Array2<f64>>),
    Lazy(Arc<dyn DenseDesignOperator>),
}
impl std::fmt::Debug for DenseDesignMatrix {
    /// Reports only the variant and shape; matrix contents are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (tag, nrows, ncols) = match self {
            Self::Materialized(matrix) => ("Materialized", matrix.nrows(), matrix.ncols()),
            Self::Lazy(op) => ("Lazy", op.nrows(), op.ncols()),
        };
        write!(f, "DenseDesignMatrix::{tag}({nrows}x{ncols})")
    }
}
/// Wrap an already-shared dense array without copying.
impl From<Arc<Array2<f64>>> for DenseDesignMatrix {
    fn from(value: Arc<Array2<f64>>) -> Self {
        Self::Materialized(value)
    }
}
/// Take ownership of a dense array, sharing it behind an `Arc`.
impl From<Array2<f64>> for DenseDesignMatrix {
    fn from(value: Array2<f64>) -> Self {
        Self::Materialized(Arc::new(value))
    }
}
/// Wrap any shared dense-design operator as a lazy design matrix.
impl<T> From<Arc<T>> for DenseDesignMatrix
where
    T: DenseDesignOperator + 'static,
{
    fn from(value: Arc<T>) -> Self {
        Self::Lazy(value)
    }
}
impl DenseDesignMatrix {
    /// Number of observations (rows).
    pub fn nrows(&self) -> usize {
        match self {
            Self::Materialized(matrix) => matrix.nrows(),
            Self::Lazy(op) => op.nrows(),
        }
    }
    /// Number of design columns.
    pub fn ncols(&self) -> usize {
        match self {
            Self::Materialized(matrix) => matrix.ncols(),
            Self::Lazy(op) => op.ncols(),
        }
    }
    /// Borrow the dense backing array when one already exists (no work done).
    pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
        match self {
            Self::Materialized(matrix) => Some(matrix.as_ref()),
            Self::Lazy(op) => op.as_dense_ref(),
        }
    }
    /// True when the dense array is already materialized.
    pub fn is_materialized_dense(&self) -> bool {
        matches!(self, Self::Materialized(_))
    }
    /// True when backed by a lazy operator.
    pub fn is_operator_backed(&self) -> bool {
        matches!(self, Self::Lazy(_))
    }
    /// Owned dense copy; panics when policy forbids densifying a lazy operator.
    pub fn to_dense(&self) -> Array2<f64> {
        match self {
            Self::Materialized(matrix) => matrix.as_ref().clone(),
            Self::Lazy(_) => self
                .try_to_dense_arc("DenseDesignMatrix::to_dense")
                .unwrap_or_else(|msg| panic!("{msg}"))
                .as_ref()
                .clone(),
        }
    }
    /// Shared dense handle: cheap `Arc` clone when materialized, panicking
    /// policy-checked materialization for the lazy case.
    pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
        match self {
            Self::Materialized(matrix) => Arc::clone(matrix),
            Self::Lazy(_) => self
                .try_to_dense_arc("DenseDesignMatrix::to_dense_arc")
                .unwrap_or_else(|msg| panic!("{msg}")),
        }
    }
    /// Densify under a default policy sized for this problem's shape.
    pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
        let policy = ResourcePolicy::for_problem(
            self.nrows(),
            self.ncols(),
            crate::resource::ProblemHints::default(),
        );
        self.try_to_dense_arc_with_policy(context, &policy)
    }
    /// Densify under an explicit policy; lazy operators are first vetted by
    /// the mode/size guard, then materialized chunk by chunk.
    pub fn try_to_dense_arc_with_policy(
        &self,
        context: &str,
        policy: &ResourcePolicy,
    ) -> Result<Arc<Array2<f64>>, String> {
        match self {
            Self::Materialized(matrix) => Ok(Arc::clone(matrix)),
            Self::Lazy(op) => {
                panic_or_error_if_biobank_mode_and_to_dense_called_with_policy(
                    context,
                    op.nrows(),
                    op.ncols(),
                    policy,
                )?;
                dense_operator_to_dense_by_chunks(op.as_ref())
                    .map(Arc::new)
                    .map_err(|err| {
                        format!("{context}: failed to materialize dense row chunks: {err}")
                    })
            }
        }
    }
    /// Owned copy of the requested rows.
    pub fn try_row_chunk(
        &self,
        rows: Range<usize>,
    ) -> Result<Array2<f64>, MatrixMaterializationError> {
        match self {
            Self::Materialized(matrix) => Ok(matrix.slice(s![rows, ..]).to_owned()),
            Self::Lazy(op) => op.try_row_chunk(rows),
        }
    }
    /// Write the requested rows into a caller-provided buffer.
    pub fn row_chunk_into(
        &self,
        rows: Range<usize>,
        out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        match self {
            Self::Materialized(matrix) => {
                let mut out = out;
                out.assign(&matrix.slice(s![rows, ..]));
                Ok(())
            }
            Self::Lazy(op) => op.row_chunk_into(rows, out),
        }
    }
}
impl LinearOperator for DenseDesignMatrix {
    fn nrows(&self) -> usize {
        DenseDesignMatrix::nrows(self)
    }
    fn ncols(&self) -> usize {
        DenseDesignMatrix::ncols(self)
    }
    /// X v.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Materialized(matrix) => dense_matvec(matrix, vector),
            Self::Lazy(op) => op.apply(vector),
        }
    }
    /// X' v.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Materialized(matrix) => dense_transpose_matvec(matrix, vector),
            Self::Lazy(op) => op.apply_transpose(vector),
        }
    }
    /// X' W X; the materialized path streams through a BLAS-backed helper.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        match self {
            Self::Materialized(matrix) => {
                let mut xtwx = Array2::<f64>::zeros((matrix.ncols(), matrix.ncols()));
                streaming_blas_xt_diag_x(matrix, weights, &mut xtwx);
                Ok(xtwx)
            }
            Self::Lazy(op) => op.diag_xtw_x(weights),
        }
    }
    /// diag(X' W X) with weights clamped at zero; parallel above a cutoff.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        match self {
            Self::Materialized(matrix) => {
                let n = matrix.nrows();
                let p = matrix.ncols();
                // Serial path below the cutoff avoids rayon overhead.
                if (n as u64) * (p as u64) < DENSE_ROW_PARALLEL_MIN_NP {
                    let mut diag = Array1::<f64>::zeros(p);
                    for i in 0..n {
                        let wi = weights[i].max(0.0);
                        if wi == 0.0 {
                            continue;
                        }
                        for j in 0..p {
                            let xij = matrix[[i, j]];
                            diag[j] += wi * xij * xij;
                        }
                    }
                    return Ok(diag);
                }
                // Parallel fold over rows with per-thread accumulators.
                let diag = (0..n)
                    .into_par_iter()
                    .fold(
                        || Array1::<f64>::zeros(p),
                        |mut acc, i| {
                            let wi = weights[i].max(0.0);
                            if wi != 0.0 {
                                for j in 0..p {
                                    let xij = matrix[[i, j]];
                                    acc[j] += wi * xij * xij;
                                }
                            }
                            acc
                        },
                    )
                    .reduce(
                        || Array1::<f64>::zeros(p),
                        |mut a, b| {
                            a += &b;
                            a
                        },
                    );
                Ok(diag)
            }
            Self::Lazy(op) => op.diag_gram(weights),
        }
    }
    /// (X' W X + S + ridge·I) v without forming the normal matrix.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        match self {
            Self::Materialized(matrix) => {
                let n = matrix.nrows();
                let p = matrix.ncols();
                let mut out = if (n as u64) * (p as u64) < DENSE_ROW_PARALLEL_MIN_NP {
                    // Serial: accumulate w_i * (x_i · v) * x_i row by row.
                    let mut out = Array1::<f64>::zeros(p);
                    for i in 0..n {
                        let wi = weights[i].max(0.0);
                        if wi == 0.0 {
                            continue;
                        }
                        let mut row_dot = 0.0_f64;
                        for j in 0..p {
                            row_dot += matrix[[i, j]] * vector[j];
                        }
                        if row_dot == 0.0 {
                            continue;
                        }
                        let scaled = wi * row_dot;
                        for j in 0..p {
                            out[j] += scaled * matrix[[i, j]];
                        }
                    }
                    out
                } else {
                    // Parallel fold/reduce over rows, same row kernel.
                    (0..n)
                        .into_par_iter()
                        .fold(
                            || Array1::<f64>::zeros(p),
                            |mut acc, i| {
                                let wi = weights[i].max(0.0);
                                if wi != 0.0 {
                                    let mut row_dot = 0.0_f64;
                                    for j in 0..p {
                                        row_dot += matrix[[i, j]] * vector[j];
                                    }
                                    if row_dot != 0.0 {
                                        let scaled = wi * row_dot;
                                        for j in 0..p {
                                            acc[j] += scaled * matrix[[i, j]];
                                        }
                                    }
                                }
                                acc
                            },
                        )
                        .reduce(
                            || Array1::<f64>::zeros(p),
                            |mut a, b| {
                                a += &b;
                                a
                            },
                        )
                };
                if let Some(pen) = penalty {
                    out += &pen.dot(vector);
                }
                if ridge > 0.0 {
                    for j in 0..p {
                        out[j] += ridge * vector[j];
                    }
                }
                out
            }
            Self::Lazy(op) => op.apply_weighted_normal(weights, vector, penalty, ridge),
        }
    }
    /// NOTE(review): materialized designs report `true` here as well —
    /// confirm this is intentional and not only meant for operator-backed ones.
    fn uses_matrix_free_pcg(&self) -> bool {
        match self {
            Self::Materialized(_) => true,
            Self::Lazy(op) => op.uses_matrix_free_pcg(),
        }
    }
}
impl DenseDesignOperator for DenseDesignMatrix {
    /// X' W y; materialized path streams rows serially, lazy path delegates.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        match self {
            Self::Materialized(matrix) => {
                Ok(dense_transpose_weighted_response(matrix, weights, y, None))
            }
            Self::Lazy(op) => op.compute_xtwy(weights, y),
        }
    }
    /// Per-row diag(X M X'), clamped at zero; materializes X·M in one shot.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        match self {
            Self::Materialized(matrix) => {
                if middle.nrows() != matrix.ncols() || middle.ncols() != matrix.ncols() {
                    return Err(format!(
                        "quadratic_form_diag dimension mismatch: matrix is {}x{}, expected {}x{}",
                        middle.nrows(),
                        middle.ncols(),
                        matrix.ncols(),
                        matrix.ncols()
                    ));
                }
                let xc = fast_ab(matrix, middle);
                let mut out = Array1::<f64>::zeros(matrix.nrows());
                for i in 0..matrix.nrows() {
                    out[i] = matrix.row(i).dot(&xc.row(i)).max(0.0);
                }
                Ok(out)
            }
            Self::Lazy(op) => op.quadratic_form_diag(middle),
        }
    }
    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
        DenseDesignMatrix::as_dense_ref(self)
    }
    /// Copy the requested rows into `out` after validating its shape.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "DenseDesignMatrix::row_chunk_into shape mismatch",
            });
        }
        match self {
            Self::Materialized(matrix) => {
                out.assign(&matrix.slice(s![rows, ..]));
                Ok(())
            }
            Self::Lazy(op) => op.row_chunk_into(rows, out),
        }
    }
    fn to_dense(&self) -> Array2<f64> {
        DenseDesignMatrix::to_dense(self)
    }
    fn to_dense_arc(&self) -> Arc<Array2<f64>> {
        DenseDesignMatrix::to_dense_arc(self)
    }
}
/// Implicit reparameterized design X · Qs (the product is never formed
/// eagerly; all operations factor through X and Qs separately).
pub struct ReparamOperator {
    x_original: DesignMatrix, // original design X (n x k)
    qs: Arc<Array2<f64>>,     // reparameterization factor Qs (k x p)
    n: usize,                 // rows of X
    p: usize,                 // columns of Qs (transformed coefficient count)
}
impl ReparamOperator {
    /// Builds the implicit product X·Qs; panics when the inner dimensions
    /// (columns of X vs rows of Qs) disagree.
    pub fn new(x_original: DesignMatrix, qs: Arc<Array2<f64>>) -> Self {
        assert_eq!(
            x_original.ncols(),
            qs.nrows(),
            "ReparamOperator: X cols ({}) must match Qs rows ({})",
            x_original.ncols(),
            qs.nrows()
        );
        let (n, p) = (x_original.nrows(), qs.ncols());
        Self {
            x_original,
            qs,
            n,
            p,
        }
    }
    /// The untransformed design X.
    pub fn x_original(&self) -> &DesignMatrix {
        &self.x_original
    }
    /// The reparameterization factor Qs.
    pub fn qs(&self) -> &Array2<f64> {
        &self.qs
    }
}
impl LinearOperator for ReparamOperator {
    fn nrows(&self) -> usize {
        self.n
    }
    fn ncols(&self) -> usize {
        self.p
    }
    /// (X Qs) v computed as X (Qs v).
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        let qv = self.qs.dot(vector);
        self.x_original.apply(&qv)
    }
    /// (X Qs)' v computed as Qs' (X' v).
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let xtv = self.x_original.apply_transpose(vector);
        fast_atv(&self.qs, &xtv)
    }
    /// Qs' (X' W X) Qs via the inner design's weighted Gram matrix.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let xtwx = self.x_original.diag_xtw_x(weights)?;
        let tmp = fast_atb(&self.qs, &xtwx);
        Ok(fast_ab(&tmp, &self.qs))
    }
    /// (Qs' X' W X Qs + S + ridge·I) v, never forming the normal matrix.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        let qv = self.qs.dot(vector);
        let xqv = self.x_original.apply(&qv);
        let mut wxqv = xqv;
        // Clamp negative weights at zero, matching the other normal paths.
        for i in 0..wxqv.len() {
            wxqv[i] *= weights[i].max(0.0);
        }
        let xtw = self.x_original.apply_transpose(&wxqv);
        let mut out = fast_atv(&self.qs, &xtw);
        if let Some(pen) = penalty {
            out += &pen.dot(vector);
        }
        if ridge > 0.0 {
            out += &vector.mapv(|x| ridge * x);
        }
        out
    }
}
impl DenseDesignOperator for ReparamOperator {
    /// (X Qs)' W y = Qs' (X' W y).
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        let xtwy = self.x_original.compute_xtwy(weights, y)?;
        Ok(fast_atv(&self.qs, &xtwy))
    }
    /// diag(X Qs M Qs' X'): push M back to the original basis, then delegate.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let qm = fast_ab(&self.qs, middle);
        // NOTE: `t().to_owned()` materializes a transposed copy of Qs.
        let m_orig = fast_ab(&qm, &self.qs.t().to_owned());
        self.x_original.quadratic_form_diag(&m_orig)
    }
    /// Dense X·Qs; reuses the dense cache when X is dense-backed.
    fn to_dense(&self) -> Array2<f64> {
        match &self.x_original {
            DesignMatrix::Dense(x) => fast_ab(x.to_dense_arc().as_ref(), &self.qs),
            _ => {
                let x_dense = self.x_original.to_dense();
                fast_ab(&x_dense, &self.qs)
            }
        }
    }
    fn to_dense_arc(&self) -> Arc<Array2<f64>> {
        Arc::new(self.to_dense())
    }
    /// The product is implicit, so there is never a dense backing array.
    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
        None
    }
    /// Materializes rows of X·Qs by extracting the matching rows of X
    /// (dense chunk or CSR scatter) and multiplying them by Qs.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "ReparamOperator::row_chunk_into shape mismatch",
            });
        }
        match &self.x_original {
            DesignMatrix::Dense(x) => {
                let chunk = x.try_row_chunk(rows)?;
                out.assign(&fast_ab(&chunk, &self.qs));
            }
            DesignMatrix::Sparse(sdm) => {
                let csr = sdm
                    .to_csr_arc()
                    .expect("ReparamOperator::row_chunk_into: CSR conversion");
                let sym = csr.symbolic();
                let row_ptr = sym.row_ptr();
                let col_idx = sym.col_idx();
                let vals = csr.val();
                let chunk_rows = rows.end - rows.start;
                let p_inner = sdm.ncols();
                // Scatter the requested sparse rows into a dense scratch chunk.
                let mut chunk = Array2::<f64>::zeros((chunk_rows, p_inner));
                for (local, global) in (rows.start..rows.end).enumerate() {
                    for ptr in row_ptr[global]..row_ptr[global + 1] {
                        chunk[[local, col_idx[ptr]]] = vals[ptr];
                    }
                }
                out.assign(&fast_ab(&chunk, &self.qs));
            }
        }
        Ok(())
    }
}
/// One-hot group-membership design: row i has a single 1 in column
/// `group_ids[i]`, or is entirely zero when the row has no group.
#[derive(Clone)]
pub struct RandomEffectOperator {
    pub group_ids: Vec<Option<usize>>, // per-row group assignment (None = no group)
    pub n: usize,                      // number of rows
    pub num_groups: usize,             // number of indicator columns
}
impl RandomEffectOperator {
    /// Builds an indicator operator; `None` entries contribute zero rows.
    pub fn new(group_ids: Vec<Option<usize>>, num_groups: usize) -> Self {
        let n = group_ids.len();
        Self {
            group_ids,
            n,
            num_groups,
        }
    }
    /// D' W Z where Z is this indicator matrix: column g of the result
    /// accumulates w_i * dense[i, :] over rows i assigned to group g.
    /// Weights are clamped at zero.
    pub fn weighted_cross_with_dense(
        &self,
        dense: &Array2<f64>,
        weights: &Array1<f64>,
    ) -> Array2<f64> {
        let p_dense = dense.ncols();
        let mut cross = Array2::<f64>::zeros((p_dense, self.num_groups));
        for i in 0..self.n {
            let Some(g) = self.group_ids[i] else { continue };
            let w = weights[i].max(0.0);
            if w == 0.0 {
                continue;
            }
            for j in 0..p_dense {
                cross[[j, g]] += w * dense[[i, j]];
            }
        }
        cross
    }
    /// Z_self' W Z_other: entry (a, b) sums clamped weights over rows that
    /// belong to group a in `self` and group b in `other`.
    pub fn weighted_cross_with_re(
        &self,
        other: &RandomEffectOperator,
        weights: &Array1<f64>,
    ) -> Array2<f64> {
        let mut cross = Array2::<f64>::zeros((self.num_groups, other.num_groups));
        for i in 0..self.n {
            let (Some(a), Some(b)) = (self.group_ids[i], other.group_ids[i]) else {
                continue;
            };
            let w = weights[i].max(0.0);
            if w != 0.0 {
                cross[[a, b]] += w;
            }
        }
        cross
    }
}
impl LinearOperator for RandomEffectOperator {
    fn nrows(&self) -> usize {
        self.n
    }
    fn ncols(&self) -> usize {
        self.num_groups
    }
    /// Z v: each row picks its group's coefficient (0 when ungrouped).
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        use rayon::prelude::*;
        let out: Vec<f64> = self
            .group_ids
            .par_iter()
            .map(|g| g.map(|g| vector[g]).unwrap_or(0.0))
            .collect();
        Array1::from(out)
    }
    /// Z' v: scatter-add each row's value into its group slot.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.num_groups);
        for i in 0..self.n {
            if let Some(g) = self.group_ids[i] {
                out[g] += vector[i];
            }
        }
        out
    }
    /// Z' W Z is diagonal for one-hot Z: per-group sums of clamped weights.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let q = self.num_groups;
        let mut xtwx = Array2::<f64>::zeros((q, q));
        for i in 0..self.n {
            if let Some(g) = self.group_ids[i] {
                xtwx[[g, g]] += weights[i].max(0.0);
            }
        }
        Ok(xtwx)
    }
    /// diag(Z' W Z): the same per-group weight sums as `diag_xtw_x`.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut diag = Array1::<f64>::zeros(self.num_groups);
        for i in 0..self.n {
            if let Some(g) = self.group_ids[i] {
                diag[g] += weights[i].max(0.0);
            }
        }
        Ok(diag)
    }
    /// (Z' W Z + S + ridge·I) v, exploiting the diagonal Gram structure.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        // Per-group clamped weight totals form the diagonal of Z' W Z.
        let mut group_wacc = Array1::<f64>::zeros(self.num_groups);
        for i in 0..self.n {
            if let Some(g) = self.group_ids[i] {
                group_wacc[g] += weights[i].max(0.0);
            }
        }
        let mut out = Array1::<f64>::zeros(self.num_groups);
        for g in 0..self.num_groups {
            out[g] = group_wacc[g] * vector[g];
        }
        if let Some(pen) = penalty {
            out += &pen.dot(vector);
        }
        if ridge > 0.0 {
            for g in 0..self.num_groups {
                out[g] += ridge * vector[g];
            }
        }
        out
    }
    /// Always prefer matrix-free PCG: materializing the indicator matrix
    /// would be pure overhead.
    fn uses_matrix_free_pcg(&self) -> bool {
        true
    }
}
impl DenseDesignOperator for RandomEffectOperator {
    /// Z^T W y: for every assigned observation, add its clamped weight times
    /// its response into the group slot.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut out = Array1::<f64>::zeros(self.num_groups);
        for (i, gid) in self.group_ids.iter().enumerate() {
            let Some(g) = gid else { continue };
            out[*g] += weights[i].max(0.0) * y[i];
        }
        Ok(out)
    }
    /// diag(Z M Z^T): observation i reads the (clamped) diagonal entry of M
    /// for its group, or 0.0 when unassigned.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        use rayon::prelude::*;
        let out: Vec<f64> = self
            .group_ids
            .par_iter()
            .map(|gid| match gid {
                Some(g) => middle[[*g, *g]].max(0.0),
                None => 0.0,
            })
            .collect();
        Ok(Array1::from(out))
    }
    /// Materializes the requested indicator rows: one 1.0 per assigned row.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        let expected_rows = rows.end - rows.start;
        if out.nrows() != expected_rows || out.ncols() != self.num_groups {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "RandomEffectOperator::row_chunk_into shape mismatch",
            });
        }
        out.fill(0.0);
        for (local, global) in rows.enumerate() {
            let Some(g) = self.group_ids[global] else { continue };
            out[[local, g]] = 1.0;
        }
        Ok(())
    }
    /// Full dense indicator matrix, filled one row at a time in parallel.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.n, self.num_groups));
        ndarray::Zip::indexed(dense.rows_mut()).par_for_each(|i, mut row| {
            if let Some(g) = self.group_ids[i] {
                row[g] = 1.0;
            }
        });
        dense
    }
}
/// One column-block of a block-structured design matrix.
#[derive(Clone)]
pub enum DesignBlock {
    /// Explicitly stored dense block.
    Dense(DenseDesignMatrix),
    /// Sparse block.
    Sparse(SparseDesignMatrix),
    /// Group-indicator block applied matrix-free.
    RandomEffect(Arc<RandomEffectOperator>),
    /// All-ones intercept column for the given number of observations.
    Intercept(usize),
}
impl DesignBlock {
    /// Row count of this block.
    pub(crate) fn nrows(&self) -> usize {
        match self {
            Self::Dense(d) => d.nrows(),
            Self::Sparse(s) => s.nrows(),
            Self::RandomEffect(op) => op.nrows(),
            Self::Intercept(n) => *n,
        }
    }
    /// Column count of this block (an intercept is a single column).
    pub(crate) fn ncols(&self) -> usize {
        match self {
            Self::Dense(d) => d.ncols(),
            Self::Sparse(s) => s.ncols(),
            Self::RandomEffect(op) => op.ncols(),
            Self::Intercept(_) => 1,
        }
    }
    /// X v for this block. Sparse blocks delegate through `DesignMatrix`;
    /// the `s.clone()` is presumably a cheap handle clone — TODO(review): confirm.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(d) => d.apply(vector),
            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).apply(vector),
            Self::RandomEffect(op) => op.apply(vector),
            Self::Intercept(n) => Array1::from_elem(*n, vector[0]),
        }
    }
    /// X^T v for this block; for an intercept this is the sum of `vector`.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(d) => d.apply_transpose(vector),
            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).apply_transpose(vector),
            Self::RandomEffect(op) => op.apply_transpose(vector),
            Self::Intercept(_) => {
                let sum: f64 = vector.iter().sum();
                Array1::from_vec(vec![sum])
            }
        }
    }
    /// Dense materialization of the given row range, shape (len, ncols).
    fn try_row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, MatrixMaterializationError> {
        match self {
            Self::Dense(d) => d.try_row_chunk(rows),
            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).try_row_chunk(rows),
            Self::RandomEffect(op) => op.try_row_chunk(rows),
            Self::Intercept(_) => Ok(Array2::ones((rows.end - rows.start, 1))),
        }
    }
    /// X^T W X for this block alone; negative weights are clamped to zero.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        match self {
            Self::Dense(d) => d.diag_xtw_x(weights),
            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).diag_xtw_x(weights),
            Self::RandomEffect(op) => op.diag_xtw_x(weights),
            Self::Intercept(_) => {
                let sum: f64 = weights.iter().map(|w| w.max(0.0)).sum();
                Ok(Array2::from_elem((1, 1), sum))
            }
        }
    }
    /// diag(X^T W X) for this block; negative weights are clamped to zero.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        match self {
            Self::Dense(d) => d.diag_gram(weights),
            Self::Sparse(s) => DesignMatrix::Sparse(s.clone()).diag_gram(weights),
            Self::RandomEffect(op) => op.diag_gram(weights),
            Self::Intercept(_) => {
                let sum: f64 = weights.iter().map(|w| w.max(0.0)).sum();
                Ok(Array1::from_vec(vec![sum]))
            }
        }
    }
    /// Full dense materialization of this block.
    fn to_dense(&self) -> Array2<f64> {
        match self {
            Self::Dense(d) => d.to_dense(),
            Self::Sparse(s) => s.to_dense_arc().as_ref().clone(),
            Self::RandomEffect(op) => op.to_dense(),
            Self::Intercept(n) => Array2::ones((*n, 1)),
        }
    }
}
/// Column-wise concatenation [B_0 | B_1 | …] of heterogeneous design blocks
/// that share one observation axis.
#[derive(Clone)]
pub struct BlockDesignOperator {
    /// Constituent blocks, left to right.
    pub blocks: Vec<DesignBlock>,
    /// Prefix sums of block widths (length blocks.len() + 1); block `i`
    /// owns columns `col_offsets[i]..col_offsets[i + 1]`.
    pub col_offsets: Vec<usize>,
    /// Total column count across all blocks.
    pub total_cols: usize,
    /// Shared row count.
    pub n: usize,
}
impl BlockDesignOperator {
    /// Validates that every block shares one row count and precomputes the
    /// column offsets.
    pub fn new(blocks: Vec<DesignBlock>) -> Result<Self, String> {
        if blocks.is_empty() {
            return Err("BlockDesignOperator: need at least one block".to_string());
        }
        let n = blocks[0].nrows();
        for (i, b) in blocks.iter().enumerate() {
            if b.nrows() != n {
                return Err(format!(
                    "BlockDesignOperator: block {i} has {} rows, expected {n}",
                    b.nrows()
                ));
            }
        }
        let mut col_offsets = Vec::with_capacity(blocks.len() + 1);
        col_offsets.push(0);
        for b in &blocks {
            col_offsets.push(col_offsets.last().unwrap() + b.ncols());
        }
        let total_cols = *col_offsets.last().unwrap();
        Ok(Self {
            blocks,
            col_offsets,
            total_cols,
            n,
        })
    }
    /// Fallback weighted cross product L^T W R between two blocks, streamed
    /// over row chunks so neither block is fully materialized. Negative
    /// weights are clamped to zero; zero terms are skipped early.
    fn weighted_cross_chunked(
        &self,
        left: &DesignBlock,
        right: &DesignBlock,
        weights: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        let pi = left.ncols();
        let pj = right.ncols();
        let mut cross = Array2::<f64>::zeros((pi, pj));
        for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
            let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
            let left_chunk = left.try_row_chunk(start..end).map_err(|e| e.to_string())?;
            let right_chunk = right.try_row_chunk(start..end).map_err(|e| e.to_string())?;
            for local in 0..(end - start) {
                let wi = weights[start + local].max(0.0);
                if wi == 0.0 {
                    continue;
                }
                for a in 0..pi {
                    let scaled = wi * left_chunk[[local, a]];
                    if scaled == 0.0 {
                        continue;
                    }
                    for b in 0..pj {
                        cross[[a, b]] += scaled * right_chunk[[local, b]];
                    }
                }
            }
        }
        Ok(cross)
    }
    /// Fallback per-observation diag(A M_ab B^T), streamed over row chunks.
    fn quadratic_form_diag_cross_chunked(
        &self,
        block_a: &DesignBlock,
        block_b: &DesignBlock,
        m_ab: &Array2<f64>,
    ) -> Result<Array1<f64>, String> {
        let mut out = Array1::<f64>::zeros(self.n);
        for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
            let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
            let a_chunk = block_a
                .try_row_chunk(start..end)
                .map_err(|e| e.to_string())?;
            let b_chunk = block_b
                .try_row_chunk(start..end)
                .map_err(|e| e.to_string())?;
            let a_m = fast_ab(&a_chunk, m_ab);
            for local in 0..(end - start) {
                out[start + local] = a_m.row(local).dot(&b_chunk.row(local));
            }
        }
        Ok(out)
    }
    /// Off-diagonal Gram block X_i^T W X_j. Dispatches to specialized fast
    /// paths where the pair of block types allows it, and falls back to the
    /// chunked streaming implementation otherwise.
    fn cross_block(
        &self,
        i: usize,
        j: usize,
        weights: &Array1<f64>,
    ) -> Result<Array2<f64>, String> {
        match (&self.blocks[i], &self.blocks[j]) {
            (DesignBlock::Dense(d_i), DesignBlock::Dense(d_j)) => {
                // Only use the BLAS-style path when both sides expose a
                // materialized dense view.
                if let (Some(xi), Some(xj)) = (d_i.as_dense_ref(), d_j.as_dense_ref()) {
                    weighted_crossprod_dense(xi, weights, xj)
                } else {
                    self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
                }
            }
            (DesignBlock::Dense(_), DesignBlock::Sparse(_))
            | (DesignBlock::Sparse(_), DesignBlock::Dense(_))
            | (DesignBlock::Sparse(_), DesignBlock::Sparse(_))
            | (DesignBlock::Sparse(_), DesignBlock::RandomEffect(_))
            | (DesignBlock::RandomEffect(_), DesignBlock::Sparse(_)) => {
                self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
            }
            (DesignBlock::Dense(d), DesignBlock::RandomEffect(re)) => {
                if let Some(dense) = d.as_dense_ref() {
                    Ok(re.weighted_cross_with_dense(dense, weights))
                } else {
                    self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
                }
            }
            (DesignBlock::RandomEffect(re), DesignBlock::Dense(d)) => {
                if let Some(dense) = d.as_dense_ref() {
                    // The helper computes the (groups × p) orientation;
                    // transpose to match this arm's (re, dense) ordering.
                    let cross_t = re.weighted_cross_with_dense(dense, weights);
                    Ok(cross_t.t().to_owned())
                } else {
                    self.weighted_cross_chunked(&self.blocks[i], &self.blocks[j], weights)
                }
            }
            (DesignBlock::RandomEffect(re_a), DesignBlock::RandomEffect(re_b)) => {
                Ok(re_a.weighted_cross_with_re(re_b, weights))
            }
            (DesignBlock::Intercept(_), other) => {
                // 1^T W X_j is just X_j^T applied to the clamped weights.
                let pj = other.ncols();
                let mut cross = Array2::<f64>::zeros((1, pj));
                let weighted = Array1::from_shape_fn(self.n, |idx| weights[idx].max(0.0));
                let row = other.apply_transpose(&weighted);
                cross.row_mut(0).assign(&row);
                Ok(cross)
            }
            (other, DesignBlock::Intercept(_)) => {
                let pi = other.ncols();
                let mut cross = Array2::<f64>::zeros((pi, 1));
                let weighted = Array1::from_shape_fn(self.n, |idx| weights[idx].max(0.0));
                let col = other.apply_transpose(&weighted);
                cross.column_mut(0).assign(&col);
                Ok(cross)
            }
        }
    }
    /// Per-observation diag(X_k M_kk X_k^T) for a single block.
    fn quadratic_form_diag_block(
        &self,
        block: &DesignBlock,
        m_kk: &Array2<f64>,
    ) -> Result<Array1<f64>, String> {
        match block {
            DesignBlock::Dense(d) => {
                if let Some(dense) = d.as_dense_ref() {
                    let xm = fast_ab(dense, m_kk);
                    let mut out = Array1::<f64>::zeros(self.n);
                    ndarray::Zip::from(&mut out)
                        .and(dense.rows())
                        .and(xm.rows())
                        .par_for_each(|o, dr, xmr| *o = dr.dot(&xmr));
                    Ok(out)
                } else {
                    d.quadratic_form_diag(m_kk)
                }
            }
            DesignBlock::Sparse(s) => {
                let sparse = DesignMatrix::Sparse(s.clone());
                sparse.quadratic_form_diag(m_kk)
            }
            DesignBlock::RandomEffect(re) => {
                // Indicator rows pick out a single diagonal entry of M.
                use rayon::prelude::*;
                let out: Vec<f64> = re
                    .group_ids
                    .par_iter()
                    .map(|g| g.map(|g| m_kk[[g, g]]).unwrap_or(0.0))
                    .collect();
                Ok(Array1::from(out))
            }
            DesignBlock::Intercept(_) => {
                Ok(Array1::from_elem(self.n, m_kk[[0, 0]]))
            }
        }
    }
    /// Per-observation diag(X_a M_ab X_b^T) for a pair of distinct blocks,
    /// with specialized paths mirroring `cross_block`'s dispatch.
    fn quadratic_form_diag_cross(
        &self,
        block_a: &DesignBlock,
        block_b: &DesignBlock,
        m_ab: &Array2<f64>,
    ) -> Result<Array1<f64>, String> {
        match (block_a, block_b) {
            (DesignBlock::Dense(da), DesignBlock::Dense(db)) => {
                if let (Some(da), Some(db)) = (da.as_dense_ref(), db.as_dense_ref()) {
                    let da_m = fast_ab(da, m_ab);
                    let mut out = Array1::<f64>::zeros(self.n);
                    ndarray::Zip::from(&mut out)
                        .and(da_m.rows())
                        .and(db.rows())
                        .par_for_each(|o, ar, br| *o = ar.dot(&br));
                    Ok(out)
                } else {
                    self.quadratic_form_diag_cross_chunked(block_a, block_b, m_ab)
                }
            }
            (DesignBlock::Dense(_), DesignBlock::Sparse(_))
            | (DesignBlock::Sparse(_), DesignBlock::Dense(_))
            | (DesignBlock::Sparse(_), DesignBlock::Sparse(_))
            | (DesignBlock::Sparse(_), DesignBlock::RandomEffect(_))
            | (DesignBlock::RandomEffect(_), DesignBlock::Sparse(_)) => {
                self.quadratic_form_diag_cross_chunked(block_a, block_b, m_ab)
            }
            (DesignBlock::Dense(d), DesignBlock::RandomEffect(re)) => {
                // The RE side selects one column of M per observation.
                let mut out = Array1::<f64>::zeros(self.n);
                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
                    let chunk = d.try_row_chunk(start..end).map_err(|e| e.to_string())?;
                    for local in 0..chunk.nrows() {
                        let i = start + local;
                        if let Some(g) = re.group_ids[i] {
                            let mut val = 0.0;
                            for j in 0..chunk.ncols() {
                                val += chunk[[local, j]] * m_ab[[j, g]];
                            }
                            out[i] = val;
                        }
                    }
                }
                Ok(out)
            }
            (DesignBlock::RandomEffect(re), DesignBlock::Dense(d)) => {
                // Mirror of the arm above: the RE side selects a row of M.
                let mut out = Array1::<f64>::zeros(self.n);
                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
                    let chunk = d.try_row_chunk(start..end).map_err(|e| e.to_string())?;
                    for local in 0..chunk.nrows() {
                        let i = start + local;
                        if let Some(g) = re.group_ids[i] {
                            let mut val = 0.0;
                            for j in 0..chunk.ncols() {
                                val += m_ab[[g, j]] * chunk[[local, j]];
                            }
                            out[i] = val;
                        }
                    }
                }
                Ok(out)
            }
            (DesignBlock::RandomEffect(re_a), DesignBlock::RandomEffect(re_b)) => {
                // Both sides are indicators: each observation reads one entry of M.
                use rayon::prelude::*;
                let out: Vec<f64> = re_a
                    .group_ids
                    .par_iter()
                    .zip(re_b.group_ids.par_iter())
                    .map(|(ga, gb)| match (ga, gb) {
                        (Some(ga), Some(gb)) => m_ab[[*ga, *gb]],
                        _ => 0.0,
                    })
                    .collect();
                Ok(Array1::from(out))
            }
            (DesignBlock::Intercept(_), other) => {
                // Intercept on the left: dot each row of the other block with
                // M's single row.
                let m_row = m_ab.row(0);
                let mut out = Array1::<f64>::zeros(self.n);
                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
                    let chunk = other.try_row_chunk(start..end).map_err(|e| e.to_string())?;
                    for local in 0..(end - start) {
                        out[start + local] = chunk.row(local).dot(&m_row);
                    }
                }
                Ok(out)
            }
            (other, DesignBlock::Intercept(_)) => {
                // Intercept on the right: dot with M's single column.
                let m_col = m_ab.column(0);
                let mut out = Array1::<f64>::zeros(self.n);
                for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
                    let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
                    let chunk = other.try_row_chunk(start..end).map_err(|e| e.to_string())?;
                    for local in 0..(end - start) {
                        out[start + local] = chunk.row(local).dot(&m_col);
                    }
                }
                Ok(out)
            }
        }
    }
}
impl LinearOperator for BlockDesignOperator {
    /// Shared row count of all blocks.
    fn nrows(&self) -> usize {
        self.n
    }
    /// Total column count across all blocks.
    fn ncols(&self) -> usize {
        self.total_cols
    }
    /// X v: sum of each block applied to its own slice of `vector`.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.n);
        for (idx, block) in self.blocks.iter().enumerate() {
            let start = self.col_offsets[idx];
            let end = self.col_offsets[idx + 1];
            // Copy the slice to an owned vector for the block API.
            let slice = vector.slice(s![start..end]).to_owned();
            let contribution = block.apply(&slice);
            out += &contribution;
        }
        out
    }
    /// X^T v: each block writes its transpose product into its column slice.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.total_cols);
        for (idx, block) in self.blocks.iter().enumerate() {
            let start = self.col_offsets[idx];
            let end = self.col_offsets[idx + 1];
            let transposed = block.apply_transpose(vector);
            out.slice_mut(s![start..end]).assign(&transposed);
        }
        out
    }
    /// Full X^T W X: per-block Gram matrices on the diagonal, pairwise cross
    /// products filled symmetrically off the diagonal.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let p = self.total_cols;
        let mut result = Array2::<f64>::zeros((p, p));
        for (idx, block) in self.blocks.iter().enumerate() {
            let start = self.col_offsets[idx];
            let end = self.col_offsets[idx + 1];
            let block_xtwx = block.diag_xtw_x(weights)?;
            result
                .slice_mut(s![start..end, start..end])
                .assign(&block_xtwx);
        }
        for i in 0..self.blocks.len() {
            for j in (i + 1)..self.blocks.len() {
                let cross = self.cross_block(i, j, weights)?;
                let si = self.col_offsets[i];
                let ei = self.col_offsets[i + 1];
                let sj = self.col_offsets[j];
                let ej = self.col_offsets[j + 1];
                result.slice_mut(s![si..ei, sj..ej]).assign(&cross);
                // Mirror into the lower triangle.
                result.slice_mut(s![sj..ej, si..ei]).assign(&cross.t());
            }
        }
        Ok(result)
    }
    /// diag(X^T W X): concatenation of per-block Gram diagonals.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut out = Array1::<f64>::zeros(self.total_cols);
        for (idx, block) in self.blocks.iter().enumerate() {
            let start = self.col_offsets[idx];
            let end = self.col_offsets[idx + 1];
            let block_diag = block.diag_gram(weights)?;
            out.slice_mut(s![start..end]).assign(&block_diag);
        }
        Ok(out)
    }
    /// (X^T W X + S + ridge·I) v computed as X^T (W (X v)), never forming the
    /// Gram matrix. Negative weights are clamped to zero.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        let xv = self.apply(vector);
        let mut weighted = xv;
        for i in 0..weighted.len() {
            weighted[i] *= weights[i].max(0.0);
        }
        let mut out = self.apply_transpose(&weighted);
        if let Some(pen) = penalty {
            out += &pen.dot(vector);
        }
        if ridge > 0.0 {
            out += &vector.mapv(|x| ridge * x);
        }
        out
    }
    /// Prefer matrix-free PCG when any block is cheap to apply implicitly.
    fn uses_matrix_free_pcg(&self) -> bool {
        self.blocks
            .iter()
            .any(|b| matches!(b, DesignBlock::RandomEffect(_) | DesignBlock::Intercept(_)))
    }
}
impl DenseDesignOperator for BlockDesignOperator {
    /// X^T W y: form w ⊙ y in parallel (negative weights clamped to zero),
    /// then reduce with the block-wise transpose product.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut weighted_response = Array1::<f64>::zeros(self.n);
        ndarray::Zip::from(&mut weighted_response)
            .and(weights)
            .and(y)
            .par_for_each(|dst, &w, &resp| *dst = w.max(0.0) * resp);
        Ok(self.apply_transpose(&weighted_response))
    }
    /// diag(X M X^T): per-block diagonal contributions plus twice each
    /// off-diagonal pair (symmetry), clamped to be non-negative at the end.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let nb = self.blocks.len();
        let mut out = Array1::<f64>::zeros(self.n);
        // Diagonal blocks of `middle`.
        for (k, block) in self.blocks.iter().enumerate() {
            let (sk, ek) = (self.col_offsets[k], self.col_offsets[k + 1]);
            let m_kk = middle.slice(s![sk..ek, sk..ek]).to_owned();
            out += &self.quadratic_form_diag_block(block, &m_kk)?;
        }
        // Off-diagonal blocks, counted twice by symmetry.
        for a in 0..nb {
            for b in (a + 1)..nb {
                let (sa, ea) = (self.col_offsets[a], self.col_offsets[a + 1]);
                let (sb, eb) = (self.col_offsets[b], self.col_offsets[b + 1]);
                let m_ab = middle.slice(s![sa..ea, sb..eb]).to_owned();
                let cross_diag =
                    self.quadratic_form_diag_cross(&self.blocks[a], &self.blocks[b], &m_ab)?;
                out.zip_mut_with(&cross_diag, |acc, &c| *acc += 2.0 * c);
            }
        }
        // Final clamp to non-negative values.
        out.mapv_inplace(|v| v.max(0.0));
        Ok(out)
    }
    /// Materializes a row range by delegating to each block's chunk within
    /// its own column span.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        let row_count = rows.end - rows.start;
        if out.nrows() != row_count || out.ncols() != self.total_cols {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "BlockDesignOperator::row_chunk_into shape mismatch",
            });
        }
        for (idx, block) in self.blocks.iter().enumerate() {
            let cols = self.col_offsets[idx]..self.col_offsets[idx + 1];
            let chunk = block.try_row_chunk(rows.clone())?;
            out.slice_mut(s![.., cols]).assign(&chunk);
        }
        Ok(())
    }
    /// Full dense materialization: each block fills its own column span.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.n, self.total_cols));
        for (idx, block) in self.blocks.iter().enumerate() {
            let cols = self.col_offsets[idx]..self.col_offsets[idx + 1];
            dense.slice_mut(s![.., cols]).assign(&block.to_dense());
        }
        dense
    }
}
/// Lazily represents the reparameterized design X · Q: the inner operator
/// post-multiplied by a transform Q, without ever materializing the product.
pub struct ReparamDesignOperator {
    /// Underlying design operator X.
    pub inner: Arc<dyn DenseDesignOperator>,
    /// Reparameterization matrix Q with `inner.ncols()` rows.
    pub q: Arc<Array2<f64>>,
    // Cached row count of `inner`.
    n: usize,
    // Column count after reparameterization (Q's column count).
    p_new: usize,
}
impl ReparamDesignOperator {
    /// Builds the lazy X·Q operator, checking that Q's row count matches the
    /// inner operator's column count.
    pub fn new(inner: Arc<dyn DenseDesignOperator>, q: Arc<Array2<f64>>) -> Result<Self, String> {
        let inner_cols = inner.ncols();
        if inner_cols != q.nrows() {
            return Err(format!(
                "ReparamDesignOperator: inner has {} cols but Q has {} rows",
                inner_cols,
                q.nrows()
            ));
        }
        let (n, p_new) = (inner.nrows(), q.ncols());
        Ok(Self { n, p_new, inner, q })
    }
}
impl LinearOperator for ReparamDesignOperator {
    fn nrows(&self) -> usize {
        self.n
    }
    fn ncols(&self) -> usize {
        self.p_new
    }
    /// (XQ) v = X (Q v).
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.inner.apply(&self.q.dot(vector))
    }
    /// (XQ)^T v = Q^T (X^T v).
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.q.t().dot(&self.inner.apply_transpose(vector))
    }
    /// Q^T (X^T W X) Q, conjugating the inner Gram matrix by Q.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let inner_gram = self.inner.diag_xtw_x(weights)?;
        Ok(fast_ab(&fast_atb(&self.q, &inner_gram), &self.q))
    }
    /// Diagonal of Q^T (X^T W X) Q, read off the full conjugated Gram matrix.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let gram = self.diag_xtw_x(weights)?;
        Ok(gram.diag().to_owned())
    }
    /// Normal-equations product in the reparameterized basis: delegate the
    /// X^T W X part to the inner operator, map back through Q, then add the
    /// penalty and ridge terms in the new basis.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        let inner_normal = self
            .inner
            .apply_weighted_normal(weights, &self.q.dot(vector), None, 0.0);
        let mut out = self.q.t().dot(&inner_normal);
        if let Some(pen) = penalty {
            out += &pen.dot(vector);
        }
        if ridge > 0.0 {
            out.zip_mut_with(vector, |o, &v| *o += ridge * v);
        }
        out
    }
    fn uses_matrix_free_pcg(&self) -> bool {
        self.inner.uses_matrix_free_pcg()
    }
}
// SAFETY(review): these impls are sound only if every `DenseDesignOperator`
// stored behind `inner` is itself Send + Sync — that bound is not visible in
// this chunk. TODO: confirm, or replace with `Arc<dyn DenseDesignOperator
// + Send + Sync>` so the compiler enforces it instead.
unsafe impl Send for ReparamDesignOperator {}
unsafe impl Sync for ReparamDesignOperator {}
impl DenseDesignOperator for ReparamDesignOperator {
    /// Materializes rows of X·Q by materializing the inner rows and
    /// post-multiplying by Q.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        let want_rows = rows.end - rows.start;
        if out.nrows() != want_rows || out.ncols() != self.p_new {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "ReparamDesignOperator::row_chunk_into shape mismatch",
            });
        }
        let inner_rows = self.inner.try_row_chunk(rows)?;
        out.assign(&fast_ab(&inner_rows, &self.q));
        Ok(())
    }
    /// diag((XQ) M (XQ)^T) = diag(X (Q M Q^T) X^T): conjugate the middle
    /// matrix by Q and delegate to the inner operator.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let q_middle = fast_ab(&self.q, middle);
        let conjugated = fast_ab(&q_middle, &self.q.t());
        self.inner.quadratic_form_diag(&conjugated)
    }
    /// Dense X·Q via one full materialization of the inner design.
    fn to_dense(&self) -> Array2<f64> {
        fast_ab(&self.inner.to_dense(), &self.q)
    }
}
/// Vertical stack of identically shaped per-channel designs sharing one
/// coefficient vector: rows are the channels' rows concatenated in order.
#[derive(Clone)]
pub struct MultiChannelOperator {
    /// One design per channel, all with the same (rows, cols) shape.
    pub channels: Vec<DesignMatrix>,
    /// Row count of each individual channel.
    pub n_per_channel: usize,
    /// Shared column count.
    pub p: usize,
}
impl MultiChannelOperator {
    /// Validates that every channel shares one (rows, cols) shape and builds
    /// the stacked operator.
    pub fn new(channels: Vec<DesignMatrix>) -> Result<Self, String> {
        let Some(first) = channels.first() else {
            return Err("MultiChannelOperator: need at least one channel".to_string());
        };
        let (n, p) = (first.nrows(), first.ncols());
        for (i, ch) in channels.iter().enumerate() {
            if ch.nrows() != n {
                return Err(format!(
                    "MultiChannelOperator: channel {i} has {} rows, expected {n}",
                    ch.nrows()
                ));
            }
            if ch.ncols() != p {
                return Err(format!(
                    "MultiChannelOperator: channel {i} has {} cols, expected {p}",
                    ch.ncols()
                ));
            }
        }
        Ok(Self {
            channels,
            n_per_channel: n,
            p,
        })
    }
}
impl LinearOperator for MultiChannelOperator {
    /// Total stacked row count across all channels.
    fn nrows(&self) -> usize {
        self.n_per_channel * self.channels.len()
    }
    /// Shared column count.
    fn ncols(&self) -> usize {
        self.p
    }
    /// X v: each channel's product fills its own row segment of the output.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        let total = self.nrows();
        let mut out = Array1::<f64>::zeros(total);
        let n = self.n_per_channel;
        for (i, ch) in self.channels.iter().enumerate() {
            let ch_result = ch.matrixvectormultiply(vector);
            out.slice_mut(s![i * n..(i + 1) * n]).assign(&ch_result);
        }
        out
    }
    /// X^T v: sum of each channel's transpose applied to its row segment.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let n = self.n_per_channel;
        let mut out = Array1::<f64>::zeros(self.p);
        for (i, ch) in self.channels.iter().enumerate() {
            out += &ch.apply_transpose_view(vector.slice(s![i * n..(i + 1) * n]));
        }
        out
    }
    /// X^T W X: sum of per-channel Gram matrices, each using that channel's
    /// slice of the stacked weight vector.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let n = self.n_per_channel;
        if weights.len() != self.nrows() {
            return Err(format!(
                "MultiChannelOperator::diag_xtw_x: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        let mut xtwx = Array2::<f64>::zeros((self.p, self.p));
        for (i, ch) in self.channels.iter().enumerate() {
            let ch_xtwx = ch.compute_xtwx_view(weights.slice(s![i * n..(i + 1) * n]))?;
            xtwx += &ch_xtwx;
        }
        Ok(xtwx)
    }
    /// diag(X^T W X): sum of per-channel Gram diagonals.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let n = self.n_per_channel;
        if weights.len() != self.nrows() {
            return Err(format!(
                "MultiChannelOperator::diag_gram: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        let mut diag = Array1::<f64>::zeros(self.p);
        for (i, ch) in self.channels.iter().enumerate() {
            diag += &ch.diag_gram_view(weights.slice(s![i * n..(i + 1) * n]))?;
        }
        Ok(diag)
    }
    /// The stacked design is never materialized, so always use matrix-free PCG.
    fn uses_matrix_free_pcg(&self) -> bool {
        true
    }
}
impl DenseDesignOperator for MultiChannelOperator {
    /// X^T W y: sum of each channel's X^T W y on its row segment.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        let n = self.n_per_channel;
        let total = self.nrows();
        if weights.len() != total || y.len() != total {
            return Err(format!(
                "MultiChannelOperator::compute_xtwy: weights={}, y={}, nrows={}",
                weights.len(),
                y.len(),
                total
            ));
        }
        let mut out = Array1::<f64>::zeros(self.p);
        for (i, ch) in self.channels.iter().enumerate() {
            out += &ch.compute_xtwy_view(
                weights.slice(s![i * n..(i + 1) * n]),
                y.slice(s![i * n..(i + 1) * n]),
            )?;
        }
        Ok(out)
    }
    /// diag(X M X^T): per-channel diagonals stacked into the corresponding
    /// output segments (the same `middle` applies to every channel).
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let n = self.n_per_channel;
        let mut out = Array1::<f64>::zeros(self.nrows());
        for (i, ch) in self.channels.iter().enumerate() {
            let ch_diag = ch.quadratic_form_diag(middle)?;
            out.slice_mut(s![i * n..(i + 1) * n]).assign(&ch_diag);
        }
        Ok(out)
    }
    /// Full dense stack: each channel materialized into its row segment.
    fn to_dense(&self) -> Array2<f64> {
        let total = self.nrows();
        let n = self.n_per_channel;
        let mut out = Array2::<f64>::zeros((total, self.p));
        for (i, ch) in self.channels.iter().enumerate() {
            let dense = ch.to_dense();
            out.slice_mut(s![i * n..(i + 1) * n, ..]).assign(&dense);
        }
        out
    }
    /// Materializes a global row range that may span several stacked
    /// channels, walking it one channel-local segment at a time.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "MultiChannelOperator::row_chunk_into shape mismatch",
            });
        }
        let n = self.n_per_channel;
        let mut local = 0usize;
        let mut global = rows.start;
        while global < rows.end {
            // Map the current global row to (channel, channel-local row) and
            // take the largest run that stays within this channel.
            let ch_idx = global / n;
            let ch_local_start = global % n;
            let ch_local_end = ((ch_idx + 1) * n).min(rows.end) - ch_idx * n;
            let segment_len = ch_local_end - ch_local_start;
            let ch_chunk = self.channels[ch_idx].try_row_chunk(ch_local_start..ch_local_end)?;
            out.slice_mut(s![local..local + segment_len, ..])
                .assign(&ch_chunk);
            local += segment_len;
            global += segment_len;
        }
        Ok(())
    }
}
/// Row-wise Kronecker design over a covariate design and a shared time basis
/// — presumably row i is the Kronecker product of the i-th rows of both
/// factors; the impl is outside this chunk, TODO(review): confirm.
#[derive(Clone)]
pub struct RowwiseKroneckerOperator {
    /// Covariate design (n × p_cov).
    pub cov: DesignMatrix,
    /// Shared time basis matrix.
    pub time_basis: Arc<Array2<f64>>,
    /// Number of observations.
    pub n: usize,
    /// Column count of the covariate design.
    pub p_cov: usize,
    /// Column count of the time basis.
    pub p_time: usize,
}
/// Decomposes a row-major flat index into per-dimension indices (last
/// dimension fastest), writing into `out[..dims.len()]`.
fn decode_multi_index(mut flat: usize, dims: &[usize], out: &mut [usize]) {
    let mut d = dims.len();
    while d > 0 {
        d -= 1;
        out[d] = flat % dims[d];
        flat /= dims[d];
    }
}
/// Tensor-product design over per-dimension marginal bases: row i of the
/// implicit design is the Kronecker product of row i of each marginal, with
/// earlier marginals varying slowest (see `row_values`).
pub struct TensorProductDesignOperator {
    // One basis matrix per dimension, all sharing `n` rows.
    marginals: Vec<Arc<Array2<f64>>>,
    // Shared row count.
    n: usize,
    // Product of the marginal column counts.
    total_cols: usize,
}
impl TensorProductDesignOperator {
pub fn new(marginals: Vec<Arc<Array2<f64>>>) -> Result<Self, String> {
if marginals.is_empty() {
return Err("TensorProductDesignOperator requires at least one marginal".to_string());
}
let n = marginals[0].nrows();
let total_cols = marginals.iter().try_fold(1usize, |acc, marginal| {
if marginal.nrows() != n {
return Err(format!(
"TensorProductDesignOperator row mismatch: expected {n}, got {}",
marginal.nrows()
));
}
acc.checked_mul(marginal.ncols()).ok_or_else(|| {
"TensorProductDesignOperator total column count overflow".to_string()
})
})?;
Ok(Self {
marginals,
n,
total_cols,
})
}
fn row_values(&self, row: usize) -> Vec<f64> {
let mut values = vec![1.0_f64];
for marginal in &self.marginals {
let q = marginal.ncols();
let mut next = vec![0.0_f64; values.len() * q];
for (prefix_idx, &prefix) in values.iter().enumerate() {
for col in 0..q {
next[prefix_idx * q + col] = prefix * marginal[[row, col]];
}
}
values = next;
}
values
}
fn apply_vectorized(&self, vector: &Array1<f64>) -> Array1<f64> {
let d = self.marginals.len();
let n = self.n;
if d == 0 {
return Array1::zeros(n);
}
let b0 = &self.marginals[0];
let q0 = b0.ncols();
if d == 1 {
return fast_av(b0.as_ref(), vector);
}
let tail_dims: Vec<usize> = self.marginals[1..].iter().map(|m| m.ncols()).collect();
let tail_total: usize = tail_dims.iter().product();
let intermediate_bytes = n * tail_total * std::mem::size_of::<f64>();
if intermediate_bytes <= TENSOR_GEMM_MAX_INTERMEDIATE_BYTES {
let beta_view =
ndarray::ArrayView2::from_shape((q0, tail_total), vector.as_slice().unwrap())
.expect("β reshape for GEMM");
let temp = fast_ab(b0.as_ref(), &beta_view);
let mut out = Array1::<f64>::zeros(n);
let mut tail_indices = vec![0usize; tail_dims.len()];
for t_flat in 0..tail_total {
decode_multi_index(t_flat, &tail_dims, &mut tail_indices);
for i in 0..n {
let mut val = temp[[i, t_flat]];
for (dim_idx, &ti) in tail_indices.iter().enumerate() {
val *= self.marginals[dim_idx + 1][[i, ti]];
}
out[i] += val;
}
}
out
} else {
let mut tail_indices = vec![0usize; tail_dims.len()];
let mut out = Array1::<f64>::zeros(n);
let mut beta_slice = Array1::<f64>::zeros(q0);
let mut contrib = Array1::<f64>::zeros(n);
for t_flat in 0..tail_total {
decode_multi_index(t_flat, &tail_dims, &mut tail_indices);
for j1 in 0..q0 {
beta_slice[j1] = vector[j1 * tail_total + t_flat];
}
fast_av_into(b0.as_ref(), &beta_slice, &mut contrib);
for (dim_idx, &ti) in tail_indices.iter().enumerate() {
let m = &self.marginals[dim_idx + 1];
for i in 0..n {
contrib[i] *= m[[i, ti]];
}
}
out += &contrib;
}
out
}
}
fn apply_transpose_vectorized(&self, vector: &Array1<f64>) -> Array1<f64> {
let d = self.marginals.len();
let n = self.n;
if d == 0 {
return Array1::zeros(self.total_cols);
}
let b0 = &self.marginals[0];
let q0 = b0.ncols();
if d == 1 {
return fast_atv(b0.as_ref(), vector);
}
let tail_dims: Vec<usize> = self.marginals[1..].iter().map(|m| m.ncols()).collect();
let tail_total: usize = tail_dims.iter().product();
let intermediate_bytes = n * tail_total * std::mem::size_of::<f64>();
if intermediate_bytes <= TENSOR_GEMM_MAX_INTERMEDIATE_BYTES {
let mut w_mat = Array2::<f64>::zeros((n, tail_total));
let mut tail_indices = vec![0usize; tail_dims.len()];
for t_flat in 0..tail_total {
decode_multi_index(t_flat, &tail_dims, &mut tail_indices);
for i in 0..n {
let mut val = vector[i];
for (dim_idx, &ti) in tail_indices.iter().enumerate() {
val *= self.marginals[dim_idx + 1][[i, ti]];
}
w_mat[[i, t_flat]] = val;
}
}
let result_mat = fast_atb(b0.as_ref(), &w_mat);
let mut out = Array1::<f64>::zeros(self.total_cols);
for j1 in 0..q0 {
for t_flat in 0..tail_total {
out[j1 * tail_total + t_flat] = result_mat[[j1, t_flat]];
}
}
out
} else {
let mut tail_indices = vec![0usize; tail_dims.len()];
let mut out = Array1::<f64>::zeros(self.total_cols);
let mut scaled_v = Array1::<f64>::zeros(n);
let mut col_result = Array1::<f64>::zeros(q0);
for t_flat in 0..tail_total {
decode_multi_index(t_flat, &tail_dims, &mut tail_indices);
scaled_v.assign(vector);
for (dim_idx, &ti) in tail_indices.iter().enumerate() {
let m = &self.marginals[dim_idx + 1];
for i in 0..n {
scaled_v[i] *= m[[i, ti]];
}
}
fast_atv_into(b0.as_ref(), &scaled_v, &mut col_result);
for j1 in 0..q0 {
out[j1 * tail_total + t_flat] = col_result[j1];
}
}
out
}
}
}
impl LinearOperator for TensorProductDesignOperator {
    fn nrows(&self) -> usize {
        self.n
    }
    fn ncols(&self) -> usize {
        self.total_cols
    }
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.apply_vectorized(vector)
    }
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.apply_transpose_vectorized(vector)
    }
    /// X^T W X assembled block-wise over tail-index pairs.
    ///
    /// For each unordered pair (a_flat, b_flat) of tail column combinations,
    /// the weights are scaled by both tail basis products (γ) and the first
    /// marginal contributes a q0×q0 block B0^T diag(γ) B0; pairs are computed
    /// in parallel and mirrored into the symmetric positions. Negative
    /// weights are clamped to zero.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        if weights.len() != self.n {
            return Err(format!(
                "TensorProductDesignOperator::diag_xtw_x: weights length {} != n {}",
                weights.len(),
                self.n
            ));
        }
        let d = self.marginals.len();
        if d == 0 {
            return Ok(Array2::zeros((0, 0)));
        }
        let n = self.n;
        let q0 = self.marginals[0].ncols();
        let mut xtwx = Array2::<f64>::zeros((self.total_cols, self.total_cols));
        let b0 = &self.marginals[0];
        let tail_dims: Vec<usize> = self.marginals[1..].iter().map(|m| m.ncols()).collect();
        let tail_total: usize = tail_dims.iter().product();
        let tail_d = tail_dims.len();
        // Upper-triangle pairs of tail combinations (a_flat <= b_flat).
        let pairs: Vec<(usize, usize)> = (0..tail_total)
            .flat_map(|a_flat| (a_flat..tail_total).map(move |b_flat| (a_flat, b_flat)))
            .collect();
        let blocks: Vec<(usize, usize, Array2<f64>)> = pairs
            .into_par_iter()
            .map(|(a_flat, b_flat)| {
                let mut a_indices = vec![0usize; tail_d];
                let mut b_indices = vec![0usize; tail_d];
                decode_multi_index(a_flat, &tail_dims, &mut a_indices);
                decode_multi_index(b_flat, &tail_dims, &mut b_indices);
                // γ_i = clamped weight times both tail basis products; bail
                // out of the product early once it hits zero.
                let mut gamma = Array1::<f64>::zeros(n);
                for i in 0..n {
                    let mut prod = weights[i].max(0.0);
                    if prod != 0.0 {
                        for dim_idx in 0..tail_d {
                            let m = &self.marginals[dim_idx + 1];
                            prod *= m[[i, a_indices[dim_idx]]] * m[[i, b_indices[dim_idx]]];
                            if prod == 0.0 {
                                break;
                            }
                        }
                    }
                    gamma[i] = prod;
                }
                let mut block = Array2::<f64>::zeros((q0, q0));
                streaming_blas_xt_diag_x(b0.as_ref(), &gamma, &mut block);
                (a_flat, b_flat, block)
            })
            .collect();
        // Scatter each block into the full Gram matrix; off-diagonal tail
        // pairs also fill their mirrored positions.
        for (a_flat, b_flat, block) in blocks {
            for a1 in 0..q0 {
                let ga = a1 * tail_total + a_flat;
                for b1 in 0..q0 {
                    let gb = b1 * tail_total + b_flat;
                    xtwx[[ga, gb]] += block[[a1, b1]];
                    if a_flat != b_flat {
                        let ga_mirror = a1 * tail_total + b_flat;
                        let gb_mirror = b1 * tail_total + a_flat;
                        xtwx[[ga_mirror, gb_mirror]] += block[[a1, b1]];
                    }
                }
            }
        }
        Ok(xtwx)
    }
    /// diag(X^T W X) without assembling the full Gram matrix: for each tail
    /// combination, accumulate the squared tail product times the squared
    /// first-marginal entries. Negative weights are clamped to zero.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        if weights.len() != self.n {
            return Err(format!(
                "TensorProductDesignOperator::diag_gram: weights length {} != n {}",
                weights.len(),
                self.n
            ));
        }
        let d = self.marginals.len();
        if d == 0 {
            return Ok(Array1::zeros(0));
        }
        let mut diag = vec![0.0_f64; self.total_cols];
        let tail_dims: Vec<usize> = self.marginals[1..].iter().map(|m| m.ncols()).collect();
        let tail_total: usize = tail_dims.iter().product();
        let q0 = self.marginals[0].ncols();
        let mut tail_indices = vec![0usize; tail_dims.len()];
        for t_flat in 0..tail_total {
            decode_multi_index(t_flat, &tail_dims, &mut tail_indices);
            for i in 0..self.n {
                let wi = weights[i].max(0.0);
                if wi == 0.0 {
                    continue;
                }
                let mut tail_prod_sq = wi;
                for (dim_idx, &ti) in tail_indices.iter().enumerate() {
                    let val = self.marginals[dim_idx + 1][[i, ti]];
                    tail_prod_sq *= val * val;
                    if tail_prod_sq == 0.0 {
                        break;
                    }
                }
                if tail_prod_sq == 0.0 {
                    continue;
                }
                for j1 in 0..q0 {
                    let b1 = self.marginals[0][[i, j1]];
                    diag[j1 * tail_total + t_flat] += tail_prod_sq * b1 * b1;
                }
            }
        }
        Ok(Array1::from_vec(diag))
    }
    /// The tensor-product design is applied implicitly, so prefer PCG.
    fn uses_matrix_free_pcg(&self) -> bool {
        true
    }
}
impl DenseDesignOperator for TensorProductDesignOperator {
    /// Computes `diag(X M Xᵀ)` by streaming row chunks of the design.
    ///
    /// Each entry is `xᵢᵀ M xᵢ`, clamped at zero; `middle` must be
    /// `total_cols x total_cols`.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        if middle.nrows() != self.total_cols || middle.ncols() != self.total_cols {
            return Err(format!(
                "TensorProductDesignOperator::quadratic_form_diag dimension mismatch: {}x{} vs expected {}x{}",
                middle.nrows(),
                middle.ncols(),
                self.total_cols,
                self.total_cols
            ));
        }
        let mut out = Array1::<f64>::zeros(self.n);
        for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
            let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
            let chunk = self.try_row_chunk(start..end).map_err(|e| e.to_string())?;
            // One GEMM per chunk, then a dot product per row.
            let chunk_m = fast_ab(&chunk, middle);
            for local in 0..(end - start) {
                out[start + local] = chunk.row(local).dot(&chunk_m.row(local)).max(0.0);
            }
        }
        Ok(out)
    }
    /// Writes rows `rows` of the materialized design into `out`, which must be
    /// shaped `(rows.len(), total_cols)`.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        if out.nrows() != rows.end - rows.start || out.ncols() != self.total_cols {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "TensorProductDesignOperator::row_chunk_into shape mismatch",
            });
        }
        for (local_row, global_row) in rows.enumerate() {
            let row_values = self.row_values(global_row);
            for (j, &value) in row_values.iter().enumerate() {
                out[[local_row, j]] = value;
            }
        }
        Ok(())
    }
    /// Materializes the whole design; `row_chunk_into` only fails on shape
    /// mismatch, which cannot happen here — hence the `expect`.
    fn to_dense(&self) -> Array2<f64> {
        self.try_row_chunk(0..self.n)
            .expect("TensorProductDesignOperator row_chunk_into is total")
    }
}
/// A spatial kernel evaluated between a data point `x` and a center `c`,
/// both given as coordinate slices of equal length.
pub trait SpatialKernelEvaluator: Send + Sync + 'static {
    fn eval(&self, x: &[f64], c: &[f64]) -> f64;
}
/// Any plain closure over coordinate slices can serve as a kernel.
impl<F> SpatialKernelEvaluator for F
where
    F: Fn(&[f64], &[f64]) -> f64 + Send + Sync + 'static,
{
    fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
        self(x, c)
    }
}
/// Shared (possibly unsized, e.g. `Arc<dyn Fn..>`) closures also qualify;
/// evaluation simply delegates through the `Arc`.
impl<F> SpatialKernelEvaluator for Arc<F>
where
    F: Fn(&[f64], &[f64]) -> f64 + Send + Sync + 'static + ?Sized,
{
    fn eval(&self, x: &[f64], c: &[f64]) -> f64 {
        self.as_ref()(x, c)
    }
}
/// A kernel-expansion design evaluated lazily in row chunks: columns are
/// kernel evaluations against `centers` (optionally compressed through a
/// constraint transform), followed by optional polynomial columns.
pub struct ChunkedKernelDesignOperator<K: SpatialKernelEvaluator> {
    // Data points, one per row; kept in standard (row-major) layout.
    data: Arc<Array2<f64>>,
    // Kernel centers, one per row; kept in standard (row-major) layout.
    centers: Arc<Array2<f64>>,
    kernel: K,
    // Optional k x k_eff transform applied to the raw kernel block.
    constraint_transform: Option<Arc<Array2<f64>>>,
    // Optional n x poly_cols polynomial basis appended after kernel columns.
    poly_basis: Option<Arc<Array2<f64>>>,
    // Number of data rows.
    n: usize,
    // k_eff + poly_cols.
    total_cols: usize,
    // Lazily cached dense design; inner `None` means it exceeded the size cap.
    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
/// Builds a chunked kernel design from data points, centers, and optional
/// constraint/polynomial components.
///
/// Validates that data and centers share a coordinate dimension, that the
/// constraint transform has one row per center, and that the polynomial
/// basis has one row per data point.
///
/// `kernel_chunk` requires `as_slice()` on `data` and `centers`, i.e.
/// standard (C) layout. Inputs that are already row-major are reused as-is
/// (no copy); only non-standard layouts are re-packed.
pub fn new(
    data: Arc<Array2<f64>>,
    centers: Arc<Array2<f64>>,
    kernel: K,
    constraint_transform: Option<Arc<Array2<f64>>>,
    poly_basis: Option<Arc<Array2<f64>>>,
) -> Result<Self, String> {
    let n = data.nrows();
    let k = centers.nrows();
    if data.ncols() != centers.ncols() {
        return Err(format!(
            "ChunkedKernelDesignOperator: data dim {} != centers dim {}",
            data.ncols(),
            centers.ncols(),
        ));
    }
    if let Some(z) = constraint_transform.as_ref() {
        if z.nrows() != k {
            return Err(format!(
                "ChunkedKernelDesignOperator: constraint_transform rows {} != centers rows {}",
                z.nrows(),
                k,
            ));
        }
    }
    if let Some(poly) = poly_basis.as_ref() {
        if poly.nrows() != n {
            return Err(format!(
                "ChunkedKernelDesignOperator: poly_basis rows {} != data rows {}",
                poly.nrows(),
                n,
            ));
        }
    }
    let k_eff = constraint_transform.as_ref().map_or(k, |z| z.ncols());
    let poly_cols = poly_basis.as_ref().map_or(0, |p| p.ncols());
    // Avoid the unconditional deep copy the previous implementation made:
    // reuse the caller's Arc when the buffer is already standard layout.
    let data = if data.is_standard_layout() {
        data
    } else {
        Arc::new(data.as_standard_layout().to_owned())
    };
    let centers = if centers.is_standard_layout() {
        centers
    } else {
        Arc::new(centers.as_standard_layout().to_owned())
    };
    Ok(Self {
        data,
        centers,
        kernel,
        constraint_transform,
        poly_basis,
        n,
        total_cols: k_eff + poly_cols,
        materialized: OnceLock::new(),
    })
}
/// Hard cap on the bytes a fully materialized combined design may occupy.
const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
/// Lazily materializes and caches the combined (kernel + polynomial) design.
///
/// Computed at most once; yields `None` whenever the dense matrix would
/// exceed the byte cap or its size computation overflows `usize`.
fn materialized_combined(&self) -> Option<&Array2<f64>> {
    self.materialized
        .get_or_init(|| {
            let bytes = self
                .n
                .checked_mul(self.total_cols)
                .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
            match bytes {
                Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
                    Some(Arc::new(self.build_row_chunk_combined(0..self.n)))
                }
                _ => None,
            }
        })
        .as_ref()
        .map(|a| a.as_ref())
}
/// Evaluates the raw kernel block for the requested row range and applies
/// the optional constraint transform.
///
/// Rows are filled in parallel via rayon; `data` and `centers` are
/// guaranteed standard-layout by `new`, so the `as_slice` views cannot fail.
///
/// Bug fix: the centers slice expression had been corrupted by a mojibake
/// (`&c` rendered as the cent sign `¢`), which does not compile; restored
/// `&centers[..]`.
fn kernel_chunk(&self, rows: Range<usize>) -> Array2<f64> {
    let chunk_n = rows.end - rows.start;
    let k_raw = self.centers.nrows();
    let dim = self.data.ncols();
    let data = self
        .data
        .as_slice()
        .expect("ChunkedKernelDesignOperator stores standard-layout data");
    let centers = self
        .centers
        .as_slice()
        .expect("ChunkedKernelDesignOperator stores standard-layout centers");
    let kernel = &self.kernel;
    let mut values = vec![0.0_f64; chunk_n * k_raw];
    values
        .par_chunks_mut(k_raw)
        .enumerate()
        .for_each(|(local, out_row)| {
            let global = rows.start + local;
            let x_start = global * dim;
            let x = &data[x_start..x_start + dim];
            for j in 0..k_raw {
                let c_start = j * dim;
                out_row[j] = kernel.eval(x, &centers[c_start..c_start + dim]);
            }
        });
    let kernel_block = Array2::from_shape_vec((chunk_n, k_raw), values)
        .expect("kernel chunk shape should match generated values");
    if let Some(z) = self.constraint_transform.as_ref() {
        fast_ab(&kernel_block, z)
    } else {
        kernel_block
    }
}
}
impl<K: SpatialKernelEvaluator> LinearOperator for ChunkedKernelDesignOperator<K> {
    fn nrows(&self) -> usize {
        self.n
    }
    fn ncols(&self) -> usize {
        self.total_cols
    }
    /// Computes `X v`, streaming kernel chunks when the design is not cached.
    ///
    /// `vector` is split into the kernel part (first `k_eff` entries) and the
    /// optional polynomial tail.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        if let Some(combined) = self.materialized_combined() {
            return dense_matvec(combined, vector);
        }
        let k_eff = self
            .constraint_transform
            .as_ref()
            .map_or(self.centers.nrows(), |z| z.ncols());
        let v_kernel = vector.slice(s![..k_eff]);
        let mut result = Array1::<f64>::zeros(self.n);
        for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
            let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
            let chunk = self.kernel_chunk(start..end);
            let partial = chunk.dot(&v_kernel);
            result.slice_mut(s![start..end]).assign(&partial);
        }
        if let Some(poly) = self.poly_basis.as_ref() {
            let v_poly = vector.slice(s![k_eff..]);
            let poly_part = poly.dot(&v_poly);
            result += &poly_part;
        }
        result
    }
    /// Computes `Xᵀ v`, accumulating kernel-column contributions chunk by
    /// chunk; polynomial columns are written in one shot at the end.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        if let Some(combined) = self.materialized_combined() {
            return dense_transpose_matvec(combined, vector);
        }
        let k_eff = self
            .constraint_transform
            .as_ref()
            .map_or(self.centers.nrows(), |z| z.ncols());
        let mut result = Array1::<f64>::zeros(self.total_cols);
        for start in (0..self.n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE) {
            let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(self.n);
            let chunk = self.kernel_chunk(start..end);
            let v_slice = vector.slice(s![start..end]);
            let partial = chunk.t().dot(&v_slice);
            // Accumulate: each chunk contributes its rows' share of Xᵀv.
            result.slice_mut(s![..k_eff]).scaled_add(1.0, &partial);
        }
        if let Some(poly) = self.poly_basis.as_ref() {
            let poly_part = poly.t().dot(vector);
            result.slice_mut(s![k_eff..]).assign(&poly_part);
        }
        result
    }
    /// Assembles `Xᵀ W X` as a parallel fold over row chunks: each rayon task
    /// accumulates `chunkᵀ (W chunk)` into a thread-local p x p matrix via a
    /// sequential faer GEMM, and partial results are summed in `reduce`.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let p = self.total_cols;
        if let Some(combined) = self.materialized_combined() {
            let mut xtwx = Array2::<f64>::zeros((p, p));
            streaming_blas_xt_diag_x(combined, weights, &mut xtwx);
            return Ok(xtwx);
        }
        let n = self.n;
        if n == 0 || p == 0 {
            return Ok(Array2::<f64>::zeros((p, p)));
        }
        let chunk_starts: Vec<usize> = (0..n).step_by(KERNEL_OPERATOR_ROW_CHUNK_SIZE).collect();
        let xtwx = chunk_starts
            .into_par_iter()
            .fold(
                || Array2::<f64>::zeros((p, p)),
                |mut acc, start| {
                    let end = (start + KERNEL_OPERATOR_ROW_CHUNK_SIZE).min(n);
                    let chunk = self.row_chunk_combined(start..end);
                    // Scale each row by its weight to form W·chunk.
                    let mut wchunk = chunk.clone();
                    for local in 0..(end - start) {
                        let wi = weights[start + local];
                        wchunk.row_mut(local).mapv_inplace(|v| v * wi);
                    }
                    let chunk_view = FaerArrayView::new(&chunk);
                    let wchunk_view = FaerArrayView::new(&wchunk);
                    let mut acc_view = array2_to_matmut(&mut acc);
                    // Sequential GEMM: rayon already parallelizes across chunks.
                    matmul(
                        acc_view.as_mut(),
                        Accum::Add,
                        chunk_view.as_ref().transpose(),
                        wchunk_view.as_ref(),
                        1.0,
                        Par::Seq,
                    );
                    acc
                },
            )
            .reduce(
                || Array2::<f64>::zeros((p, p)),
                |mut a, b| {
                    a += &b;
                    a
                },
            );
        Ok(xtwx)
    }
}
impl<K: SpatialKernelEvaluator> ChunkedKernelDesignOperator<K> {
    /// Returns rows `rows` of the combined design, served from the cache when
    /// available and rebuilt from the kernel otherwise.
    fn row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
        if let Some(combined) = self.materialized_combined() {
            return combined.slice(s![rows, ..]).to_owned();
        }
        self.build_row_chunk_combined(rows)
    }
    /// Builds rows `rows` from scratch: kernel (possibly constraint-reduced)
    /// columns first, then the optional polynomial columns.
    fn build_row_chunk_combined(&self, rows: Range<usize>) -> Array2<f64> {
        let chunk_n = rows.end - rows.start;
        let k_eff = self
            .constraint_transform
            .as_ref()
            .map_or(self.centers.nrows(), |z| z.ncols());
        let kernel = self.kernel_chunk(rows.clone());
        let poly_cols = self.poly_basis.as_ref().map_or(0, |p| p.ncols());
        let mut combined = Array2::<f64>::zeros((chunk_n, k_eff + poly_cols));
        combined.slice_mut(s![.., ..k_eff]).assign(&kernel);
        if let Some(poly) = self.poly_basis.as_ref() {
            combined
                .slice_mut(s![.., k_eff..])
                .assign(&poly.slice(s![rows, ..]));
        }
        combined
    }
}
impl<K: SpatialKernelEvaluator> DenseDesignOperator for ChunkedKernelDesignOperator<K> {
    /// Exposes the cached dense design when it fits under the size cap.
    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
        self.materialized_combined()
    }
    /// Copies combined rows `rows` into `out`, which must be shaped
    /// `(rows.len(), total_cols)`.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        let expected_rows = rows.end - rows.start;
        if out.nrows() != expected_rows || out.ncols() != self.total_cols {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "ChunkedKernelDesignOperator::row_chunk_into shape mismatch",
            });
        }
        match self.materialized_combined() {
            Some(combined) => out.assign(&combined.slice(s![rows, ..])),
            None => out.assign(&self.row_chunk_combined(rows)),
        }
        Ok(())
    }
    /// Returns the full combined design, cloning the cache when present.
    fn to_dense(&self) -> Array2<f64> {
        match self.materialized_combined() {
            Some(combined) => combined.clone(),
            None => self.row_chunk_combined(0..self.n),
        }
    }
}
/// A dense design composed with a coefficient-space transform `T`, acting as
/// the matrix product `X T` without necessarily materializing it.
pub struct CoefficientTransformOperator {
    // Underlying design `X` (n x p_inner).
    inner: DenseDesignMatrix,
    // Transform `T` (p_inner x p_out).
    transform: Arc<Array2<f64>>,
    // Row count of `inner`.
    n: usize,
    // Column count after the transform.
    p_out: usize,
    // Lazily cached `X T`; inner `None` means it exceeded the size cap.
    materialized: OnceLock<Option<Arc<Array2<f64>>>>,
}
impl CoefficientTransformOperator {
    /// Hard cap on the bytes the cached `X T` product may occupy.
    const MATERIALIZE_MAX_BYTES: usize = 1024 * 1024 * 1024;
    /// Wraps `inner` with coefficient transform `transform`; fails when the
    /// inner column count does not match the transform row count.
    pub fn new(inner: DenseDesignMatrix, transform: Array2<f64>) -> Result<Self, String> {
        let p_inner = inner.ncols();
        if transform.nrows() != p_inner {
            return Err(format!(
                "CoefficientTransformOperator: inner has {} cols but transform has {} rows",
                p_inner,
                transform.nrows(),
            ));
        }
        let n = inner.nrows();
        let p_out = transform.ncols();
        Ok(Self {
            inner,
            transform: Arc::new(transform),
            n,
            p_out,
            materialized: OnceLock::new(),
        })
    }
    /// Lazily computes and caches the dense product `X T`.
    ///
    /// Yields `None` when the product would exceed the byte cap (or its size
    /// overflows), or when the inner design itself refuses materialization;
    /// the outcome is decided once and reused.
    fn materialized_combined(&self) -> Option<&Array2<f64>> {
        self.materialized
            .get_or_init(|| {
                let bytes = self
                    .n
                    .checked_mul(self.p_out)
                    .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()));
                match bytes {
                    Some(b) if b <= Self::MATERIALIZE_MAX_BYTES => {
                        let auto_policy = ResourcePolicy::for_problem(
                            self.n,
                            self.p_out,
                            crate::resource::ProblemHints::default(),
                        );
                        // Raise the single-materialization limit to this
                        // operator's own cap while keeping the auto-selected
                        // derivative storage mode.
                        let cache_policy = ResourcePolicy {
                            max_single_materialization_bytes: Self::MATERIALIZE_MAX_BYTES,
                            derivative_storage_mode: auto_policy.derivative_storage_mode,
                            ..ResourcePolicy::default_library()
                        };
                        self.inner
                            .try_to_dense_arc_with_policy(
                                "CoefficientTransformOperator materialization",
                                &cache_policy,
                            )
                            .ok()
                            .map(|x| Arc::new(fast_ab(x.as_ref(), &self.transform)))
                    }
                    _ => None,
                }
            })
            .as_ref()
            .map(|a| a.as_ref())
    }
}
impl LinearOperator for CoefficientTransformOperator {
fn nrows(&self) -> usize {
self.n
}
fn ncols(&self) -> usize {
self.p_out
}
fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
if let Some(combined) = self.materialized_combined() {
return dense_matvec(combined, vector);
}
let tv = self.transform.dot(vector);
self.inner.apply(&tv)
}
fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
if let Some(combined) = self.materialized_combined() {
return dense_transpose_matvec(combined, vector);
}
let xtv = self.inner.apply_transpose(vector);
self.transform.t().dot(&xtv)
}
fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
if let Some(combined) = self.materialized_combined() {
let mut xtwx = Array2::<f64>::zeros((self.p_out, self.p_out));
streaming_blas_xt_diag_x(combined, weights, &mut xtwx);
return Ok(xtwx);
}
let inner_xtwx = self.inner.diag_xtw_x(weights)?;
let tmp = fast_ab(&self.transform.t().to_owned(), &inner_xtwx);
Ok(fast_ab(&tmp, &self.transform))
}
}
impl DenseDesignOperator for CoefficientTransformOperator {
    /// Exposes the cached `X T` product when available.
    fn as_dense_ref(&self) -> Option<&Array2<f64>> {
        self.materialized_combined()
    }
    /// Full materialization: the cached product, or `X T` computed once.
    fn to_dense(&self) -> Array2<f64> {
        match self.materialized_combined() {
            Some(combined) => combined.clone(),
            None => fast_ab(&self.inner.to_dense(), &self.transform),
        }
    }
    /// Copies rows `rows` of `X T` into `out`, transforming a freshly pulled
    /// inner chunk when no cache exists.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        let expected_rows = rows.end - rows.start;
        if out.nrows() != expected_rows || out.ncols() != self.p_out {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "CoefficientTransformOperator::row_chunk_into shape mismatch",
            });
        }
        if let Some(combined) = self.materialized_combined() {
            out.assign(&combined.slice(s![rows, ..]));
        } else {
            let chunk = self.inner.try_row_chunk(rows)?;
            out.assign(&fast_ab(&chunk, &self.transform));
        }
        Ok(())
    }
}
impl RowwiseKroneckerOperator {
    /// Builds the row-wise Kronecker design whose row `i` is
    /// `cov_iᵀ ⊗ time_iᵀ`; fails when the covariate design and the time basis
    /// disagree on the number of rows.
    pub fn new(cov: DesignMatrix, time_basis: Arc<Array2<f64>>) -> Result<Self, String> {
        let n = cov.nrows();
        if time_basis.nrows() != n {
            return Err(format!(
                "RowwiseKroneckerOperator: cov has {} rows but time_basis has {}",
                n,
                time_basis.nrows()
            ));
        }
        let p_cov = cov.ncols();
        let p_time = time_basis.ncols();
        Ok(Self {
            cov,
            time_basis,
            n,
            p_cov,
            p_time,
        })
    }
}
impl LinearOperator for RowwiseKroneckerOperator {
    fn nrows(&self) -> usize {
        self.n
    }
    fn ncols(&self) -> usize {
        self.p_cov * self.p_time
    }
    /// Computes `X v` where columns are ordered covariate-major
    /// (index `j * p_time + t`).
    ///
    /// For each time column `t` the coefficient slice `v[j * p_time + t]` is
    /// pushed through the covariate design once, then scaled by the time
    /// basis column and accumulated.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        let p_cov = self.p_cov;
        let p_time = self.p_time;
        let n = self.n;
        let time = self.time_basis.as_ref();
        let mut out = Array1::<f64>::zeros(n);
        // Scratch buffer reused across time columns.
        let mut beta_slice = Array1::<f64>::zeros(p_cov);
        for t in 0..p_time {
            for j in 0..p_cov {
                beta_slice[j] = vector[j * p_time + t];
            }
            let cov_beta_t = self.cov.matrixvectormultiply(&beta_slice);
            let time_col = time.column(t);
            ndarray::Zip::from(&mut out)
                .and(&cov_beta_t)
                .and(&time_col)
                .par_for_each(|o, &cb, &tt| *o += cb * tt);
        }
        out
    }
    /// Computes `Xᵀ v`: for each time column `t`, weight `v` elementwise by
    /// that column and push it through the covariate design transpose.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let p_cov = self.p_cov;
        let p_time = self.p_time;
        let n = self.n;
        let time = self.time_basis.as_ref();
        let mut out = Array1::<f64>::zeros(p_cov * p_time);
        // Scratch buffer reused across time columns.
        let mut w_t = Array1::<f64>::zeros(n);
        for t in 0..p_time {
            let time_col = time.column(t);
            ndarray::Zip::from(&mut w_t)
                .and(vector)
                .and(&time_col)
                .par_for_each(|o, &v, &tt| *o = v * tt);
            let col_t = self.cov.transpose_vector_multiply(&w_t);
            for j in 0..p_cov {
                out[j * p_time + t] = col_t[j];
            }
        }
        out
    }
    /// Assembles `Xᵀ W X` blockwise: the `(t1, t2)` time block equals
    /// `Cᵀ diag(w ∘ time_t1 ∘ time_t2) C`, computed in parallel over the
    /// lower-triangular time pairs and mirrored (each block is symmetric in
    /// the covariate indices). Weights are clamped at zero.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let n = self.n;
        let p_cov = self.p_cov;
        let p_time = self.p_time;
        let p_total = p_cov * p_time;
        if weights.len() != n {
            return Err(format!(
                "RowwiseKroneckerOperator::diag_xtw_x: weights length {} != n {}",
                weights.len(),
                n
            ));
        }
        let mut xtwx = Array2::<f64>::zeros((p_total, p_total));
        let time = self.time_basis.as_ref();
        // All (t1, t2) pairs with t2 <= t1.
        let pairs: Vec<(usize, usize)> = (0..p_time)
            .flat_map(|t1| (0..=t1).map(move |t2| (t1, t2)))
            .collect();
        let blocks: Result<Vec<(usize, usize, Array2<f64>)>, String> = pairs
            .into_par_iter()
            .map(|(t1, t2)| {
                let time_t1 = time.column(t1);
                let time_t2 = time.column(t2);
                // Effective per-row weight for this time pair.
                let mut gamma = Array1::<f64>::zeros(n);
                ndarray::Zip::from(&mut gamma)
                    .and(weights)
                    .and(&time_t1)
                    .and(&time_t2)
                    .for_each(|g, &w, &a, &b| *g = w.max(0.0) * a * b);
                self.cov.compute_xtwx(&gamma).map(|block| (t1, t2, block))
            })
            .collect();
        for (t1, t2, block) in blocks? {
            for j1 in 0..p_cov {
                for j2 in 0..p_cov {
                    xtwx[[j1 * p_time + t1, j2 * p_time + t2]] = block[[j1, j2]];
                    if t1 != t2 {
                        // Mirror into the symmetric time block.
                        xtwx[[j1 * p_time + t2, j2 * p_time + t1]] = block[[j1, j2]];
                    }
                }
            }
        }
        Ok(xtwx)
    }
    /// Diagonal of `Xᵀ W X`: for time column `t` the effective weight is
    /// `w ∘ time_t²`, delegated to the covariate design's `diag_gram`.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let n = self.n;
        let p_cov = self.p_cov;
        let p_time = self.p_time;
        if weights.len() != n {
            return Err(format!(
                "RowwiseKroneckerOperator::diag_gram: weights {} != n {}",
                weights.len(),
                n
            ));
        }
        let time = self.time_basis.as_ref();
        let mut out = Array1::<f64>::zeros(p_cov * p_time);
        let mut gamma = Array1::<f64>::zeros(n);
        for t in 0..p_time {
            let time_col = time.column(t);
            ndarray::Zip::from(&mut gamma)
                .and(weights)
                .and(&time_col)
                .par_for_each(|g, &w, &tt| *g = w.max(0.0) * tt * tt);
            let cov_diag = <DesignMatrix as LinearOperator>::diag_gram(&self.cov, &gamma)?;
            for j in 0..p_cov {
                out[j * p_time + t] = cov_diag[j];
            }
        }
        Ok(out)
    }
    /// Operator-backed by construction: use the matrix-free PCG path.
    fn uses_matrix_free_pcg(&self) -> bool {
        true
    }
}
impl DenseDesignOperator for RowwiseKroneckerOperator {
    /// Computes `diag(X M Xᵀ)` by streaming row chunks; entries are clamped
    /// at zero.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let p_total = self.p_cov * self.p_time;
        if middle.nrows() != p_total || middle.ncols() != p_total {
            return Err(format!(
                "RowwiseKroneckerOperator::quadratic_form_diag dimension mismatch: {}x{} vs expected {}x{}",
                middle.nrows(),
                middle.ncols(),
                p_total,
                p_total
            ));
        }
        let mut out = Array1::<f64>::zeros(self.n);
        for start in (0..self.n).step_by(OPERATOR_ROW_CHUNK_SIZE) {
            let end = (start + OPERATOR_ROW_CHUNK_SIZE).min(self.n);
            let chunk = self.try_row_chunk(start..end).map_err(|e| e.to_string())?;
            // One GEMM per chunk, then a dot product per row.
            let chunk_m = fast_ab(&chunk, middle);
            for local in 0..(end - start) {
                out[start + local] = chunk.row(local).dot(&chunk_m.row(local)).max(0.0);
            }
        }
        Ok(out)
    }
    /// Materializes rows `rows`: entry `(i, j * p_time + t)` equals
    /// `cov[i, j] * time[i, t]`; zero covariate entries are skipped.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        let p_cov = self.p_cov;
        let p_time = self.p_time;
        let chunk_rows = rows.end - rows.start;
        if out.nrows() != chunk_rows || out.ncols() != p_cov * p_time {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "RowwiseKroneckerOperator::row_chunk_into shape mismatch",
            });
        }
        out.fill(0.0);
        let cov_chunk = self.cov.try_row_chunk(rows.clone())?;
        let time = self.time_basis.as_ref();
        for local in 0..chunk_rows {
            let global = rows.start + local;
            for j in 0..p_cov {
                let cij = cov_chunk[[local, j]];
                if cij == 0.0 {
                    continue;
                }
                for t in 0..p_time {
                    out[[local, j * p_time + t]] = cij * time[[global, t]];
                }
            }
        }
        Ok(())
    }
    /// Deliberately refuses full materialization: this operator exists to keep
    /// the `n x (p_cov * p_time)` matrix out of memory, so a persistent dense
    /// copy is treated as a programming error and panics with the would-be size.
    fn to_dense(&self) -> Array2<f64> {
        let n = self.n;
        let p_cov = self.p_cov;
        let p_time = self.p_time;
        let bytes = n
            .saturating_mul(p_cov)
            .saturating_mul(p_time)
            .saturating_mul(std::mem::size_of::<f64>());
        panic!(
            "RowwiseKroneckerOperator must remain operator-backed; refused persistent n x p_covariate x p_time materialization (n={n}, p_covariate={p_cov}, p_time={p_time}, dense={:.1} MiB)",
            bytes as f64 / (1024.0 * 1024.0),
        );
    }
}
/// A design with selected columns conditioned as `(x_j - mean) / scale`,
/// applied lazily on top of the inner design.
pub struct ConditionedDesign {
    inner: DesignMatrix,
    // (column index, mean, scale) per conditioned column.
    columns: Vec<(usize, f64, f64)>,
}
impl ConditionedDesign {
    /// Wraps `inner`; `columns` lists `(index, mean, scale)` triples for the
    /// columns to center and rescale.
    pub fn new(inner: DesignMatrix, columns: Vec<(usize, f64, f64)>) -> Self {
        Self { inner, columns }
    }
}
impl LinearOperator for ConditionedDesign {
    fn nrows(&self) -> usize {
        self.inner.nrows()
    }
    fn ncols(&self) -> usize {
        self.inner.ncols()
    }
    /// Applies the conditioned design: column `j` acts as `(x_j - mean)/scale`,
    /// so `X_c v = X v' - (Σ_j mean_j v'_j)·1` with `v'_j = v_j / scale_j` on
    /// conditioned columns.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut scaled = vector.clone();
        let mut shift = 0.0;
        for &(j, mean, scale) in &self.columns {
            scaled[j] /= scale;
            shift += mean * scaled[j];
        }
        let mut result = self.inner.apply(&scaled);
        if shift != 0.0 {
            result.mapv_inplace(|v| v - shift);
        }
        result
    }
    /// Transpose product: each conditioned entry of `Xᵀ u` is corrected by the
    /// centering term `mean · Σ u` before rescaling.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut result = self.inner.apply_transpose(vector);
        let sum_u: f64 = vector.iter().sum();
        for &(j, mean, scale) in &self.columns {
            result[j] = (result[j] - mean * sum_u) / scale;
        }
        result
    }
    /// Expands `(X A - 1 dᵀ)ᵀ W (X A - 1 dᵀ)` from the inner `Xᵀ W X`, with
    /// `a_j = 1/scale_j` and `d_j = mean_j/scale_j` on conditioned columns
    /// (identity / zero elsewhere). Weights are clamped at zero for the
    /// correction terms; the base Gram comes from `inner.diag_xtw_x` as-is.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let mut base = self.inner.diag_xtw_x(weights)?;
        if self.columns.is_empty() {
            return Ok(base);
        }
        let p = base.ncols();
        let w_pos: Array1<f64> = weights.mapv(|w| w.max(0.0));
        let sum_w: f64 = w_pos.sum();
        // cw = Xᵀ w, needed for the centering cross terms.
        let cw = self.inner.apply_transpose(&w_pos);
        let mut a = vec![1.0_f64; p];
        let mut d = vec![0.0_f64; p];
        for &(j, mean, scale) in &self.columns {
            a[j] = 1.0 / scale;
            d[j] = mean / scale;
        }
        for i in 0..p {
            for j in i..p {
                // a_i a_j G_ij - a_i (Xᵀw)_i d_j - d_i (Xᵀw)_j a_j + d_i d_j Σw
                let val = a[i] * base[[i, j]] * a[j] - a[i] * cw[i] * d[j] - d[i] * cw[j] * a[j]
                    + sum_w * d[i] * d[j];
                base[[i, j]] = val;
                base[[j, i]] = val;
            }
        }
        Ok(base)
    }
    /// Weighted squared column norms: for each conditioned column,
    /// `a² g_j - 2 a d (Xᵀw)_j + d² Σw`.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut result = self.inner.diag_gram(weights)?;
        if self.columns.is_empty() {
            return Ok(result);
        }
        let w_pos: Array1<f64> = weights.mapv(|w| w.max(0.0));
        let sum_w: f64 = w_pos.sum();
        let cw = self.inner.apply_transpose(&w_pos);
        for &(j, mean, scale) in &self.columns {
            let a_j = 1.0 / scale;
            let d_j = mean / scale;
            result[j] = a_j * a_j * result[j] - 2.0 * a_j * cw[j] * d_j + sum_w * d_j * d_j;
        }
        Ok(result)
    }
    /// Dense-backed designs take the matrix-free PCG path; sparse-backed ones
    /// do not.
    fn uses_matrix_free_pcg(&self) -> bool {
        match &self.inner {
            DesignMatrix::Dense(_) => true,
            DesignMatrix::Sparse(_) => false,
        }
    }
}
impl DenseDesignOperator for ConditionedDesign {
    /// Computes `(X_c)ᵀ W y`: each conditioned entry of the inner result is
    /// corrected by the centering term `mean · Σ(w y)` and rescaled. Weights
    /// are clamped at zero for the correction sum.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut result = self.inner.compute_xtwy(weights, y)?;
        if self.columns.is_empty() {
            return Ok(result);
        }
        let sum_wy: f64 = weights
            .iter()
            .zip(y.iter())
            .map(|(&w, &yi)| w.max(0.0) * yi)
            .sum();
        for &(j, mean, scale) in &self.columns {
            result[j] = (result[j] - mean * sum_wy) / scale;
        }
        Ok(result)
    }
    /// Computes `diag(X_c M X_cᵀ)` via the expansion
    /// `(Ax - d)ᵀ M (Ax - d) = xᵀ(AMA)x - 2 xᵀ(AMd) + dᵀMd`,
    /// where `A` rescales conditioned columns and `d` carries their centered
    /// means. NOTE(review): the `-2xᵀAMd` grouping assumes `middle` is
    /// symmetric, as Gram-like inputs are — confirm at call sites.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        if self.columns.is_empty() {
            return self.inner.quadratic_form_diag(middle);
        }
        let p = self.ncols();
        let mut d = Array1::zeros(p);
        for &(j, mean, scale) in &self.columns {
            d[j] = mean / scale;
        }
        // A M A: rescale row j and column j of `middle` per conditioned column.
        let mut ama = middle.clone();
        for &(j, _, scale) in &self.columns {
            for k in 0..p {
                ama[[j, k]] /= scale;
                ama[[k, j]] /= scale;
            }
        }
        // A (M d)
        let md = middle.dot(&d);
        let mut amd = md;
        for &(j, _, scale) in &self.columns {
            amd[j] /= scale;
        }
        let dtmd: f64 = d.dot(&middle.dot(&d));
        let mut result = self.inner.quadratic_form_diag(&ama)?;
        let x_amd = self.inner.apply(&amd);
        for i in 0..result.len() {
            // Clamp: the quadratic form is only meaningful when non-negative.
            result[i] = (result[i] - 2.0 * x_amd[i] + dtmd).max(0.0);
        }
        Ok(result)
    }
    /// Copies rows `rows`, applying `(v - mean)/scale` to conditioned columns.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "ConditionedDesign::row_chunk_into shape mismatch",
            });
        }
        let mut chunk = self.inner.try_row_chunk(rows)?;
        for &(j, mean, scale) in &self.columns {
            chunk.column_mut(j).mapv_inplace(|v| (v - mean) / scale);
        }
        out.assign(&chunk);
        Ok(())
    }
    /// Full materialization with conditioning applied column by column.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = self.inner.to_dense();
        for &(j, mean, scale) in &self.columns {
            dense.column_mut(j).mapv_inplace(|v| (v - mean) / scale);
        }
        dense
    }
}
/// A design matrix that is either densely or sparsely backed.
#[derive(Clone)]
pub enum DesignMatrix {
    Dense(DenseDesignMatrix),
    Sparse(SparseDesignMatrix),
}
/// Debug prints only the variant and dimensions, never the data.
impl std::fmt::Debug for DesignMatrix {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Dense(m) => write!(f, "DesignMatrix::Dense({}x{})", m.nrows(), m.ncols()),
            Self::Sparse(s) => write!(f, "DesignMatrix::Sparse({}x{})", s.nrows(), s.ncols()),
        }
    }
}
/// A symmetric matrix stored densely (full) or sparsely. The sparse form is
/// interpreted symmetrically by the accessors below, which mirror each stored
/// off-diagonal entry across the diagonal on the fly.
#[derive(Clone, Debug)]
pub enum SymmetricMatrix {
    Dense(Array2<f64>),
    Sparse(faer::sparse::SparseColMat<usize, f64>),
}
impl SymmetricMatrix {
/// Borrows the dense storage, if this is the dense variant.
pub fn as_dense(&self) -> Option<&Array2<f64>> {
    match self {
        Self::Dense(mat) => Some(mat),
        Self::Sparse(_) => None,
    }
}
/// Borrows the sparse storage, if this is the sparse variant.
pub fn as_sparse(&self) -> Option<&faer::sparse::SparseColMat<usize, f64>> {
    match self {
        Self::Sparse(mat) => Some(mat),
        Self::Dense(_) => None,
    }
}
/// Materializes a full dense matrix.
///
/// Sparse entries are mirrored across the diagonal; `+=` lets duplicate
/// stored coordinates accumulate rather than overwrite.
pub fn to_dense(&self) -> Array2<f64> {
    match self {
        Self::Dense(mat) => mat.clone(),
        Self::Sparse(mat) => {
            let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
            let (symbolic, values) = mat.parts();
            let col_ptr = symbolic.col_ptr();
            let row_idx = symbolic.row_idx();
            for col in 0..mat.ncols() {
                let start = col_ptr[col];
                let end = col_ptr[col + 1];
                for idx in start..end {
                    let row = row_idx[idx];
                    let value = values[idx];
                    out[[row, col]] += value;
                    if row != col {
                        out[[col, row]] += value;
                    }
                }
            }
            out
        }
    }
}
/// Dense materialization with validation: requires a square matrix and
/// all-finite entries; intended for "exact" downstream algebra.
pub fn try_to_dense_exact(&self, context: &str) -> Result<Array2<f64>, String> {
    if self.nrows() != self.ncols() {
        return Err(format!(
            "{context}: exact symmetric matrix must be square, got {}x{}",
            self.nrows(),
            self.ncols()
        ));
    }
    let dense = self.to_dense();
    if dense.iter().any(|v| !v.is_finite()) {
        return Err(format!(
            "{context}: exact dense materialization contains non-finite entries"
        ));
    }
    Ok(dense)
}
/// Factorizes the matrix with the storage-appropriate backend (dense stable
/// solver or sparse SPD factorization), returning a boxed solver.
pub fn factorize(&self) -> Result<Box<dyn FactorizedSystem>, String> {
    match self {
        Self::Dense(mat) => {
            let factor = crate::linalg::utils::StableSolver::new("unnamed")
                .factorize(mat)
                .map_err(|e| format!("Dense SymmetricMatrix factorization failed: {e:?}"))?;
            Ok(Box::new(factor))
        }
        Self::Sparse(mat) => {
            let factor = crate::linalg::sparse_exact::factorize_sparse_spd(mat)
                .map_err(|e| format!("Sparse SymmetricMatrix factorization failed: {e:?}"))?;
            Ok(Box::new(factor))
        }
    }
}
/// Adds two symmetric matrices. Mixed dense/sparse operands densify; two
/// sparse operands stay sparse (upper triangle only).
pub fn add(&self, other: &SymmetricMatrix) -> Result<Self, String> {
    if self.nrows() != other.nrows() || self.ncols() != other.ncols() {
        return Err(format!(
            "SymmetricMatrix::add shape mismatch: lhs {}x{}, rhs {}x{}",
            self.nrows(),
            self.ncols(),
            other.nrows(),
            other.ncols()
        ));
    }
    match (self, other) {
        (Self::Dense(a), Self::Dense(b)) => Ok(Self::Dense(a + b)),
        (Self::Dense(a), Self::Sparse(_)) => {
            let b_dense = other.to_dense();
            Ok(Self::Dense(a + &b_dense))
        }
        (Self::Sparse(_), Self::Dense(b)) => {
            let a_dense = self.to_dense();
            Ok(Self::Dense(&a_dense + b))
        }
        (Self::Sparse(a), Self::Sparse(b)) => {
            Ok(Self::Sparse(add_sparse_symmetric_upper(a, b)?))
        }
    }
}
/// Adds a dense symmetric matrix, preserving this matrix's storage kind:
/// sparse storage converts `other` to sparse-upper form before adding.
pub(crate) fn add_dense(&self, other: &Array2<f64>) -> Result<Self, String> {
    if self.nrows() != other.nrows() || self.ncols() != other.ncols() {
        return Err(format!(
            "SymmetricMatrix::add_dense shape mismatch: lhs {}x{}, rhs {}x{}",
            self.nrows(),
            self.ncols(),
            other.nrows(),
            other.ncols()
        ));
    }
    match self {
        Self::Dense(mat) => {
            let mut out = mat.clone();
            out += other;
            Ok(Self::Dense(out))
        }
        Self::Sparse(mat) => {
            let other_sparse =
                crate::linalg::sparse_exact::dense_to_sparse_symmetric_upper(other, 0.0)
                    .map_err(|e| format!("SymmetricMatrix::add_dense failed: {e}"))?;
            Ok(Self::Sparse(add_sparse_symmetric_upper(
                mat,
                &other_sparse,
            )?))
        }
    }
}
/// Returns `self + ridge · I`, preserving the storage kind; a zero ridge is
/// a cheap clone.
pub fn addridge(&self, ridge: f64) -> Result<Self, String> {
    if ridge == 0.0 {
        return Ok(self.clone());
    }
    match self {
        Self::Dense(mat) => {
            let mut out = mat.clone();
            for i in 0..out.nrows() {
                out[[i, i]] += ridge;
            }
            Ok(Self::Dense(out))
        }
        Self::Sparse(mat) => {
            // Assemble a sparse ridge diagonal and reuse the symmetric adder.
            let n = mat.nrows();
            let mut trip = Vec::with_capacity(n);
            for i in 0..n {
                trip.push(Triplet::new(i, i, ridge));
            }
            let diagonal = SparseColMat::<usize, f64>::try_new_from_triplets(n, n, &trip)
                .map_err(|_| {
                    "SymmetricMatrix::addridge failed to assemble sparse diagonal".to_string()
                })?;
            Ok(Self::Sparse(add_sparse_symmetric_upper(mat, &diagonal)?))
        }
    }
}
pub fn nrows(&self) -> usize {
    match self {
        Self::Dense(m) => m.nrows(),
        Self::Sparse(m) => m.nrows(),
    }
}
pub fn ncols(&self) -> usize {
    match self {
        Self::Dense(m) => m.ncols(),
        Self::Sparse(m) => m.ncols(),
    }
}
/// Symmetric matrix-vector product. For sparse storage each stored
/// off-diagonal entry contributes twice: once at its coordinates and once
/// mirrored.
pub fn dot(&self, rhs: &Array1<f64>) -> Array1<f64> {
    match self {
        Self::Dense(mat) => mat.dot(rhs),
        Self::Sparse(mat) => {
            let mut out = Array1::<f64>::zeros(mat.nrows());
            let (symbolic, values) = mat.parts();
            let col_ptr = symbolic.col_ptr();
            let row_idx = symbolic.row_idx();
            for col in 0..mat.ncols() {
                let rhs_j = rhs[col];
                let start = col_ptr[col];
                let end = col_ptr[col + 1];
                for idx in start..end {
                    let row = row_idx[idx];
                    let value = values[idx];
                    out[row] += value * rhs_j;
                    if row != col {
                        // Mirrored (col, row) contribution.
                        out[col] += value * rhs[row];
                    }
                }
            }
            out
        }
    }
}
/// Largest absolute diagonal entry; 0.0 when the (stored) diagonal is empty.
pub fn max_abs_diag(&self) -> f64 {
    match self {
        Self::Dense(mat) => {
            let n = mat.nrows().min(mat.ncols());
            (0..n).map(|i| mat[[i, i]].abs()).fold(0.0_f64, f64::max)
        }
        Self::Sparse(mat) => {
            let (symbolic, values) = mat.parts();
            let col_ptr = symbolic.col_ptr();
            let row_idx = symbolic.row_idx();
            let mut max_val = 0.0_f64;
            for col in 0..mat.ncols() {
                let start = col_ptr[col];
                let end = col_ptr[col + 1];
                for idx in start..end {
                    if row_idx[idx] == col {
                        max_val = max_val.max(values[idx].abs());
                    }
                }
            }
            max_val
        }
    }
}
/// Symmetric matrix-matrix product `S · rhs`, mirroring stored sparse
/// entries exactly as `dot` does, column by column of `rhs`.
pub fn dot_matrix(&self, rhs: &Array2<f64>) -> Array2<f64> {
    match self {
        Self::Dense(mat) => mat.dot(rhs),
        Self::Sparse(mat) => {
            let n = mat.nrows();
            let k = rhs.ncols();
            let mut out = Array2::<f64>::zeros((n, k));
            let (symbolic, values) = mat.parts();
            let col_ptr = symbolic.col_ptr();
            let row_idx = symbolic.row_idx();
            for col in 0..mat.ncols() {
                let start = col_ptr[col];
                let end = col_ptr[col + 1];
                for idx in start..end {
                    let row = row_idx[idx];
                    let value = values[idx];
                    for c in 0..k {
                        out[[row, c]] += value * rhs[[col, c]];
                        if row != col {
                            // Mirrored (col, row) contribution.
                            out[[col, c]] += value * rhs[[row, c]];
                        }
                    }
                }
            }
            out
        }
    }
}
/// Left product `lhs · S`, computed as `(S · lhsᵀ)ᵀ` using the symmetry of
/// `S`.
pub fn left_dot_matrix(&self, lhs: &Array2<f64>) -> Array2<f64> {
    let lhs_t = lhs.t().to_owned();
    let result_t = self.dot_matrix(&lhs_t);
    result_t.t().to_owned()
}
}
/// Computes `Xᵀ diag(w) X` as a `SymmetricMatrix`, choosing dense or sparse
/// assembly from the design's density.
///
/// Dense designs delegate to `diag_xtw_x`. Sparse designs averaging at least
/// `p/4` nonzeros per row are densified and assembled with the streaming BLAS
/// path; genuinely sparse designs accumulate the upper triangle into
/// per-chunk `SparseHessianAccumulator`s in parallel and merge them.
///
/// Cleanup: dropped the function-local `use rayon::iter::{...}` that
/// duplicated the file-level import.
///
/// Returns an error when `diag` does not have one entry per design row, or
/// when the sparse design cannot provide a CSR view.
pub fn xt_diag_x_symmetric(
    design: &DesignMatrix,
    diag: &Array1<f64>,
) -> Result<SymmetricMatrix, String> {
    if design.nrows() != diag.len() {
        return Err(format!(
            "xt_diag_x_symmetric row mismatch: design has {} rows but diag has {} entries",
            design.nrows(),
            diag.len()
        ));
    }
    match design {
        DesignMatrix::Dense(x) => Ok(SymmetricMatrix::Dense(x.diag_xtw_x(diag)?)),
        DesignMatrix::Sparse(xs) => {
            let n = xs.nrows();
            let p = xs.ncols();
            let nnz_x = xs.val().len();
            let avg_nnz_row = if n > 0 { nnz_x / n } else { p };
            // Heuristic: treat as dense when rows average >= p/4 nonzeros.
            let dense_regime = 4 * avg_nnz_row >= p;
            if dense_regime {
                let xd = xs.to_dense_arc();
                let mut xtwx = Array2::<f64>::zeros((p, p));
                streaming_blas_xt_diag_x(xd.as_ref(), diag, &mut xtwx);
                return Ok(SymmetricMatrix::Dense(xtwx));
            }
            let csr = xs
                .to_csr_arc()
                .ok_or_else(|| "xt_diag_x_symmetric: failed to obtain CSR view".to_string())?;
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            let acc_template = SparseHessianAccumulator::from_single_csr(&*csr, p);
            // Aim for ~16 chunks per thread, with at least 256 rows per chunk.
            let n_threads = rayon::current_num_threads().max(1);
            let target_chunks = (n_threads * 16).max(n_threads);
            let chunk_rows = (n / target_chunks).max(256).min(n.max(1));
            let chunk_starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
            let mut local_accs: Vec<SparseHessianAccumulator> = chunk_starts
                .into_par_iter()
                .map(|start| {
                    let end = (start + chunk_rows).min(n);
                    let mut local = acc_template.empty_clone();
                    for i in start..end {
                        let wi = diag[i];
                        if wi == 0.0 {
                            continue;
                        }
                        let r_start = row_ptr[i];
                        let r_end = row_ptr[i + 1];
                        // Row i contributes w_i * x_ia * x_ib for every pair
                        // of its nonzero columns (upper triangle only; CSR
                        // column indices are ascending, so a <= b holds).
                        for a_ptr in r_start..r_end {
                            let a = col_idx[a_ptr];
                            let wxa = wi * vals[a_ptr];
                            local.add_upper(a, a, wxa * vals[a_ptr]);
                            for b_ptr in (a_ptr + 1)..r_end {
                                let b = col_idx[b_ptr];
                                local.add_upper(a, b, wxa * vals[b_ptr]);
                            }
                        }
                    }
                    local
                })
                .collect();
            // Merge per-chunk accumulators (identical sparsity patterns).
            let mut acc = if let Some(first) = local_accs.pop() {
                first
            } else {
                acc_template.empty_clone()
            };
            for other in local_accs.into_iter() {
                acc.add_values(&other.values);
            }
            Ok(SymmetricMatrix::Sparse(acc.into_sparse_col_mat()))
        }
    }
}
/// Adds two sparse symmetric matrices, returning a CSC matrix that stores
/// each entry once in the upper triangle; entries that sum to exactly zero
/// are dropped.
fn add_sparse_symmetric_upper(
    lhs: &SparseColMat<usize, f64>,
    rhs: &SparseColMat<usize, f64>,
) -> Result<SparseColMat<usize, f64>, String> {
    if lhs.nrows() != rhs.nrows() || lhs.ncols() != rhs.ncols() {
        return Err(format!(
            "add_sparse_symmetric_upper shape mismatch: lhs {}x{}, rhs {}x{}",
            lhs.nrows(),
            lhs.ncols(),
            rhs.nrows(),
            rhs.ncols()
        ));
    }
    // Fold every stored entry of both operands into a canonical
    // upper-triangle coordinate map.
    let mut acc = BTreeMap::<(usize, usize), f64>::new();
    for matrix in [lhs, rhs] {
        let (sym, vals) = matrix.parts();
        let ptr = sym.col_ptr();
        let rows = sym.row_idx();
        for c in 0..matrix.ncols() {
            for k in ptr[c]..ptr[c + 1] {
                let r = rows[k];
                let key = if r <= c { (r, c) } else { (c, r) };
                *acc.entry(key).or_default() += vals[k];
            }
        }
    }
    let mut triplets = Vec::with_capacity(acc.len());
    for ((r, c), v) in acc {
        if v != 0.0 {
            triplets.push(Triplet::new(r, c, v));
        }
    }
    SparseColMat::try_new_from_triplets(lhs.nrows(), lhs.ncols(), &triplets)
        .map_err(|_| "add_sparse_symmetric_upper failed to assemble CSC".to_string())
}
/// Precomputed CSC sparsity pattern (upper triangle) for a sparse Hessian.
struct SparseHessianSymbolic {
    // Hessian dimension (square).
    dim: usize,
    // Number of stored upper-triangle entries.
    nnz: usize,
    // CSC column pointers, length dim + 1.
    col_ptrs: Vec<usize>,
    // Row indices per column, ascending.
    row_indices: Vec<usize>,
    // First stored row of each column (usize::MAX for empty columns).
    first_row: Vec<usize>,
    // True when every column's rows form one contiguous run, enabling O(1)
    // index lookup in `SparseHessianAccumulator::add_upper`.
    contiguous: bool,
}
impl SparseHessianSymbolic {
    /// Builds the upper-triangle CSC pattern of the Hessian for the union of
    /// the given CSR designs (all sharing the same row count).
    ///
    /// For each data row, every pair of its nonzero columns `(ca, cb)` with
    /// `ca <= cb` contributes a structural entry. Per-column contiguity of
    /// row indices is recorded to enable the fast accumulation path.
    fn build(csrs: &[&SparseRowMat<usize, f64>], dim: usize) -> Self {
        use std::collections::BTreeSet;
        let n = csrs[0].nrows();
        let mut rows_by_col = vec![BTreeSet::<usize>::new(); dim];
        // Scratch buffer holding the current row's merged column support.
        let mut cols = Vec::with_capacity(32);
        for i in 0..n {
            cols.clear();
            for csr in csrs {
                let sym = csr.symbolic();
                let rp = sym.row_ptr();
                let ci = sym.col_idx();
                for p in rp[i]..rp[i + 1] {
                    cols.push(ci[p]);
                }
            }
            cols.sort_unstable();
            cols.dedup();
            for (ai, &ca) in cols.iter().enumerate() {
                assert!(
                    ca < dim,
                    "SparseHessianSymbolic::build: column index {ca} out of Hessian dimension {dim}"
                );
                for &cb in &cols[ai..] {
                    assert!(
                        cb < dim,
                        "SparseHessianSymbolic::build: column index {cb} out of Hessian dimension {dim}"
                    );
                    // `cols` is sorted, so ca <= cb: record row ca in column cb.
                    rows_by_col[cb].insert(ca);
                }
            }
        }
        let nnz = rows_by_col.iter().map(BTreeSet::len).sum();
        let mut col_ptrs = Vec::with_capacity(dim + 1);
        let mut row_indices = Vec::with_capacity(nnz);
        col_ptrs.push(0);
        for rows in rows_by_col {
            // BTreeSet iteration yields ascending row indices.
            row_indices.extend(rows);
            col_ptrs.push(row_indices.len());
        }
        // Detect whether each column's rows are one contiguous run starting
        // at `first_row[c]`; the first gap disables the fast path globally.
        let mut first_row = vec![usize::MAX; dim];
        let mut contiguous = true;
        for c in 0..dim {
            let start = col_ptrs[c];
            let end = col_ptrs[c + 1];
            if start == end {
                continue;
            }
            first_row[c] = row_indices[start];
            for (off, &ri) in row_indices[start..end].iter().enumerate() {
                if ri != first_row[c] + off {
                    contiguous = false;
                    break;
                }
            }
            if !contiguous {
                break;
            }
        }
        SparseHessianSymbolic {
            dim,
            nnz,
            col_ptrs,
            row_indices,
            first_row,
            contiguous,
        }
    }
}
/// Value buffer over a shared upper-triangular sparsity pattern. Cloning is
/// cheap on the pattern side (refcount bump); per-thread copies can be merged
/// back with `add_values`.
pub struct SparseHessianAccumulator {
    // Shared symbolic pattern describing where `values` entries live.
    sym: Arc<SparseHessianSymbolic>,
    // Nonzero values, aligned one-to-one with `sym.row_indices`.
    pub values: Vec<f64>,
}
impl Clone for SparseHessianAccumulator {
fn clone(&self) -> Self {
SparseHessianAccumulator {
sym: Arc::clone(&self.sym),
values: self.values.clone(),
}
}
}
impl SparseHessianAccumulator {
    /// Accumulator over the Gram pattern of a single CSR design.
    pub fn from_single_csr(csr: &SparseRowMat<usize, f64>, dim: usize) -> Self {
        Self::from_multi_csr(&[csr], dim)
    }
    /// Accumulator over the union Gram pattern of several CSR designs that
    /// share the same row count; values start at zero.
    pub fn from_multi_csr(csrs: &[&SparseRowMat<usize, f64>], dim: usize) -> Self {
        let sym = Arc::new(SparseHessianSymbolic::build(csrs, dim));
        let nnz = sym.nnz;
        SparseHessianAccumulator {
            sym,
            values: vec![0.0; nnz],
        }
    }
    /// Adds `val` at upper-triangular position (r, c). Requires r <= c and
    /// that (r, c) belongs to the precomputed pattern.
    #[inline(always)]
    pub fn add_upper(&mut self, r: usize, c: usize, val: f64) {
        assert!(r <= c, "add_upper requires r <= c, got ({r}, {c})");
        let s = &*self.sym;
        if s.contiguous {
            // O(1) lookup: column c's rows form one contiguous run starting
            // at first_row[c]. If r precedes the run, the subtraction
            // underflows (debug panic; in release the wrapped index trips
            // the assert below).
            let idx = s.col_ptrs[c] + (r - s.first_row[c]);
            assert!(idx < s.col_ptrs[c + 1], "add_upper contiguous OOB");
            // SAFETY: idx < col_ptrs[c + 1] <= values.len() was just asserted.
            unsafe {
                *self.values.get_unchecked_mut(idx) += val;
            }
        } else {
            // Linear scan of the column's stored rows.
            let start = s.col_ptrs[c];
            let end = s.col_ptrs[c + 1];
            let slice = &s.row_indices[start..end];
            for (off, &ri) in slice.iter().enumerate() {
                if ri == r {
                    // SAFETY: start + off < end <= values.len() by pattern
                    // construction.
                    unsafe {
                        *self.values.get_unchecked_mut(start + off) += val;
                    }
                    return;
                }
            }
            // NOTE(review): an out-of-pattern (r, c) is silently dropped in
            // release builds; only debug builds trap it here.
            #[cfg(debug_assertions)]
            unreachable!("SparseHessianAccumulator::add_upper: ({r}, {c}) not in pattern");
        }
    }
    /// Element-wise merge of another value buffer built over the same
    /// pattern (length checked in debug builds only).
    #[inline]
    pub fn add_values(&mut self, other: &[f64]) {
        debug_assert_eq!(self.values.len(), other.len());
        for (a, &b) in self.values.iter_mut().zip(other.iter()) {
            *a += b;
        }
    }
    /// New zeroed accumulator sharing this one's symbolic pattern.
    pub fn empty_clone(&self) -> Self {
        SparseHessianAccumulator {
            sym: Arc::clone(&self.sym),
            values: vec![0.0; self.values.len()],
        }
    }
    /// Consumes the accumulator into a CSC matrix holding the upper triangle.
    /// The symbolic arrays are moved out when uniquely owned and cloned when
    /// the pattern is still shared.
    pub fn into_sparse_col_mat(self) -> SparseColMat<usize, f64> {
        use faer::sparse::SymbolicSparseColMat;
        let (col_ptrs, row_indices, dim) = match Arc::try_unwrap(self.sym) {
            Ok(owned) => (owned.col_ptrs, owned.row_indices, owned.dim),
            Err(shared) => (
                shared.col_ptrs.clone(),
                shared.row_indices.clone(),
                shared.dim,
            ),
        };
        // The pattern was built sorted and in-bounds by
        // SparseHessianSymbolic::build, which is what new_unchecked requires.
        let symbolic =
            unsafe { SymbolicSparseColMat::new_unchecked(dim, dim, col_ptrs, None, row_indices) };
        SparseColMat::new(symbolic, self.values)
    }
}
/// A factorized symmetric positive-definite system supporting repeated
/// solves and log-determinant queries.
pub trait FactorizedSystem: Send + Sync {
    /// Solves A x = rhs for a single right-hand side.
    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String>;
    /// Solves A X = RHS for multiple right-hand sides at once.
    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String>;
    /// Log-determinant of the factorized matrix.
    fn logdet(&self) -> f64;
}
/// Abstract design operator X exposing matrix-vector products and weighted
/// normal-equation assembly without requiring a dense materialization.
pub trait LinearOperator {
    /// Number of rows (observations).
    fn nrows(&self) -> usize;
    /// Number of columns (coefficients).
    fn ncols(&self) -> usize;
    /// y = X v.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64>;
    /// y = Xᵀ v.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64>;
    /// Dense Gram matrix Xᵀ diag(w) X.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String>;
    /// Diagonal of Xᵀ diag(w) X; the default extracts it from the full Gram.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let xtwx = self.diag_xtw_x(weights)?;
        Ok(Array1::from_iter((0..self.ncols()).map(|j| xtwx[[j, j]])))
    }
    /// (Xᵀ diag(w⁺) X + penalty + ridge·I) v, where w⁺ clamps negative
    /// weights to zero. Default: forward product, weight, transpose product.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        let xv = self.apply(vector);
        let mut weighted_xv = xv;
        for i in 0..weighted_xv.len() {
            weighted_xv[i] *= weights[i].max(0.0);
        }
        let mut out = self.apply_transpose(&weighted_xv);
        if let Some(pen) = penalty {
            out += &pen.dot(vector);
        }
        if ridge > 0.0 {
            out += &vector.mapv(|x| ridge * x);
        }
        out
    }
    /// Opt-in flag for the matrix-free PCG solver path; disabled by default.
    fn uses_matrix_free_pcg(&self) -> bool {
        false
    }
    /// Matrix-free PCG solve, discarding convergence diagnostics.
    fn solve_system_matrix_free_pcg_try(
        &self,
        weights: &Array1<f64>,
        rhs: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        baseridge: f64,
    ) -> Result<Array1<f64>, String> {
        self.solve_system_matrix_free_pcg_with_info_try(weights, rhs, penalty, baseridge)
            .map(|(solution, _)| solution)
    }
    /// Matrix-free PCG solve of the penalized normal equations, returning
    /// the solution and solver diagnostics.
    ///
    /// Up to 8 attempts are made, escalating the ridge by 10x each retry
    /// (only when `baseridge > 0`; with a zero base ridge every retry runs
    /// the identical solve — NOTE(review): likely wasted work, confirm).
    /// A solution is accepted only when every entry is finite.
    fn solve_system_matrix_free_pcg_with_info_try(
        &self,
        weights: &Array1<f64>,
        rhs: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        baseridge: f64,
    ) -> Result<(Array1<f64>, PcgSolveInfo), String> {
        if rhs.len() != self.ncols() {
            return Err(format!(
                "solve_system_matrix_free_pcg rhs dimension mismatch: rhs length {} != ncols {}",
                rhs.len(),
                self.ncols()
            ));
        }
        if !self.uses_matrix_free_pcg() {
            return Err("matrix-free PCG is only enabled for eligible operator types".to_string());
        }
        if let Some(pen) = penalty
            && (pen.nrows() != self.ncols() || pen.ncols() != self.ncols())
        {
            return Err(format!(
                "solve_system_matrix_free_pcg penalty shape mismatch: got {}x{}, expected {}x{}",
                pen.nrows(),
                pen.ncols(),
                self.ncols(),
                self.ncols()
            ));
        }
        for retry in 0..8 {
            let ridge = if baseridge > 0.0 {
                baseridge * 10f64.powi(retry as i32)
            } else {
                0.0
            };
            let normal_op = PenalizedWeightedNormalOperator {
                operator: self,
                weights,
                penalty,
                ridge,
            };
            // Jacobi (diagonal) preconditioning keeps the operator matrix-free.
            let preconditioner = normal_op.jacobi_preconditioner()?;
            let solved = crate::linalg::utils::solve_spd_pcg_with_info(
                |v| normal_op.apply(v),
                rhs,
                &preconditioner,
                MATRIX_FREE_PCG_REL_TOL,
                // Allow more iterations for very wide systems.
                MATRIX_FREE_PCG_MAX_ITER.max(4 * self.ncols()),
            );
            if let Some((solution, info)) = solved
                && solution.iter().all(|v| v.is_finite())
            {
                return Ok((solution, info));
            }
        }
        Err("matrix-free PCG failed after ridge retries".to_string())
    }
    /// Assembles Xᵀ W X (+ penalty) densely and factorizes it with the
    /// shared stable solver.
    fn factorize_system(
        &self,
        weights: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
    ) -> Result<Box<dyn FactorizedSystem>, String> {
        let mut system = self.diag_xtw_x(weights)?;
        if let Some(pen) = penalty {
            if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
                return Err(format!(
                    "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
                    pen.nrows(),
                    pen.ncols(),
                    system.nrows(),
                    system.ncols()
                ));
            }
            system += pen;
        }
        let factor = crate::linalg::utils::StableSolver::new("linear operator system")
            .factorize(&system)
            .map_err(|e| format!("factorize_system failed: {e:?}"))?;
        Ok(Box::new(factor))
    }
    /// Solves the penalized normal equations with the default ridge policy.
    fn solve_system(
        &self,
        weights: &Array1<f64>,
        rhs: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
    ) -> Result<Array1<f64>, String> {
        self.solve_systemwith_policy(
            weights,
            rhs,
            penalty,
            1e-15,
            RidgePolicy::explicit_stabilization_pospart(),
        )
    }
    /// Solves the penalized normal equations: tries matrix-free PCG first
    /// for eligible wide operators, then falls back to dense assembly plus
    /// the stable solver with ridge retries.
    fn solve_systemwith_policy(
        &self,
        weights: &Array1<f64>,
        rhs: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge_floor: f64,
        ridge_policy: RidgePolicy,
    ) -> Result<Array1<f64>, String> {
        if rhs.len() != self.ncols() {
            return Err(format!(
                "solve_systemwith_policy rhs dimension mismatch: rhs length {} != ncols {}",
                rhs.len(),
                self.ncols()
            ));
        }
        // Policy decides whether any stabilizing ridge is applied at all.
        let baseridge = if ridge_policy.include_laplacehessian {
            ridge_floor.max(1e-15)
        } else {
            0.0
        };
        // PCG fast path only pays off for sufficiently wide systems.
        if self.uses_matrix_free_pcg() && self.ncols() >= MATRIX_FREE_PCG_MIN_P {
            if let Ok(solution) =
                self.solve_system_matrix_free_pcg_try(weights, rhs, penalty, baseridge)
            {
                return Ok(solution);
            }
        }
        let mut system = self.diag_xtw_x(weights)?;
        if let Some(pen) = penalty {
            if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
                return Err(format!(
                    "solve_systemwith_policy penalty shape mismatch: got {}x{}, expected {}x{}",
                    pen.nrows(),
                    pen.ncols(),
                    system.nrows(),
                    system.ncols()
                ));
            }
            system += pen;
        }
        crate::linalg::utils::StableSolver::new("linear operator system")
            .solvevectorwithridge_retries(&system, rhs, baseridge)
            .ok_or_else(|| "solve_systemwith_policy failed after ridge retries".to_string())
    }
    /// Alias for `apply` kept for call-site readability.
    fn matvec(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.apply(vector)
    }
    /// Alias for `apply_transpose`.
    fn matvec_trans(&self, vector: &Array1<f64>) -> Array1<f64> {
        self.apply_transpose(vector)
    }
    /// Alias for `diag_xtw_x`.
    fn compute_xtwx(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        self.diag_xtw_x(weights)
    }
}
impl LinearOperator for DesignMatrix {
    /// Matrix-free PCG is delegated to the dense backend; sparse designs
    /// always use assembled solves.
    fn uses_matrix_free_pcg(&self) -> bool {
        match self {
            Self::Dense(matrix) => matrix.uses_matrix_free_pcg(),
            Self::Sparse(_) => false,
        }
    }
    fn nrows(&self) -> usize {
        match self {
            Self::Dense(matrix) => matrix.nrows(),
            Self::Sparse(matrix) => matrix.nrows(),
        }
    }
    fn ncols(&self) -> usize {
        match self {
            Self::Dense(matrix) => matrix.ncols(),
            Self::Sparse(matrix) => matrix.ncols(),
        }
    }
    /// y = X v; the sparse arm scatters each CSC column scaled by v[col].
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(matrix) => matrix.apply(vector),
            Self::Sparse(matrix) => {
                let mut output = Array1::<f64>::zeros(matrix.nrows());
                let (symbolic, values) = matrix.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..matrix.ncols() {
                    let start = col_ptr[col];
                    let end = col_ptr[col + 1];
                    let x = vector[col];
                    for idx in start..end {
                        let row = row_idx[idx];
                        output[row] += values[idx] * x;
                    }
                }
                output
            }
        }
    }
    /// (Xᵀ diag(w⁺) X + penalty + ridge·I) v. The sparse arm fuses the
    /// forward and transpose products into one CSR sweep per row when a CSR
    /// view is available.
    fn apply_weighted_normal(
        &self,
        weights: &Array1<f64>,
        vector: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
        ridge: f64,
    ) -> Array1<f64> {
        match self {
            Self::Dense(matrix) => matrix.apply_weighted_normal(weights, vector, penalty, ridge),
            Self::Sparse(_) => {
                let sparse = self
                    .as_sparse()
                    .expect("DesignMatrix::Sparse must expose sparse view");
                let mut out = if let Some(csr) = sparse.to_csr_arc() {
                    let sym = csr.symbolic();
                    let row_ptr = sym.row_ptr();
                    let col_idx = sym.col_idx();
                    let vals = csr.val();
                    let mut fused = Array1::<f64>::zeros(self.ncols());
                    for i in 0..self.nrows() {
                        // Negative weights are clamped to zero, matching the
                        // default trait implementation.
                        let wi = weights[i].max(0.0);
                        if wi == 0.0 {
                            continue;
                        }
                        let start = row_ptr[i];
                        let end = row_ptr[i + 1];
                        // row_dot = x_i · v
                        let mut row_dot = 0.0_f64;
                        for ptr in start..end {
                            row_dot += vals[ptr] * vector[col_idx[ptr]];
                        }
                        if row_dot == 0.0 {
                            continue;
                        }
                        // fused += w_i (x_i · v) x_i
                        let scaled = wi * row_dot;
                        for ptr in start..end {
                            fused[col_idx[ptr]] += vals[ptr] * scaled;
                        }
                    }
                    fused
                } else {
                    // No CSR view available: fall back to two passes.
                    let xv = self.apply(vector);
                    let mut weighted_xv = xv;
                    for i in 0..weighted_xv.len() {
                        weighted_xv[i] *= weights[i].max(0.0);
                    }
                    self.apply_transpose(&weighted_xv)
                };
                if let Some(pen) = penalty {
                    out += &pen.dot(vector);
                }
                if ridge > 0.0 {
                    for j in 0..out.len() {
                        out[j] += ridge * vector[j];
                    }
                }
                out
            }
        }
    }
    /// y = Xᵀ v; the sparse arm gathers one dot product per CSC column.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(matrix) => matrix.apply_transpose(vector),
            Self::Sparse(matrix) => {
                let mut output = Array1::<f64>::zeros(matrix.ncols());
                let (symbolic, values) = matrix.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..matrix.ncols() {
                    let mut acc = 0.0;
                    let start = col_ptr[col];
                    let end = col_ptr[col + 1];
                    for idx in start..end {
                        let row = row_idx[idx];
                        acc += values[idx] * vector[row];
                    }
                    output[col] = acc;
                }
                output
            }
        }
    }
    /// Xᵀ diag(w) X. For sparse designs a density heuristic chooses between
    /// a streamed dense GEMM and the CSR-native kernel.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        if weights.len() != self.nrows() {
            return Err(format!(
                "compute_xtwx dimension mismatch: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        let p = self.ncols();
        match self {
            Self::Dense(x) => x.diag_xtw_x(weights),
            Self::Sparse(xs) => {
                let n = self.nrows();
                let nnz_x = xs.as_ref().val().len();
                let avg_nnz_row = if n > 0 { nnz_x / n } else { p };
                // "Dense regime": rows average at least p/4 stored entries,
                // so densify and use the chunked BLAS accumulation.
                let dense_regime = 4 * avg_nnz_row >= p;
                if dense_regime {
                    let xd = xs.to_dense_arc();
                    let mut xtwx = Array2::<f64>::zeros((p, p));
                    streaming_blas_xt_diag_x(xd.as_ref(), weights, &mut xtwx);
                    return Ok(xtwx);
                }
                let csr = xs
                    .to_csr_arc()
                    .ok_or_else(|| "failed to obtain CSR view in compute_xtwx".to_string())?;
                let sym = csr.symbolic();
                Ok(sparse_csr_weighted_xtwx(
                    sym.row_ptr(),
                    sym.col_idx(),
                    csr.val(),
                    n,
                    p,
                    weights.view(),
                ))
            }
        }
    }
    /// Per-coefficient diagonal of Xᵀ diag(w) X.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        if weights.len() != self.nrows() {
            return Err(format!(
                "diag_gram dimension mismatch: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        let p = self.ncols();
        match self {
            Self::Dense(x) => x.diag_gram(weights),
            Self::Sparse(xs) => {
                let csr = xs
                    .to_csr_arc()
                    .ok_or_else(|| "failed to obtain CSR view in diag_gram".to_string())?;
                let sym = csr.symbolic();
                Ok(sparse_csr_diag_gram(
                    sym.row_ptr(),
                    sym.col_idx(),
                    csr.val(),
                    self.nrows(),
                    p,
                    weights.view(),
                ))
            }
        }
    }
    /// Factorizes Xᵀ W X (+ penalty): dense designs use the dense stable
    /// solver; sparse designs assemble a CSC system and factorize sparsely.
    fn factorize_system(
        &self,
        weights: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
    ) -> Result<Box<dyn FactorizedSystem>, String> {
        if weights.len() != self.nrows() {
            return Err(format!(
                "factorize_system dimension mismatch: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        match self {
            Self::Dense(_) => self.factorize_system_dense(weights, penalty),
            Self::Sparse(matrix) => {
                let system = assemble_sparseweighted_gram_system(matrix, weights, penalty)?;
                let factor = crate::linalg::sparse_exact::factorize_sparse_spd(&system)
                    .map_err(|e| format!("factorize_system failed: {e:?}"))?;
                Ok(Box::new(factor))
            }
        }
    }
}
impl DenseDesignOperator for DesignMatrix {
    /// Xᵀ diag(w⁺) y with negative weights clamped to zero.
    fn compute_xtwy(&self, weights: &Array1<f64>, y: &Array1<f64>) -> Result<Array1<f64>, String> {
        if weights.len() != self.nrows() || y.len() != self.nrows() {
            return Err(format!(
                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
                weights.len(),
                y.len(),
                self.nrows()
            ));
        }
        match self {
            Self::Dense(x) => x.compute_xtwy(weights, y),
            Self::Sparse(xs) => {
                // NOTE(review): converts via to_row_major() instead of the
                // cached to_csr_arc() used by the other sparse paths —
                // confirm whether the cached view should be reused here.
                let csr = xs
                    .as_ref()
                    .to_row_major()
                    .map_err(|_| "failed to obtain CSR view in compute_xtwy".to_string())?;
                let sym = csr.symbolic();
                let row_ptr = sym.row_ptr();
                let col_idx = sym.col_idx();
                let vals = csr.val();
                let mut out = Array1::<f64>::zeros(xs.ncols());
                for i in 0..xs.nrows() {
                    let scaled = weights[i].max(0.0) * y[i];
                    if scaled == 0.0 {
                        continue;
                    }
                    for idx in row_ptr[i]..row_ptr[i + 1] {
                        out[col_idx[idx]] += vals[idx] * scaled;
                    }
                }
                Ok(out)
            }
        }
    }
    /// diag(X · middle · Xᵀ), clamped entrywise at zero. The sparse arm
    /// costs O(nnz_row²) per row over stored entries only.
    fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        if middle.nrows() != self.ncols() || middle.ncols() != self.ncols() {
            return Err(format!(
                "quadratic_form_diag dimension mismatch: matrix is {}x{}, expected {}x{}",
                middle.nrows(),
                middle.ncols(),
                self.ncols(),
                self.ncols()
            ));
        }
        match self {
            Self::Dense(xd) => xd.quadratic_form_diag(middle),
            Self::Sparse(xs) => {
                let csr = xs
                    .to_csr_arc()
                    .ok_or_else(|| "quadratic_form_diag: failed to obtain CSR view".to_string())?;
                let sym = csr.symbolic();
                let row_ptr = sym.row_ptr();
                let col_idx = sym.col_idx();
                let vals = csr.val();
                let mut out = Array1::<f64>::zeros(self.nrows());
                for i in 0..xs.nrows() {
                    let start = row_ptr[i];
                    let end = row_ptr[i + 1];
                    // acc = x_iᵀ · middle · x_i over the row's stored entries.
                    let mut acc = 0.0_f64;
                    for a in start..end {
                        let j = col_idx[a];
                        let xij = vals[a];
                        for b in start..end {
                            let k = col_idx[b];
                            let xik = vals[b];
                            acc += xij * middle[[j, k]] * xik;
                        }
                    }
                    // Clamp small negative round-off to zero.
                    out[i] = acc.max(0.0);
                }
                Ok(out)
            }
        }
    }
    /// Copies rows `rows` into `out` (shape-checked); sparse entries are
    /// scattered into the zero-filled destination view.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        if out.nrows() != rows.end - rows.start || out.ncols() != self.ncols() {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "DesignMatrix::row_chunk_into shape mismatch",
            });
        }
        match self {
            Self::Dense(matrix) => matrix.row_chunk_into(rows, out),
            Self::Sparse(matrix) => {
                out.fill(0.0);
                let csr =
                    matrix
                        .to_csr_arc()
                        .ok_or(MatrixMaterializationError::MissingRowChunk {
                            context: "DesignMatrix::row_chunk_into: failed to obtain CSR view",
                        })?;
                let sym = csr.symbolic();
                let row_ptr = sym.row_ptr();
                let col_idx = sym.col_idx();
                let vals = csr.val();
                for (local_row, row) in rows.enumerate() {
                    for ptr in row_ptr[row]..row_ptr[row + 1] {
                        out[[local_row, col_idx[ptr]]] = vals[ptr];
                    }
                }
                Ok(())
            }
        }
    }
    /// Full dense materialization via the inherent `DesignMatrix::to_dense`.
    fn to_dense(&self) -> Array2<f64> {
        DesignMatrix::to_dense(self)
    }
}
// Lazy view of the product X = base · first · second (each right factor
// optional): products are pushed through the factors instead of ever
// materializing X.
impl LinearOperator for DenseRightProductView<'_> {
    fn nrows(&self) -> usize {
        self.base.nrows()
    }
    fn ncols(&self) -> usize {
        self.transformed_ncols()
    }
    /// X v = base · (first · (second · v)); `second` touches the input first.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        let rhs;
        let v = match (self.second, self.first) {
            (None, None) => vector,
            (Some(s), None) => {
                rhs = fast_av(s, vector);
                &rhs
            }
            (None, Some(f)) => {
                rhs = fast_av(f, vector);
                &rhs
            }
            (Some(s), Some(f)) => {
                let tmp = fast_av(s, vector);
                rhs = fast_av(f, &tmp);
                &rhs
            }
        };
        dense_matvec(self.base, v)
    }
    /// Xᵀ v = secondᵀ · firstᵀ · baseᵀ · v (factors applied in reverse order).
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut out = dense_transpose_matvec(self.base, vector);
        if let Some(factor) = self.first {
            out = fast_atv(factor, &out);
        }
        if let Some(factor) = self.second {
            out = fast_atv(factor, &out);
        }
        out
    }
    /// Xᵀ diag(w) X = secondᵀ firstᵀ (baseᵀ diag(w) base) first second,
    /// sandwiching the base Gram with each factor in turn.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        if weights.len() != self.nrows() {
            return Err(format!(
                "compute_xtwx dimension mismatch: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        let mut gram = fast_xt_diag_x(self.base, weights);
        if let Some(factor) = self.first {
            gram = fast_ab(&fast_atb(factor, &gram), factor);
        }
        if let Some(factor) = self.second {
            gram = fast_ab(&fast_atb(factor, &gram), factor);
        }
        Ok(gram)
    }
    /// Diagonal extracted from the full transformed Gram matrix.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        Ok(self.diag_xtw_x(weights)?.diag().to_owned())
    }
}
impl DenseRightProductView<'_> {
    /// Computes Xᵀ W y for the transformed design X = base · first · second
    /// by pulling the weighted response through each right factor in turn.
    pub fn compute_xtwy(
        &self,
        weights: &Array1<f64>,
        y: &Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        let n = self.nrows();
        if weights.len() != n || y.len() != n {
            return Err(format!(
                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
                weights.len(),
                y.len(),
                n
            ));
        }
        // Start from baseᵀ (w ⊙ y), then fold each present factor in order.
        let mut result = dense_transpose_weighted_response(self.base, weights, y, None);
        for factor in [self.first, self.second].into_iter().flatten() {
            result = fast_atv(factor, &result);
        }
        Ok(result)
    }
    /// Diagonal of X · middle · Xᵀ, delegating to the dense implementation
    /// after materializing the transformed matrix.
    pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let materialized = self.materialize();
        DesignMatrix::Dense(DenseDesignMatrix::from(materialized)).quadratic_form_diag(middle)
    }
}
// Operator view of D X where D = diag(scale) applies a per-row scaling.
impl LinearOperator for DenseRowScaledView<'_> {
    fn nrows(&self) -> usize {
        self.matrix.nrows()
    }
    fn ncols(&self) -> usize {
        self.matrix.ncols()
    }
    /// y = D X v: the plain product, then one scale per output row.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut product = dense_matvec(self.matrix, vector);
        for (i, value) in product.iter_mut().enumerate() {
            *value *= self.scale[i];
        }
        product
    }
    /// y = Xᵀ D v: scale the input first, then apply the plain transpose.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let scaled = Array1::from_shape_fn(vector.len(), |i| vector[i] * self.scale[i]);
        dense_transpose_matvec(self.matrix, &scaled)
    }
    /// (D X)ᵀ diag(w) (D X) = Xᵀ diag(w · s²) X — fold the squared row
    /// scales into the weight vector and reuse the plain kernel.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        if weights.len() != self.nrows() {
            return Err(format!(
                "compute_xtwx dimension mismatch: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        let combined = Array1::from_shape_fn(weights.len(), |i| {
            let s = self.scale[i];
            weights[i] * s * s
        });
        Ok(fast_xt_diag_x(self.matrix, &combined))
    }
    /// Diagonal extracted from the full scaled Gram matrix.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let full = self.diag_xtw_x(weights)?;
        Ok(full.diag().to_owned())
    }
}
impl DenseRowScaledView<'_> {
    /// (D X)ᵀ W y, with the row scales folded in by the fused kernel.
    pub fn compute_xtwy(
        &self,
        weights: &Array1<f64>,
        y: &Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        let n = self.nrows();
        if weights.len() != n || y.len() != n {
            return Err(format!(
                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
                weights.len(),
                y.len(),
                n
            ));
        }
        Ok(dense_transpose_weighted_response(
            self.matrix,
            weights,
            y,
            Some(self.scale),
        ))
    }
    /// Diagonal of (D X) · middle · (D X)ᵀ, clamped entrywise at zero.
    pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let p = self.ncols();
        if middle.nrows() != p || middle.ncols() != p {
            return Err(format!(
                "quadratic_form_diag dimension mismatch: matrix is {}x{}, expected {}x{}",
                middle.nrows(),
                middle.ncols(),
                p,
                p
            ));
        }
        // xm = X · middle; each diagonal entry is then s_i² · (x_i · xm_i).
        let xm = fast_ab(self.matrix, middle);
        let mut diag = Array1::<f64>::zeros(self.nrows());
        for (i, slot) in diag.iter_mut().enumerate() {
            let s = self.scale[i];
            *slot = (self.matrix.row(i).dot(&xm.row(i)) * (s * s)).max(0.0);
        }
        Ok(diag)
    }
}
// A dense block `local` occupying columns `global_range` of a wider design
// with `total_cols` columns; all columns outside the range behave as zeros.
impl LinearOperator for EmbeddedColumnBlock<'_> {
    fn nrows(&self) -> usize {
        self.local.nrows()
    }
    fn ncols(&self) -> usize {
        self.total_cols
    }
    /// X v reads only the coefficient slice this block covers.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        dense_matvec(
            self.local,
            &vector
                .slice(ndarray::s![self.global_range.clone()])
                .to_owned(),
        )
    }
    /// Xᵀ v is nonzero only inside the block's global column range.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.total_cols);
        out.slice_mut(ndarray::s![self.global_range.clone()])
            .assign(&dense_transpose_matvec(self.local, vector));
        out
    }
    /// Full-size Gram with the local Xᵀ W X placed on its diagonal block;
    /// everything outside the block stays zero.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        if weights.len() != self.nrows() {
            return Err(format!(
                "compute_xtwx dimension mismatch: weights length {} != nrows {}",
                weights.len(),
                self.nrows()
            ));
        }
        let mut out = Array2::<f64>::zeros((self.total_cols, self.total_cols));
        let local = fast_xt_diag_x(self.local, weights);
        out.slice_mut(ndarray::s![
            self.global_range.clone(),
            self.global_range.clone()
        ])
        .assign(&local);
        Ok(out)
    }
    /// Gram diagonal embedded at the block's global positions. Weight-length
    /// validation happens inside the delegated dense diag_gram.
    fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut out = Array1::<f64>::zeros(self.total_cols);
        let local =
            DesignMatrix::Dense(DenseDesignMatrix::from(self.local.clone())).diag_gram(weights)?;
        out.slice_mut(ndarray::s![self.global_range.clone()])
            .assign(&local);
        Ok(out)
    }
}
impl EmbeddedColumnBlock<'_> {
    /// Local Xᵀ W y embedded at the block's global column positions;
    /// entries outside the range are zero.
    pub fn compute_xtwy(
        &self,
        weights: &Array1<f64>,
        y: &Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        if weights.len() != self.nrows() || y.len() != self.nrows() {
            return Err(format!(
                "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
                weights.len(),
                y.len(),
                self.nrows()
            ));
        }
        let local = dense_transpose_weighted_response(self.local, weights, y, None);
        let mut out = Array1::<f64>::zeros(self.total_cols);
        out.slice_mut(ndarray::s![self.global_range.clone()])
            .assign(&local);
        Ok(out)
    }
    /// diag(X · middle · Xᵀ): only the block's diagonal sub-matrix of
    /// `middle` can contribute, so slice it out and delegate to the dense
    /// implementation.
    pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
        let middle_local = middle
            .slice(ndarray::s![
                self.global_range.clone(),
                self.global_range.clone()
            ])
            .to_owned();
        DesignMatrix::Dense(DenseDesignMatrix::from(self.local.clone()))
            .quadratic_form_diag(&middle_local)
    }
}
/// Accumulates Xᵀ diag(w) X into `out` by streaming bounded row chunks
/// through faer's GEMM, keeping the scratch buffer near TARGET_BYTES.
///
/// `out` is accumulated into (`Accum::Add`), so callers must pre-zero it.
/// Weights are used as-is — no clamping of negative values here.
fn streaming_blas_xt_diag_x(x: &Array2<f64>, weights: &Array1<f64>, out: &mut Array2<f64>) {
    let n = x.nrows();
    let p = x.ncols();
    if n == 0 || p == 0 {
        return;
    }
    // Chunk sizing: target ~8 MiB of f64 scratch, clamped to a sane range
    // and never exceeding the row count.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let chunk_rows = (TARGET_BYTES / (p * 8)).max(MIN_ROWS).min(MAX_ROWS).min(n);
    let par = faer::get_global_parallelism();
    // Column-major (`.f()`) scratch holding diag(w) * x for the active chunk.
    let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p).f());
    let mut out_view = array2_to_matmut(out);
    for start in (0..n).step_by(chunk_rows) {
        let rows = (n - start).min(chunk_rows);
        {
            // Fill the scratch with each row of the chunk scaled by its weight.
            let mut chunk = wx_chunk.slice_mut(s![0..rows, ..]);
            for local in 0..rows {
                let src = start + local;
                let wi = weights[src];
                for col in 0..p {
                    chunk[[local, col]] = x[[src, col]] * wi;
                }
            }
        }
        let x_slice = x.slice(s![start..start + rows, ..]);
        let wx_slice = wx_chunk.slice(s![0..rows, ..]);
        let x_view = FaerArrayView::new(&x_slice);
        let wx_view = FaerArrayView::new(&wx_slice);
        // out += x_chunkᵀ · (diag(w) x_chunk).
        matmul(
            out_view.as_mut(),
            Accum::Add,
            x_view.as_ref().transpose(),
            wx_view.as_ref(),
            1.0,
            par,
        );
    }
}
impl DesignMatrix {
    /// Assembles the dense penalized system Xᵀ W X (+ penalty) and hands it
    /// to the shared stable solver for factorization.
    fn factorize_system_dense(
        &self,
        weights: &Array1<f64>,
        penalty: Option<&Array2<f64>>,
    ) -> Result<Box<dyn FactorizedSystem>, String> {
        let mut system = self.diag_xtw_x(weights)?;
        if let Some(pen) = penalty {
            if pen.nrows() != system.nrows() || pen.ncols() != system.ncols() {
                return Err(format!(
                    "factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
                    pen.nrows(),
                    pen.ncols(),
                    system.nrows(),
                    system.ncols()
                ));
            }
            system += pen;
        }
        crate::linalg::utils::StableSolver::new("linear operator system")
            .factorize(&system)
            .map(|factor| Box::new(factor) as Box<dyn FactorizedSystem>)
            .map_err(|e| format!("factorize_system failed: {e:?}"))
    }
}
fn assemble_sparseweighted_gram_system(
matrix: &SparseDesignMatrix,
weights: &Array1<f64>,
penalty: Option<&Array2<f64>>,
) -> Result<SparseColMat<usize, f64>, String> {
let csr = matrix
.to_csr_arc()
.ok_or_else(|| "failed to obtain CSR view in factorize_system".to_string())?;
let sym = csr.symbolic();
let row_ptr = sym.row_ptr();
let col_idx = sym.col_idx();
let vals = csr.val();
let p = matrix.ncols();
let mut upper = BTreeMap::<(usize, usize), f64>::new();
for i in 0..csr.nrows() {
let wi = weights[i].max(0.0);
if wi == 0.0 {
continue;
}
let start = row_ptr[i];
let end = row_ptr[i + 1];
for a_ptr in start..end {
let a = col_idx[a_ptr];
let xa = vals[a_ptr];
for b_ptr in a_ptr..end {
let b = col_idx[b_ptr];
let xb = vals[b_ptr];
let key = if a <= b { (a, b) } else { (b, a) };
*upper.entry(key).or_insert(0.0) += wi * xa * xb;
}
}
}
if let Some(pen) = penalty {
if pen.nrows() != p || pen.ncols() != p {
return Err(format!(
"factorize_system penalty shape mismatch: got {}x{}, expected {}x{}",
pen.nrows(),
pen.ncols(),
p,
p
));
}
for i in 0..p {
for j in i..p {
let value = pen[[i, j]];
if value != 0.0 {
*upper.entry((i, j)).or_insert(0.0) += value;
}
}
}
}
let mut triplets = Vec::with_capacity(upper.len());
for ((row, col), value) in upper {
if value != 0.0 {
triplets.push(Triplet::new(row, col, value));
}
}
Ok(SparseColMat::try_new_from_triplets(p, p, &triplets)
.map_err(|_| "failed to build sparse penalized system".to_string())?)
}
impl DesignMatrix {
/// Horizontally concatenates design blocks into one logical design matrix.
///
/// A single block is returned unchanged; multiple blocks are wrapped in a
/// `BlockDesignOperator` exposed through the dense operator path.
pub fn hstack(blocks: Vec<DesignMatrix>) -> Result<Self, String> {
    match blocks.len() {
        0 => Err("DesignMatrix::hstack requires at least one block".to_string()),
        1 => {
            let mut blocks = blocks;
            Ok(blocks.pop().expect("non-empty block list"))
        }
        _ => {
            let operator =
                BlockDesignOperator::new(blocks.into_iter().map(DesignBlock::from).collect())?;
            Ok(Self::Dense(DenseDesignMatrix::from(Arc::new(operator))))
        }
    }
}
/// Number of rows (observations); delegates to the `LinearOperator` impl.
pub fn nrows(&self) -> usize {
    <Self as LinearOperator>::nrows(self)
}
/// Number of columns (coefficients); delegates to the `LinearOperator` impl.
pub fn ncols(&self) -> usize {
    <Self as LinearOperator>::ncols(self)
}
/// Materializes rows `rows` as a freshly allocated dense chunk; sparse
/// entries are scattered into a zero-initialized buffer.
pub fn try_row_chunk(
    &self,
    rows: Range<usize>,
) -> Result<Array2<f64>, MatrixMaterializationError> {
    match self {
        Self::Dense(matrix) => matrix.try_row_chunk(rows),
        Self::Sparse(matrix) => {
            let csr =
                matrix
                    .to_csr_arc()
                    .ok_or(MatrixMaterializationError::MissingRowChunk {
                        context: "DesignMatrix::try_row_chunk: failed to obtain CSR view",
                    })?;
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            let chunk_rows = rows.end - rows.start;
            let ncols = self.ncols();
            let mut out = Array2::<f64>::zeros((chunk_rows, ncols));
            for (local_row, row) in rows.enumerate() {
                for ptr in row_ptr[row]..row_ptr[row + 1] {
                    out[[local_row, col_idx[ptr]]] = vals[ptr];
                }
            }
            Ok(out)
        }
    }
}
/// Materializes the full design densely, streaming bounded row chunks so
/// the temporary footprint stays near CHUNKED_DENSE_MATERIALIZATION_BYTES.
/// `context` prefixes any chunk-materialization error message.
pub fn try_to_dense_by_chunks(&self, context: &str) -> Result<Array2<f64>, String> {
    let (n, p) = (self.nrows(), self.ncols());
    let step = dense_materialization_chunk_rows(n, p);
    let mut dense = Array2::<f64>::zeros((n, p));
    let mut start = 0;
    while start < n {
        let end = n.min(start + step);
        let view = dense.slice_mut(s![start..end, ..]);
        self.row_chunk_into(start..end, view)
            .map_err(|err| format!("{context}: failed to materialize row chunk: {err}"))?;
        start = end;
    }
    Ok(dense)
}
/// Dot product of row `row` with `beta`; convenience wrapper over
/// `dot_row_view`.
pub fn dot_row(&self, row: usize, beta: &Array1<f64>) -> f64 {
    self.dot_row_view(row, beta.view())
}
/// Dot product of row `row` with `beta` without materializing the whole
/// matrix; sparse rows iterate stored entries only.
///
/// Panics when `beta` does not match `ncols`, or when a dense row chunk /
/// CSR view cannot be obtained.
pub fn dot_row_view(&self, row: usize, beta: ArrayView1<'_, f64>) -> f64 {
    assert_eq!(
        beta.len(),
        self.ncols(),
        "DesignMatrix::dot_row_view length mismatch: beta={}, ncols={}",
        beta.len(),
        self.ncols()
    );
    match self {
        Self::Dense(matrix) => {
            if let Some(dense) = matrix.as_dense_ref() {
                dense.row(row).dot(&beta)
            } else {
                // No resident dense buffer: materialize a one-row chunk.
                matrix
                    .try_row_chunk(row..row + 1)
                    .expect("DesignMatrix::dot_row_view: try_row_chunk must succeed")
                    .row(0)
                    .dot(&beta)
            }
        }
        Self::Sparse(matrix) => {
            let csr = matrix
                .to_csr_arc()
                .unwrap_or_else(|| panic!("DesignMatrix::dot_row: failed to obtain CSR view"));
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            let mut out = 0.0;
            for ptr in row_ptr[row]..row_ptr[row + 1] {
                out += vals[ptr] * beta[col_idx[ptr]];
            }
            out
        }
    }
}
/// out += alpha * X[row, :]. Sparse rows touch only stored entries; a
/// no-op when alpha == 0. Errors on length mismatch or a failed chunk
/// materialization; panics if a sparse CSR view cannot be obtained.
pub fn axpy_row_into(
    &self,
    row: usize,
    alpha: f64,
    out: &mut ArrayViewMut1<'_, f64>,
) -> Result<(), String> {
    if out.len() != self.ncols() {
        return Err(format!(
            "DesignMatrix::axpy_row_into length mismatch: out={}, ncols={}",
            out.len(),
            self.ncols()
        ));
    }
    if alpha == 0.0 {
        return Ok(());
    }
    match self {
        Self::Dense(matrix) => {
            if let Some(dense) = matrix.as_dense_ref() {
                for (dst, &value) in out.iter_mut().zip(dense.row(row).iter()) {
                    *dst += alpha * value;
                }
            } else {
                // No resident dense buffer: materialize a one-row chunk.
                let chunk = matrix
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("DesignMatrix::axpy_row_into: {e}"))?;
                for (dst, &value) in out.iter_mut().zip(chunk.row(0).iter()) {
                    *dst += alpha * value;
                }
            }
        }
        Self::Sparse(matrix) => {
            let csr = matrix.to_csr_arc().unwrap_or_else(|| {
                panic!("DesignMatrix::axpy_row_into: failed to obtain CSR view")
            });
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            for ptr in row_ptr[row]..row_ptr[row + 1] {
                out[col_idx[ptr]] += alpha * vals[ptr];
            }
        }
    }
    Ok(())
}
/// out += alpha * X[row, :]² (element-wise square of the row entries).
/// Same error/panic behavior as `axpy_row_into`; a no-op when alpha == 0.
pub fn squared_axpy_row_into(
    &self,
    row: usize,
    alpha: f64,
    out: &mut ArrayViewMut1<'_, f64>,
) -> Result<(), String> {
    if out.len() != self.ncols() {
        return Err(format!(
            "DesignMatrix::squared_axpy_row_into length mismatch: out={}, ncols={}",
            out.len(),
            self.ncols()
        ));
    }
    if alpha == 0.0 {
        return Ok(());
    }
    match self {
        Self::Dense(matrix) => {
            if let Some(dense) = matrix.as_dense_ref() {
                for (dst, &value) in out.iter_mut().zip(dense.row(row).iter()) {
                    *dst += alpha * value * value;
                }
            } else {
                // No resident dense buffer: materialize a one-row chunk.
                let chunk = matrix
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("DesignMatrix::squared_axpy_row_into: {e}"))?;
                for (dst, &value) in out.iter_mut().zip(chunk.row(0).iter()) {
                    *dst += alpha * value * value;
                }
            }
        }
        Self::Sparse(matrix) => {
            let csr = matrix.to_csr_arc().unwrap_or_else(|| {
                panic!("DesignMatrix::squared_axpy_row_into: failed to obtain CSR view")
            });
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            for ptr in row_ptr[row]..row_ptr[row + 1] {
                let value = vals[ptr];
                out[col_idx[ptr]] += alpha * value * value;
            }
        }
    }
    Ok(())
}
/// out[c] += alpha * X[row, c] * Y[row, c] element-wise over the shared
/// columns of `self` and `other` (which must have equal ncols; checked in
/// debug builds only). A no-op when alpha == 0.
///
/// Sparse/sparse uses a merge-join over the two rows' stored columns
/// (assumes CSR column indices are sorted within each row — TODO confirm
/// against the CSR construction invariants); mixed cases iterate the
/// sparse row and index into the dense one.
pub fn crossdiag_axpy_row_into(
    &self,
    row: usize,
    other: &DesignMatrix,
    alpha: f64,
    out: &mut ArrayViewMut1<'_, f64>,
) -> Result<(), String> {
    debug_assert_eq!(self.ncols(), other.ncols());
    debug_assert_eq!(out.len(), self.ncols());
    if alpha == 0.0 {
        return Ok(());
    }
    match (self, other) {
        (Self::Dense(lhs), Self::Dense(rhs)) => {
            // Bindings outlive the borrow of a possibly-chunked row.
            let lhs_chunk;
            let rhs_chunk;
            let x = if let Some(lhs_dense) = lhs.as_dense_ref() {
                lhs_dense.row(row)
            } else {
                lhs_chunk = lhs
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("crossdiag_axpy_row_into lhs: {e}"))?;
                lhs_chunk.row(0)
            };
            let y = if let Some(rhs_dense) = rhs.as_dense_ref() {
                rhs_dense.row(row)
            } else {
                rhs_chunk = rhs
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("crossdiag_axpy_row_into rhs: {e}"))?;
                rhs_chunk.row(0)
            };
            for (dst, (&xi, &yi)) in out.iter_mut().zip(x.iter().zip(y.iter())) {
                *dst += alpha * xi * yi;
            }
        }
        (Self::Sparse(lhs), Self::Sparse(rhs)) => {
            let lhs_csr = lhs.to_csr_arc().unwrap_or_else(|| {
                panic!("crossdiag_axpy_row_into: failed to obtain lhs CSR view")
            });
            let rhs_csr = rhs.to_csr_arc().unwrap_or_else(|| {
                panic!("crossdiag_axpy_row_into: failed to obtain rhs CSR view")
            });
            let lhs_sym = lhs_csr.symbolic();
            let rhs_sym = rhs_csr.symbolic();
            let lhs_rp = lhs_sym.row_ptr();
            let rhs_rp = rhs_sym.row_ptr();
            let lhs_ci = lhs_sym.col_idx();
            let rhs_ci = rhs_sym.col_idx();
            let lhs_v = lhs_csr.val();
            let rhs_v = rhs_csr.val();
            // Two-pointer merge over the sorted column lists of both rows:
            // only columns stored on both sides contribute.
            let mut li = lhs_rp[row];
            let mut ri = rhs_rp[row];
            let l_end = lhs_rp[row + 1];
            let r_end = rhs_rp[row + 1];
            while li < l_end && ri < r_end {
                let lc = lhs_ci[li];
                let rc = rhs_ci[ri];
                if lc == rc {
                    out[lc] += alpha * lhs_v[li] * rhs_v[ri];
                    li += 1;
                    ri += 1;
                } else if lc < rc {
                    li += 1;
                } else {
                    ri += 1;
                }
            }
        }
        _ => {
            // Mixed case: one sparse, one dense (order-insensitive since the
            // product is element-wise).
            let (sparse_mat, dense_mat) = match (self, other) {
                (Self::Sparse(s), Self::Dense(d)) => (s, d),
                (Self::Dense(d), Self::Sparse(s)) => (s, d),
                _ => unreachable!(),
            };
            let csr = sparse_mat.to_csr_arc().unwrap_or_else(|| {
                panic!("crossdiag_axpy_row_into: failed to obtain CSR view")
            });
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            let dense_chunk;
            let dense_row = if let Some(dense_ref) = dense_mat.as_dense_ref() {
                dense_ref.row(row)
            } else {
                dense_chunk = dense_mat
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("crossdiag_axpy_row_into dense chunk: {e}"))?;
                dense_chunk.row(0)
            };
            for ptr in row_ptr[row]..row_ptr[row + 1] {
                let c = col_idx[ptr];
                out[c] += alpha * vals[ptr] * dense_row[c];
            }
        }
    }
    Ok(())
}
/// Symmetric rank-1 accumulation: adds `alpha * x_rowᵀ x_row` into `target`,
/// where `x_row` is row `row` of this design.
///
/// Convenience wrapper over [`Self::syr_row_into_view`] taking an owned
/// mutable array instead of a view.
pub fn syr_row_into(
    &self,
    row: usize,
    alpha: f64,
    target: &mut Array2<f64>,
) -> Result<(), String> {
    self.syr_row_into_view(row, alpha, target.view_mut())
}
/// Symmetric rank-1 accumulation into a view: `target += alpha * xᵀx` for
/// row `row`.
///
/// `target` must be square with side `self.ncols()`. `alpha == 0.0` is a
/// no-op. Dense designs borrow the row directly when a materialized buffer
/// exists, otherwise stream a one-row chunk; sparse designs touch only the
/// row's nonzero pattern.
pub fn syr_row_into_view(
    &self,
    row: usize,
    alpha: f64,
    mut target: ArrayViewMut2<'_, f64>,
) -> Result<(), String> {
    if target.nrows() != self.ncols() || target.ncols() != self.ncols() {
        return Err(format!(
            "DesignMatrix::syr_row_into shape mismatch: target={}x{}, ncols={}",
            target.nrows(),
            target.ncols(),
            self.ncols()
        ));
    }
    if alpha == 0.0 {
        return Ok(());
    }
    match self {
        Self::Dense(matrix) => {
            // Single row source shared by both storage forms; the chunk
            // binding outlives the borrow taken by `x` (same pattern as
            // row_outer_into_view), removing the previously duplicated loop.
            let chunk;
            let x = if let Some(dense) = matrix.as_dense_ref() {
                dense.row(row)
            } else {
                chunk = matrix
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("DesignMatrix::syr_row_into: {e}"))?;
                chunk.row(0)
            };
            for i in 0..x.len() {
                let xi = x[i];
                // Skip zero rows of the outer product entirely.
                if xi == 0.0 {
                    continue;
                }
                for j in 0..x.len() {
                    target[[i, j]] += alpha * xi * x[j];
                }
            }
        }
        Self::Sparse(matrix) => {
            let csr = matrix.to_csr_arc().unwrap_or_else(|| {
                panic!("DesignMatrix::syr_row_into: failed to obtain CSR view")
            });
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            // Outer product over the row's nonzero pattern only.
            for ptr_i in row_ptr[row]..row_ptr[row + 1] {
                let i = col_idx[ptr_i];
                let xi = vals[ptr_i];
                for ptr_j in row_ptr[row]..row_ptr[row + 1] {
                    let j = col_idx[ptr_j];
                    target[[i, j]] += alpha * xi * vals[ptr_j];
                }
            }
        }
    }
    Ok(())
}
/// Accumulates `alpha * x_row ⊗ y_row` (outer product of the `row`-th rows
/// of `self` and `other`) into `target`.
///
/// Owned-array convenience wrapper over [`Self::row_outer_into_view`].
pub fn row_outer_into(
    &self,
    row: usize,
    other: &DesignMatrix,
    alpha: f64,
    target: &mut Array2<f64>,
) -> Result<(), String> {
    self.row_outer_into_view(row, other, alpha, target.view_mut())
}
/// Accumulates `alpha * x_row ⊗ y_row` into `target`, where `x_row` and
/// `y_row` are row `row` of `self` and `other` respectively.
///
/// `target` must be `self.ncols() x other.ncols()`. `alpha == 0.0` is a
/// no-op. Sparse/sparse pairs touch only nonzero index pairs.
pub fn row_outer_into_view(
    &self,
    row: usize,
    other: &DesignMatrix,
    alpha: f64,
    mut target: ArrayViewMut2<'_, f64>,
) -> Result<(), String> {
    if target.nrows() != self.ncols() || target.ncols() != other.ncols() {
        return Err(format!(
            "DesignMatrix::row_outer_into shape mismatch: target={}x{}, lhs={}, rhs={}",
            target.nrows(),
            target.ncols(),
            self.ncols(),
            other.ncols()
        ));
    }
    if alpha == 0.0 {
        return Ok(());
    }
    match (self, other) {
        (Self::Dense(lhs), Self::Dense(rhs)) => {
            // Borrow each row from the materialized buffer when available,
            // otherwise stream a one-row chunk kept alive by the binding.
            let lhs_chunk;
            let rhs_chunk;
            let x = if let Some(lhs_dense) = lhs.as_dense_ref() {
                lhs_dense.row(row)
            } else {
                lhs_chunk = lhs
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("row_outer_into_view lhs: {e}"))?;
                lhs_chunk.row(0)
            };
            let y = if let Some(rhs_dense) = rhs.as_dense_ref() {
                rhs_dense.row(row)
            } else {
                rhs_chunk = rhs
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("row_outer_into_view rhs: {e}"))?;
                rhs_chunk.row(0)
            };
            for i in 0..x.len() {
                let xi = x[i];
                if xi == 0.0 {
                    continue;
                }
                for j in 0..y.len() {
                    target[[i, j]] += alpha * xi * y[j];
                }
            }
        }
        (Self::Sparse(lhs), Self::Sparse(rhs)) => {
            let lhs_csr = lhs
                .to_csr_arc()
                .unwrap_or_else(|| panic!("row_outer_into: failed to obtain lhs CSR view"));
            let rhs_csr = rhs
                .to_csr_arc()
                .unwrap_or_else(|| panic!("row_outer_into: failed to obtain rhs CSR view"));
            let lhs_sym = lhs_csr.symbolic();
            let rhs_sym = rhs_csr.symbolic();
            let lhs_rp = lhs_sym.row_ptr();
            let rhs_rp = rhs_sym.row_ptr();
            let lhs_ci = lhs_sym.col_idx();
            let rhs_ci = rhs_sym.col_idx();
            let lhs_v = lhs_csr.val();
            let rhs_v = rhs_csr.val();
            // Outer product restricted to the nonzero patterns of both rows.
            for pi in lhs_rp[row]..lhs_rp[row + 1] {
                let i = lhs_ci[pi];
                let xi = lhs_v[pi];
                for pj in rhs_rp[row]..rhs_rp[row + 1] {
                    let j = rhs_ci[pj];
                    target[[i, j]] += alpha * xi * rhs_v[pj];
                }
            }
        }
        _ => {
            // Mixed dense/sparse: fetch one row from each side, borrowing a
            // materialized dense buffer directly when present instead of
            // always copying a one-row chunk (consistent with the
            // dense/dense arm above).
            let lhs_chunk;
            let x_row = if let Some(dense) = self.as_dense_ref() {
                dense.row(row)
            } else {
                lhs_chunk = self
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("row_outer_into_view lhs: {e}"))?;
                lhs_chunk.row(0)
            };
            let rhs_chunk;
            let y_row = if let Some(dense) = other.as_dense_ref() {
                dense.row(row)
            } else {
                rhs_chunk = other
                    .try_row_chunk(row..row + 1)
                    .map_err(|e| format!("row_outer_into_view rhs: {e}"))?;
                rhs_chunk.row(0)
            };
            for i in 0..x_row.len() {
                let xi = x_row[i];
                if xi == 0.0 {
                    continue;
                }
                for j in 0..y_row.len() {
                    target[[i, j]] += alpha * xi * y_row[j];
                }
            }
        }
    }
    Ok(())
}
/// Reads a single element `X[i, j]` without materializing the matrix.
///
/// Dense lazy operators are probed with the j-th basis vector (one
/// mat-vec); sparse designs read the CSC column directly, summing duplicate
/// `(i, j)` entries — the same accumulation semantics densification uses.
/// This avoids the previous behavior of densifying the entire sparse
/// matrix (which could panic under the materialization policy) to read one
/// element.
#[inline]
pub fn get(&self, i: usize, j: usize) -> f64 {
    match self {
        Self::Dense(matrix) => match matrix.as_dense_ref() {
            Some(dense) => dense[[i, j]],
            None => {
                // Probe with e_j: entry i of X·e_j is X[i, j].
                let mut e_j = Array1::<f64>::zeros(matrix.ncols());
                e_j[j] = 1.0;
                matrix.apply(&e_j)[i]
            }
        },
        Self::Sparse(sp) => {
            let (symbolic, values) = sp.parts();
            let col_ptr = symbolic.col_ptr();
            let row_idx = symbolic.row_idx();
            let mut acc = 0.0;
            // Scan column j's nonzeros; duplicates at (i, j) accumulate.
            for idx in col_ptr[j]..col_ptr[j + 1] {
                if row_idx[idx] == i {
                    acc += values[idx];
                }
            }
            acc
        }
    }
}
/// Returns column `j` as an owned dense vector.
///
/// Sparse designs scatter the CSC column's nonzeros with `+=` so that
/// duplicate `(row, j)` entries accumulate — matching the dense
/// materialization semantics (see the duplicate-entry test for this file);
/// the previous `=` assignment silently kept only the last duplicate.
pub fn extract_column(&self, j: usize) -> Array1<f64> {
    match self {
        Self::Dense(m) => {
            if let Some(dense) = m.as_dense_ref() {
                dense.column(j).to_owned()
            } else {
                // Operator-backed: column j is X·e_j.
                let mut e_j = Array1::zeros(m.ncols());
                e_j[j] = 1.0;
                m.apply(&e_j)
            }
        }
        Self::Sparse(sp) => {
            let n = sp.nrows();
            let mut col = Array1::zeros(n);
            let (symbolic, values) = sp.parts();
            let col_ptr = symbolic.col_ptr();
            let row_idx = symbolic.row_idx();
            let start = col_ptr[j];
            let end = col_ptr[j + 1];
            for idx in start..end {
                // Accumulate (not assign) so duplicate entries sum.
                col[row_idx[idx]] += values[idx];
            }
            col
        }
    }
}
/// Borrows the materialized dense buffer, if one exists.
///
/// Sparse and operator-backed designs yield `None`.
pub fn as_dense_ref(&self) -> Option<&Array2<f64>> {
    if let Self::Dense(matrix) = self {
        matrix.as_dense_ref()
    } else {
        None
    }
}
/// True only when the design is dense AND fully materialized in memory
/// (not lazy/operator-backed, not sparse).
pub fn is_materialized_dense(&self) -> bool {
    matches!(self, Self::Dense(DenseDesignMatrix::Materialized(_)))
}
/// Whether the dense side is backed by a lazy operator rather than an
/// in-memory buffer; sparse designs are never operator-backed.
pub fn is_operator_backed(&self) -> bool {
    matches!(self, Self::Dense(matrix) if matrix.is_operator_backed())
}
/// Exposes dense storage as a `Cow`, borrowing when a dense buffer already
/// exists and densifying sparse designs into an owned copy.
///
/// # Panics
/// Panics for operator-backed dense designs (there is no buffer to borrow)
/// and when sparse densification is rejected by the materialization policy.
pub fn as_dense_cow(&self) -> Cow<'_, Array2<f64>> {
    match self {
        Self::Dense(DenseDesignMatrix::Materialized(matrix)) => Cow::Borrowed(matrix.as_ref()),
        Self::Dense(DenseDesignMatrix::Lazy(op)) => {
            let Some(dense) = op.as_dense_ref() else {
                panic!(
                    "DesignMatrix::as_dense_cow called on operator-backed design ({}x{}); use row chunks or matrix-vector products",
                    op.nrows(),
                    op.ncols()
                )
            };
            Cow::Borrowed(dense)
        }
        Self::Sparse(matrix) => {
            let dense = matrix
                .try_to_dense_arc("DesignMatrix::as_dense_cow")
                .unwrap_or_else(|msg| panic!("{msg}"));
            Cow::Owned(dense.as_ref().clone())
        }
    }
}
/// Like [`Self::as_dense_cow`] but materializes operator-backed dense
/// designs instead of panicking: borrows when a dense buffer already
/// exists, otherwise returns an owned densified copy.
///
/// # Panics
/// Panics when materialization is rejected (the fallible path's error
/// message is forwarded), matching the previous behavior.
pub fn to_dense_cow(&self) -> Cow<'_, Array2<f64>> {
    // Fast borrow paths first.
    if let Self::Dense(DenseDesignMatrix::Materialized(matrix)) = self {
        return Cow::Borrowed(matrix.as_ref());
    }
    if let Self::Dense(DenseDesignMatrix::Lazy(op)) = self {
        if let Some(dense) = op.as_dense_ref() {
            return Cow::Borrowed(dense);
        }
    }
    // Lazy-without-cache and sparse share the same owned fallback; the
    // previous version duplicated this expression in two match arms.
    Cow::Owned(
        self.try_to_dense_arc("DesignMatrix::to_dense_cow")
            .unwrap_or_else(|msg| panic!("{msg}"))
            .as_ref()
            .clone(),
    )
}
pub fn to_dense(&self) -> Array2<f64> {
match self {
Self::Dense(matrix) => matrix
.try_to_dense_arc("DesignMatrix::to_dense")
.unwrap_or_else(|msg| panic!("{msg}"))
.as_ref()
.clone(),
Self::Sparse(matrix) => matrix
.try_to_dense_arc("DesignMatrix::to_dense")
.unwrap_or_else(|msg| panic!("{msg}"))
.as_ref()
.clone(),
}
}
pub fn to_dense_arc(&self) -> Arc<Array2<f64>> {
match self {
Self::Dense(matrix) => matrix
.try_to_dense_arc("DesignMatrix::to_dense_arc")
.unwrap_or_else(|msg| panic!("{msg}")),
Self::Sparse(matrix) => matrix
.try_to_dense_arc("DesignMatrix::to_dense_arc")
.unwrap_or_else(|msg| panic!("{msg}")),
}
}
/// Fallible dense materialization for either backend.
///
/// `context` labels the error message when the materialization policy
/// rejects the conversion (e.g. oversized sparse densification).
pub fn try_to_dense_arc(&self, context: &str) -> Result<Arc<Array2<f64>>, String> {
    match self {
        Self::Dense(matrix) => matrix.try_to_dense_arc(context),
        Self::Sparse(matrix) => matrix.try_to_dense_arc(context),
    }
}
/// Clones the sparse design's CSR form, if available.
///
/// Returns `None` for dense designs or when no CSR view can be produced.
pub fn to_csr_cache(&self) -> Option<SparseRowMat<usize, f64>> {
    let Self::Sparse(matrix) = self else {
        return None;
    };
    matrix.to_csr_arc().map(|arc| arc.as_ref().clone())
}
/// Borrows the sparse backend, if this design is sparse.
pub fn as_sparse(&self) -> Option<&SparseDesignMatrix> {
    if let Self::Sparse(matrix) = self {
        Some(matrix)
    } else {
        None
    }
}
pub fn as_dense(&self) -> Option<&Array2<f64>> {
match self {
Self::Dense(matrix) => matrix.as_dense_ref(),
Self::Sparse(_) => None,
}
}
/// Computes `Xᵀ·v` from a borrowed vector view, per backend.
fn apply_transpose_view(&self, vector: ArrayView1<'_, f64>) -> Array1<f64> {
    match self {
        Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
            dense_transpose_matvec_view(matrix, vector)
        }
        // Lazy operators take owned vectors; copy the view once.
        Self::Dense(DenseDesignMatrix::Lazy(op)) => op.apply_transpose(&vector.to_owned()),
        Self::Sparse(matrix) => {
            // CSC layout: each output entry is the dot product of one
            // stored column with `vector`, accumulated in storage order.
            let (symbolic, values) = matrix.parts();
            let col_ptr = symbolic.col_ptr();
            let row_idx = symbolic.row_idx();
            Array1::from_shape_fn(matrix.ncols(), |col| {
                (col_ptr[col]..col_ptr[col + 1])
                    .map(|idx| values[idx] * vector[row_idx[idx]])
                    .sum()
            })
        }
    }
}
/// Forms the weighted Gram matrix `Xᵀ·diag(w)·X` with a backend-specific
/// kernel, validating that `weights` matches the row count first.
fn compute_xtwx_view(&self, weights: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
    let n = self.nrows();
    if weights.len() != n {
        return Err(format!(
            "compute_xtwx dimension mismatch: weights length {} != nrows {}",
            weights.len(),
            n
        ));
    }
    match self {
        Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
            Ok(dense_xtwx_view(matrix, weights))
        }
        Self::Dense(DenseDesignMatrix::Lazy(op)) => op.diag_xtw_x(&weights.to_owned()),
        Self::Sparse(sparse) => {
            // Sparse path works on the row-major (CSR) cache.
            let csr = sparse
                .to_csr_arc()
                .ok_or_else(|| "failed to obtain CSR view in compute_xtwx".to_string())?;
            let sym = csr.symbolic();
            Ok(sparse_csr_weighted_xtwx(
                sym.row_ptr(),
                sym.col_idx(),
                csr.val(),
                sparse.nrows(),
                sparse.ncols(),
                weights,
            ))
        }
    }
}
/// Computes the diagonal of the weighted Gram matrix `Xᵀ·diag(w)·X`
/// without forming the full matrix.
fn diag_gram_view(&self, weights: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
    let n = self.nrows();
    if weights.len() != n {
        return Err(format!(
            "diag_gram dimension mismatch: weights length {} != nrows {}",
            weights.len(),
            n
        ));
    }
    match self {
        Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
            Ok(dense_diag_gram_view(matrix, weights))
        }
        Self::Dense(DenseDesignMatrix::Lazy(op)) => op.diag_gram(&weights.to_owned()),
        Self::Sparse(sparse) => {
            // Sparse path works on the row-major (CSR) cache.
            let csr = sparse
                .to_csr_arc()
                .ok_or_else(|| "failed to obtain CSR view in diag_gram".to_string())?;
            let sym = csr.symbolic();
            Ok(sparse_csr_diag_gram(
                sym.row_ptr(),
                sym.col_idx(),
                csr.val(),
                sparse.nrows(),
                sparse.ncols(),
                weights,
            ))
        }
    }
}
/// Computes `Xᵀ·diag(w⁺)·y` where `w⁺ = max(w, 0)` — negative weights are
/// clamped to zero in the sparse path below (the dense kernels are defined
/// elsewhere; presumably they apply the same positive-part clamp, matching
/// the `w.max(0.0)` expectations in this file's tests — TODO confirm).
fn compute_xtwy_view(
    &self,
    weights: ArrayView1<'_, f64>,
    y: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, String> {
    if weights.len() != self.nrows() || y.len() != self.nrows() {
        return Err(format!(
            "compute_xtwy dimension mismatch: weights={}, y={}, nrows={}",
            weights.len(),
            y.len(),
            self.nrows()
        ));
    }
    match self {
        Self::Dense(DenseDesignMatrix::Materialized(matrix)) => {
            Ok(dense_transpose_weighted_response_view(matrix, weights, y))
        }
        Self::Dense(DenseDesignMatrix::Lazy(op)) => {
            // Lazy operators take owned vectors; copy the views once.
            op.compute_xtwy(&weights.to_owned(), &y.to_owned())
        }
        Self::Sparse(xs) => {
            // Row-major traversal: scale each row's response once, then
            // scatter-add into the output at that row's column indices.
            let csr = xs
                .as_ref()
                .to_row_major()
                .map_err(|_| "failed to obtain CSR view in compute_xtwy".to_string())?;
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            let mut out = Array1::<f64>::zeros(xs.ncols());
            for i in 0..xs.nrows() {
                // Clamp negative weights; zero contribution rows are skipped.
                let scaled = weights[i].max(0.0) * y[i];
                if scaled == 0.0 {
                    continue;
                }
                for idx in row_ptr[i]..row_ptr[i + 1] {
                    out[col_idx[idx]] += vals[idx] * scaled;
                }
            }
            Ok(out)
        }
    }
}
/// Matrix–vector product `X·v`, delegating to the `LinearOperator` impl.
pub fn dot(&self, vector: &Array1<f64>) -> Array1<f64> {
    <Self as LinearOperator>::apply(self, vector)
}
/// Alias of [`Self::dot`]: computes `X·v`.
// NOTE(review): non-snake-case name; retained unchanged because external
// callers may depend on it — consider deprecating in favor of `dot`.
pub fn matrixvectormultiply(&self, vector: &Array1<f64>) -> Array1<f64> {
    <Self as LinearOperator>::apply(self, vector)
}
/// Transposed product `Xᵀ·v`, delegating to the `LinearOperator` impl.
pub fn transpose_vector_multiply(&self, vector: &Array1<f64>) -> Array1<f64> {
    <Self as LinearOperator>::apply_transpose(self, vector)
}
/// Weighted Gram matrix `Xᵀ·diag(w)·X` via the `LinearOperator` impl.
pub fn compute_xtwx(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
    <Self as LinearOperator>::diag_xtw_x(self, weights)
}
/// Weighted transposed response `Xᵀ·diag(w)·y` via the
/// `DenseDesignOperator` impl.
pub fn compute_xtwy(
    &self,
    weights: &Array1<f64>,
    y: &Array1<f64>,
) -> Result<Array1<f64>, String> {
    <Self as DenseDesignOperator>::compute_xtwy(self, weights, y)
}
/// Diagonal of `Xᵀ·diag(w)·X` via the `LinearOperator` impl.
pub fn diag_gram(&self, weights: &Array1<f64>) -> Result<Array1<f64>, String> {
    <Self as LinearOperator>::diag_gram(self, weights)
}
/// Diagonal of `X·M·Xᵀ` for a middle matrix `M`, via the
/// `DenseDesignOperator` impl.
// NOTE(review): the exact quadratic-form orientation is defined by the
// trait impl elsewhere in this file — confirm against that definition.
pub fn quadratic_form_diag(&self, middle: &Array2<f64>) -> Result<Array1<f64>, String> {
    <Self as DenseDesignOperator>::quadratic_form_diag(self, middle)
}
/// Applies the weighted normal-equations operator
/// (Xᵀ·diag(w)·X + penalty + ridge·I)·v via the `LinearOperator` impl.
pub fn apply_weighted_normal(
    &self,
    weights: &Array1<f64>,
    vector: &Array1<f64>,
    penalty: Option<&Array2<f64>>,
    ridge: f64,
) -> Array1<f64> {
    <Self as LinearOperator>::apply_weighted_normal(self, weights, vector, penalty, ridge)
}
/// Solves the (optionally penalized) weighted normal equations for `rhs`
/// via the `LinearOperator` impl.
pub fn solve_system(
    &self,
    weights: &Array1<f64>,
    rhs: &Array1<f64>,
    penalty: Option<&Array2<f64>>,
) -> Result<Array1<f64>, String> {
    <Self as LinearOperator>::solve_system(self, weights, rhs, penalty)
}
/// Policy-controlled variant of [`Self::solve_system`] with an explicit
/// ridge floor and stabilization policy.
// NOTE(review): name is missing an underscore (`solve_system_with_policy`);
// kept unchanged for API compatibility.
pub fn solve_systemwith_policy(
    &self,
    weights: &Array1<f64>,
    rhs: &Array1<f64>,
    penalty: Option<&Array2<f64>>,
    ridge_floor: f64,
    ridge_policy: RidgePolicy,
) -> Result<Array1<f64>, String> {
    <Self as LinearOperator>::solve_systemwith_policy(
        self,
        weights,
        rhs,
        penalty,
        ridge_floor,
        ridge_policy,
    )
}
/// Matrix-free preconditioned conjugate-gradient solve of the weighted,
/// penalized normal equations.
///
/// The ridge floor is clamped to at least `1e-15` before dispatch so the
/// system is never exactly singular at the floor.
pub fn solve_system_matrix_free_pcg(
    &self,
    weights: &Array1<f64>,
    rhs: &Array1<f64>,
    penalty: Option<&Array2<f64>>,
    ridge_floor: f64,
) -> Result<Array1<f64>, String> {
    <Self as LinearOperator>::solve_system_matrix_free_pcg_try(
        self,
        weights,
        rhs,
        penalty,
        ridge_floor.max(1e-15),
    )
}
/// Like [`Self::solve_system_matrix_free_pcg`] but also returns
/// [`PcgSolveInfo`] convergence diagnostics (iterations, residual norm).
/// The ridge floor is clamped to at least `1e-15` before dispatch.
pub fn solve_system_matrix_free_pcg_with_info(
    &self,
    weights: &Array1<f64>,
    rhs: &Array1<f64>,
    penalty: Option<&Array2<f64>>,
    ridge_floor: f64,
) -> Result<(Array1<f64>, PcgSolveInfo), String> {
    <Self as LinearOperator>::solve_system_matrix_free_pcg_with_info_try(
        self,
        weights,
        rhs,
        penalty,
        ridge_floor.max(1e-15),
    )
}
/// Whether matrix-free PCG should be preferred: the backend must opt in
/// AND the problem must be wide enough (at least `MATRIX_FREE_PCG_MIN_P`
/// columns) to amortize the iterative solve.
pub fn should_use_matrix_free_pcg(&self) -> bool {
    let backend_opts_in = <Self as LinearOperator>::uses_matrix_free_pcg(self);
    backend_opts_in && self.ncols() >= MATRIX_FREE_PCG_MIN_P
}
/// Factorizes the weighted, optionally penalized normal-equations system
/// for repeated solves, via the `LinearOperator` impl.
pub fn factorize_system(
    &self,
    weights: &Array1<f64>,
    penalty: Option<&Array2<f64>>,
) -> Result<Box<dyn FactorizedSystem>, String> {
    <Self as LinearOperator>::factorize_system(self, weights, penalty)
}
}
impl<'a> From<ArrayView2<'a, f64>> for DesignMatrix {
fn from(value: ArrayView2<'a, f64>) -> Self {
Self::Dense(DenseDesignMatrix::from(value.to_owned()))
}
}
impl From<Array2<f64>> for DesignMatrix {
    /// Wraps an owned dense array as a dense design.
    fn from(value: Array2<f64>) -> Self {
        Self::Dense(DenseDesignMatrix::from(value))
    }
}
impl From<Arc<Array2<f64>>> for DesignMatrix {
    /// Wraps a shared dense buffer as a dense design.
    fn from(value: Arc<Array2<f64>>) -> Self {
        Self::Dense(DenseDesignMatrix::from(value))
    }
}
impl From<&Array2<f64>> for DesignMatrix {
fn from(value: &Array2<f64>) -> Self {
Self::Dense(DenseDesignMatrix::from(value.clone()))
}
}
impl From<DenseDesignMatrix> for DesignMatrix {
    /// Lifts a dense backend into the design-matrix enum.
    fn from(value: DenseDesignMatrix) -> Self {
        Self::Dense(value)
    }
}
impl From<SparseColMat<usize, f64>> for DesignMatrix {
    /// Takes ownership of a CSC matrix as a sparse design.
    fn from(value: SparseColMat<usize, f64>) -> Self {
        Self::Sparse(SparseDesignMatrix::new(value))
    }
}
impl From<&SparseColMat<usize, f64>> for DesignMatrix {
fn from(value: &SparseColMat<usize, f64>) -> Self {
Self::Sparse(SparseDesignMatrix::new(value.clone()))
}
}
impl From<&DesignMatrix> for DesignMatrix {
    /// Clones the referenced design.
    fn from(value: &DesignMatrix) -> Self {
        value.clone()
    }
}
impl From<DesignMatrix> for DesignBlock {
fn from(value: DesignMatrix) -> Self {
match value {
DesignMatrix::Dense(matrix) => Self::Dense(matrix),
DesignMatrix::Sparse(matrix) => Self::Sparse(matrix),
}
}
}
impl From<&DesignMatrix> for DesignBlock {
fn from(value: &DesignMatrix) -> Self {
match value {
DesignMatrix::Dense(matrix) => Self::Dense(matrix.clone()),
DesignMatrix::Sparse(matrix) => Self::Sparse(matrix.clone()),
}
}
}
#[cfg(test)]
mod tests {
use super::{
ChunkedKernelDesignOperator, CoefficientTransformOperator, DenseDesignMatrix,
DenseDesignOperator, DesignMatrix, EmbeddedColumnBlock, MultiChannelOperator,
ReparamOperator, RowwiseKroneckerOperator, SparseDesignMatrix, SparseHessianAccumulator,
dense_matvec, dense_operator_to_dense_by_chunks, dense_transpose_matvec,
dense_transpose_weighted_response, dense_xtwx_view,
};
use crate::linalg::matrix::LinearOperator;
use crate::linalg::utils::{PcgSolveInfo, StableSolver};
use crate::resource::MatrixMaterializationError;
use crate::testing::no_densify_design;
use crate::types::RidgePolicy;
use faer::sparse::{SparseColMat, SymbolicSparseColMat, Triplet};
use ndarray::{Array1, Array2, ArrayViewMut2, Axis, array, s};
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Test-only dense operator that is readable ONLY through `row_chunk_into`
/// (its `to_dense` panics), forcing materialization to stream row chunks.
struct ChunkOnlyOperator {
    // Number of rows in the synthetic design.
    n: usize,
    // Number of columns in the synthetic design.
    p: usize,
    // Counts `row_chunk_into` invocations so tests can assert streaming.
    row_chunk_calls: AtomicUsize,
}
impl ChunkOnlyOperator {
    /// Deterministic synthetic entry at (i, j), cheap to recompute so
    /// streamed chunks can be verified against direct evaluation.
    fn value(&self, i: usize, j: usize) -> f64 {
        let row_part = ((i % 251) as f64) * 0.25;
        let col_part = ((j % 127) as f64) * 0.5;
        let mix_part = ((i + j) % 7) as f64;
        row_part - col_part + mix_part
    }
}
impl LinearOperator for ChunkOnlyOperator {
    fn nrows(&self) -> usize {
        self.n
    }
    fn ncols(&self) -> usize {
        self.p
    }
    /// Dense mat-vec against the synthetic pattern, one row at a time.
    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
        Array1::from_shape_fn(self.n, |i| {
            (0..self.p).map(|j| self.value(i, j) * vector[j]).sum()
        })
    }
    /// Transposed mat-vec, accumulating each row's contribution.
    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(self.p);
        for i in 0..self.n {
            let vi = vector[i];
            for j in 0..self.p {
                out[j] += self.value(i, j) * vi;
            }
        }
        out
    }
    /// Materializes through bounded row chunks, then forms the weighted
    /// Gram matrix densely.
    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
        let dense = dense_operator_to_dense_by_chunks(self).map_err(|err| err.to_string())?;
        Ok(dense_xtwx_view(&dense, weights.view()))
    }
}
impl DenseDesignOperator for ChunkOnlyOperator {
    // Fills `out` with rows `rows` of the synthetic matrix, recording the
    // call so tests can assert that materialization streamed in chunks.
    fn row_chunk_into(
        &self,
        rows: Range<usize>,
        mut out: ArrayViewMut2<'_, f64>,
    ) -> Result<(), MatrixMaterializationError> {
        self.row_chunk_calls.fetch_add(1, Ordering::SeqCst);
        if out.nrows() != rows.end - rows.start || out.ncols() != self.p {
            return Err(MatrixMaterializationError::MissingRowChunk {
                context: "ChunkOnlyOperator::row_chunk_into shape mismatch",
            });
        }
        for (local, row) in rows.enumerate() {
            for col in 0..self.p {
                out[[local, col]] = self.value(row, col);
            }
        }
        Ok(())
    }
    // Deliberately panics: these tests require the chunked streaming path,
    // never the whole-matrix fallback.
    fn to_dense(&self) -> Array2<f64> {
        panic!("ChunkOnlyOperator::to_dense fallback must not be used")
    }
}
// Dense reference solver: explicitly forms H = Xᵀ·diag(w)·X + S + ridge·I
// and solves H·β = rhs, for comparison against matrix-free PCG results.
fn exact_weighted_penalized_solve(
    design: &Array2<f64>,
    weights: &Array1<f64>,
    rhs: &Array1<f64>,
    penalty: &Array2<f64>,
    ridge: f64,
) -> Array1<f64> {
    let mut h = design
        .t()
        .dot(&(design * &weights.view().insert_axis(Axis(1))));
    h += penalty;
    if ridge > 0.0 {
        for i in 0..h.nrows() {
            h[[i, i]] += ridge;
        }
    }
    StableSolver::new("matrix-free pcg exact reference")
        .solvevectorwithridge_retries(&h, rhs, 0.0)
        .expect("exact reference solve")
}
// The custom dense mat-vec kernel must agree with ndarray's `dot`.
#[test]
fn dense_matvec_matches_ndarray_dot() {
    let x = array![[1.0, 2.0, -1.0], [0.5, -3.0, 4.0], [2.0, 0.0, 1.5]];
    let v = array![0.25, -1.0, 2.0];
    let expected = x.dot(&v);
    let got = dense_matvec(&x, &v);
    for i in 0..expected.len() {
        assert!((expected[i] - got[i]).abs() < 1e-12);
    }
}
// The custom transposed mat-vec kernel must agree with ndarray's `t().dot`.
#[test]
fn dense_transpose_matvec_matches_ndarray_dot() {
    let x = array![[1.0, 2.0, -1.0], [0.5, -3.0, 4.0], [2.0, 0.0, 1.5]];
    let v = array![0.25, -1.0, 2.0];
    let expected = x.t().dot(&v);
    let got = dense_transpose_matvec(&x, &v);
    for i in 0..expected.len() {
        assert!((expected[i] - got[i]).abs() < 1e-12);
    }
}
// Duplicate (row, col) entries in the CSC storage must SUM when densified
// (2.0 + 3.5 = 5.5), and sparse mat-vec must agree with the densified form.
#[test]
fn sparse_to_dense_accumulates_duplicate_entries() {
    let symbolic = unsafe {
        SymbolicSparseColMat::new_unchecked(
            3,
            2,
            vec![0_usize, 2, 3],
            None,
            vec![1_usize, 1, 0],
        )
    };
    let sparse = SparseColMat::new(symbolic, vec![2.0_f64, 3.5, -1.0]);
    let design = DesignMatrix::from(sparse);
    let dense = design.to_dense_arc();
    assert!((dense[[1, 0]] - 5.5).abs() < 1e-12);
    assert!((dense[[0, 1]] + 1.0).abs() < 1e-12);
    let v = array![4.0, -2.0];
    let y_sparse = design.matrixvectormultiply(&v);
    let y_dense = dense.dot(&v);
    for i in 0..y_sparse.len() {
        assert!((y_sparse[i] - y_dense[i]).abs() < 1e-12);
    }
}
// A 500k x 10k sparse design exceeds the densification budget; the fallible
// conversion must refuse BEFORE allocating the dense buffer.
#[test]
fn huge_sparse_densification_is_rejected_before_allocation() {
    let sparse = SparseColMat::try_new_from_triplets(500_000, 10_000, &[])
        .expect("empty sparse matrix should build");
    let design = SparseDesignMatrix::new(sparse);
    let err = design
        .try_to_dense_arc("matrix test")
        .expect_err("huge sparse densification should be rejected");
    assert!(err.contains("refusing to densify sparse design"));
}
// apply / apply_transpose / diag_xtw_x / diag_gram / compute_xtwy on a mixed
// dense+sparse multi-channel operator must match the vertically stacked
// dense reference; weighted paths use the positive part max(w, 0).
#[test]
fn multi_channel_operator_view_paths_match_stacked_dense_reference() {
    let dense_channel = array![[1.0, 2.0], [0.5, -1.0], [3.0, 0.25]];
    let sparse_dense = array![[0.0, 1.5], [2.0, 0.0], [-1.0, 0.75]];
    let sparse = SparseColMat::try_new_from_triplets(
        3,
        2,
        &[
            Triplet::new(1, 0, 2.0),
            Triplet::new(2, 0, -1.0),
            Triplet::new(0, 1, 1.5),
            Triplet::new(2, 1, 0.75),
        ],
    )
    .expect("sparse channel");
    let op = MultiChannelOperator::new(vec![
        DesignMatrix::Dense(DenseDesignMatrix::from(dense_channel.clone())),
        DesignMatrix::from(sparse),
    ])
    .expect("multi-channel operator");
    let mut stacked = Array2::<f64>::zeros((6, 2));
    stacked.slice_mut(s![0..3, ..]).assign(&dense_channel);
    stacked.slice_mut(s![3..6, ..]).assign(&sparse_dense);
    let beta = array![0.25, -0.4];
    let expected_apply = stacked.dot(&beta);
    let got_apply = op.apply(&beta);
    for i in 0..expected_apply.len() {
        assert!((expected_apply[i] - got_apply[i]).abs() < 1e-12);
    }
    let probe = array![0.5, -1.0, 0.25, 1.5, -0.75, 0.2];
    let expected_transpose = stacked.t().dot(&probe);
    let got_transpose = op.apply_transpose(&probe);
    for i in 0..expected_transpose.len() {
        assert!((expected_transpose[i] - got_transpose[i]).abs() < 1e-12);
    }
    // Negative weight (-0.5) exercises the positive-part clamp.
    let weights = array![1.0, -0.5, 0.75, 2.0, 0.25, 1.5];
    let w_pos = weights.mapv(|w: f64| w.max(0.0));
    let weighted = stacked.clone() * &w_pos.view().insert_axis(Axis(1));
    let expected_xtwx = stacked.t().dot(&weighted);
    let got_xtwx = op.diag_xtw_x(&weights).expect("multi-channel xtwx");
    for i in 0..expected_xtwx.nrows() {
        for j in 0..expected_xtwx.ncols() {
            assert!((expected_xtwx[[i, j]] - got_xtwx[[i, j]]).abs() < 1e-12);
        }
    }
    let expected_diag = Array1::from_iter((0..2).map(|j| expected_xtwx[[j, j]]));
    let got_diag = op.diag_gram(&weights).expect("multi-channel diag gram");
    for i in 0..expected_diag.len() {
        assert!((expected_diag[i] - got_diag[i]).abs() < 1e-12);
    }
    let y = array![1.0, 0.5, -0.25, 2.0, -1.0, 0.75];
    let expected_xtwy = stacked.t().dot(&(w_pos * &y));
    let got_xtwy = op.compute_xtwy(&weights, &y).expect("multi-channel xtwy");
    for i in 0..expected_xtwy.len() {
        assert!((expected_xtwy[i] - got_xtwy[i]).abs() < 1e-12);
    }
}
// X has 2 columns but Qs has 3 rows; the constructor must panic with the
// exact shape-mismatch message.
#[test]
#[should_panic(expected = "ReparamOperator: X cols (2) must match Qs rows (3)")]
fn reparam_operator_rejects_incompatible_transform_shape() {
    let x = array![[1.0, 2.0], [0.5, -1.0]];
    let qs = Arc::new(Array2::<f64>::zeros((3, 1)));
    let _ = ReparamOperator::new(DesignMatrix::Dense(DenseDesignMatrix::from(x)), qs);
}
// The kernel operator's column count comes from the number of centers (3),
// not from the data row count (2).
#[test]
fn chunked_kernel_operator_uses_center_rows_for_column_count() {
    let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
    let centers = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
    let kernel =
        |x: &[f64], c: &[f64]| x.iter().zip(c.iter()).map(|(xi, ci)| xi * ci).sum::<f64>();
    let operator = ChunkedKernelDesignOperator::new(data, centers, kernel, None, None)
        .expect("chunked kernel operator");
    assert_eq!(operator.ncols(), 3);
    let chunk = operator.row_chunk_combined(0..2);
    assert_eq!(chunk.dim(), (2, 3));
}
// The accumulated Hessian symbolic pattern must be lower-triangular CSC:
// column pointers [0, 1, 3, 6] and row indices [0, 0, 1, 0, 1, 2] for a
// fully dense 3-column row.
#[test]
fn sparse_hessian_pattern_is_column_major_csc() {
    let sparse = SparseColMat::try_new_from_triplets(
        1,
        3,
        &[
            Triplet::new(0, 0, 1.0),
            Triplet::new(0, 1, 1.0),
            Triplet::new(0, 2, 1.0),
        ],
    )
    .expect("sparse column matrix");
    let csr = sparse.to_row_major().expect("csr conversion");
    let accumulator = SparseHessianAccumulator::from_single_csr(&csr, 3);
    assert_eq!(accumulator.sym.col_ptrs, vec![0, 1, 3, 6]);
    assert_eq!(accumulator.sym.row_indices, vec![0, 0, 1, 0, 1, 2]);
}
// Optional constraint/poly matrices with wrong row counts must be rejected
// with descriptive errors naming both the actual and expected sizes.
#[test]
fn chunked_kernel_operator_rejects_incompatible_optional_shapes() {
    let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5]]);
    let centers = Arc::new(array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]]);
    let kernel = |_: &[f64], _: &[f64]| 0.0;
    let bad_constraint = Arc::new(Array2::<f64>::zeros((2, 1)));
    let bad_poly = Arc::new(Array2::<f64>::zeros((3, 1)));
    let constraint_err = match ChunkedKernelDesignOperator::new(
        data.clone(),
        centers.clone(),
        kernel,
        Some(bad_constraint),
        None,
    ) {
        Ok(_) => panic!("constraint rows should match centers rows"),
        Err(err) => err,
    };
    assert!(constraint_err.contains("constraint_transform rows 2 != centers rows 3"));
    let poly_err =
        match ChunkedKernelDesignOperator::new(data, centers, kernel, None, Some(bad_poly)) {
            Ok(_) => panic!("poly rows should match data rows"),
            Err(err) => err,
        };
    assert!(poly_err.contains("poly_basis rows 3 != data rows 2"));
}
// Non-standard-layout (transposed-view) inputs must still yield correct
// kernel values, i.e. the constructor canonicalizes layout internally.
#[test]
fn chunked_kernel_operator_canonicalizes_non_contiguous_inputs() {
    let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5]].reversed_axes());
    let centers = Arc::new(array![[0.0, 1.0, 2.0], [0.0, 1.0, -1.0]].reversed_axes());
    assert!(!data.is_standard_layout());
    assert!(!centers.is_standard_layout());
    let kernel =
        |x: &[f64], c: &[f64]| x.iter().zip(c.iter()).map(|(xi, ci)| xi * ci).sum::<f64>();
    let operator = ChunkedKernelDesignOperator::new(data, centers, kernel, None, None)
        .expect("chunked kernel operator");
    let chunk = operator.row_chunk_combined(0..2);
    assert_eq!(chunk.dim(), (2, 3));
    assert_eq!(chunk[[0, 0]], 0.0);
    assert_eq!(chunk[[1, 1]], 1.5);
}
// After an apply_transpose forces materialization, the cached dense X·T must
// be reachable through DenseDesignMatrix::as_dense_ref and equal the
// explicit product.
#[test]
fn coefficient_transform_operator_exposes_cached_dense_to_block_dispatch() {
    let inner = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
    let transform = array![[0.5, -1.0, 2.0], [1.0, 0.0, -0.5]];
    let expected = inner.dot(&transform);
    let op =
        CoefficientTransformOperator::new(DenseDesignMatrix::from(inner), transform.clone())
            .expect("coefficient transform operator");
    let dense_design = DenseDesignMatrix::from(Arc::new(op));
    let probe = Array1::from_elem(3, 1.0);
    let _ = dense_design.apply_transpose(&probe);
    let dense_ref = dense_design
        .as_dense_ref()
        .expect("DenseDesignMatrix::as_dense_ref must reach the cached X·T");
    assert_eq!(dense_ref.dim(), expected.dim());
    for ((r, c), v) in expected.indexed_iter() {
        assert!((dense_ref[[r, c]] - v).abs() < 1e-12);
    }
}
// Same cached-dense dispatch check as above, but for the kernel operator.
#[test]
fn chunked_kernel_operator_exposes_cached_dense_to_block_dispatch() {
    let data = Arc::new(array![[0.0, 1.0], [1.0, 0.5], [2.0, -1.0]]);
    let centers = Arc::new(array![[0.0, 0.0], [1.0, 1.0]]);
    let kernel =
        |x: &[f64], c: &[f64]| x.iter().zip(c.iter()).map(|(xi, ci)| xi * ci).sum::<f64>();
    let op = ChunkedKernelDesignOperator::new(data, centers, kernel, None, None)
        .expect("chunked kernel operator");
    let expected = op.to_dense();
    let dense_design = DenseDesignMatrix::from(Arc::new(op));
    let probe = Array1::from_elem(3, 1.0);
    let _ = dense_design.apply_transpose(&probe);
    let dense_ref = dense_design
        .as_dense_ref()
        .expect("DenseDesignMatrix::as_dense_ref must reach the cached kernel block");
    assert_eq!(dense_ref.dim(), expected.dim());
    for ((r, c), v) in expected.indexed_iter() {
        assert!((dense_ref[[r, c]] - v).abs() < 1e-12);
    }
}
// Horizontally stacking lazy blocks must stay operator-backed (no eager
// densification) while mat-vec and row-chunk reads remain correct.
#[test]
fn design_matrix_hstack_preserves_lazy_blocks() {
    let left_dense = array![[1.0, 2.0], [3.0, 4.0]];
    let right_dense = array![[5.0], [6.0]];
    let left = no_densify_design(left_dense.clone());
    let right = no_densify_design(right_dense.clone());
    let stacked = DesignMatrix::hstack(vec![left, right]).expect("stacked design");
    assert!(stacked.as_dense_ref().is_none());
    assert!(!stacked.is_materialized_dense());
    assert!(stacked.is_operator_backed());
    assert_eq!(stacked.nrows(), 2);
    assert_eq!(stacked.ncols(), 3);
    let beta = array![0.25, -0.5, 2.0];
    let expected = array![9.25, 10.75];
    let got = stacked.dot(&beta);
    for i in 0..expected.len() {
        assert!((got[i] - expected[i]).abs() < 1e-12);
    }
    let chunk = stacked
        .try_row_chunk(0..2)
        .expect("stacked.try_row_chunk must succeed");
    assert_eq!(chunk, array![[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]);
}
// as_dense_cow must panic (not silently densify) on operator-backed designs.
#[test]
#[should_panic(expected = "DesignMatrix::as_dense_cow called on operator-backed design")]
fn design_matrix_as_dense_cow_rejects_operator_backed_designs() {
    let design = no_densify_design(array![[1.0, 2.0], [3.0, 4.0]]);
    let _ = design.as_dense_cow();
}
// A native sparse factorization and a dense factorization of the same
// penalized system must agree to 1e-10.
#[test]
fn sparse_factorized_solve_matches_dense_operator_solve() {
    let triplets = vec![
        Triplet::new(0usize, 0usize, 1.0),
        Triplet::new(1, 0, 2.0),
        Triplet::new(1, 1, -1.0),
        Triplet::new(2, 1, 3.0),
        Triplet::new(2, 2, 0.5),
    ];
    let sparse = SparseColMat::try_new_from_triplets(3, 3, &triplets)
        .expect("sparse design should build");
    let sparse_design = DesignMatrix::from(sparse);
    let dense_design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(
        sparse_design.to_dense(),
    ));
    let weights = array![1.5, 0.75, 2.0];
    let rhs = array![1.0, -0.5, 2.0];
    let penalty = Array2::from_diag(&array![0.25, 0.5, 0.75]);
    let sparse_sol = sparse_design
        .solve_system(&weights, &rhs, Some(&penalty))
        .expect("sparse solve should factorize natively");
    let dense_sol = dense_design
        .solve_system(&weights, &rhs, Some(&penalty))
        .expect("dense solve should factorize");
    for i in 0..rhs.len() {
        assert!(
            (sparse_sol[i] - dense_sol[i]).abs() < 1e-10,
            "solution mismatch at {i}: sparse={} dense={}",
            sparse_sol[i],
            dense_sol[i]
        );
    }
}
// A slightly indefinite penalty (tiny negative eigenvalue) must be
// stabilized: the solve succeeds and returns finite values near the
// well-posed solution.
#[test]
fn solve_system_stabilizes_indefinite_penalty_and_returns_finite_solution() {
    let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(array![
        [1.0, 0.0],
        [0.0, 0.0]
    ]));
    let weights = array![1.0, 1.0];
    let rhs = array![2.0, 0.0];
    let penalty = array![[0.0, 0.0], [0.0, -1e-12]];
    let beta = design
        .solve_system(&weights, &rhs, Some(&penalty))
        .expect("solve_system should stabilize indefinite systems");
    assert!(beta.iter().all(|v| v.is_finite()));
    assert!((beta[0] - 2.0).abs() < 1e-10);
    assert!(beta[1].abs() < 1e-8);
}
// Matrix-free PCG on a wide (p >> n) dense system must match the explicit
// Cholesky-style reference to 1e-5 per coordinate and keep the residual
// of the assembled normal equations below 1e-4.
#[test]
fn explicit_matrix_free_pcg_matches_exact_large_dense_weighted_penalized_solve() {
    let n = 48usize;
    let p = 520usize;
    let mut x = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        for j in 0..p {
            x[[i, j]] = (((i + 3) * (j + 5)) % 17) as f64 / 17.0
                + 0.02 * (i as f64)
                + 0.001 * (j as f64);
        }
    }
    let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone()));
    let weights = Array1::from_iter((0..n).map(|i| 0.5 + (i as f64) / (2.0 * n as f64)));
    let rhs = Array1::from_iter((0..p).map(|j| ((j % 13) as f64 - 6.0) / 13.0));
    let penalty = Array2::from_diag(&Array1::from_iter(
        (0..p).map(|j| 0.1 + 0.005 * ((j % 7) as f64)),
    ));
    let ridge = 1e-8;
    let pcg = design
        .solve_system_matrix_free_pcg(&weights, &rhs, Some(&penalty), ridge)
        .expect("matrix-free pcg solve");
    let exact = exact_weighted_penalized_solve(&x, &weights, &rhs, &penalty, ridge);
    for i in 0..p {
        assert!(
            (pcg[i] - exact[i]).abs() < 1e-5,
            "solution mismatch at {i}: pcg={} exact={}",
            pcg[i],
            exact[i]
        );
    }
    let mut h = x
        .t()
        .dot(&(x.clone() * &weights.view().insert_axis(Axis(1))));
    h += &penalty;
    for i in 0..p {
        h[[i, i]] += ridge;
    }
    let residual = h.dot(&pcg) - &rhs;
    let residual_norm = residual.dot(&residual).sqrt();
    assert!(residual_norm < 1e-4, "residual_norm={residual_norm}");
}
// The policy-driven solver configured for explicit positive-part
// stabilization must route to the same matrix-free PCG path and agree with
// the explicit call to 1e-6.
#[test]
fn policy_solve_matches_explicit_matrix_free_pcg_on_large_dense_system() {
    let n = 40usize;
    let p = 520usize;
    let mut x = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        for j in 0..p {
            x[[i, j]] = (((2 * i + j + 11) % 23) as f64 / 23.0) + 0.0005 * (j as f64);
        }
    }
    let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x));
    let weights = Array1::from_iter((0..n).map(|i| 1.0 + 0.01 * i as f64));
    let rhs = Array1::from_iter((0..p).map(|j| ((j % 5) as f64) - 2.0));
    let penalty = Array2::from_diag(&Array1::from_iter(
        (0..p).map(|j| 0.2 + 0.01 * ((j % 3) as f64)),
    ));
    let ridge_floor = 1e-8;
    let explicit = design
        .solve_system_matrix_free_pcg(&weights, &rhs, Some(&penalty), ridge_floor)
        .expect("explicit pcg");
    let policy = design
        .solve_systemwith_policy(
            &weights,
            &rhs,
            Some(&penalty),
            ridge_floor,
            RidgePolicy::explicit_stabilization_pospart(),
        )
        .expect("policy solve");
    for i in 0..p {
        assert!(
            (explicit[i] - policy[i]).abs() < 1e-6,
            "policy mismatch at {i}: explicit={} policy={}",
            explicit[i],
            policy[i]
        );
    }
}
// With p = 2160 (above MATRIX_FREE_PCG_MIN_P) the design must opt into PCG,
// and the info-returning variant must report convergence, a positive
// iteration count, and a finite relative residual below 1e-6 while still
// matching the exact reference.
#[test]
fn explicit_matrix_free_pcg_reports_convergence_diagnostics() {
    let n = 36usize;
    let p = 2160usize;
    let mut x = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        for j in 0..p {
            x[[i, j]] = (((3 * i + 5 * j + 7) % 29) as f64 / 29.0)
                + 0.015 * (i as f64)
                + 1e-4 * j as f64;
        }
    }
    let design = DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(x.clone()));
    assert!(design.should_use_matrix_free_pcg());
    let weights = Array1::from_iter((0..n).map(|i| 0.75 + 0.01 * i as f64));
    let rhs = Array1::from_iter((0..p).map(|j| ((j % 9) as f64 - 4.0) / 9.0));
    let penalty = Array2::from_diag(&Array1::from_iter(
        (0..p).map(|j| 0.05 + 0.002 * ((j % 11) as f64)),
    ));
    let ridge = 1e-8;
    let (pcg, info): (Array1<f64>, PcgSolveInfo) = design
        .solve_system_matrix_free_pcg_with_info(&weights, &rhs, Some(&penalty), ridge)
        .expect("pcg with info");
    assert!(info.converged);
    assert!(info.iterations > 0);
    assert!(info.relative_residual_norm.is_finite());
    assert!(info.relative_residual_norm < 1e-6);
    let exact = exact_weighted_penalized_solve(&x, &weights, &rhs, &penalty, ridge);
    for i in 0..p {
        assert!(
            (pcg[i] - exact[i]).abs() < 1e-5,
            "solution mismatch at {i}: pcg={} exact={}",
            pcg[i],
            exact[i]
        );
    }
}
#[test]
fn compute_xtwy_dense_allocationfree_matches_matvec() {
    // The fused X'(w ⊙ y) kernel must match an explicit two-step weighted matvec.
    let n = 2_000usize;
    let p = 64usize;
    let x = Array2::from_shape_fn((n, p), |(i, j)| (((i * 13 + j * 7) % 97) as f64) / 97.0);
    let y = Array1::from_shape_fn(n, |i| ((i % 17) as f64 - 8.0) * 0.1);
    let w = Array1::from_shape_fn(n, |i| 0.25 + ((i % 11) as f64) * 0.05);
    // Reference: clamp weights at zero, scale the response, then apply X'.
    let wy = Array1::from_shape_fn(n, |i| y[i] * w[i].max(0.0));
    let reference = dense_transpose_matvec(&x, &wy);
    let fused = dense_transpose_weighted_response(&x, &w, &y, None);
    for j in 0..p {
        assert!(
            (reference[j] - fused[j]).abs() < 1e-10,
            "mismatch at column {j}: ref={} fused={}",
            reference[j],
            fused[j]
        );
    }
}
#[test]
fn large_lazy_dense_materialization_streams_chunks_without_to_dense_fallback() {
    // Operator exposes rows only via chunked access; n is large enough that a
    // single-shot materialization would be rejected in favor of streaming.
    let n = 11_000usize;
    let p = 128usize;
    let op = Arc::new(ChunkOnlyOperator {
        n,
        p,
        row_chunk_calls: AtomicUsize::new(0),
    });
    let design = DenseDesignMatrix::from(Arc::clone(&op));
    let dense = design.to_dense_arc();
    assert_eq!(dense.dim(), (n, p));
    assert!(
        op.row_chunk_calls.load(Ordering::SeqCst) > 1,
        "expected dense materialization to stream more than one row chunk"
    );
    // Spot-check entries on both sides of a chunk boundary and at the extremes.
    let probes = [(0usize, 0usize), (8_191, 127), (8_192, 0), (10_999, 64)];
    for (i, j) in probes {
        assert_eq!(dense[[i, j]], op.value(i, j));
    }
}
#[test]
fn try_to_dense_by_chunks_writes_directly_into_output_slices() {
    // Same chunk-only operator as the streaming test, but driven through the
    // public chunked-conversion entry point on DesignMatrix.
    let n = 11_000usize;
    let p = 128usize;
    let op = Arc::new(ChunkOnlyOperator {
        n,
        p,
        row_chunk_calls: AtomicUsize::new(0),
    });
    let design = DesignMatrix::Dense(DenseDesignMatrix::from(Arc::clone(&op)));
    let dense = design
        .try_to_dense_by_chunks("large chunked regression")
        .expect("chunked materialization");
    assert_eq!(dense.dim(), (n, p));
    assert!(
        op.row_chunk_calls.load(Ordering::SeqCst) > 1,
        "expected direct chunked conversion to use bounded row chunks"
    );
    // Verify values written through the output slices at scattered positions.
    let probes = [(1usize, 7usize), (4_096, 12), (8_193, 63), (10_998, 127)];
    for (i, j) in probes {
        assert_eq!(dense[[i, j]], op.value(i, j));
    }
}
#[test]
fn tensor_product_design_operator_matches_dense_2d() {
    use super::{DenseDesignOperator, TensorProductDesignOperator};
    // Two piecewise-linear (hat-function) marginal bases evaluated at n shared
    // sample points; each row of b1/b2 has exactly two adjacent nonzeros summing to 1.
    let n = 10;
    let q1 = 4;
    let q2 = 3;
    let mut b1 = Array2::<f64>::zeros((n, q1));
    let mut b2 = Array2::<f64>::zeros((n, q2));
    for i in 0..n {
        // Map sample i onto [0, q1-1]; `min(q1 - 2)` keeps the right endpoint in
        // the last interval so j1 + 1 stays in bounds.
        let t1 = i as f64 / (n - 1) as f64 * (q1 - 1) as f64;
        let j1 = (t1.floor() as usize).min(q1 - 2);
        let frac1 = t1 - j1 as f64;
        b1[[i, j1]] = 1.0 - frac1;
        b1[[i, j1 + 1]] = frac1;
        let t2 = i as f64 / (n - 1) as f64 * (q2 - 1) as f64;
        let j2 = (t2.floor() as usize).min(q2 - 2);
        let frac2 = t2 - j2 as f64;
        b2[[i, j2]] = 1.0 - frac2;
        b2[[i, j2 + 1]] = frac2;
    }
    let op = TensorProductDesignOperator::new(vec![Arc::new(b1.clone()), Arc::new(b2.clone())])
        .unwrap();
    // Dense reference: row-wise Kronecker product with column order j1-major, j2-minor.
    let p = q1 * q2;
    let mut dense = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        for j1 in 0..q1 {
            for j2 in 0..q2 {
                dense[[i, j1 * q2 + j2]] = b1[[i, j1]] * b2[[i, j2]];
            }
        }
    }
    // 1) Full materialization must match the explicit Kronecker expansion.
    let op_dense = op.to_dense();
    let max_diff = (&op_dense - &dense)
        .iter()
        .map(|v: &f64| v.abs())
        .fold(0.0f64, f64::max);
    assert!(max_diff < 1e-14, "to_dense mismatch: max_diff={max_diff}");
    // 2) Forward matvec Xβ against the dense dot product.
    let beta = Array1::from_vec((0..p).map(|j| (j as f64 + 1.0) * 0.1).collect());
    let ref_result = dense.dot(&beta);
    let op_result = op.apply(&beta);
    let max_diff = (&op_result - &ref_result)
        .iter()
        .map(|v: &f64| v.abs())
        .fold(0.0f64, f64::max);
    assert!(max_diff < 1e-12, "apply mismatch: max_diff={max_diff}");
    // 3) Transpose matvec X'v against the dense reference.
    let v = Array1::from_vec((0..n).map(|i| (i as f64 + 1.0) * 0.3).collect());
    let ref_xt_v = dense.t().dot(&v);
    let op_xt_v = op.apply_transpose(&v);
    let max_diff = (&op_xt_v - &ref_xt_v)
        .iter()
        .map(|v: &f64| v.abs())
        .fold(0.0f64, f64::max);
    assert!(
        max_diff < 1e-12,
        "apply_transpose mismatch: max_diff={max_diff}"
    );
    // 4) Weighted Gram matrix X'WX against a naive O(n·p²) accumulation.
    let w = Array1::from_vec((0..n).map(|i| 1.0 + i as f64 * 0.1).collect());
    let ref_xtwx = {
        let mut out = Array2::<f64>::zeros((p, p));
        for i in 0..n {
            for a in 0..p {
                for b in 0..p {
                    out[[a, b]] += w[i] * dense[[i, a]] * dense[[i, b]];
                }
            }
        }
        out
    };
    let op_xtwx = op.diag_xtw_x(&w).unwrap();
    let max_diff = (&op_xtwx - &ref_xtwx)
        .iter()
        .map(|v: &f64| v.abs())
        .fold(0.0f64, f64::max);
    assert!(max_diff < 1e-10, "diag_xtw_x mismatch: max_diff={max_diff}");
}
#[test]
fn tensor_product_design_operator_3d() {
    use super::{DenseDesignOperator, TensorProductDesignOperator};
    // Three hat-function marginal bases over n shared sample points.
    let n = 8;
    let dims = [3, 2, 2];
    let mut marginals: Vec<Array2<f64>> = Vec::new();
    for &q in &dims {
        let mut basis = Array2::<f64>::zeros((n, q));
        for row in 0..n {
            let t = row as f64 / (n - 1) as f64 * (q - 1) as f64;
            let left = (t.floor() as usize).min(q - 2);
            let frac = t - left as f64;
            basis[[row, left]] = 1.0 - frac;
            basis[[row, left + 1]] = frac;
        }
        marginals.push(basis);
    }
    let op =
        TensorProductDesignOperator::new(marginals.iter().cloned().map(Arc::new).collect())
            .unwrap();
    let p: usize = dims.iter().copied().product();
    // Dense reference: triple Kronecker expansion, column index j0-major.
    let mut dense = Array2::<f64>::zeros((n, p));
    for row in 0..n {
        for j0 in 0..dims[0] {
            for j1 in 0..dims[1] {
                for j2 in 0..dims[2] {
                    let col = (j0 * dims[1] + j1) * dims[2] + j2;
                    dense[[row, col]] = marginals[0][[row, j0]]
                        * marginals[1][[row, j1]]
                        * marginals[2][[row, j2]];
                }
            }
        }
    }
    let op_dense = op.to_dense();
    let max_diff = op_dense
        .iter()
        .zip(dense.iter())
        .fold(0.0f64, |acc, (a, b)| acc.max((a - b).abs()));
    assert!(
        max_diff < 1e-14,
        "3D to_dense mismatch: max_diff={max_diff}"
    );
    // Round trip X'(Xβ) must agree with the dense computation.
    let beta = Array1::from_shape_fn(p, |j| (j as f64).sin());
    let xb = op.apply(&beta);
    let xtxb = op.apply_transpose(&xb);
    let ref_xtxb = dense.t().dot(&dense.dot(&beta));
    let max_diff = xtxb
        .iter()
        .zip(ref_xtxb.iter())
        .fold(0.0f64, |acc, (a, b)| acc.max((a - b).abs()));
    assert!(max_diff < 1e-10, "3D X'Xβ mismatch: max_diff={max_diff}");
}
#[test]
fn sparse_weighted_crossprod_parallel_path_matches_dense_reference() {
    use faer::sparse::Triplet;
    // 4096 x 192 sparse design with exactly 4 nonzeros per row, sized so the
    // flop-count heuristics can select the parallel crossprod path.
    let n = 4096;
    let p = 192;
    let mut triplets = Vec::with_capacity(n * 4);
    let mut dense = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        // Column pattern (base + k*11) % p gives 4 distinct columns per row,
        // so the dense mirror and the triplet list stay consistent.
        let base = (i * 37) % p;
        for k in 0..4 {
            let col = (base + k * 11) % p;
            let val = ((i + 3 * k + 1) as f64).sin() * 0.25 + 0.5;
            triplets.push(Triplet::new(i, col, val));
            dense[[i, col]] = val;
        }
    }
    let sparse = faer::sparse::SparseColMat::try_new_from_triplets(n, p, &triplets).unwrap();
    let design = DesignMatrix::Sparse(SparseDesignMatrix::new(sparse));
    // Every 7th weight is exactly zero to exercise zero-weight row skipping.
    let weights = Array1::from_iter((0..n).map(|i| match i % 7 {
        0 => 0.0,
        r => 0.5 + r as f64 * 0.125,
    }));
    let got = design.compute_xtwx(&weights).unwrap();
    // Naive dense X'WX reference; clamps weights at zero (all are >= 0 here)
    // and skips zero rows/entries purely as a speed-up.
    let mut reference = Array2::<f64>::zeros((p, p));
    for i in 0..n {
        let wi = weights[i].max(0.0);
        if wi == 0.0 {
            continue;
        }
        for a in 0..p {
            let xa = dense[[i, a]];
            if xa == 0.0 {
                continue;
            }
            for b in 0..p {
                reference[[a, b]] += wi * xa * dense[[i, b]];
            }
        }
    }
    let max_diff = (&got - &reference)
        .iter()
        .map(|v: &f64| v.abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_diff < 1e-10,
        "sparse xtwx mismatch: max_diff={max_diff}"
    );
    // The diagonal-only fast path must match the diagonal of the full Gram matrix.
    let got_diag = design.diag_gram(&weights).unwrap();
    let ref_diag = reference.diag().to_owned();
    let max_diag_diff = (&got_diag - &ref_diag)
        .iter()
        .map(|v: &f64| v.abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_diag_diff < 1e-10,
        "sparse diag gram mismatch: max_diff={max_diag_diff}"
    );
}
#[test]
fn rowwise_kronecker_sparse_structured_xtwx_matches_dense_reference() {
    use faer::sparse::Triplet;
    // Sparse covariate block (3 nonzeros per row) crossed row-wise with a small
    // dense time basis; total width is p_cov * p_time.
    let n = 2048;
    let p_cov = 64;
    let p_time = 6;
    let mut triplets = Vec::with_capacity(n * 3);
    let mut cov_dense = Array2::<f64>::zeros((n, p_cov));
    for i in 0..n {
        let base = (i * 17) % p_cov;
        for k in 0..3 {
            let col = (base + k * 7) % p_cov;
            let val = 0.2 + (((i + k) % 13) as f64) / 17.0;
            triplets.push(Triplet::new(i, col, val));
            cov_dense[[i, col]] = val;
        }
    }
    let cov_sparse =
        faer::sparse::SparseColMat::try_new_from_triplets(n, p_cov, &triplets).unwrap();
    let cov = DesignMatrix::Sparse(SparseDesignMatrix::new(cov_sparse));
    let mut time = Array2::<f64>::zeros((n, p_time));
    for i in 0..n {
        for t in 0..p_time {
            time[[i, t]] = (((i + 1) * (t + 3)) as f64).cos() * 0.1 + 0.4;
        }
    }
    let op = RowwiseKroneckerOperator::new(cov, Arc::new(time.clone())).unwrap();
    let weights = Array1::from_iter((0..n).map(|i| 0.25 + ((i % 11) as f64) * 0.05));
    let got = op.diag_xtw_x(&weights).unwrap();
    // Dense reference for X'WX where row i of X is kron(cov_dense[i], time[i]);
    // column index for (covariate c, time t) is c * p_time + t.
    let p_total = p_cov * p_time;
    let mut reference = Array2::<f64>::zeros((p_total, p_total));
    for i in 0..n {
        for c1 in 0..p_cov {
            let x1 = cov_dense[[i, c1]];
            if x1 == 0.0 {
                // Skip structural zeros of the sparse covariate block.
                continue;
            }
            for t1 in 0..p_time {
                let a = c1 * p_time + t1;
                let xa = x1 * time[[i, t1]];
                for c2 in 0..p_cov {
                    let x2 = cov_dense[[i, c2]];
                    if x2 == 0.0 {
                        continue;
                    }
                    for t2 in 0..p_time {
                        let b = c2 * p_time + t2;
                        reference[[a, b]] += weights[i] * xa * x2 * time[[i, t2]];
                    }
                }
            }
        }
    }
    let max_diff = (&got - &reference)
        .iter()
        .map(|v: &f64| v.abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_diff < 1e-9,
        "rowwise kronecker sparse xtwx mismatch: max_diff={max_diff}"
    );
}
#[test]
fn embedded_column_block_zero_row_local_materializes_empty_global_width() {
    // A 0x0 local block embedded at columns 2..5 of a width-7 design must
    // materialize as a 0-row matrix that still carries the full global width.
    let empty_local = Array2::<f64>::zeros((0, 0));
    let materialized = EmbeddedColumnBlock::new(&empty_local, 2..5, 7).materialize();
    assert_eq!(materialized.dim(), (0, 7));
}
}