opensrdk_linear_algebra/matrix/di/operators/
mul.rs1use crate::number::{c64, Number};
2use crate::DiagonalMatrix;
3use rayon::prelude::*;
4use std::ops::Mul;
5
6pub(crate) fn mul_scalar<T>(slf: T, rhs: DiagonalMatrix<T>) -> DiagonalMatrix<T>
7where
8 T: Number,
9{
10 let mut rhs = rhs;
11 rhs.d
12 .par_iter_mut()
13 .map(|di| {
14 *di *= slf;
15 })
16 .collect::<Vec<_>>();
17
18 rhs
19}
20
21fn mul_di<T>(lhs: DiagonalMatrix<T>, rhs: &DiagonalMatrix<T>) -> DiagonalMatrix<T>
22where
23 T: Number,
24{
25 if lhs.dim() != rhs.dim() {
26 panic!("Dimension mismatch.")
27 }
28
29 DiagonalMatrix::new(mul_vec(lhs.d, rhs.d()))
30}
31
32fn mul_vec<T>(lhs: Vec<T>, rhs: &[T]) -> Vec<T>
33where
34 T: Number,
35{
36 if lhs.len() != rhs.len() {
37 panic!("Dimension mismatch.")
38 }
39
40 let mut lhs = lhs;
41 lhs.par_iter_mut()
42 .zip(rhs.par_iter())
43 .for_each(|(li, &ri)| *li *= ri);
44
45 lhs
46}
47
48macro_rules! impl_mul_scalar {
49 {$t: ty} => {
50 impl Mul<DiagonalMatrix<$t>> for $t {
51 type Output = DiagonalMatrix<$t>;
52
53 fn mul(self, rhs: DiagonalMatrix<$t>) -> Self::Output {
54 mul_scalar(self, rhs)
55 }
56 }
57
58 impl Mul<$t> for DiagonalMatrix<$t> {
59 type Output = DiagonalMatrix<$t>;
60
61 fn mul(self, rhs: $t) -> Self::Output {
62 mul_scalar(rhs, self)
63 }
64 }
65 };
66}
67
68impl_mul_scalar! {f64}
69impl_mul_scalar! {c64}
70
71macro_rules! impl_mul_di {
72 {$t: ty, $e: expr} => {
73 impl Mul<DiagonalMatrix<$t>> for DiagonalMatrix<$t> {
74 type Output = DiagonalMatrix<$t>;
75
76 fn mul(self, rhs: DiagonalMatrix<$t>) -> Self::Output {
77 $e(self, &rhs)
78 }
79 }
80
81 impl Mul<&DiagonalMatrix<$t>> for DiagonalMatrix<$t> {
82 type Output = DiagonalMatrix<$t>;
83
84 fn mul(self, rhs: &DiagonalMatrix<$t>) -> Self::Output {
85 $e(self, rhs)
86 }
87 }
88
89 impl Mul<DiagonalMatrix<$t>> for &DiagonalMatrix<$t> {
90 type Output = DiagonalMatrix<$t>;
91
92 fn mul(self, rhs: DiagonalMatrix<$t>) -> Self::Output {
93 $e(rhs, self)
94 }
95 }
96 };
97}
98
99impl_mul_di! {f64, mul_di}
100impl_mul_di! {c64, mul_di}
101
102macro_rules! impl_mul_vec {
103 {$t: ty, $e: expr} => {
104 impl Mul<Vec<$t>> for DiagonalMatrix<$t> {
105 type Output = Vec<$t>;
106
107 fn mul(self, rhs: Vec<$t>) -> Self::Output {
108 $e(self.d, &rhs)
109 }
110 }
111
112 impl Mul<&Vec<$t>> for DiagonalMatrix<$t> {
113 type Output = Vec<$t>;
114
115 fn mul(self, rhs: &Vec<$t>) -> Self::Output {
116 $e(self.d, rhs)
117 }
118 }
119 impl Mul<Vec<$t>> for &DiagonalMatrix<$t> {
120 type Output = Vec<$t>;
121
122 fn mul(self, rhs: Vec<$t>) -> Self::Output {
123 $e(rhs, self.d())
124 }
125 }
126 };
127}
128
129impl_mul_vec! {f64, mul_vec}
130impl_mul_vec! {c64, mul_vec}
131
#[cfg(test)]
mod tests {
    use crate::*;

    /// Owned × owned diagonal product: every element must be checked, not just the first.
    #[test]
    fn mul() {
        let a = DiagonalMatrix::new(vec![2.0_f64, 3.0]) * DiagonalMatrix::new(vec![4.0, 5.0]);
        assert_eq!(a[0], 8.0);
        assert_eq!(a[1], 15.0);
    }

    /// Borrowed × owned takes the operand-swapping path in the `Mul` impl.
    #[test]
    fn mul_ref_lhs() {
        let lhs = DiagonalMatrix::new(vec![2.0_f64, 3.0]);
        let a = &lhs * DiagonalMatrix::new(vec![4.0, 5.0]);
        assert_eq!(a[0], 8.0);
        assert_eq!(a[1], 15.0);
    }

    /// Mismatched dimensions must panic instead of silently truncating the zip.
    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn mul_dimension_mismatch() {
        let _ = DiagonalMatrix::new(vec![1.0_f64, 2.0]) * DiagonalMatrix::new(vec![1.0_f64]);
    }

    /// Scalar multiplication in both operand orders.
    #[test]
    fn mul_by_scalar() {
        let a = 3.0_f64 * DiagonalMatrix::new(vec![1.0_f64, 2.0]);
        assert_eq!(a[0], 3.0);
        assert_eq!(a[1], 6.0);

        let b = DiagonalMatrix::new(vec![1.0_f64, 2.0]) * 3.0_f64;
        assert_eq!(b[0], 3.0);
        assert_eq!(b[1], 6.0);
    }

    /// Diagonal matrix applied to an owned vector.
    #[test]
    fn mul_vec() {
        let a = DiagonalMatrix::new(vec![2.0_f64, 3.0]) * vec![4.0, 5.0];
        assert_eq!(a, vec![8.0, 15.0]);
    }

    /// Borrowed matrix × owned vector takes the operand-swapping path.
    #[test]
    fn mul_vec_ref_lhs() {
        let lhs = DiagonalMatrix::new(vec![2.0_f64, 3.0]);
        let a = &lhs * vec![4.0, 5.0];
        assert_eq!(a, vec![8.0, 15.0]);
    }
}