#![cfg(feature = "backend-sprs")]
use std::any::Any;
use std::sync::Arc;
use sprs::{CsMat, TriMat, CSR};
use crate::algebra::prelude::*;
use crate::error::KError;
use crate::matrix::backend::SparseBackend;
use crate::matrix::dense_api::{DenseMatMut, DenseMatRef, DenseMatShape};
use crate::matrix::format::{BackendFormatSupport, OpFormat};
use crate::matrix::op::LinOp;
/// Zero-sized marker type that carries the `sprs`-based [`SparseBackend`] implementation.
pub struct SprsBackend;
/// Dense matrix stored as a flat, row-major `Vec<f64>`.
///
/// Element `(i, j)` lives at offset `i * ncols + j` of the backing vector.
#[derive(Clone, Debug, PartialEq)]
pub struct SprsDenseMat {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Row-major element storage; its length equals `nrows * ncols`.
    data: Vec<f64>,
}

impl SprsDenseMat {
    /// Builds a matrix from `data` laid out row by row.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from `nrows * ncols`, or if that product
    /// does not fit in a `usize`.
    pub fn from_row_major(nrows: usize, ncols: usize, data: Vec<f64>) -> Self {
        // `checked_mul` keeps a wrapped product (release builds) from accidentally
        // matching the length of a too-short buffer.
        let expected_len = nrows.checked_mul(ncols);
        assert_eq!(
            Some(data.len()),
            expected_len,
            "row-major dense data length must equal nrows * ncols"
        );
        Self { nrows, ncols, data }
    }

    /// Maps `(i, j)` to the flat offset inside `data`.
    ///
    /// # Panics
    ///
    /// Panics if `j` is not a valid column. An out-of-range row is caught by the
    /// bounds check on `data` itself, because the resulting offset is then at
    /// least `nrows * ncols`.
    #[inline]
    fn idx(&self, i: usize, j: usize) -> usize {
        // Without this check `j >= ncols` would silently alias an element of a later row.
        assert!(
            j < self.ncols,
            "column index {} out of bounds for {} columns",
            j,
            self.ncols
        );
        i * self.ncols + j
    }
}
// Shape accessors: both methods simply report the dimensions stored on the struct.
impl DenseMatShape for SprsDenseMat {
/// Returns the stored row count.
fn nrows(&self) -> usize {
self.nrows
}
/// Returns the stored column count.
fn ncols(&self) -> usize {
self.ncols
}
}
impl DenseMatRef<f64> for SprsDenseMat {
    /// Reads the element at row `i`, column `j` from the row-major buffer.
    ///
    /// Panics if the computed offset lies outside the buffer.
    fn get(&self, i: usize, j: usize) -> f64 {
        let offset = self.idx(i, j);
        self.data[offset]
    }
}
impl DenseMatMut<f64> for SprsDenseMat {
    /// Overwrites the element at row `i`, column `j` with `val`.
    ///
    /// Panics if the computed offset lies outside the buffer.
    fn set(&mut self, i: usize, j: usize, val: f64) {
        // Compute the offset first so the shared borrow ends before the write.
        let offset = self.idx(i, j);
        let slot = &mut self.data[offset];
        *slot = val;
    }
}
impl LinOp for SprsDenseMat {
    type S = f64;

    /// Returns `(rows, columns)`.
    fn dims(&self) -> (usize, usize) {
        let Self { nrows, ncols, .. } = self;
        (*nrows, *ncols)
    }

    /// Computes the dense matrix-vector product `y = A * x`.
    ///
    /// All of `y` is zeroed first, so entries beyond the row count end up as `0.0`.
    /// Panics via slice indexing if `x` has fewer than `ncols` entries (and there is
    /// at least one row) or `y` has fewer than `nrows` entries.
    fn matvec(&self, x: &[Self::S], y: &mut [Self::S]) {
        y.fill(0.0);
        for row in 0..self.nrows {
            let start = row * self.ncols;
            // Explicit fold seeded with +0.0 so an empty row yields exactly 0.0.
            y[row] = (0..self.ncols)
                .fold(0.0, |dot, col| dot + self.data[start + col] * x[col]);
        }
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// This operator always reports the dense format.
    fn format(&self) -> OpFormat {
        OpFormat::Dense
    }
}
impl LinOp for CsMat<f64> {
    type S = f64;

    /// Returns `(rows, columns)` of the sparse matrix.
    fn dims(&self) -> (usize, usize) {
        (self.rows(), self.cols())
    }

    /// Computes the sparse matrix-vector product `y = A * x`.
    ///
    /// `y` is zeroed first. CSR storage is traversed row by row, accumulating one
    /// dot product per row. Any other storage order is traversed along its outer
    /// (column) dimension and scattered into `y`, which avoids building a temporary
    /// CSR copy on every call. Each `y[row]` still receives its contributions in
    /// ascending column order, starting from `0.0`.
    ///
    /// Panics via slice indexing when a stored entry refers to a position outside
    /// `x` or `y`.
    fn matvec(&self, x: &[Self::S], y: &mut [Self::S]) {
        y.fill(0.0);
        if self.storage() == CSR {
            for (row, entries) in self.outer_iterator().enumerate() {
                let mut acc = 0.0;
                for (col, val) in entries.iter() {
                    acc += val * x[col];
                }
                y[row] = acc;
            }
        } else {
            // Column-major storage: the outer index is the column, the inner one the row.
            for (col, entries) in self.outer_iterator().enumerate() {
                for (row, val) in entries.iter() {
                    y[row] += val * x[col];
                }
            }
        }
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Always reports `OpFormat::Csr`, whatever the underlying storage order is.
    fn format(&self) -> OpFormat {
        OpFormat::Csr
    }
}
impl<S> SparseBackend<S> for SprsBackend
where
    S: KrystScalar<Real = f64>,
{
    // NOTE(review): the meaning of each flag is defined by `BackendFormatSupport::new`,
    // which is not visible here; presumably CSR and dense are the enabled formats — confirm.
    const FORMAT_SUPPORT: BackendFormatSupport = BackendFormatSupport::new(true, true, false, false);
    type Csr = CsMat<f64>;
    type Csc = ();
    type Dense = SprsDenseMat;

    /// Converts a dense matrix to CSR, keeping only entries with `|value| > drop_tol`.
    ///
    /// Because the comparison is a strict `>`, NaN entries compare false and are dropped.
    fn csr_from_dense(dense: &Self::Dense, drop_tol: S::Real) -> Result<Self::Csr, KError> {
        let tol: f64 = drop_tol;
        // Count survivors first so the triplet buffers are sized for the kept entries
        // instead of for every element of the dense matrix.
        let kept = dense.data.iter().filter(|v| v.abs() > tol).count();
        let mut triplet = TriMat::with_capacity((dense.nrows, dense.ncols), kept);
        for i in 0..dense.nrows {
            for j in 0..dense.ncols {
                let val = dense.get(i, j);
                if val.abs() > tol {
                    triplet.add_triplet(i, j, val);
                }
            }
        }
        Ok(triplet.to_csr())
    }

    /// CSC is represented by `()` in this backend, so there is nothing to build.
    fn csc_from_csr(_csr: &Self::Csr, _drop_tol: S::Real) -> Self::Csc {
        ()
    }

    /// Not supported by this backend.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn csr_from_csc(_csc: &Self::Csc, _drop_tol: S::Real) -> Self::Csr {
        unreachable!("sprs backend does not support CSC materialization")
    }

    /// Expands a sparse matrix into a row-major dense matrix.
    ///
    /// The input is read in place for either storage order; no copy or conversion
    /// of the sparse matrix is made.
    fn dense_from_csr(csr: &Self::Csr) -> Result<Self::Dense, KError> {
        let (nrows, ncols) = (csr.rows(), csr.cols());
        let row_major = csr.storage() == CSR;
        let mut data = vec![0.0; nrows * ncols];
        for (outer, lane) in csr.outer_iterator().enumerate() {
            for (inner, val) in lane.iter() {
                // The outer dimension is the row for CSR and the column otherwise.
                let (row, col) = if row_major {
                    (outer, inner)
                } else {
                    (inner, outer)
                };
                data[row * ncols + col] += val;
            }
        }
        Ok(SprsDenseMat::from_row_major(nrows, ncols, data))
    }

    /// Not supported by this backend; always returns `KError::Unsupported`.
    fn dense_from_csc(_csc: &Self::Csc) -> Result<Self::Dense, KError> {
        Err(KError::Unsupported(
            "sprs backend does not support CSC materialization",
        ))
    }
}
pub fn try_materialize(
op: Arc<dyn LinOp<S = S>>,
want: OpFormat,
drop_tol: R,
) -> Result<Arc<dyn LinOp<S = S>>, KError> {
if want.is_any() {
return Ok(op);
}
if !<SprsBackend as SparseBackend<S>>::FORMAT_SUPPORT.supports(want) {
return Err(KError::Unsupported(
"sprs backend does not support the requested format",
));
}
if let Some(csr) = op.as_any().downcast_ref::<CsMat<f64>>() {
return match want {
OpFormat::Csr => Ok(Arc::new(csr.clone())),
OpFormat::Dense => Ok(Arc::new(
<SprsBackend as SparseBackend<S>>::dense_from_csr(csr)?,
)),
OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
"sprs backend cannot materialize the requested format",
)),
};
}
if let Some(dense) = op.as_any().downcast_ref::<SprsDenseMat>() {
return match want {
OpFormat::Csr => Ok(Arc::new(
<SprsBackend as SparseBackend<S>>::csr_from_dense(dense, drop_tol)?,
)),
OpFormat::Dense => Ok(Arc::new(dense.clone())),
OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
"sprs backend cannot materialize the requested format",
)),
};
}
Err(KError::Unsupported(
"sprs backend cannot materialize the requested operator",
))
}
/// Materializes a borrowed operator into the requested format using the sprs backend.
///
/// Unlike the owning variant, `OpFormat::Any` is rejected here. Only operators that
/// downcast to `CsMat<f64>` or `SprsDenseMat` are handled, and the result is always a
/// freshly allocated operator. `drop_tol` is only used for dense-to-CSR conversion.
///
/// # Errors
///
/// Returns `KError::Unsupported` when `want` is `Any`, when the backend does not
/// support `want`, when the operator type is unknown, or when a conversion fails.
pub fn try_materialize_ref(
    op: &dyn LinOp<S = S>,
    want: OpFormat,
    drop_tol: R,
) -> Result<Arc<dyn LinOp<S = S>>, KError> {
    if want.is_any() {
        return Err(KError::Unsupported(
            "sprs backend cannot materialize OpFormat::Any",
        ));
    }
    let supported = <SprsBackend as SparseBackend<S>>::FORMAT_SUPPORT.supports(want);
    if !supported {
        return Err(KError::Unsupported(
            "sprs backend does not support the requested format",
        ));
    }

    let source = op.as_any();
    let as_sparse = source.downcast_ref::<CsMat<f64>>();
    let as_dense = source.downcast_ref::<SprsDenseMat>();

    // The sparse downcast takes precedence, matching the order of the checks.
    match (as_sparse, as_dense) {
        (Some(csr), _) => match want {
            OpFormat::Csr => Ok(Arc::new(csr.clone())),
            OpFormat::Dense => Ok(Arc::new(
                <SprsBackend as SparseBackend<S>>::dense_from_csr(csr)?,
            )),
            OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
                "sprs backend cannot materialize the requested format",
            )),
        },
        (None, Some(dense)) => match want {
            OpFormat::Csr => Ok(Arc::new(
                <SprsBackend as SparseBackend<S>>::csr_from_dense(dense, drop_tol)?,
            )),
            OpFormat::Dense => Ok(Arc::new(dense.clone())),
            OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
                "sprs backend cannot materialize the requested format",
            )),
        },
        (None, None) => Err(KError::Unsupported(
            "sprs backend cannot materialize the requested operator",
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::backend;
    use crate::matrix::format::OpFormat;

    /// Round-trips a 2x2 diagonal matrix dense -> CSR -> dense through
    /// `backend::materialize` and checks the format tags and diagonal entries.
    #[test]
    fn materialize_dense_and_csr() {
        let original = SprsDenseMat::from_row_major(2, 2, vec![1.0, 0.0, 0.0, 2.0]);
        let dense_op: Arc<dyn LinOp<S = S>> = Arc::new(original.clone());

        // Dense -> CSR.
        let sparse_op = backend::materialize(dense_op.clone(), OpFormat::Csr, 0.0).unwrap();
        assert_eq!(sparse_op.format(), OpFormat::Csr);
        let sparse = sparse_op.as_any().downcast_ref::<CsMat<f64>>().unwrap();
        assert_eq!(sparse.rows(), 2);

        // CSR -> dense again.
        let round_trip =
            backend::materialize(Arc::new(sparse.clone()), OpFormat::Dense, 0.0).unwrap();
        assert_eq!(round_trip.format(), OpFormat::Dense);
        let recovered = round_trip
            .as_any()
            .downcast_ref::<SprsDenseMat>()
            .unwrap();
        assert_eq!(recovered.get(0, 0), 1.0);
        assert_eq!(recovered.get(1, 1), 2.0);
    }
}