algebra-sparse 0.4.0-beta.1

Efficient sparse linear algebra library built on nalgebra with CSR/CSC formats and block diagonal matrix support
Documentation
// Copyright (C) 2020-2025 algebra-sparse authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Mul;

use na::{DMatrix, DMatrixView, DVector, DVectorView, DVectorViewMut};

use crate::csv::CsVecRef;
use crate::{
    CscMatrixView, CsrMatrix, CsrMatrixView, CsrMatrixViewMethods, DiagonalBlockMatrixView, Real,
};

/// Multiply a sparse matrix `a` with a block diagonal matrix `b` and store the result in `o`.
///
/// `C = A * B`
///
/// For every row of `a`, the stored entries that fall inside the row range of a
/// diagonal block are multiplied against that block only. Blocks the row does not
/// touch are skipped entirely. Products whose magnitude does not exceed
/// `T::zero_threshold()` are not stored.
///
/// # Panics
///
/// Panics if `a.ncols() != b.nrows()`, if `b.ncols() != o.ncols()`, or if `o`
/// already contains rows.
///
/// # Note
///
/// The column indices of every row of `a` must be stored in ascending order.
/// The computation walks the blocks of `b` in index order. It also relies on
/// each block's columns starting at the same offset as its rows, since output
/// columns are emitted as `range.start + j`.
pub(crate) fn mul_csr_bd_to<T>(
    a: CsrMatrixView<T>,
    b: DiagonalBlockMatrixView<T>,
    o: &mut CsrMatrix<T>,
) where
    T: Real,
{
    assert_eq!(a.ncols(), b.nrows());
    assert_eq!(b.ncols(), o.ncols());
    assert_eq!(o.nrows(), 0);

    for i in 0..a.nrows() {
        let mut or = o.new_row_builder(T::zero_threshold());
        let ar = a.get_row(i);
        let a_indices = ar.indices();
        let a_values = ar.values();
        // Position in `ar` of the first stored entry not yet consumed by a block.
        let mut a_col_start = 0;
        for bindex in 0..b.num_blocks() {
            let range = b.get_block_row_range(bindex);
            // Number of remaining stored entries that fall into this block's rows.
            let a_n = a_indices[a_col_start..]
                .iter()
                .take_while(|&&col| col < range.end)
                .count();
            if a_n == 0 {
                // The row has no entries in this block, so every product is zero.
                continue;
            }
            let a_col_end = a_col_start + a_n;
            // Slice once per block instead of re-skipping the row prefix for
            // every output column.
            let seg_indices = &a_indices[a_col_start..a_col_end];
            let seg_values = &a_values[a_col_start..a_col_end];

            let block = b.view_block(bindex);
            for j in 0..block.ncols() {
                let col = block.column(j);
                let mut o_ij = T::zero();
                for (&k, &a_ik) in seg_indices.iter().zip(seg_values) {
                    o_ij += a_ik * col[k - range.start];
                }

                if o_ij.abs() > T::zero_threshold() {
                    or.push(range.start + j, o_ij);
                }
            }
            a_col_start = a_col_end;
        }
    }
}

/// Csr(O) = Csr(a) * Csc(b).
///
/// Each entry `o[i, j]` is the sparse dot product of row `i` of `a` with column
/// `j` of `b`. Entries whose magnitude does not exceed `T::zero_threshold()` are
/// not stored.
///
/// `o` must be empty (no rows) before it is passed to this function, and it must
/// have `b.ncols()` columns.
///
/// # Panics
///
/// Panics if `o` already contains rows, if `a.ncols() != b.nrows()`, or if
/// `o.ncols() != b.ncols()`.
///
/// # Note
///
/// The indices of the rows of `a` and of the columns of `b` must be stored in
/// ascending order (required by `dot_csvec`).
pub(crate) fn mul_csr_csc_to<T: Real>(
    a: CsrMatrixView<T>,
    b: CscMatrixView<T>,
    o: &mut CsrMatrix<T>,
) {
    assert_eq!(o.nrows(), 0);
    assert_eq!(a.ncols(), b.nrows());
    // Column indices pushed below range over `0..b.ncols()`, so `o` must agree.
    assert_eq!(o.ncols(), b.ncols());

    for i in 0..a.nrows() {
        let mut or = o.new_row_builder(T::zero_threshold());
        let ar = a.get_row(i);

        for j in 0..b.ncols() {
            let bc = b.get_col(j);

            let o_ij = dot_csvec(ar, bc);
            if o_ij.abs() > T::zero_threshold() {
                or.push(j, o_ij);
            }
        }
    }
}

/// `DenseVec(o) = CsrMat(a) x DenseVec(b)`
///
/// Every entry of `o` is overwritten, so `o` does not need to be zeroed first.
///
/// # Panics
///
/// Panics if `o.len() != a.nrows()`, if `b.len() != a.ncols()`, or if `b` is
/// empty.
#[inline]
pub(crate) fn mul_csr_dvec_to_dvec<T: Real>(
    a: CsrMatrixView<T>,
    b: DVectorView<T>,
    mut o: DVectorViewMut<T>,
) {
    assert_eq!(a.nrows(), o.len());
    // Shape check matching the CSC kernel. Without it a longer `b` would be
    // silently truncated, and a shorter one would fail inside the loop with an
    // unhelpful index panic.
    assert_eq!(a.ncols(), b.len());
    assert!(!b.is_empty());

    for i in 0..a.nrows() {
        let ar = a.get_row(i);
        let mut o_i = T::zero();

        for (j, a_ij) in ar.iter() {
            o_i += a_ij * b[j];
        }

        o[i] = o_i;
    }
}

/// `DenseVec(o) = CscMat(a) x DenseVec(b)`
///
/// The product is accumulated into `o` column by column (`o[i] += a[i, j] * b[j]`),
/// so the caller must pass a zeroed `o`. This precondition is only checked in
/// debug builds.
///
/// # Panics
///
/// Panics if `b.len() != a.ncols()`, if `o.len() != a.nrows()`, or if `b` is
/// empty. In debug builds it also panics if any entry of `o` has a magnitude
/// above `T::zero_threshold()` on entry.
pub(crate) fn mul_csc_dvec_to_dvec<T>(
    a: CscMatrixView<T>,
    b: DVectorView<T>,
    mut o: DVectorViewMut<T>,
) where
    T: Real,
{
    assert_eq!(a.ncols(), b.len());
    assert_eq!(o.len(), a.nrows());
    assert!(!b.is_empty());
    // "Zero" means `|v| <= zero_threshold`, the complement of the
    // `> zero_threshold` test used to decide which entries get stored. A strict
    // `<` would reject an exactly-zero `o` if the threshold were zero.
    debug_assert!(o.iter().all(|v| v.abs() <= T::zero_threshold()));

    for j in 0..a.ncols() {
        let aj = a.get_col(j);
        // Loop-invariant for the whole column.
        let b_j = b[j];

        for (i, a_ij) in aj.iter() {
            o[i] += a_ij * b_j;
        }
    }
}

/// Computes `Csr(o) = Csr(a) * Dense(b)`, appending one row to `o` per row of `a`.
///
/// Entry `(i, j)` of the product is the dot product of sparse row `i` of `a` with
/// dense column `j` of `b`. Only entries whose magnitude exceeds
/// `T::zero_threshold()` are stored.
///
/// # Panics
///
/// Panics if `a.ncols() != b.nrows()`, if `o.ncols() != b.ncols()`, or if `o`
/// already holds rows.
pub(crate) fn mul_csr_dmat_to_csr<T>(a: CsrMatrixView<T>, b: DMatrixView<T>, o: &mut CsrMatrix<T>)
where
    T: Real,
{
    assert_eq!(a.ncols(), b.nrows());
    assert_eq!(o.ncols(), b.ncols());
    assert_eq!(o.nrows(), 0);

    let threshold = T::zero_threshold();
    let dense_cols = b.ncols();
    for row_idx in 0..a.nrows() {
        let sparse_row = a.get_row(row_idx);
        let mut builder = o.new_row_builder(threshold);
        for col_idx in 0..dense_cols {
            let value = dot_csv_dv(sparse_row, b.column(col_idx));
            if value.abs() > threshold {
                builder.push(col_idx, value);
            }
        }
    }
}

/// Compute the dot product of two sparse vectors `a` and `b`.
///
/// Both index lists are walked together in a single merge pass, and every step
/// advances at least one cursor, so the cost is `O(nnz(a) + nnz(b))`.
///
/// # Note
///
/// The stored elements of both vectors must be in ascending index order.
/// Otherwise matching entries can be missed and the result is wrong. Memory
/// safety is not affected.
///
/// # Panics
///
/// Panics if either vector exposes fewer stored values than stored indices.
pub(crate) fn dot_csvec<T: Real>(a: CsVecRef<T>, b: CsVecRef<T>) -> T {
    let col_a = a.indices();
    let col_b = b.indices();
    let values_a = a.values();
    let values_b = b.values();
    // These checks make the unchecked value reads below sound. The cursors are
    // bounded by the index lengths, so the value slices must be at least as long.
    assert!(values_a.len() >= col_a.len());
    assert!(values_b.len() >= col_b.len());

    let mut res = T::zero();
    let mut ia = 0;
    let mut ib = 0;

    while ia < col_a.len() && ib < col_b.len() {
        // SAFETY: `ia < col_a.len()` and `ib < col_b.len()` by the loop condition.
        let (ca, cb) = unsafe { (*col_a.get_unchecked(ia), *col_b.get_unchecked(ib)) };
        match ca.cmp(&cb) {
            std::cmp::Ordering::Less => {
                ia += 1;
            }
            std::cmp::Ordering::Equal => {
                // SAFETY: `ia < col_a.len() <= values_a.len()` and
                // `ib < col_b.len() <= values_b.len()` by the loop condition and
                // the assertions above.
                res += unsafe { *values_a.get_unchecked(ia) * *values_b.get_unchecked(ib) };
                ia += 1;
                ib += 1;
            }
            std::cmp::Ordering::Greater => {
                ib += 1;
            }
        }
    }

    res
}

/// Dot product between sparse vec `a` and dense vector `b`.
///
/// Only the stored entries of `a` contribute. Each one is multiplied by the entry
/// of `b` at the same index, and the products are summed in storage order.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
pub(crate) fn dot_csv_dv<T>(a: CsVecRef<T>, b: DVectorView<T>) -> T
where
    T: Real,
{
    assert_eq!(a.len(), b.len());
    let mut acc = T::zero();
    a.iter().for_each(|(idx, val)| acc += val * b[idx]);
    acc
}

/// Computes `o = a + d`, where `a` is a sparse vector and `d` is a dense vector.
///
/// `o` is first overwritten with `d`, and then each stored entry of `a` is added
/// at its index. Because of this, `o` does not need to be zeroed beforehand.
///
/// # Panics
///
/// Panics if `a`, `d` and `o` do not all have the same length.
pub(crate) fn add_csv_dv<T>(a: CsVecRef<T>, d: DVectorView<T>, mut o: DVectorViewMut<T>)
where
    T: Real,
{
    assert_eq!(a.len(), d.len());
    assert_eq!(a.len(), o.len());
    o.copy_from(&d);
    for (i, a_ij) in a.iter() {
        // Element-wise sum: o[i] = d[i] + a[i]. Indices absent from `a` keep d[i].
        o[i] += a_ij;
    }
}

/// Computes `DenseVec(o) = BlockDiag(a) x DenseVec(b)`.
///
/// The product is evaluated block by block. Each block is read from `a.values()`
/// as the next `size * size` chunk and viewed as a dense `size x size` matrix.
/// It is then multiplied by the matching row range of `b`, and the result is
/// written into the same row range of `o`.
///
/// # Panics
///
/// Panics if `a.ncols() != b.len()` or `a.nrows() != o.len()`.
pub(crate) fn mul_bd_vec<T>(
    a: DiagonalBlockMatrixView<T>,
    b: DVectorView<T>,
    mut o: DVectorViewMut<T>,
) where
    T: Real,
{
    assert_eq!(a.ncols(), b.len());
    assert_eq!(a.nrows(), o.len());

    let all_values = a.values();
    let mut value_cursor: usize = 0;
    let mut row_cursor: usize = 0;
    for block_index in 0..a.num_blocks() {
        let size = a.get_block_size(block_index);
        let chunk_len = size * size;
        let chunk = &all_values[value_cursor..value_cursor + chunk_len];
        let dense_block = DMatrixView::from_slice(chunk, size, size);

        let rows = row_cursor..row_cursor + size;
        let rhs_part = b.rows_range(rows.clone());
        let mut out_part = o.rows_range_mut(rows);
        dense_block.mul_to(&rhs_part, &mut out_part);

        value_cursor += chunk_len;
        row_cursor += size;
    }
}

/// Rescales and shifts the diagonal of the square CSR matrix `o` in place.
///
/// For every row `i` it computes `o[i, i] = o[i, i] * diag_scale[i] + diag_add[i]`.
/// Off-diagonal entries are left untouched.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `o` is not square.
/// - Either vector's length differs from `o.nrows()`.
/// - Some row has no stored diagonal entry.
///
/// The diagonal entry is located with a binary search over each row's column
/// indices, so those indices must be sorted in ascending order.
pub fn mul_add_diag_to_csr<T: Real>(
    o: &mut CsrMatrix<T>,
    diag_scale: DVectorView<T>,
    diag_add: DVectorView<T>,
) {
    let n = o.nrows();
    assert_eq!(o.ncols(), n);
    assert_eq!(n, diag_scale.len());
    assert_eq!(n, diag_add.len());

    let factors = diag_scale.iter().zip(diag_add.iter());
    for (i, (&scale, &shift)) in factors.enumerate() {
        let row = o.get_row_mut(i);
        let diag_pos = row
            .col_indices
            .binary_search(&i)
            .expect("Diagonal element must be present in CSR row");
        let current = row.values[diag_pos];
        row.values[diag_pos] = current * scale + shift;
    }
}

/// `BlockDiag * DenseVecView`, returning a freshly allocated dense vector.
impl<'a, T: Real> Mul<DVectorView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVectorView<'a, T>) -> DVector<T> {
        let rows = self.nrows();
        let mut product = DVector::zeros(rows);
        mul_bd_vec(self, rhs, product.as_view_mut());
        product
    }
}

impl<'a, T: Real> Mul<DVector<T>> for DiagonalBlockMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVector<T>) -> DVector<T> {
        let mut o = DVector::zeros(self.nrows());
        mul_bd_vec(self, rhs.as_view(), o.as_view_mut());
        o
    }
}

/// `Csr * BlockDiag`, producing a new CSR matrix with `rhs.ncols()` columns.
impl<'a, T: Real> Mul<DiagonalBlockMatrixView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = CsrMatrix<T>;

    fn mul(self, rhs: DiagonalBlockMatrixView<'a, T>) -> Self::Output {
        let out_cols = rhs.ncols();
        let mut product = CsrMatrix::<T>::new(out_cols);
        mul_csr_bd_to(self, rhs, &mut product);
        product
    }
}

/// `Csr * Csc`, producing a new CSR matrix with `rhs.ncols()` columns.
impl<'a, T: Real> Mul<CscMatrixView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = CsrMatrix<T>;

    fn mul(self, rhs: CscMatrixView<'a, T>) -> Self::Output {
        let out_cols = rhs.ncols();
        let mut product = CsrMatrix::<T>::new(out_cols);
        mul_csr_csc_to(self, rhs, &mut product);
        product
    }
}

/// `Csr * Dense`, producing a new CSR matrix with `rhs.ncols()` columns.
impl<'a, T: Real> Mul<DMatrixView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = CsrMatrix<T>;

    fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
        let out_cols = rhs.ncols();
        let mut product = CsrMatrix::<T>::new(out_cols);
        mul_csr_dmat_to_csr(self, rhs, &mut product);
        product
    }
}

/// `Csr * DenseVecView`, returning a freshly allocated dense vector.
impl<'a, T: Real> Mul<DVectorView<'a, T>> for CsrMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
        let rows = self.nrows();
        let mut product = DVector::<T>::zeros(rows);
        mul_csr_dvec_to_dvec(self, rhs, product.as_view_mut());
        product
    }
}

impl<'a, T: Real> Mul<DVector<T>> for CsrMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVector<T>) -> Self::Output {
        let rhs: DVectorView<T> = rhs.as_view();
        self * rhs
    }
}

/// `Csc * DenseVecView`, returning a freshly allocated dense vector.
///
/// The output starts zeroed, as the accumulating CSC kernel requires.
impl<'a, T: Real> Mul<DVectorView<'a, T>> for CscMatrixView<'a, T> {
    type Output = DVector<T>;

    #[inline]
    fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
        let rows = self.nrows();
        let mut product = DVector::<T>::zeros(rows);
        mul_csc_dvec_to_dvec(self, rhs, product.as_view_mut());
        product
    }
}

/// `BlockDiag * Dense`, returning a new dense matrix.
impl<'a, T: Real> Mul<DMatrixView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
    type Output = DMatrix<T>;

    /// Multiplies each diagonal block by the matching row range of `rhs`, writing
    /// the result into the same row range of the output.
    ///
    /// # Panics
    ///
    /// Panics if `self.ncols() != rhs.nrows()`.
    fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
        // Without this check a taller `rhs` would have its extra rows silently
        // ignored. The vector kernel `mul_bd_vec` performs the equivalent check.
        assert_eq!(self.ncols(), rhs.nrows());

        let mut result = DMatrix::zeros(self.nrows(), rhs.ncols());

        for bindex in 0..self.num_blocks() {
            let block = self.view_block(bindex);
            // Relies on the block's column range matching its row range, so the
            // same range selects the rows of `rhs` this block multiplies.
            let range = self.get_block_row_range(bindex);
            let mut output = result.rows_range_mut(range.clone());
            let rhs = rhs.rows_range(range);
            block.mul_to(&rhs, &mut output);
        }

        result
    }
}

// Addition operators between sparse and dense vectors.
mod add {

    use std::ops::Add;

    use super::*;

    /// `SparseVec + DenseVecView`, returning a freshly allocated dense vector.
    impl<'a, T: Real> Add<DVectorView<'a, T>> for CsVecRef<'a, T> {
        type Output = DVector<T>;

        #[inline]
        fn add(self, rhs: DVectorView<'a, T>) -> Self::Output {
            let dim = self.len();
            let mut sum = DVector::<T>::zeros(dim);
            add_csv_dv(self, rhs, sum.as_view_mut());
            sum
        }
    }
}