// differential_equations/linalg/matrix/mul.rs

//! Matrix multiplication helpers: scalar scaling of stored entries and matrix-state products.
2
3use crate::traits::{Real, State};
4
5use super::base::{Matrix, MatrixStorage};
6
// Scalar scaling (Matrix * T) and Matrix * State (vector-like) products
8impl<T: Real> Matrix<T> {
9    /// Return a new matrix where each stored entry is multiplied by `rhs`.
10    pub fn component_mul(mut self, rhs: T) -> Self {
11        match &mut self.storage {
12            MatrixStorage::Identity => Matrix::diagonal(vec![rhs; self.n]),
13            MatrixStorage::Full => {
14                for v in &mut self.data {
15                    *v *= rhs;
16                }
17                self
18            }
19            MatrixStorage::Banded { ml, mu, .. } => {
20                let n = self.n;
21                let data = self.data.into_iter().map(|x| x * rhs).collect();
22                Matrix {
23                    n,
24                    m: n,
25                    data,
26                    storage: MatrixStorage::Banded {
27                        ml: *ml,
28                        mu: *mu,
29                        zero: T::zero(),
30                    },
31                }
32            }
33        }
34    }
35
36    /// In-place component-wise scalar multiplication: self[i,j] *= rhs for all stored entries.
37    /// For Identity, converts to a diagonal banded matrix with `rhs` on the diagonal.
38    pub fn component_mul_mut(&mut self, rhs: T) {
39        match &mut self.storage {
40            MatrixStorage::Identity => {
41                // Become diagonal with rhs on the main diagonal
42                let n = self.n;
43                self.data = vec![rhs; n];
44                self.storage = MatrixStorage::Banded {
45                    ml: 0,
46                    mu: 0,
47                    zero: T::zero(),
48                };
49            }
50            MatrixStorage::Full => {
51                for v in &mut self.data {
52                    *v *= rhs;
53                }
54            }
55            MatrixStorage::Banded { .. } => {
56                for v in &mut self.data {
57                    *v *= rhs;
58                }
59            }
60        }
61    }
62
63    pub fn mul_state<V: State<T>>(&self, vec: &V) -> V {
64        let n = self.n;
65        assert_eq!(vec.len(), n, "dimension mismatch in Matrix::mul_state");
66
67        let mut result = V::zeros();
68        for i in 0..n {
69            let mut sum = T::zero();
70            for j in 0..n {
71                sum += self[(i, j)] * vec.get(j);
72            }
73            result.set(i, sum);
74        }
75        result
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::Matrix;

    /// Scaling a dense matrix multiplies every entry (data is row-major here).
    #[test]
    fn mul_matrix_full() {
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 10.0);
        assert_eq!(out[(1, 0)], 15.0);
        assert_eq!(out[(1, 1)], 20.0);
    }

    /// Scaling the identity yields `s` on the diagonal and zeros elsewhere.
    #[test]
    fn mul_identity() {
        let a: Matrix<f64> = Matrix::identity(2);
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 0.0);
        assert_eq!(out[(1, 0)], 0.0);
        assert_eq!(out[(1, 1)], 5.0);
    }

    /// In-place scaling of the identity (the `Identity` -> diagonal conversion path).
    #[test]
    fn mul_assign() {
        let mut a: Matrix<f64> = Matrix::identity(2);
        let s = 5.0;
        a.component_mul_mut(s);
        assert_eq!(a[(0, 0)], 5.0);
        assert_eq!(a[(0, 1)], 0.0);
        assert_eq!(a[(1, 0)], 0.0);
        assert_eq!(a[(1, 1)], 5.0);
    }

    /// In-place scaling of a dense matrix.
    #[test]
    fn mul_assign_full() {
        let mut a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        a.component_mul_mut(5.0);
        assert_eq!(a[(0, 0)], 5.0);
        assert_eq!(a[(0, 1)], 10.0);
        assert_eq!(a[(1, 0)], 15.0);
        assert_eq!(a[(1, 1)], 20.0);
    }

    /// Scaling an already-diagonal matrix scales only the stored diagonal.
    #[test]
    fn mul_diagonal() {
        let a: Matrix<f64> = Matrix::diagonal(vec![1.0, 2.0]);
        let out = a.component_mul(3.0);
        assert_eq!(out[(0, 0)], 3.0);
        assert_eq!(out[(0, 1)], 0.0);
        assert_eq!(out[(1, 0)], 0.0);
        assert_eq!(out[(1, 1)], 6.0);
    }
}