kryst 3.2.1

Krylov subspace and preconditioned iterative solvers for dense and sparse linear systems, with shared and distributed memory parallelism.
#![cfg(feature = "backend-sprs")]
//! Sprs-backed sparse backend.

use std::any::Any;
use std::sync::Arc;

use sprs::{CsMat, TriMat, CSR};

use crate::algebra::prelude::*;
use crate::error::KError;
use crate::matrix::backend::SparseBackend;
use crate::matrix::dense_api::{DenseMatMut, DenseMatRef, DenseMatShape};
use crate::matrix::format::{BackendFormatSupport, OpFormat};
use crate::matrix::op::LinOp;

/// Marker type for the sprs backend.
///
/// Selects the `sprs` crate as the sparse storage provider: its `SparseBackend`
/// implementation in this module uses `sprs::CsMat<f64>` for CSR storage and
/// [`SprsDenseMat`] for dense storage. CSC materialization is not supported.
pub struct SprsBackend;

/// Minimal row-major dense matrix for sprs-backed conversions.
///
/// Element `(i, j)` lives at `data[i * ncols + j]`. The only constructor,
/// [`SprsDenseMat::from_row_major`], guarantees `data.len() == nrows * ncols`.
#[derive(Clone, Debug, PartialEq)]
pub struct SprsDenseMat {
    /// Number of rows.
    nrows: usize,
    /// Number of columns; also the stride between consecutive rows in `data`.
    ncols: usize,
    /// Entries in row-major order.
    data: Vec<f64>,
}

impl SprsDenseMat {
    /// Builds a matrix from `data` laid out in row-major order.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`, or if
    /// `data.len() != nrows * ncols`.
    pub fn from_row_major(nrows: usize, ncols: usize, data: Vec<f64>) -> Self {
        // Checked so that absurd dimensions fail loudly instead of wrapping in
        // release builds and comparing against a meaningless length.
        let expected = nrows
            .checked_mul(ncols)
            .expect("dense matrix dimensions overflow usize");
        assert_eq!(
            data.len(),
            expected,
            "row-major dense data length must equal nrows * ncols"
        );
        Self { nrows, ncols, data }
    }

    /// Flat offset of entry `(i, j)` inside `data`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= nrows` or `j >= ncols`. Without the column check, an
    /// out-of-range `j` would silently alias an entry of the following row
    /// instead of being reported.
    #[inline]
    fn idx(&self, i: usize, j: usize) -> usize {
        assert!(
            i < self.nrows && j < self.ncols,
            "index ({}, {}) out of bounds for {}x{} dense matrix",
            i,
            j,
            self.nrows,
            self.ncols
        );
        i * self.ncols + j
    }
}

/// Shape queries for the row-major dense matrix.
impl DenseMatShape for SprsDenseMat {
    /// Number of rows.
    fn nrows(&self) -> usize {
        let Self { nrows, .. } = self;
        *nrows
    }

    /// Number of columns.
    fn ncols(&self) -> usize {
        let Self { ncols, .. } = self;
        *ncols
    }
}

/// Read access to individual entries.
impl DenseMatRef<f64> for SprsDenseMat {
    /// Returns entry `(i, j)`, located through the private row-major `idx` helper.
    fn get(&self, i: usize, j: usize) -> f64 {
        let offset = self.idx(i, j);
        self.data[offset]
    }
}

/// Write access to individual entries.
impl DenseMatMut<f64> for SprsDenseMat {
    /// Overwrites entry `(i, j)` with `val`.
    fn set(&mut self, i: usize, j: usize, val: f64) {
        // Compute the offset before taking the mutable borrow of `data`.
        let offset = self.idx(i, j);
        let slot = &mut self.data[offset];
        *slot = val;
    }
}

/// Dense row-major matrix exposed as a linear operator.
impl LinOp for SprsDenseMat {
    type S = f64;

    /// Returns `(nrows, ncols)`.
    fn dims(&self) -> (usize, usize) {
        let Self { nrows, ncols, .. } = self;
        (*nrows, *ncols)
    }

    /// Computes `y = A x`, one row dot product at a time.
    ///
    /// `y` is zero-filled first, so entries of `y` beyond `nrows` end up `0.0`.
    /// Indexing panics if `x` or `y` is too short for the entries it touches.
    fn matvec(&self, x: &[Self::S], y: &mut [Self::S]) {
        y.fill(0.0);
        let width = self.ncols;
        for r in 0..self.nrows {
            let row = &self.data[r * width..(r + 1) * width];
            // Left-to-right accumulation from 0.0, element by element.
            y[r] = row
                .iter()
                .enumerate()
                .fold(0.0, |sum, (c, &a)| sum + a * x[c]);
        }
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// This operator is stored densely.
    fn format(&self) -> OpFormat {
        OpFormat::Dense
    }
}

/// A `sprs::CsMat<f64>` (either storage order) exposed as a linear operator.
impl LinOp for CsMat<f64> {
    type S = f64;

    /// Returns `(rows, cols)`; this does not depend on the storage order.
    fn dims(&self) -> (usize, usize) {
        (self.rows(), self.cols())
    }

    /// Computes `y = A x`.
    ///
    /// `y` is zero-filled first. For CSR storage each row is reduced with a dot
    /// product. For CSC storage each column's contribution is scattered into
    /// `y` directly. Earlier code allocated a full CSR copy of the matrix on
    /// every call, which is costly inside an iterative solver loop.
    fn matvec(&self, x: &[Self::S], y: &mut [Self::S]) {
        y.fill(0.0);
        if self.storage() == CSR {
            for (row, row_vec) in self.outer_iterator().enumerate() {
                y[row] = row_vec
                    .iter()
                    .fold(0.0, |acc, (col, &val)| acc + val * x[col]);
            }
        } else {
            // CSC: the outer index is the column, inner indices are rows.
            for (col, col_vec) in self.outer_iterator().enumerate() {
                for (row, &val) in col_vec.iter() {
                    y[row] += val * x[col];
                }
            }
        }
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Always reports `OpFormat::Csr`, even when the underlying `CsMat` is
    /// stored column-major. `matvec` above handles both storage orders.
    fn format(&self) -> OpFormat {
        OpFormat::Csr
    }
}

/// `SparseBackend` implementation on top of `sprs`.
///
/// Requires `S::Real = f64`; storage is always `f64`. CSR matrices are
/// `CsMat<f64>`, dense matrices are [`SprsDenseMat`], and CSC has no storage
/// type (`Csc = ()`).
impl<S> SparseBackend<S> for SprsBackend
where
    S: KrystScalar<Real = f64>,
{
    // NOTE(review): flag order of `BackendFormatSupport::new` is defined elsewhere;
    // presumably this enables CSR and dense only — confirm.
    const FORMAT_SUPPORT: BackendFormatSupport = BackendFormatSupport::new(true, true, false, false);

    type Csr = CsMat<f64>;
    type Csc = ();
    type Dense = SprsDenseMat;

    /// Builds a CSR matrix from `dense`.
    ///
    /// Entries with `|v| <= drop_tol` are dropped. NaN entries are always kept:
    /// dropping them would silently hide corrupted input from the solver.
    /// This never fails; the `Result` is part of the trait contract.
    fn csr_from_dense(dense: &Self::Dense, drop_tol: S::Real) -> Result<Self::Csr, KError> {
        let mut triplet = TriMat::with_capacity((dense.nrows, dense.ncols), dense.data.len());
        for i in 0..dense.nrows {
            for j in 0..dense.ncols {
                let val = dense.get(i, j);
                if val.is_nan() || val.abs() > drop_tol {
                    triplet.add_triplet(i, j, val);
                }
            }
        }
        Ok(triplet.to_csr())
    }

    /// CSC is not supported by this backend; `Csc` is the unit type, so there is
    /// nothing to build.
    fn csc_from_csr(_csr: &Self::Csr, _drop_tol: S::Real) -> Self::Csc {}

    /// Never expected to be called: this backend has no CSC storage.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn csr_from_csc(_csc: &Self::Csc, _drop_tol: S::Real) -> Self::Csr {
        unreachable!("sprs backend does not support CSC materialization")
    }

    /// Expands a `CsMat` into a dense row-major matrix.
    ///
    /// Works for either storage order by walking the outer vectors in place:
    /// rows for CSR, columns for CSC. Neither a clone nor a CSR conversion of
    /// the matrix is made.
    fn dense_from_csr(csr: &Self::Csr) -> Result<Self::Dense, KError> {
        let (nrows, ncols) = (csr.rows(), csr.cols());
        let is_csr = csr.storage() == CSR;
        let mut data = vec![0.0; nrows * ncols];
        for (outer, lane) in csr.outer_iterator().enumerate() {
            for (inner, val) in lane.iter() {
                let (i, j) = if is_csr { (outer, inner) } else { (inner, outer) };
                data[i * ncols + j] += val;
            }
        }
        Ok(SprsDenseMat::from_row_major(nrows, ncols, data))
    }

    /// CSC input cannot be materialized by this backend.
    ///
    /// # Errors
    ///
    /// Always returns `KError::Unsupported`.
    fn dense_from_csc(_csc: &Self::Csc) -> Result<Self::Dense, KError> {
        Err(KError::Unsupported(
            "sprs backend does not support CSC materialization",
        ))
    }
}

pub fn try_materialize(
    op: Arc<dyn LinOp<S = S>>,
    want: OpFormat,
    drop_tol: R,
) -> Result<Arc<dyn LinOp<S = S>>, KError> {
    if want.is_any() {
        return Ok(op);
    }
    if !<SprsBackend as SparseBackend<S>>::FORMAT_SUPPORT.supports(want) {
        return Err(KError::Unsupported(
            "sprs backend does not support the requested format",
        ));
    }

    if let Some(csr) = op.as_any().downcast_ref::<CsMat<f64>>() {
        return match want {
            OpFormat::Csr => Ok(Arc::new(csr.clone())),
            OpFormat::Dense => Ok(Arc::new(
                <SprsBackend as SparseBackend<S>>::dense_from_csr(csr)?,
            )),
            OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
                "sprs backend cannot materialize the requested format",
            )),
        };
    }

    if let Some(dense) = op.as_any().downcast_ref::<SprsDenseMat>() {
        return match want {
            OpFormat::Csr => Ok(Arc::new(
                <SprsBackend as SparseBackend<S>>::csr_from_dense(dense, drop_tol)?,
            )),
            OpFormat::Dense => Ok(Arc::new(dense.clone())),
            OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
                "sprs backend cannot materialize the requested format",
            )),
        };
    }

    Err(KError::Unsupported(
        "sprs backend cannot materialize the requested operator",
    ))
}

pub fn try_materialize_ref(
    op: &dyn LinOp<S = S>,
    want: OpFormat,
    drop_tol: R,
) -> Result<Arc<dyn LinOp<S = S>>, KError> {
    if want.is_any() {
        return Err(KError::Unsupported(
            "sprs backend cannot materialize OpFormat::Any",
        ));
    }
    if !<SprsBackend as SparseBackend<S>>::FORMAT_SUPPORT.supports(want) {
        return Err(KError::Unsupported(
            "sprs backend does not support the requested format",
        ));
    }

    if let Some(csr) = op.as_any().downcast_ref::<CsMat<f64>>() {
        return match want {
            OpFormat::Csr => Ok(Arc::new(csr.clone())),
            OpFormat::Dense => Ok(Arc::new(
                <SprsBackend as SparseBackend<S>>::dense_from_csr(csr)?,
            )),
            OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
                "sprs backend cannot materialize the requested format",
            )),
        };
    }

    if let Some(dense) = op.as_any().downcast_ref::<SprsDenseMat>() {
        return match want {
            OpFormat::Csr => Ok(Arc::new(
                <SprsBackend as SparseBackend<S>>::csr_from_dense(dense, drop_tol)?,
            )),
            OpFormat::Dense => Ok(Arc::new(dense.clone())),
            OpFormat::Csc | OpFormat::BlockCsr | OpFormat::Any => Err(KError::Unsupported(
                "sprs backend cannot materialize the requested format",
            )),
        };
    }

    Err(KError::Unsupported(
        "sprs backend cannot materialize the requested operator",
    ))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::backend;
    use crate::matrix::format::OpFormat;

    /// Dense -> CSR -> dense round trip through the generic `backend::materialize`.
    #[test]
    fn materialize_dense_and_csr() {
        let dense = SprsDenseMat::from_row_major(2, 2, vec![1.0, 0.0, 0.0, 2.0]);
        let op: Arc<dyn LinOp<S = S>> = Arc::new(dense.clone());

        let csr = backend::materialize(op.clone(), OpFormat::Csr, 0.0).unwrap();
        assert_eq!(csr.format(), OpFormat::Csr);
        let csr_ref = csr.as_any().downcast_ref::<CsMat<f64>>().unwrap();
        assert_eq!(csr_ref.rows(), 2);
        assert_eq!(csr_ref.cols(), 2);
        // drop_tol = 0.0 removes the exact zeros; only the diagonal is stored.
        assert_eq!(csr_ref.nnz(), 2);

        let dense_again =
            backend::materialize(Arc::new(csr_ref.clone()), OpFormat::Dense, 0.0).unwrap();
        assert_eq!(dense_again.format(), OpFormat::Dense);
        let dense_ref = dense_again
            .as_any()
            .downcast_ref::<SprsDenseMat>()
            .unwrap();
        assert_eq!(dense_ref.get(0, 0), 1.0);
        assert_eq!(dense_ref.get(1, 1), 2.0);
        // The round trip must reproduce the original matrix exactly.
        assert_eq!(dense_ref, &dense);
    }

    /// The dense and CSR operators must compute the same product.
    #[test]
    fn csr_and_dense_matvec_agree() {
        let dense = SprsDenseMat::from_row_major(2, 3, vec![1.0, 0.0, 2.0, 0.0, 3.0, 4.0]);
        let csr = <SprsBackend as SparseBackend<S>>::csr_from_dense(&dense, 0.0).unwrap();
        let x = [1.0, 2.0, 3.0];
        let mut y_dense = [f64::NAN; 2];
        let mut y_csr = [f64::NAN; 2];
        LinOp::matvec(&dense, &x[..], &mut y_dense[..]);
        LinOp::matvec(&csr, &x[..], &mut y_csr[..]);
        assert_eq!(y_dense, [7.0, 18.0]);
        assert_eq!(y_csr, [7.0, 18.0]);
    }
}