opensrdk_linear_algebra/matrix/di/operators/
add.rs1use crate::{
2 number::{c64, Number},
3 DiagonalMatrix, Matrix,
4};
5use rayon::prelude::*;
6use std::ops::Add;
7
8fn add_scalar<T>(lhs: DiagonalMatrix<T>, rhs: T) -> DiagonalMatrix<T>
9where
10 T: Number,
11{
12 let mut lhs = lhs;
13
14 lhs.d.par_iter_mut().for_each(|l| {
15 *l += rhs;
16 });
17
18 lhs
19}
20
21impl<T> Add<T> for DiagonalMatrix<T>
22where
23 T: Number,
24{
25 type Output = DiagonalMatrix<T>;
26
27 fn add(self, rhs: T) -> Self::Output {
28 add_scalar(self, rhs)
29 }
30}
31
32impl Add<DiagonalMatrix> for f64 {
33 type Output = DiagonalMatrix;
34
35 fn add(self, rhs: DiagonalMatrix) -> Self::Output {
36 add_scalar(rhs, self)
37 }
38}
39
40impl Add<DiagonalMatrix<c64>> for c64 {
41 type Output = DiagonalMatrix<c64>;
42
43 fn add(self, rhs: DiagonalMatrix<c64>) -> Self::Output {
44 add_scalar(rhs, self)
45 }
46}
47
48fn add<T>(lhs: DiagonalMatrix<T>, rhs: &DiagonalMatrix<T>) -> DiagonalMatrix<T>
49where
50 T: Number,
51{
52 if lhs.dim() != rhs.dim() {
53 panic!("Dimension mismatch.")
54 }
55 let mut lhs = lhs;
56
57 lhs.d
58 .par_iter_mut()
59 .zip(rhs.d.par_iter())
60 .for_each(|(l, &r)| {
61 *l += r;
62 });
63
64 lhs
65}
66
67impl<T> Add<DiagonalMatrix<T>> for DiagonalMatrix<T>
68where
69 T: Number,
70{
71 type Output = DiagonalMatrix<T>;
72
73 fn add(self, rhs: DiagonalMatrix<T>) -> Self::Output {
74 add(self, &rhs)
75 }
76}
77
78impl<T> Add<&DiagonalMatrix<T>> for DiagonalMatrix<T>
79where
80 T: Number,
81{
82 type Output = DiagonalMatrix<T>;
83
84 fn add(self, rhs: &DiagonalMatrix<T>) -> Self::Output {
85 add(self, rhs)
86 }
87}
88
89impl<T> Add<DiagonalMatrix<T>> for &DiagonalMatrix<T>
90where
91 T: Number,
92{
93 type Output = DiagonalMatrix<T>;
94
95 fn add(self, rhs: DiagonalMatrix<T>) -> Self::Output {
96 add(rhs, self)
97 }
98}
99
100fn add_mat<T>(lhs: Matrix<T>, rhs: &DiagonalMatrix<T>) -> Matrix<T>
101where
102 T: Number,
103{
104 let n = rhs.dim();
105 if lhs.rows() != n || lhs.cols() != n {
106 panic!("Dimension mismatch.")
107 }
108 let mut lhs = lhs;
109
110 for i in 0..n {
111 lhs[i][i] += rhs[i];
112 }
113
114 lhs
115}
116
117impl<T> Add<Matrix<T>> for DiagonalMatrix<T>
118where
119 T: Number,
120{
121 type Output = Matrix<T>;
122
123 fn add(self, rhs: Matrix<T>) -> Self::Output {
124 add_mat(rhs, &self)
125 }
126}
127
128impl<T> Add<Matrix<T>> for &DiagonalMatrix<T>
129where
130 T: Number,
131{
132 type Output = Matrix<T>;
133
134 fn add(self, rhs: Matrix<T>) -> Self::Output {
135 add_mat(rhs, self)
136 }
137}
138
139impl<T> Add<DiagonalMatrix<T>> for Matrix<T>
140where
141 T: Number,
142{
143 type Output = Matrix<T>;
144
145 fn add(self, rhs: DiagonalMatrix<T>) -> Self::Output {
146 add_mat(self, &rhs)
147 }
148}
149
150impl<T> Add<&DiagonalMatrix<T>> for Matrix<T>
151where
152 T: Number,
153{
154 type Output = Matrix<T>;
155
156 fn add(self, rhs: &DiagonalMatrix<T>) -> Self::Output {
157 add_mat(self, rhs)
158 }
159}
160
#[cfg(test)]
mod tests {
    use crate::*;

    /// Diagonal + diagonal: every entry is summed, not just the first.
    #[test]
    fn add() {
        let a = DiagonalMatrix::new(vec![2.0, 3.0]) + DiagonalMatrix::new(vec![4.0, 5.0]);
        assert_eq!(a[0], 6.0);
        assert_eq!(a[1], 8.0);
    }

    /// Scalar addition shifts every diagonal entry, in both operand orders.
    #[test]
    fn add_scalar() {
        let a = DiagonalMatrix::new(vec![2.0, 3.0]) + 1.0;
        assert_eq!(a[0], 3.0);
        assert_eq!(a[1], 4.0);

        let b = 1.0f64 + DiagonalMatrix::new(vec![2.0, 3.0]);
        assert_eq!(b[0], 3.0);
        assert_eq!(b[1], 4.0);
    }

    /// Mismatched diagonal lengths must panic rather than silently truncate via `zip`.
    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn add_dimension_mismatch() {
        let _ = DiagonalMatrix::new(vec![1.0]) + DiagonalMatrix::new(vec![1.0, 2.0]);
    }

    /// Diagonal + dense: the diagonal changes and the off-diagonal entries are untouched.
    #[test]
    fn add_mat() {
        let a = DiagonalMatrix::new(vec![2.0, 3.0])
            + mat!(
                4.0, 5.0;
                6.0, 7.0
            );
        assert_eq!(a[(0, 0)], 6.0);
        assert_eq!(a[(0, 1)], 5.0);
        assert_eq!(a[(1, 0)], 6.0);
        assert_eq!(a[(1, 1)], 10.0);
    }
}