differential_equations/linalg/matrix/
mul.rs1use crate::traits::{Real, State};
4
5use super::base::{Matrix, MatrixStorage};
6
7impl<T: Real> Matrix<T> {
9 pub fn component_mul(mut self, rhs: T) -> Self {
11 match &mut self.storage {
12 MatrixStorage::Identity => Matrix::diagonal(vec![rhs; self.nrows]),
13 MatrixStorage::Full => {
14 for v in &mut self.data {
15 *v = *v * rhs;
16 }
17 self
18 }
19 MatrixStorage::Banded { ml, mu, .. } => {
20 let n = self.nrows;
21 let data = self.data.into_iter().map(|x| x * rhs).collect();
22 Matrix {
23 nrows: n,
24 ncols: n,
25 data,
26 storage: MatrixStorage::Banded {
27 ml: *ml,
28 mu: *mu,
29 zero: T::zero(),
30 },
31 }
32 }
33 }
34 }
35
36 pub fn component_mul_mut(&mut self, rhs: T) {
39 match &mut self.storage {
40 MatrixStorage::Identity => {
41 let n = self.nrows;
43 self.data = vec![rhs; n];
44 self.storage = MatrixStorage::Banded {
45 ml: 0,
46 mu: 0,
47 zero: T::zero(),
48 };
49 }
50 MatrixStorage::Full => {
51 for v in &mut self.data {
52 *v = *v * rhs;
53 }
54 }
55 MatrixStorage::Banded { .. } => {
56 for v in &mut self.data {
57 *v = *v * rhs;
58 }
59 }
60 }
61 }
62
63 pub fn mul_state<V: State<T>>(&self, vec: &V) -> V {
64 let n = self.n();
65 assert_eq!(vec.len(), n, "dimension mismatch in Matrix::mul_state");
66
67 let mut result = V::zeros();
68 for i in 0..n {
69 let mut sum = T::zero();
70 for j in 0..n {
71 sum = sum + self[(i, j)] * vec.get(j);
72 }
73 result.set(i, sum);
74 }
75 result
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::Matrix;

    /// By-value scalar multiplication of a dense (row-major) matrix.
    #[test]
    fn mul_matrix_full() {
        let a: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 10.0);
        assert_eq!(out[(1, 0)], 15.0);
        assert_eq!(out[(1, 1)], 20.0);
    }

    /// By-value scaling turns the identity into a diagonal matrix.
    #[test]
    fn mul_identity() {
        let a: Matrix<f64> = Matrix::identity(2);
        let s = 5.0;
        let out = a.component_mul(s);
        assert_eq!(out[(0, 0)], 5.0);
        assert_eq!(out[(0, 1)], 0.0);
        assert_eq!(out[(1, 0)], 0.0);
        assert_eq!(out[(1, 1)], 5.0);
    }

    /// In-place scaling of the identity via `component_mul_mut`.
    #[test]
    fn mul_assign() {
        let mut a: Matrix<f64> = Matrix::identity(2);
        let s = 5.0;
        a.component_mul_mut(s);
        assert_eq!(a[(0, 0)], 5.0);
        assert_eq!(a[(0, 1)], 0.0);
        assert_eq!(a[(1, 0)], 0.0);
        assert_eq!(a[(1, 1)], 5.0);
    }

    /// In-place scaling of a dense matrix via `component_mul_mut`.
    #[test]
    fn mul_assign_full() {
        let mut a: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        a.component_mul_mut(5.0);
        assert_eq!(a[(0, 0)], 5.0);
        assert_eq!(a[(0, 1)], 10.0);
        assert_eq!(a[(1, 0)], 15.0);
        assert_eq!(a[(1, 1)], 20.0);
    }

    /// By-value scaling of a diagonal matrix keeps off-diagonal entries at zero.
    #[test]
    fn mul_diagonal() {
        let a: Matrix<f64> = Matrix::diagonal(vec![1.0, 2.0]);
        let out = a.component_mul(3.0);
        assert_eq!(out[(0, 0)], 3.0);
        assert_eq!(out[(0, 1)], 0.0);
        assert_eq!(out[(1, 0)], 0.0);
        assert_eq!(out[(1, 1)], 6.0);
    }
}