use crate::algebra::prelude::*;
use crate::core::traits::{Indexing, SubmatrixExtract};
use crate::error::KError;
use crate::matrix::sparse_api::CsrMatRef;
#[cfg(all(feature = "backend-faer", feature = "simd"))]
use crate::matrix::spmv::{SpmvPlan, SpmvTuning, build_plan_owned as build_spmv_plan};
use std::collections::HashMap;
/// Read-only access to a matrix stored in compressed sparse row (CSR) layout.
pub trait SparseMatrix {
/// Element type of the stored values.
type Scalar;
/// Number of rows.
fn nrows(&self) -> usize;
/// Number of columns.
fn ncols(&self) -> usize;
/// Row offsets into `col_idx()`/`values()`. In CSR layout this conventionally has
/// `nrows() + 1` entries; the trait itself does not enforce that — see implementors.
fn row_ptr(&self) -> &[usize];
/// Column index of every stored entry.
fn col_idx(&self) -> &[usize];
/// Value of every stored entry, parallel to `col_idx()`.
fn values(&self) -> &[Self::Scalar];
}
/// Compressed sparse row (CSR) matrix.
///
/// Row `i` occupies the half-open range `row_ptr[i]..row_ptr[i + 1]` of `col_idx` and
/// `values`. The storage position of each row's diagonal entry, if one is stored, is
/// cached in `diag_pos` when the matrix is built.
#[derive(Clone, Debug)]
pub struct CsrMatrix<T> {
    nrows: usize,
    ncols: usize,
    /// Row offsets into `col_idx`/`values` (`nrows + 1` entries, debug-checked).
    row_ptr: Vec<usize>,
    /// Column index of every stored entry; expected sorted within each row.
    col_idx: Vec<usize>,
    /// Value of every stored entry, parallel to `col_idx`.
    values: Vec<T>,
    /// `diag_pos[i]` is the index of entry `(i, i)` in `col_idx`/`values`, if stored.
    diag_pos: Vec<Option<usize>>,
    /// Cached SpMV plan built from a snapshot of the values; cleared by the mutable
    /// value accessors so it never reflects stale data.
    #[cfg(all(feature = "backend-faer", feature = "simd"))]
    spmv_plan: Option<SpmvPlan<f64>>,
}

impl<T> CsrMatrix<T> {
    /// Builds a matrix from raw CSR arrays and caches the diagonal positions.
    ///
    /// The arrays are stored as given. These invariants are checked in debug builds only:
    /// - `row_ptr.len() == nrows + 1`, `row_ptr[0] == 0` and `row_ptr[nrows] == col_idx.len()`;
    /// - `row_ptr` is non-decreasing;
    /// - `col_idx.len() == values.len()` and every column index is `< ncols`;
    /// - column indices are non-decreasing within each row (the diagonal lookup uses a
    ///   binary search that relies on this).
    ///
    /// # Panics
    /// In debug builds, panics if any invariant above is violated. In any build, panics if
    /// `row_ptr` is too short or describes a row range outside `col_idx`.
    pub fn from_csr(
        nrows: usize,
        ncols: usize,
        row_ptr: Vec<usize>,
        col_idx: Vec<usize>,
        values: Vec<T>,
    ) -> Self {
        debug_assert_eq!(row_ptr.len(), nrows + 1);
        // A row_ptr that does not span exactly [0, nnz) would silently hide leading or
        // trailing entries while `nnz()` (which is `values.len()`) still counts them.
        debug_assert_eq!(
            row_ptr.first().copied(),
            Some(0),
            "CsrMatrix::from_csr: row_ptr[0] must be 0"
        );
        debug_assert_eq!(
            row_ptr.last().copied(),
            Some(col_idx.len()),
            "CsrMatrix::from_csr: row_ptr[nrows] must equal nnz"
        );
        debug_assert_eq!(col_idx.len(), values.len());
        debug_assert!(row_ptr.windows(2).all(|w| w[0] <= w[1]));
        debug_assert!(col_idx.iter().all(|&j| j < ncols));
        #[cfg(debug_assertions)]
        {
            for i in 0..nrows {
                let start = row_ptr[i];
                let end = row_ptr[i + 1];
                debug_assert!(
                    col_idx[start..end].windows(2).all(|w| w[0] <= w[1]),
                    "CsrMatrix::from_csr: row {i} has unsorted column indices"
                );
            }
        }
        let mut this = Self {
            nrows,
            ncols,
            row_ptr,
            col_idx,
            values,
            diag_pos: Vec::new(),
            #[cfg(all(feature = "backend-faer", feature = "simd"))]
            spmv_plan: None,
        };
        this.build_diag_pos();
        this
    }

    /// Number of rows.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Number of stored entries (including any explicitly stored zeros).
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Row offset array (`nrows + 1` entries).
    #[inline]
    pub fn row_ptr(&self) -> &[usize] {
        &self.row_ptr
    }

    /// Column index of every stored entry.
    #[inline]
    pub fn col_idx(&self) -> &[usize] {
        &self.col_idx
    }

    /// Value of every stored entry, parallel to `col_idx()`.
    #[inline]
    pub fn values(&self) -> &[T] {
        &self.values
    }

    /// Mutable access to all stored values (the sparsity pattern stays fixed).
    ///
    /// Clears the cached SpMV plan, when that feature set is enabled, because the plan
    /// was built from the previous values.
    #[inline]
    pub fn values_mut(&mut self) -> &mut [T] {
        #[cfg(all(feature = "backend-faer", feature = "simd"))]
        self.invalidate_spmv_plan();
        &mut self.values
    }

    /// Column indices and values stored in row `i`.
    ///
    /// # Panics
    /// Panics if `i >= nrows()`.
    #[inline]
    pub fn row(&self, i: usize) -> (&[usize], &[T]) {
        let start = self.row_ptr[i];
        let end = self.row_ptr[i + 1];
        (&self.col_idx[start..end], &self.values[start..end])
    }

    /// Mutable values of row `i`; clears the cached SpMV plan like `values_mut`.
    ///
    /// # Panics
    /// Panics if `i >= nrows()`.
    #[inline]
    pub fn row_values_mut(&mut self, i: usize) -> &mut [T] {
        #[cfg(all(feature = "backend-faer", feature = "simd"))]
        self.invalidate_spmv_plan();
        let start = self.row_ptr[i];
        let end = self.row_ptr[i + 1];
        &mut self.values[start..end]
    }

    /// Stored diagonal entry `(i, i)`, or `None` if row `i` stores no such entry.
    ///
    /// # Panics
    /// Panics if `i >= nrows()`.
    #[inline]
    pub fn diag_ref(&self, i: usize) -> Option<&T> {
        self.diag_pos[i].map(|k| &self.values[k])
    }

    /// Mutable stored diagonal entry `(i, i)`, or `None` if absent.
    ///
    /// Only when the entry exists are the values borrowed through `values_mut`, which
    /// clears the cached SpMV plan.
    ///
    /// # Panics
    /// Panics if `i >= nrows()`.
    #[inline]
    pub fn diag_mut(&mut self, i: usize) -> Option<&mut T> {
        let k = self.diag_pos[i]?;
        Some(&mut self.values_mut()[k])
    }

    /// Recomputes `diag_pos` by binary-searching each row for its diagonal column.
    ///
    /// Relies on column indices being sorted within each row. With duplicate diagonal
    /// entries, whichever one the binary search lands on is recorded.
    pub fn build_diag_pos(&mut self) {
        let diag_pos: Vec<Option<usize>> = (0..self.nrows)
            .map(|i| {
                let (start, end) = (self.row_ptr[i], self.row_ptr[i + 1]);
                self.col_idx[start..end]
                    .binary_search(&i)
                    .ok()
                    .map(|off| start + off)
            })
            .collect();
        self.diag_pos = diag_pos;
    }
}
/// Exposes the CSR arrays through the generic [`SparseMatrix`] view by reading the
/// fields directly.
impl<T> SparseMatrix for CsrMatrix<T> {
    type Scalar = T;

    fn nrows(&self) -> usize {
        self.nrows
    }

    fn ncols(&self) -> usize {
        self.ncols
    }

    fn row_ptr(&self) -> &[usize] {
        &self.row_ptr
    }

    fn col_idx(&self) -> &[usize] {
        &self.col_idx
    }

    fn values(&self) -> &[Self::Scalar] {
        &self.values
    }
}
impl<T: KrystScalar> CsrMatrix<T> {
    /// Returns the `n x n` identity matrix with one stored entry per row.
    pub fn identity(n: usize) -> Self {
        let row_ptr: Vec<usize> = (0..=n).collect();
        let col_idx: Vec<usize> = (0..n).collect();
        let values: Vec<T> = vec![T::one(); n];
        Self::from_csr(n, n, row_ptr, col_idx, values)
    }

    /// Returns the main diagonal as a dense vector of length `min(nrows, ncols)`.
    ///
    /// Entry `i` is the first stored value in row `i` whose column is `i`, or zero when
    /// row `i` stores no diagonal entry.
    pub fn diagonal(&self) -> Vec<T> {
        let n = self.nrows().min(self.ncols());
        (0..n)
            .map(|i| {
                let (cols, vals) = self.row(i);
                cols.iter()
                    .position(|&col| col == i)
                    .map(|k| vals[k])
                    .unwrap_or_else(|| T::zero())
            })
            .collect()
    }

    /// Computes `y = A x` through [`Self::try_spmv`], discarding any error.
    ///
    /// Debug builds panic on error; release builds silently drop it. Use `try_spmv`
    /// when the error must be observed.
    pub fn spmv(&self, x: &[T], y: &mut [T]) {
        if let Err(err) = self.try_spmv(x, y) {
            debug_assert!(false, "CsrMatrix::spmv dimension mismatch: {err}");
        }
    }

    /// Computes `y = A x` via `crate::matrix::spmv::csr_matvec`, returning its error
    /// (e.g. on a dimension mismatch).
    pub fn try_spmv(&self, x: &[T], y: &mut [T]) -> Result<(), KError> {
        crate::matrix::spmv::csr_matvec(self, x, y)
    }

    /// Scaled SpMV, conventionally `y <- alpha * A x + beta * y`; the arithmetic is done
    /// by `spmv::scalar::spmv_scaled_csr`.
    ///
    /// # Errors
    /// `KError::InvalidInput` if `x.len() != ncols()` or `y.len() != nrows()`.
    pub fn spmv_scaled(
        &self,
        alpha: T,
        x: &[T],
        beta: T,
        y: &mut [T],
    ) -> Result<(), KError> {
        if x.len() != self.ncols() || y.len() != self.nrows() {
            // Field labels match the transpose variant's message (`x.len()`, `y.len()`).
            return Err(KError::InvalidInput(format!(
                "Dimension mismatch in spmv: A={}x{}, x.len()={}, y.len()={}",
                self.nrows(),
                self.ncols(),
                x.len(),
                y.len()
            )));
        }
        crate::matrix::spmv::scalar::spmv_scaled_csr(
            self.nrows(),
            self.row_ptr(),
            self.col_idx(),
            self.values(),
            alpha,
            x,
            beta,
            y,
        );
        Ok(())
    }

    /// Scaled transposed SpMV, conventionally `y <- alpha * A^T x + beta * y`; the
    /// arithmetic is done by `spmv::scalar::spmv_t_scaled_csr`.
    ///
    /// # Errors
    /// `KError::InvalidInput` if `x.len() != nrows()` or `y.len() != ncols()`.
    pub fn spmv_transpose_scaled(
        &self,
        alpha: T,
        x: &[T],
        beta: T,
        y: &mut [T],
    ) -> Result<(), KError> {
        if x.len() != self.nrows() || y.len() != self.ncols() {
            return Err(KError::InvalidInput(format!(
                "Dimension mismatch in spmv^T: A={}x{}, x.len()={}, y.len()={}",
                self.nrows(),
                self.ncols(),
                x.len(),
                y.len()
            )));
        }
        crate::matrix::spmv::scalar::spmv_t_scaled_csr(
            self.nrows(),
            self.row_ptr(),
            self.col_idx(),
            self.values(),
            alpha,
            x,
            beta,
            y,
        );
        Ok(())
    }
}
impl<T> CsrMatrix<T>
where
    T: KrystScalar<Real = f64>,
{
    /// Converts to a dense `faer` matrix of the real parts.
    ///
    /// Duplicate `(i, j)` entries are accumulated rather than overwritten. `from_csr` only
    /// requires non-decreasing columns per row, so duplicates can be stored, and standard
    /// CSR semantics treat them additively.
    ///
    /// # Errors
    /// `KError::Unsupported` when `T` is a complex scalar type.
    #[cfg(feature = "backend-faer")]
    pub fn to_dense(&self) -> Result<faer::Mat<f64>, crate::error::KError> {
        if crate::algebra::scalar::is_complex_scalar::<T>() {
            return Err(crate::error::KError::Unsupported(
                "CSR to_dense is real-only; complex scalars are unsupported",
            ));
        }
        let mut dense = faer::Mat::zeros(self.nrows, self.ncols);
        for i in 0..self.nrows {
            let (cols, vals) = self.row(i);
            for (&j, &v) in cols.iter().zip(vals.iter()) {
                // `+=` rather than `=`, so a repeated column does not drop earlier values.
                dense[(i, j)] += v.real();
            }
        }
        Ok(dense)
    }

    /// Builds a CSR matrix from a dense real matrix, scanning row by row.
    ///
    /// An entry is kept when `val.abs() >= drop_tol`. With `drop_tol == 0.0` every finite
    /// entry, including exact zeros, is stored; NaN entries fail the comparison and are
    /// dropped.
    ///
    /// # Errors
    /// `KError::Unsupported` when `T` is a complex scalar type.
    #[cfg(feature = "backend-faer")]
    pub fn from_dense(
        dense: &faer::Mat<R>,
        drop_tol: R,
    ) -> Result<Self, crate::error::KError> {
        if crate::algebra::scalar::is_complex_scalar::<T>() {
            return Err(crate::error::KError::Unsupported(
                "CSR from_dense is real-only; complex scalars are unsupported",
            ));
        }
        let nrows = dense.nrows();
        let ncols = dense.ncols();
        let mut row_ptr = Vec::with_capacity(nrows + 1);
        row_ptr.push(0);
        let mut col_idx = Vec::new();
        let mut values = Vec::new();
        for i in 0..nrows {
            // Columns are visited in ascending order, so each row comes out sorted.
            for j in 0..ncols {
                let val = dense[(i, j)];
                if val.abs() >= drop_tol {
                    col_idx.push(j);
                    values.push(T::from_real(val));
                }
            }
            row_ptr.push(col_idx.len());
        }
        Ok(Self::from_csr(nrows, ncols, row_ptr, col_idx, values))
    }

    /// Same as [`Self::from_dense`] but takes the dense matrix by value; it is only
    /// borrowed internally.
    #[cfg(feature = "backend-faer")]
    pub fn from_dense_owned(
        dense: faer::Mat<R>,
        drop_tol: R,
    ) -> Result<Self, crate::error::KError> {
        Self::from_dense(&dense, drop_tol)
    }
}
impl CsrMatrix<f64> {
    /// Copies this real matrix into the crate's scalar-typed CSR container.
    ///
    /// The sparsity pattern is duplicated and every stored value is lifted with
    /// `S::from_real`.
    pub fn to_scalar_csr(&self) -> crate::matrix::csr::CsrMatrix<S> {
        let lifted = self.values.iter().map(|&v| S::from_real(v)).collect();
        let pattern_rows = self.row_ptr.clone();
        let pattern_cols = self.col_idx.clone();
        crate::matrix::csr::CsrMatrix::new(self.nrows, self.ncols, pattern_rows, pattern_cols, lifted)
    }
}
/// Row count for generic indexing helpers; reads the stored row count directly.
impl<T> Indexing for CsrMatrix<T> {
    fn nrows(&self) -> usize {
        self.nrows
    }
}
impl<T: Clone> SubmatrixExtract for CsrMatrix<T> {
    type S = T;

    /// Extracts `A[rows, cols]` as a new `rows.len() x cols.len()` CSR matrix.
    ///
    /// Local row `r` is global row `rows[r]`; local column `c` is global column `cols[c]`.
    /// Each output row is emitted in ascending local-column order, so the result satisfies
    /// `from_csr`'s sorted-row invariant (and the diagonal binary search) even when `cols`
    /// is not ascending, e.g. a permutation.
    ///
    /// If `cols` lists the same global column more than once, only its last occurrence
    /// receives entries; earlier duplicates become empty columns.
    ///
    /// # Panics
    /// Panics if any entry of `rows` is `>= self.nrows()`.
    fn extract_submatrix(&self, rows: &[usize], cols: &[usize]) -> Self {
        let m = rows.len();
        let n = cols.len();
        let mut row_ptr = Vec::with_capacity(m + 1);
        row_ptr.push(0);
        let mut col_idx = Vec::new();
        let mut values = Vec::new();
        // Global column -> local column; later duplicates overwrite earlier ones.
        let g2l: HashMap<usize, usize> =
            cols.iter().enumerate().map(|(l, &g)| (g, l)).collect();
        // Scratch buffer reused across rows to reorder each row by local column.
        let mut row_buf: Vec<(usize, T)> = Vec::new();
        for &g_row in rows {
            let rs = self.row_ptr[g_row];
            let re = self.row_ptr[g_row + 1];
            row_buf.extend((rs..re).filter_map(|p| {
                g2l.get(&self.col_idx[p])
                    .map(|&lcol| (lcol, self.values[p].clone()))
            }));
            // Stable sort: duplicate stored entries keep their original relative order.
            row_buf.sort_by_key(|&(lcol, _)| lcol);
            for (lcol, v) in row_buf.drain(..) {
                col_idx.push(lcol);
                values.push(v);
            }
            row_ptr.push(col_idx.len());
        }
        CsrMatrix::from_csr(m, n, row_ptr, col_idx, values)
    }
}
#[cfg(all(feature = "backend-faer", feature = "simd"))]
impl<T> CsrMatrix<T> {
    /// Discards any cached SpMV plan; called whenever stored values may change.
    #[inline]
    fn invalidate_spmv_plan(&mut self) {
        let _ = self.spmv_plan.take();
    }
}
#[cfg(all(feature = "backend-faer", feature = "simd"))]
impl CsrMatrix<f64> {
    /// Builds and caches an SpMV plan from an owned snapshot of this matrix.
    ///
    /// The plan is built from copies of the CSR arrays. A later call to `values_mut`,
    /// `row_values_mut`, or a successful `diag_mut` discards it.
    pub fn build_spmv_plan(&mut self, tuning: &SpmvTuning) {
        let snapshot = crate::matrix::csr::CsrMatrix::new(
            self.nrows,
            self.ncols,
            self.row_ptr.clone(),
            self.col_idx.clone(),
            self.values.clone(),
        );
        let plan = build_spmv_plan(snapshot, tuning);
        self.spmv_plan = Some(plan);
    }

    /// Drops any cached SpMV plan.
    pub fn clear_spmv_plan(&mut self) {
        self.invalidate_spmv_plan();
    }
}
#[cfg(feature = "rayon")]
impl<T> CsrMatrix<T>
where
    T: KrystScalar,
{
    /// Parallel `y = A x` through [`Self::try_spmv_parallel`], discarding any error.
    ///
    /// Debug builds panic on error; release builds silently drop it.
    pub fn spmv_parallel(&self, x: &[T], y: &mut [T]) {
        let Err(err) = self.try_spmv_parallel(x, y) else {
            return;
        };
        debug_assert!(false, "CsrMatrix::spmv_parallel dimension mismatch: {err}");
    }

    /// Parallel `y = A x` via `crate::matrix::spmv::csr_matvec_par`, returning its error.
    pub fn try_spmv_parallel(&self, x: &[T], y: &mut [T]) -> Result<(), KError> {
        crate::matrix::spmv::csr_matvec_par(self, x, y)
    }
}
/// Borrowed CSR view used by the SpMV kernels; hands out the stored arrays as slices.
impl<T: KrystScalar> CsrMatRef<T> for CsrMatrix<T> {
    fn nrows(&self) -> usize {
        self.nrows
    }

    fn ncols(&self) -> usize {
        self.ncols
    }

    fn row_ptr(&self) -> &[usize] {
        self.row_ptr.as_slice()
    }

    fn col_idx(&self) -> &[usize] {
        self.col_idx.as_slice()
    }

    fn values(&self) -> &[T] {
        self.values.as_slice()
    }
}
// Unit tests for construction, SpMV variants, the diagonal cache and dense conversions.
#[cfg(test)]
mod tests {
use super::*;
// 3x3 identity: y = 1*A*x + 0*y must reproduce x.
#[test]
fn identity_spmv() {
let m = CsrMatrix::from_csr(
3,
3,
vec![0, 1, 2, 3],
vec![0, 1, 2],
vec![S::from_real(1.0), S::from_real(1.0), S::from_real(1.0)],
);
let x = vec![S::from_real(2.0), S::from_real(3.0), S::from_real(5.0)];
let mut y = vec![S::zero(); 3];
m.spmv_scaled(S::one(), &x, S::zero(), &mut y).unwrap();
assert_eq!(y, x);
}
// 2x3 matrix [[1,2,0],[0,3,4]] times ones gives row sums [3, 7].
#[test]
fn simple_pattern() {
let m = CsrMatrix::from_csr(
2,
3,
vec![0, 2, 4],
vec![0, 1, 1, 2],
vec![
S::from_real(1.0),
S::from_real(2.0),
S::from_real(3.0),
S::from_real(4.0),
],
);
let x = vec![S::one(), S::one(), S::one()];
let mut y = vec![S::zero(); 2];
m.spmv_scaled(S::one(), &x, S::zero(), &mut y).unwrap();
assert_eq!(y, vec![S::from_real(3.0), S::from_real(7.0)]);
}
// Same matrix: A^T * [1, 2] = [1, 2 + 6, 8].
#[test]
fn transpose_spmv() {
let m = CsrMatrix::from_csr(
2,
3,
vec![0, 2, 4],
vec![0, 1, 1, 2],
vec![
S::from_real(1.0),
S::from_real(2.0),
S::from_real(3.0),
S::from_real(4.0),
],
);
let x = vec![S::from_real(1.0), S::from_real(2.0)];
let mut y = vec![S::zero(); 3];
m.spmv_transpose_scaled(S::one(), &x, S::zero(), &mut y)
.unwrap();
assert_eq!(
y,
vec![S::from_real(1.0), S::from_real(8.0), S::from_real(8.0)]
);
}
// Upper-bidiagonal pattern: the cached diagonal positions must point at 1, 3 and 5.
#[test]
fn diag_ref_tracks_cached_positions() {
let m = CsrMatrix::from_csr(
3,
3,
vec![0, 2, 4, 5],
vec![0, 1, 1, 2, 2],
vec![
S::from_real(1.0),
S::from_real(2.0),
S::from_real(3.0),
S::from_real(4.0),
S::from_real(5.0),
],
);
assert_eq!(m.diag_ref(0).map(|v| *v), Some(S::from_real(1.0)));
assert_eq!(m.diag_ref(1).map(|v| *v), Some(S::from_real(3.0)));
assert_eq!(m.diag_ref(2).map(|v| *v), Some(S::from_real(5.0)));
}
// Row 0 stores columns [1, 0]; the debug-only sortedness check must reject it.
#[cfg(debug_assertions)]
#[test]
#[should_panic(expected = "unsorted column indices")]
fn from_csr_panics_on_unsorted_row() {
let _ = CsrMatrix::from_csr(
2,
2,
vec![0, 2, 4],
vec![1, 0, 0, 1], vec![1.0, 1.0, 1.0, 1.0],
);
}
// Both dense conversions must refuse complex scalar types with `Unsupported`.
#[cfg(all(feature = "backend-faer", feature = "complex"))]
#[test]
fn dense_conversions_reject_complex_scalars() {
let csr = CsrMatrix::from_csr(
2,
2,
vec![0, 2, 4],
vec![0, 1, 0, 1],
vec![S::from_parts(1.0, 0.5), S::from_parts(2.0, -1.0), S::one(), S::zero()],
);
let err = csr.to_dense().unwrap_err();
assert!(matches!(err, crate::error::KError::Unsupported(_)));
let dense = faer::Mat::<R>::from_fn(2, 2, |i, j| if i == j { 1.0 } else { 0.0 });
let err = CsrMatrix::<S>::from_dense(&dense, 0.0).unwrap_err();
assert!(matches!(err, crate::error::KError::Unsupported(_)));
}
// x has 2 entries but the matrix has 3 columns: try_spmv must report InvalidInput.
#[test]
fn try_spmv_reports_dim_mismatch() {
let m = CsrMatrix::from_csr(
2,
3,
vec![0, 2, 3],
vec![0, 2, 1],
vec![S::one(), S::from_real(2.0), S::from_real(3.0)],
);
let x = vec![S::one(); 2];
let mut y = vec![S::zero(); 2];
let err = m.try_spmv(&x, &mut y).unwrap_err();
assert!(matches!(err, KError::InvalidInput(_)));
}
}