opensrdk_linear_algebra/matrix/di/operators/
mul.rs

1use crate::number::{c64, Number};
2use crate::DiagonalMatrix;
3use rayon::prelude::*;
4use std::ops::Mul;
5
6pub(crate) fn mul_scalar<T>(slf: T, rhs: DiagonalMatrix<T>) -> DiagonalMatrix<T>
7where
8    T: Number,
9{
10    let mut rhs = rhs;
11    rhs.d
12        .par_iter_mut()
13        .map(|di| {
14            *di *= slf;
15        })
16        .collect::<Vec<_>>();
17
18    rhs
19}
20
21fn mul_di<T>(lhs: DiagonalMatrix<T>, rhs: &DiagonalMatrix<T>) -> DiagonalMatrix<T>
22where
23    T: Number,
24{
25    if lhs.dim() != rhs.dim() {
26        panic!("Dimension mismatch.")
27    }
28
29    DiagonalMatrix::new(mul_vec(lhs.d, rhs.d()))
30}
31
32fn mul_vec<T>(lhs: Vec<T>, rhs: &[T]) -> Vec<T>
33where
34    T: Number,
35{
36    if lhs.len() != rhs.len() {
37        panic!("Dimension mismatch.")
38    }
39
40    let mut lhs = lhs;
41    lhs.par_iter_mut()
42        .zip(rhs.par_iter())
43        .for_each(|(li, &ri)| *li *= ri);
44
45    lhs
46}
47
48macro_rules! impl_mul_scalar {
49    {$t: ty} => {
50        impl Mul<DiagonalMatrix<$t>> for $t {
51            type Output = DiagonalMatrix<$t>;
52
53            fn mul(self, rhs: DiagonalMatrix<$t>) -> Self::Output {
54                mul_scalar(self, rhs)
55            }
56        }
57
58        impl Mul<$t> for DiagonalMatrix<$t> {
59            type Output = DiagonalMatrix<$t>;
60
61            fn mul(self, rhs: $t) -> Self::Output {
62                mul_scalar(rhs, self)
63            }
64        }
65    };
66}
67
68impl_mul_scalar! {f64}
69impl_mul_scalar! {c64}
70
71macro_rules! impl_mul_di {
72  {$t: ty, $e: expr} => {
73      impl Mul<DiagonalMatrix<$t>> for DiagonalMatrix<$t> {
74          type Output = DiagonalMatrix<$t>;
75
76          fn mul(self, rhs: DiagonalMatrix<$t>) -> Self::Output {
77              $e(self, &rhs)
78          }
79      }
80
81      impl Mul<&DiagonalMatrix<$t>> for DiagonalMatrix<$t> {
82          type Output = DiagonalMatrix<$t>;
83
84          fn mul(self, rhs: &DiagonalMatrix<$t>) -> Self::Output {
85              $e(self, rhs)
86          }
87      }
88
89      impl Mul<DiagonalMatrix<$t>> for &DiagonalMatrix<$t> {
90        type Output = DiagonalMatrix<$t>;
91
92        fn mul(self, rhs: DiagonalMatrix<$t>) -> Self::Output {
93            $e(rhs, self)
94        }
95      }
96  };
97}
98
99impl_mul_di! {f64, mul_di}
100impl_mul_di! {c64, mul_di}
101
102macro_rules! impl_mul_vec {
103  {$t: ty, $e: expr} => {
104      impl Mul<Vec<$t>> for DiagonalMatrix<$t> {
105          type Output = Vec<$t>;
106
107          fn mul(self, rhs: Vec<$t>) -> Self::Output {
108              $e(self.d, &rhs)
109          }
110      }
111
112      impl Mul<&Vec<$t>> for DiagonalMatrix<$t> {
113        type Output = Vec<$t>;
114
115        fn mul(self, rhs: &Vec<$t>) -> Self::Output {
116            $e(self.d, rhs)
117        }
118    }
119      impl Mul<Vec<$t>> for &DiagonalMatrix<$t> {
120          type Output = Vec<$t>;
121
122          fn mul(self, rhs: Vec<$t>) -> Self::Output {
123              $e(rhs, self.d())
124          }
125      }
126  };
127}
128
129impl_mul_vec! {f64, mul_vec}
130impl_mul_vec! {c64, mul_vec}
131
#[cfg(test)]
mod tests {
    use crate::*;

    // Checks every diagonal entry, not just the first, and exercises all three
    // diagonal-by-diagonal operand forms (including the operand-swapping `&A * B`).
    #[test]
    fn mul() {
        let a = DiagonalMatrix::new(vec![2.0_f64, 3.0]) * DiagonalMatrix::new(vec![4.0_f64, 5.0]);
        assert_eq!(a.d().to_vec(), vec![8.0, 15.0]);

        let lhs = DiagonalMatrix::new(vec![2.0_f64, 3.0]);
        let b = &lhs * DiagonalMatrix::new(vec![4.0_f64, 5.0]);
        assert_eq!(b.d().to_vec(), vec![8.0, 15.0]);

        let rhs = DiagonalMatrix::new(vec![4.0_f64, 5.0]);
        let c = lhs * &rhs;
        assert_eq!(c.d().to_vec(), vec![8.0, 15.0]);
    }

    // Checks every output element and all three matrix-vector operand forms.
    #[test]
    fn mul_vec() {
        let a = DiagonalMatrix::new(vec![2.0_f64, 3.0]) * vec![4.0_f64, 5.0];
        assert_eq!(a, vec![8.0, 15.0]);

        let d = DiagonalMatrix::new(vec![2.0_f64, 3.0]);
        let b = &d * vec![4.0_f64, 5.0];
        assert_eq!(b, vec![8.0, 15.0]);

        let v = vec![4.0_f64, 5.0];
        let c = d * &v;
        assert_eq!(c, vec![8.0, 15.0]);
    }

    // Scalar multiplication must work with the scalar on either side.
    #[test]
    fn mul_scalar_both_sides() {
        let a = 2.0_f64 * DiagonalMatrix::new(vec![1.0_f64, 3.0]);
        assert_eq!(a.d().to_vec(), vec![2.0, 6.0]);

        let b = DiagonalMatrix::new(vec![1.0_f64, 3.0]) * 2.0_f64;
        assert_eq!(b.d().to_vec(), vec![2.0, 6.0]);
    }

    // Mismatched sizes must panic rather than silently truncate via `zip`.
    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn mul_dimension_mismatch_panics() {
        let _ = DiagonalMatrix::new(vec![1.0_f64, 2.0]) * DiagonalMatrix::new(vec![1.0_f64]);
    }
}