use crate::algebra::prelude::*;
use crate::matrix::sparse_api::CsrMatRef;
/// Sparse matrix in compressed sparse row (CSR) layout.
///
/// `rowptr` holds `nrows + 1` offsets into the parallel `colind` / `values`
/// arrays; `colind[k]` is the column of the stored entry `values[k]`.
/// Use [`CsrMatrix::is_valid`] to check these structural invariants.
#[derive(Debug, Clone, PartialEq)]
pub struct CsrMatrix<S: KrystScalar> {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Row offsets into `colind` / `values` (expected length `nrows + 1`).
    pub rowptr: Vec<usize>,
    /// Column index of each stored entry (same length as `values`).
    pub colind: Vec<usize>,
    /// Stored entry values.
    pub values: Vec<S>,
}
/// Read-only CSR view over an owned [`CsrMatrix`]; every accessor simply
/// borrows the corresponding field.
impl<S: KrystScalar> CsrMatRef<S> for CsrMatrix<S> {
    /// Row count (the `nrows` field).
    fn nrows(&self) -> usize {
        let rows = self.nrows;
        rows
    }

    /// Column count (the `ncols` field).
    fn ncols(&self) -> usize {
        let cols = self.ncols;
        cols
    }

    /// Row-offset array, borrowed as a slice.
    fn row_ptr(&self) -> &[usize] {
        self.rowptr.as_slice()
    }

    /// Column-index array, borrowed as a slice.
    fn col_idx(&self) -> &[usize] {
        self.colind.as_slice()
    }

    /// Stored values, borrowed as a slice.
    fn values(&self) -> &[S] {
        self.values.as_slice()
    }
}
impl<S: KrystScalar> CsrMatrix<S> {
    /// Number of stored entries (length of `values`).
    #[inline]
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Matrix shape as `(nrows, ncols)`.
    #[inline]
    pub fn dims(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Row-offset array.
    #[inline]
    pub fn rowptr(&self) -> &[usize] {
        &self.rowptr
    }

    /// Mutable row-offset array. Callers must keep the CSR invariants
    /// (see [`CsrMatrix::is_valid`]) intact.
    #[inline]
    pub fn rowptr_mut(&mut self) -> &mut [usize] {
        &mut self.rowptr
    }

    /// Alias for [`CsrMatrix::rowptr`].
    #[inline]
    pub fn row_ptr(&self) -> &[usize] {
        self.rowptr()
    }

    /// Alias for [`CsrMatrix::rowptr_mut`].
    #[inline]
    pub fn row_ptr_mut(&mut self) -> &mut [usize] {
        self.rowptr_mut()
    }

    /// Column-index array.
    #[inline]
    pub fn colind(&self) -> &[usize] {
        &self.colind
    }

    /// Mutable column-index array. Callers must keep every index `< ncols`.
    #[inline]
    pub fn colind_mut(&mut self) -> &mut [usize] {
        &mut self.colind
    }

    /// Alias for [`CsrMatrix::colind`].
    #[inline]
    pub fn col_idx(&self) -> &[usize] {
        self.colind()
    }

    /// Alias for [`CsrMatrix::colind_mut`].
    #[inline]
    pub fn col_idx_mut(&mut self) -> &mut [usize] {
        self.colind_mut()
    }

    /// Stored values.
    #[inline]
    pub fn values(&self) -> &[S] {
        &self.values
    }

    /// Mutable stored values (the sparsity pattern cannot change through this).
    #[inline]
    pub fn values_mut(&mut self) -> &mut [S] {
        &mut self.values
    }

    /// Builds a matrix from raw CSR arrays without copying them.
    ///
    /// Only cheap length checks are performed, and only in debug builds;
    /// call [`CsrMatrix::is_valid`] for a full structural check.
    ///
    /// # Panics
    ///
    /// In debug builds, if `rowptr.len() != nrows + 1` or
    /// `colind.len() != values.len()`.
    pub fn new(
        nrows: usize,
        ncols: usize,
        rowptr: Vec<usize>,
        colind: Vec<usize>,
        values: Vec<S>,
    ) -> Self {
        debug_assert_eq!(rowptr.len(), nrows + 1);
        debug_assert_eq!(colind.len(), values.len());
        Self {
            nrows,
            ncols,
            rowptr,
            colind,
            values,
        }
    }

    /// Checks the structural CSR invariants.
    ///
    /// Returns `true` only if:
    /// - `rowptr` has `nrows + 1` entries,
    /// - `colind` and `values` have the same length (`nnz`),
    /// - `rowptr` starts at `0` and ends at `nnz`,
    /// - `rowptr` is non-decreasing, and
    /// - every column index is `< ncols`.
    pub fn is_valid(&self) -> bool {
        let nnz = self.values.len();
        self.rowptr.len() == self.nrows + 1
            && self.colind.len() == nnz
            // Without these two bounds, row ranges can index past the end of
            // `colind`/`values` or leave stored entries attached to no row.
            && self.rowptr.first() == Some(&0)
            && self.rowptr.last() == Some(&nnz)
            && self.rowptr.windows(2).all(|w| w[0] <= w[1])
            && self.colind.iter().all(|&j| j < self.ncols)
    }

    /// Sparse matrix-vector product via `crate::matrix::spmv::csr_matvec`,
    /// writing into `y`.
    ///
    /// # Panics
    ///
    /// If `csr_matvec` returns an error (the panic message assumes a
    /// dimension mismatch between `self`, `x` and `y`).
    #[inline]
    pub fn spmv(&self, x: &[S], y: &mut [S]) {
        crate::matrix::spmv::csr_matvec(self, x, y).expect("CsrMatrix::spmv dimension mismatch");
    }

    /// Scaled product delegated to `spmv_scaled_csr`; presumably computes
    /// `y <- alpha * A * x + beta * y` (BLAS-style naming — confirm in the kernel).
    ///
    /// # Panics
    ///
    /// If `x.len() != ncols` or `y.len() != nrows`.
    #[inline]
    pub fn spmv_scaled(&self, alpha: S, x: &[S], beta: S, y: &mut [S]) {
        assert_eq!(x.len(), self.ncols);
        assert_eq!(y.len(), self.nrows);
        crate::matrix::spmv::scalar::spmv_scaled_csr(
            self.nrows,
            &self.rowptr,
            &self.colind,
            &self.values,
            alpha,
            x,
            beta,
            y,
        );
    }

    /// Transposed product: [`CsrMatrix::spmv_t_scaled`] with `alpha = 1`,
    /// `beta = 0`.
    ///
    /// # Panics
    ///
    /// If `x.len() != nrows` or `y.len() != ncols`.
    #[inline]
    pub fn spmv_t(&self, x: &[S], y: &mut [S]) {
        self.spmv_t_scaled(S::one(), x, S::zero(), y);
    }

    /// Scaled transposed product delegated to `spmv_t_scaled_csr`; presumably
    /// computes `y <- alpha * A^T * x + beta * y` (confirm in the kernel).
    ///
    /// # Panics
    ///
    /// If `x.len() != nrows` or `y.len() != ncols`.
    #[inline]
    pub fn spmv_t_scaled(&self, alpha: S, x: &[S], beta: S, y: &mut [S]) {
        assert_eq!(x.len(), self.nrows);
        assert_eq!(y.len(), self.ncols);
        crate::matrix::spmv::scalar::spmv_t_scaled_csr(
            self.nrows,
            &self.rowptr,
            &self.colind,
            &self.values,
            alpha,
            x,
            beta,
            y,
        );
    }
}
impl<S> CsrMatrix<S>
where
    S: KrystScalar<Real = f64>,
{
    /// Lifts a real `f64` CSR matrix into scalar type `S`.
    ///
    /// The sparsity pattern (`row_ptr`, `col_idx`) is copied verbatim and each
    /// stored value is converted with `S::from_real`.
    pub fn from_real_csr(real: &crate::matrix::sparse::CsrMatrix<f64>) -> Self {
        let converted: Vec<S> = real.values().iter().map(|&v| S::from_real(v)).collect();
        let rows = real.nrows();
        let cols = real.ncols();
        let offsets = real.row_ptr().to_vec();
        let columns = real.col_idx().to_vec();
        Self::new(rows, cols, offsets, columns, converted)
    }
}
/// CSR matrix with `f64` entries.
pub type CsrMatrix64 = CsrMatrix<f64>;