use super::*;
#[derive(Clone, Debug)]
pub enum SymmetricMatrix {
Dense(Array2<f64>),
Sparse(faer::sparse::SparseColMat<usize, f64>),
}
impl SymmetricMatrix {
    /// Returns the dense backing array, or `None` if the matrix is stored sparsely.
    pub fn as_dense(&self) -> Option<&Array2<f64>> {
        match self {
            Self::Dense(mat) => Some(mat),
            Self::Sparse(_) => None,
        }
    }

    /// Returns the sparse CSC backing matrix, or `None` if the matrix is stored densely.
    pub fn as_sparse(&self) -> Option<&faer::sparse::SparseColMat<usize, f64>> {
        match self {
            Self::Sparse(mat) => Some(mat),
            Self::Dense(_) => None,
        }
    }

    /// Materializes the matrix as a dense `Array2`.
    ///
    /// The dense variant is cloned. For the sparse variant, each stored entry at
    /// `(row, col)` is added to `(row, col)` and to its mirror `(col, row)`. Diagonal
    /// entries are added once. The sparse storage is therefore read as one triangle of a
    /// symmetric matrix. A sparse matrix that stored both triangles would have its
    /// off-diagonal entries doubled.
    ///
    /// # Panics
    ///
    /// For a non-square sparse matrix, the mirrored write can index out of bounds. Use
    /// [`Self::try_to_dense_exact`] for a checked conversion.
    pub fn to_dense(&self) -> Array2<f64> {
        match self {
            Self::Dense(mat) => mat.clone(),
            Self::Sparse(mat) => {
                let mut out = Array2::<f64>::zeros((mat.nrows(), mat.ncols()));
                let (symbolic, values) = mat.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..mat.ncols() {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let row = row_idx[idx];
                        let value = values[idx];
                        out[[row, col]] += value;
                        if row != col {
                            out[[col, row]] += value;
                        }
                    }
                }
                out
            }
        }
    }

    /// Materializes the matrix densely after validating it.
    ///
    /// # Errors
    ///
    /// Returns an error message prefixed with `context` in either of these cases:
    /// - the matrix is not square;
    /// - the dense result contains a non-finite (`NaN` or `±inf`) entry.
    pub fn try_to_dense_exact(&self, context: &str) -> Result<Array2<f64>, String> {
        if self.nrows() != self.ncols() {
            return Err(format!(
                "{context}: exact symmetric matrix must be square, got {}x{}",
                self.nrows(),
                self.ncols()
            ));
        }
        let dense = self.to_dense();
        if dense.iter().any(|v| !v.is_finite()) {
            return Err(format!(
                "{context}: exact dense materialization contains non-finite entries"
            ));
        }
        Ok(dense)
    }

    /// Factorizes the matrix so that linear systems can be solved against it.
    ///
    /// The dense variant is factorized by `StableSolver`, labelled `"unnamed"`. The sparse
    /// variant is factorized by `factorize_sparse_spd`. By its name, that routine expects
    /// a symmetric positive-definite matrix. NOTE(review): confirm its exact requirements.
    ///
    /// # Errors
    ///
    /// Returns the underlying solver error, formatted with `{:?}`, if factorization fails.
    pub fn factorize(&self) -> Result<Box<dyn FactorizedSystem>, String> {
        match self {
            Self::Dense(mat) => {
                let factor = crate::linalg::utils::StableSolver::new("unnamed")
                    .factorize(mat)
                    .map_err(|e| format!("Dense SymmetricMatrix factorization failed: {e:?}"))?;
                Ok(Box::new(factor))
            }
            Self::Sparse(mat) => {
                let factor = crate::linalg::sparse_exact::factorize_sparse_spd(mat)
                    .map_err(|e| format!("Sparse SymmetricMatrix factorization failed: {e:?}"))?;
                Ok(Box::new(factor))
            }
        }
    }

    /// Returns `self + other`.
    ///
    /// When both operands are sparse, the result stays sparse and is merged in
    /// upper-triangular storage. If either operand is dense, the sparse side is densified
    /// and the result is dense.
    ///
    /// # Errors
    ///
    /// Returns an error if the shapes differ or if sparse assembly fails.
    pub fn add(&self, other: &SymmetricMatrix) -> Result<Self, String> {
        if self.nrows() != other.nrows() || self.ncols() != other.ncols() {
            return Err(format!(
                "SymmetricMatrix::add shape mismatch: lhs {}x{}, rhs {}x{}",
                self.nrows(),
                self.ncols(),
                other.nrows(),
                other.ncols()
            ));
        }
        match (self, other) {
            (Self::Dense(a), Self::Dense(b)) => Ok(Self::Dense(a + b)),
            (Self::Dense(a), Self::Sparse(_)) => {
                let b_dense = other.to_dense();
                Ok(Self::Dense(a + &b_dense))
            }
            (Self::Sparse(_), Self::Dense(b)) => {
                let a_dense = self.to_dense();
                Ok(Self::Dense(&a_dense + b))
            }
            (Self::Sparse(a), Self::Sparse(b)) => {
                Ok(Self::Sparse(add_sparse_symmetric_upper(a, b)?))
            }
        }
    }

    /// Returns `self + other` for a dense `other`, keeping `self`'s storage kind.
    ///
    /// For the sparse variant, `other` is first converted with
    /// `dense_to_sparse_symmetric_upper(other, 0.0)`. That conversion presumably keeps only
    /// the upper triangle, so `other` should be symmetric. TODO: confirm against that helper.
    ///
    /// # Errors
    ///
    /// Returns an error if the shapes differ, or if the sparse conversion or merge fails.
    pub(crate) fn add_dense(&self, other: &Array2<f64>) -> Result<Self, String> {
        if self.nrows() != other.nrows() || self.ncols() != other.ncols() {
            return Err(format!(
                "SymmetricMatrix::add_dense shape mismatch: lhs {}x{}, rhs {}x{}",
                self.nrows(),
                self.ncols(),
                other.nrows(),
                other.ncols()
            ));
        }
        match self {
            Self::Dense(mat) => {
                let mut out = mat.clone();
                out += other;
                Ok(Self::Dense(out))
            }
            Self::Sparse(mat) => {
                let other_sparse =
                    crate::linalg::sparse_exact::dense_to_sparse_symmetric_upper(other, 0.0)
                        .map_err(|e| format!("SymmetricMatrix::add_dense failed: {e}"))?;
                Ok(Self::Sparse(add_sparse_symmetric_upper(
                    mat,
                    &other_sparse,
                )?))
            }
        }
    }

    /// Returns a copy of the matrix with `ridge` added to every diagonal entry.
    ///
    /// A ridge of exactly `0.0` returns an unchanged clone. In the sparse result, any
    /// entry whose sum is exactly `0.0` is dropped from the sparsity pattern by
    /// `add_sparse_symmetric_upper`.
    ///
    /// # Errors
    ///
    /// Sparse variant only: fails if the diagonal cannot be assembled or if the matrix is
    /// not square (the `n × n` diagonal would not match its shape).
    pub fn addridge(&self, ridge: f64) -> Result<Self, String> {
        if ridge == 0.0 {
            return Ok(self.clone());
        }
        match self {
            Self::Dense(mat) => {
                let mut out = mat.clone();
                // `diag_mut` covers `min(nrows, ncols)` entries. Unlike `out[[i, i]]` for
                // `i < nrows`, it cannot index out of bounds.
                out.diag_mut().mapv_inplace(|d| d + ridge);
                Ok(Self::Dense(out))
            }
            Self::Sparse(mat) => {
                let n = mat.nrows();
                let trip: Vec<_> = (0..n).map(|i| Triplet::new(i, i, ridge)).collect();
                let diagonal = SparseColMat::<usize, f64>::try_new_from_triplets(n, n, &trip)
                    .map_err(|_| {
                        "SymmetricMatrix::addridge failed to assemble sparse diagonal".to_string()
                    })?;
                Ok(Self::Sparse(add_sparse_symmetric_upper(mat, &diagonal)?))
            }
        }
    }

    /// Number of rows.
    pub fn nrows(&self) -> usize {
        match self {
            Self::Dense(m) => m.nrows(),
            Self::Sparse(m) => m.nrows(),
        }
    }

    /// Number of columns.
    pub fn ncols(&self) -> usize {
        match self {
            Self::Dense(m) => m.ncols(),
            Self::Sparse(m) => m.ncols(),
        }
    }

    /// Computes the matrix-vector product `self · rhs`.
    ///
    /// The sparse branch mirrors each stored off-diagonal entry, consistent with
    /// [`Self::to_dense`].
    ///
    /// # Panics
    ///
    /// The sparse branch reads `rhs[col]` for every column, so it panics if `rhs` is
    /// shorter than `ncols()`.
    pub fn dot(&self, rhs: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Dense(mat) => fast_av(mat, rhs),
            Self::Sparse(mat) => {
                let mut out = Array1::<f64>::zeros(mat.nrows());
                let (symbolic, values) = mat.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..mat.ncols() {
                    let rhs_j = rhs[col];
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let row = row_idx[idx];
                        let value = values[idx];
                        out[row] += value * rhs_j;
                        if row != col {
                            out[col] += value * rhs[row];
                        }
                    }
                }
                out
            }
        }
    }

    /// Returns the largest absolute diagonal value.
    ///
    /// Returns `0.0` when there is no diagonal, i.e. an empty dense matrix or a sparse
    /// matrix with no stored diagonal entries. `f64::max` returns its non-`NaN` argument,
    /// so `NaN` diagonal entries never propagate to the result.
    pub fn max_abs_diag(&self) -> f64 {
        match self {
            Self::Dense(mat) => {
                let n = mat.nrows().min(mat.ncols());
                (0..n).map(|i| mat[[i, i]].abs()).fold(0.0_f64, f64::max)
            }
            Self::Sparse(mat) => {
                let (symbolic, values) = mat.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                let mut max_val = 0.0_f64;
                for col in 0..mat.ncols() {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        if row_idx[idx] == col {
                            max_val = max_val.max(values[idx].abs());
                        }
                    }
                }
                max_val
            }
        }
    }

    /// Computes the matrix-matrix product `self · rhs`.
    ///
    /// The sparse branch mirrors each stored off-diagonal entry, consistent with
    /// [`Self::to_dense`]. Each stored entry updates whole output rows with `scaled_add`.
    /// That performs the same per-element `out = out + value * rhs` as a scalar loop, but
    /// without a bounds-checked 2-D index per element.
    ///
    /// # Panics
    ///
    /// The sparse branch panics if `rhs` has fewer rows than the matrix has columns
    /// (a row is requested out of bounds).
    pub fn dot_matrix(&self, rhs: &Array2<f64>) -> Array2<f64> {
        match self {
            Self::Dense(mat) => fast_ab(mat, rhs),
            Self::Sparse(mat) => {
                let n = mat.nrows();
                let k = rhs.ncols();
                let mut out = Array2::<f64>::zeros((n, k));
                let (symbolic, values) = mat.parts();
                let col_ptr = symbolic.col_ptr();
                let row_idx = symbolic.row_idx();
                for col in 0..mat.ncols() {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let row = row_idx[idx];
                        let value = values[idx];
                        out.row_mut(row).scaled_add(value, &rhs.row(col));
                        if row != col {
                            out.row_mut(col).scaled_add(value, &rhs.row(row));
                        }
                    }
                }
                out
            }
        }
    }

    /// Computes `lhs · self`.
    ///
    /// Because `self` is symmetric, `lhs · S = (S · lhsᵀ)ᵀ`, so this reuses
    /// [`Self::dot_matrix`].
    pub fn left_dot_matrix(&self, lhs: &Array2<f64>) -> Array2<f64> {
        let lhs_t = lhs.t().to_owned();
        let result_t = self.dot_matrix(&lhs_t);
        result_t.t().to_owned()
    }
}
pub fn xt_diag_x_signed(
design: &DesignMatrix,
diag: SignedWeightsView<'_>,
) -> Result<SymmetricMatrix, String> {
xt_diag_x_symmetric(design, &diag.view().to_owned())
}
/// Makes `matrix` symmetric in place by averaging each off-diagonal pair.
///
/// Every pair `(r, c)` / `(c, r)` with `c < r < nrows` is replaced by
/// `(m[r, c] + m[c, r]) / 2`. The diagonal is left untouched.
///
/// The matrix is expected to be square. Only the leading `nrows × nrows` block is
/// visited, and a matrix with more rows than columns panics on an out-of-bounds index.
pub fn symmetrize_in_place(matrix: &mut Array2<f64>) {
    let dim = matrix.nrows();
    // Row 0 has no entries strictly below the diagonal, so start at row 1.
    for row in 1..dim {
        for col in 0..row {
            let mean = (matrix[[row, col]] + matrix[[col, row]]) * 0.5;
            matrix[[row, col]] = mean;
            matrix[[col, row]] = mean;
        }
    }
}
/// Returns the symmetric part `(A + Aᵀ) / 2` of `matrix` as a new array.
///
/// The input is not modified. It is expected to be square. Mismatched shapes fall under
/// ndarray's broadcasting rules for `&A + &B`.
pub fn symmetrize(matrix: &Array2<f64>) -> Array2<f64> {
    let transposed = matrix.t();
    let mut half_sum: Array2<f64> = matrix + &transposed;
    half_sum *= 0.5;
    half_sum
}
pub fn xt_diag_x_psd(
design: &DesignMatrix,
diag: PsdWeightsView<'_>,
) -> Result<SymmetricMatrix, String> {
xt_diag_x_symmetric(design, &diag.view().to_owned())
}
/// Computes the weighted cross-product `Xᵀ · diag(w) · X` for a design matrix `X`.
///
/// * Dense designs delegate to `diag_xtw_x` and return [`SymmetricMatrix::Dense`].
/// * Sparse designs use the "dense regime" when `4 * (nnz / n) >= p` (integer division);
///   an empty design (`n == 0`) always does. This regime produces a dense `p × p` result
///   in one of two ways:
///   - when the `n × p` dense copy of `X` fits within `MAX_SPARSE_TO_DENSE_BYTES`,
///     `X` is densified and a weighted cross-product is streamed over it;
///   - otherwise the CSC entries are streamed directly.
/// * All other sparse designs accumulate the product in parallel over row chunks of the
///   CSR form via `add_upper`. The result is returned as [`SymmetricMatrix::Sparse`].
///
/// # Errors
///
/// Returns an error if `diag.len()` differs from the number of design rows, or if any
/// underlying size check or conversion fails.
pub fn xt_diag_x_symmetric(
    design: &DesignMatrix,
    diag: &Array1<f64>,
) -> Result<SymmetricMatrix, String> {
    if design.nrows() != diag.len() {
        return Err(format!(
            "xt_diag_x_symmetric row mismatch: design has {} rows but diag has {} entries",
            design.nrows(),
            diag.len()
        ));
    }
    match design {
        DesignMatrix::Dense(x) => Ok(SymmetricMatrix::Dense(x.diag_xtw_x(diag)?)),
        DesignMatrix::Sparse(xs) => {
            let n = xs.nrows();
            let p = xs.ncols();
            let nnz_x = xs.val().len();
            // Average stored entries per row (floored). An empty design counts as fully
            // populated, so it takes the simple dense route.
            let avg_nnz_row = if n > 0 { nnz_x / n } else { p };
            let dense_regime = 4 * avg_nnz_row >= p;
            if dense_regime {
                // Run the fallible size check before allocating the p×p output, so the
                // error path does not pay for that allocation.
                let dense_bytes =
                    checked_dense_nbytes(n, p, "xt_diag_x_symmetric dense sparse route")?;
                let mut xtwx = Array2::<f64>::zeros((p, p));
                if dense_bytes <= MAX_SPARSE_TO_DENSE_BYTES {
                    let xd = xs.try_to_dense_arc("xt_diag_x_symmetric dense sparse route")?;
                    stream_weighted_crossprod_into(
                        xd.as_ref(),
                        diag,
                        &mut xtwx,
                        CrossprodStructure::Full,
                        CrossprodAccum::Replace,
                        effective_global_parallelism(),
                    );
                } else {
                    // Densifying X would exceed the byte budget: stream over CSC instead.
                    let (symbolic, values) = xs.parts();
                    streaming_sparse_csc_xt_diag_x(
                        symbolic.col_ptr(),
                        symbolic.row_idx(),
                        values,
                        n,
                        p,
                        diag.view(),
                        &mut xtwx,
                    );
                }
                return Ok(SymmetricMatrix::Dense(xtwx));
            }

            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let csr = xs
                .to_csr_arc()
                .ok_or_else(|| "xt_diag_x_symmetric: failed to obtain CSR view".to_string())?;
            let sym = csr.symbolic();
            let row_ptr = sym.row_ptr();
            let col_idx = sym.col_idx();
            let vals = csr.val();
            let acc_template = SparseHessianAccumulator::from_single_csr(&csr, p);
            // Aim for ~16 chunks per thread to balance load. Each chunk has at least 256
            // rows, is capped at `n`, and is never 0 rows, so `step_by` cannot panic.
            let n_threads = rayon::current_num_threads().max(1);
            let target_chunks = n_threads * 16;
            let chunk_rows = (n / target_chunks).max(256).min(n.max(1));
            let chunk_starts: Vec<usize> = (0..n).step_by(chunk_rows).collect();
            let mut local_accs: Vec<SparseHessianAccumulator> = chunk_starts
                .into_par_iter()
                .map(|start| {
                    let end = (start + chunk_rows).min(n);
                    let mut local = acc_template.empty_clone();
                    for i in start..end {
                        let wi = diag[i];
                        // Zero-weight rows contribute nothing. NaN weights are not skipped.
                        if wi == 0.0 {
                            continue;
                        }
                        let r_start = row_ptr[i];
                        let r_end = row_ptr[i + 1];
                        // Accumulate w_i * x_ia * x_ib for each stored column a and each
                        // column b stored after it in the same row. This addresses the upper
                        // triangle only if column indices are sorted ascending within each
                        // row. NOTE(review): confirm the CSR view guarantees sorted indices,
                        // or that `add_upper` normalizes (a, b).
                        for a_ptr in r_start..r_end {
                            let a = col_idx[a_ptr];
                            let wxa = wi * vals[a_ptr];
                            local.add_upper(a, a, wxa * vals[a_ptr]);
                            for b_ptr in (a_ptr + 1)..r_end {
                                let b = col_idx[b_ptr];
                                local.add_upper(a, b, wxa * vals[b_ptr]);
                            }
                        }
                    }
                    local
                })
                .collect();
            // Rayon's `collect` on an indexed iterator preserves chunk order, so this
            // reduction is deterministic. The last chunk seeds the accumulator and the
            // remaining chunks are added in order.
            let mut acc = local_accs
                .pop()
                .unwrap_or_else(|| acc_template.empty_clone());
            for other in local_accs {
                acc.add_values(&other.values);
            }
            Ok(SymmetricMatrix::Sparse(acc.into_sparse_col_mat()))
        }
    }
}
/// Adds two sparse matrices and returns the sum in upper-triangular CSC storage.
///
/// Each stored entry `(row, col)` of either operand is folded onto the upper-triangle key
/// `(min(row, col), max(row, col))` and summed into that key. All `lhs` entries are added
/// before any `rhs` entry, each in column-major scan order. Keys whose total is exactly
/// `0.0` are omitted from the result's sparsity pattern.
///
/// # Errors
///
/// Returns an error if the operand shapes differ or if faer fails to assemble the result.
fn add_sparse_symmetric_upper(
    lhs: &SparseColMat<usize, f64>,
    rhs: &SparseColMat<usize, f64>,
) -> Result<SparseColMat<usize, f64>, String> {
    let shapes_match = lhs.nrows() == rhs.nrows() && lhs.ncols() == rhs.ncols();
    if !shapes_match {
        return Err(format!(
            "add_sparse_symmetric_upper shape mismatch: lhs {}x{}, rhs {}x{}",
            lhs.nrows(),
            lhs.ncols(),
            rhs.nrows(),
            rhs.ncols()
        ));
    }

    // The ordered map yields keys sorted, so each triplet below is emitted exactly once.
    let mut sums: BTreeMap<(usize, usize), f64> = BTreeMap::new();
    for operand in [lhs, rhs] {
        let (structure, entries) = operand.parts();
        let starts = structure.col_ptr();
        let rows = structure.row_idx();
        for col in 0..operand.ncols() {
            let span = starts[col]..starts[col + 1];
            for (&row, &value) in rows[span.clone()].iter().zip(&entries[span]) {
                let key = (row.min(col), row.max(col));
                *sums.entry(key).or_default() += value;
            }
        }
    }

    let triplets: Vec<_> = sums
        .into_iter()
        .filter(|&(_, total)| total != 0.0)
        .map(|((row, col), total)| Triplet::new(row, col, total))
        .collect();
    SparseColMat::try_new_from_triplets(lhs.nrows(), lhs.ncols(), &triplets)
        .map_err(|_| "add_sparse_symmetric_upper failed to assemble CSC".to_string())
}