use crate::traits::{Real, State};
use super::base::{Matrix, MatrixStorage};
impl<T: Real> Matrix<T> {
    /// Returns the matrix with every entry multiplied by the scalar `rhs`,
    /// consuming `self`.
    ///
    /// An identity matrix is materialised as `Matrix::diagonal(vec![rhs; n])`.
    /// Every other storage kind is scaled in place. Its dimensions (`n`, `m`)
    /// and its storage metadata (band widths, sparse triplets) are kept
    /// unchanged.
    pub fn component_mul(mut self, rhs: T) -> Self {
        if matches!(self.storage, MatrixStorage::Identity) {
            return Matrix::diagonal(vec![rhs; self.n]);
        }
        // Scale in place instead of rebuilding the matrix. A banded matrix is
        // not necessarily square (`mul_state` bounds its columns by `self.m`),
        // so reconstructing it with `m = n` would change its shape. Delegating
        // also keeps this method in agreement with `component_mul_mut`.
        self.component_mul_mut(rhs);
        self
    }

    /// Multiplies every entry of the matrix by the scalar `rhs` in place.
    ///
    /// An identity matrix stores no values. It is therefore converted into a
    /// banded matrix with no sub- or super-diagonals (`ml = mu = 0`) whose `n`
    /// stored values are all `rhs`.
    pub fn component_mul_mut(&mut self, rhs: T) {
        match &mut self.storage {
            MatrixStorage::Identity => {
                let n = self.n;
                self.data = vec![rhs; n];
                self.storage = MatrixStorage::Banded {
                    ml: 0,
                    mu: 0,
                    zero: T::zero(),
                };
            }
            // Dense and banded matrices keep all their values in `data`.
            MatrixStorage::Full | MatrixStorage::Banded { .. } => {
                for v in &mut self.data {
                    *v *= rhs;
                }
            }
            // Sparse matrices keep their values in the (row, col, value) triplets.
            MatrixStorage::Sparse { coords, .. } => {
                for (_, _, v) in coords.iter_mut() {
                    *v *= rhs;
                }
            }
        }
    }

    /// Computes the matrix-vector product `self * vec`.
    ///
    /// Identity matrices return a clone of `vec`. Dense matrices delegate to
    /// `State::mul_by_dense_matrix`. Sparse and banded matrices are multiplied
    /// entry by entry.
    ///
    /// # Panics
    /// Panics if `vec.len()` differs from the column count `self.m`.
    ///
    /// NOTE(review): the sparse and banded paths start from `vec.zeros_like()`,
    /// so their result has `vec`'s length (`m`) while rows are written up to
    /// `n`. This presumably assumes a square matrix; confirm before using
    /// non-square sparse/banded matrices here.
    pub fn mul_state<V: State<T>>(&self, vec: &V) -> V {
        assert_eq!(vec.len(), self.m, "dimension mismatch in Matrix::mul_state");
        match self.storage {
            MatrixStorage::Identity => vec.clone(),
            MatrixStorage::Sparse { ref coords, .. } => {
                let mut result = vec.zeros_like();
                // Accumulate instead of overwriting, so duplicate (row, col)
                // triplets are summed.
                for &(r, c, v) in coords {
                    let current = result.get_component(r);
                    result.set_component(r, current + v * vec.get_component(c));
                }
                result
            }
            MatrixStorage::Banded { ml, mu, .. } => {
                let mut result = vec.zeros_like();
                for i in 0..self.n {
                    // Only columns in [i - ml, i + mu] (clamped to the matrix)
                    // are visited for row i.
                    let start = i.saturating_sub(ml);
                    let end = (i + mu + 1).min(self.m);
                    let mut sum = T::zero();
                    for j in start..end {
                        sum += self[(i, j)] * vec.get_component(j);
                    }
                    result.set_component(i, sum);
                }
                result
            }
            MatrixStorage::Full => vec.mul_by_dense_matrix(&self.data, self.n, self.m),
        }
    }
}
#[cfg(all(test, feature = "nalgebra"))]
mod tests {
    use super::Matrix;
    use nalgebra::Vector2;

    #[test]
    fn mul_matrix_full() {
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 10.0);
        assert_eq!(out[(1, 0)], 15.0);
        assert_eq!(out[(1, 1)], 20.0);
    }

    #[test]
    fn mul_identity() {
        let a: Matrix<f64> = Matrix::identity(2);
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 0.0);
        assert_eq!(out[(1, 0)], 0.0);
        assert_eq!(out[(1, 1)], 5.0);
    }

    /// Exercises the in-place path, including the Identity -> Banded conversion.
    #[test]
    fn mul_assign() {
        let mut a: Matrix<f64> = Matrix::identity(2);
        a.component_mul_mut(5.0);
        assert_eq!(a[(0, 0)], 5.0);
        assert_eq!(a[(0, 1)], 0.0);
        assert_eq!(a[(1, 0)], 0.0);
        assert_eq!(a[(1, 1)], 5.0);
    }

    #[test]
    fn mul_assign_full_matches_component_mul() {
        let expected = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0])
            .unwrap()
            .component_mul(5.0);
        let mut a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        a.component_mul_mut(5.0);
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(a[(i, j)], expected[(i, j)]);
            }
        }
    }

    #[test]
    fn mul_sparse_preserves_dimensions() {
        let a = Matrix::sparse_from_triplets(2, 3, vec![(0, 2, 4.0)]).component_mul(2.0);
        assert_eq!(a.dims(), (2, 3));
        assert_eq!(a[(0, 2)], 8.0);
    }

    #[test]
    fn mul_assign_sparse_preserves_dimensions() {
        let mut a = Matrix::sparse_from_triplets(2, 3, vec![(0, 2, 4.0)]);
        a.component_mul_mut(2.0);
        assert_eq!(a.dims(), (2, 3));
        assert_eq!(a[(0, 2)], 8.0);
    }

    #[test]
    fn mul_state_sparse_accumulates_duplicate_entries() {
        let a = Matrix::sparse_from_triplets(2, 2, vec![(0, 0, 2.0), (0, 0, 3.0), (1, 1, 4.0)]);
        let x = Vector2::new(2.0, 3.0);
        let y = a.mul_state(&x);
        assert_eq!(y, Vector2::new(10.0, 12.0));
    }
}