differential-equations 0.6.1

A Rust library for solving differential equations.
Documentation
//! Matrix multiplication helpers.

use crate::traits::{Real, State};

use super::base::{Matrix, MatrixStorage};

// Scalar scaling (`component_mul*`) and matrix-vector products (`mul_state`) for `Matrix`.
impl<T: Real> Matrix<T> {
    /// Return a new matrix where each stored entry is multiplied by `rhs`.
    ///
    /// The dimensions and storage metadata (band widths, sparse coordinates and
    /// the storage's `zero` value) are preserved. An `Identity` matrix stores no
    /// entries, so it is materialized via `Matrix::diagonal` with `rhs` in every
    /// diagonal position.
    pub fn component_mul(mut self, rhs: T) -> Self {
        match self.storage {
            MatrixStorage::Identity => Matrix::diagonal(vec![rhs; self.n]),
            // Scale in place rather than rebuilding the struct: this keeps `n`,
            // `m`, the band widths and the stored `zero` exactly as they were
            // (rebuilding a banded matrix with `m: n` would break non-square bands).
            MatrixStorage::Full
            | MatrixStorage::Banded { .. }
            | MatrixStorage::Sparse { .. } => {
                self.component_mul_mut(rhs);
                self
            }
        }
    }

    /// In-place component-wise scalar multiplication: `self[i,j] *= rhs` for all stored entries.
    ///
    /// For `Identity`, converts to a diagonal banded matrix (`ml = mu = 0`) whose
    /// `data` holds `rhs` in each of the `n` diagonal positions.
    pub fn component_mul_mut(&mut self, rhs: T) {
        match &mut self.storage {
            MatrixStorage::Identity => {
                // Become diagonal with rhs on the main diagonal.
                let n = self.n;
                self.data = vec![rhs; n];
                self.storage = MatrixStorage::Banded {
                    ml: 0,
                    mu: 0,
                    zero: T::zero(),
                };
            }
            // Dense and banded storage keep every stored value in `data`.
            MatrixStorage::Full | MatrixStorage::Banded { .. } => {
                for v in &mut self.data {
                    *v *= rhs;
                }
            }
            // Sparse storage keeps values in the third component of each triplet.
            MatrixStorage::Sparse { coords, .. } => {
                for item in coords.iter_mut() {
                    item.2 *= rhs;
                }
            }
        }
    }

    /// Compute the matrix-vector product `self * vec`.
    ///
    /// The result is built from `vec` (via `zeros_like` or `clone`), so it has the
    /// same type as the input. The sparse and banded paths write row indices
    /// `0..n` into that result. They are therefore only meaningful for square
    /// matrices (`n == m`).
    /// NOTE(review): confirm callers never pass non-square sparse/banded matrices.
    ///
    /// Sparse triplets that share a `(row, col)` position are summed.
    ///
    /// # Panics
    ///
    /// Panics if `vec.len()` differs from the number of columns (`self.m`).
    pub fn mul_state<V: State<T>>(&self, vec: &V) -> V {
        let (n, m) = (self.n, self.m);
        assert_eq!(vec.len(), m, "dimension mismatch in Matrix::mul_state");

        match self.storage {
            MatrixStorage::Identity => vec.clone(),
            MatrixStorage::Sparse { ref coords, .. } => {
                let mut result = vec.zeros_like();
                // Accumulate instead of overwriting so duplicate coordinates add up.
                for &(r, c, v) in coords {
                    let current = result.get_component(r);
                    result.set_component(r, current + v * vec.get_component(c));
                }
                result
            }
            MatrixStorage::Banded { ml, mu, .. } => {
                let mut result = vec.zeros_like();
                for i in 0..n {
                    // Row `i` only has entries in columns `i - ml ..= i + mu`,
                    // clipped to the matrix bounds.
                    let start = i.saturating_sub(ml);
                    let end = (i + mu + 1).min(m);
                    let mut sum = T::zero();
                    for j in start..end {
                        sum += self[(i, j)] * vec.get_component(j);
                    }
                    result.set_component(i, sum);
                }
                result
            }
            MatrixStorage::Full => vec.mul_by_dense_matrix(&self.data, n, m),
        }
    }
}

#[cfg(all(test, feature = "nalgebra"))]
mod tests {
    use super::Matrix;
    use nalgebra::Vector2;

    #[test]
    fn mul_matrix_full() {
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 10.0);
        assert_eq!(out[(1, 0)], 15.0);
        assert_eq!(out[(1, 1)], 20.0);
    }

    #[test]
    fn mul_identity() {
        let a: Matrix<f64> = Matrix::identity(2);
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 0.0);
        assert_eq!(out[(1, 0)], 0.0);
        assert_eq!(out[(1, 1)], 5.0);
    }

    #[test]
    fn mul_assign() {
        // Exercise the in-place variant; `mul_identity` already covers `component_mul`.
        let mut a: Matrix<f64> = Matrix::identity(2);
        a.component_mul_mut(5.0);
        assert_eq!(a[(0, 0)], 5.0);
        assert_eq!(a[(0, 1)], 0.0);
        assert_eq!(a[(1, 0)], 0.0);
        assert_eq!(a[(1, 1)], 5.0);
    }

    #[test]
    fn mul_diagonal_scales_entries() {
        let a: Matrix<f64> = Matrix::diagonal(vec![1.0, 2.0]).component_mul(3.0);
        assert_eq!(a.dims(), (2, 2));
        assert_eq!(a[(0, 0)], 3.0);
        assert_eq!(a[(0, 1)], 0.0);
        assert_eq!(a[(1, 0)], 0.0);
        assert_eq!(a[(1, 1)], 6.0);
    }

    #[test]
    fn mul_sparse_preserves_dimensions() {
        let a = Matrix::sparse_from_triplets(2, 3, vec![(0, 2, 4.0)]).component_mul(2.0);
        assert_eq!(a.dims(), (2, 3));
        assert_eq!(a[(0, 2)], 8.0);
    }

    #[test]
    fn mul_mut_sparse_preserves_dimensions() {
        let mut a = Matrix::sparse_from_triplets(2, 3, vec![(0, 2, 4.0)]);
        a.component_mul_mut(0.5);
        assert_eq!(a.dims(), (2, 3));
        assert_eq!(a[(0, 2)], 2.0);
    }

    #[test]
    fn mul_state_sparse_accumulates_duplicate_entries() {
        let a = Matrix::sparse_from_triplets(2, 2, vec![(0, 0, 2.0), (0, 0, 3.0), (1, 1, 4.0)]);
        let x = Vector2::new(2.0, 3.0);
        let y = a.mul_state(&x);
        assert_eq!(y, Vector2::new(10.0, 12.0));
    }

    #[test]
    fn mul_state_identity_returns_input() {
        let a: Matrix<f64> = Matrix::identity(2);
        let x = Vector2::new(2.0, -3.0);
        assert_eq!(a.mul_state(&x), x);
    }

    #[test]
    fn mul_state_diagonal() {
        let a: Matrix<f64> = Matrix::diagonal(vec![2.0, 3.0]);
        let y = a.mul_state(&Vector2::new(4.0, 5.0));
        assert_eq!(y, Vector2::new(8.0, 15.0));
    }

    #[test]
    fn mul_state_full() {
        // Row-major layout, matching the indexing checked in `mul_matrix_full`.
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = a.mul_state(&Vector2::new(1.0, 1.0));
        assert_eq!(y, Vector2::new(3.0, 7.0));
    }
}