// opensrdk_linear_algebra/matrix/di/operators/add.rs

use std::ops::Add;

use rayon::prelude::*;

use crate::{
    number::{c64, Number},
    DiagonalMatrix, Matrix,
};

8fn add_scalar<T>(lhs: DiagonalMatrix<T>, rhs: T) -> DiagonalMatrix<T>
9where
10    T: Number,
11{
12    let mut lhs = lhs;
13
14    lhs.d.par_iter_mut().for_each(|l| {
15        *l += rhs;
16    });
17
18    lhs
19}
20
21impl<T> Add<T> for DiagonalMatrix<T>
22where
23    T: Number,
24{
25    type Output = DiagonalMatrix<T>;
26
27    fn add(self, rhs: T) -> Self::Output {
28        add_scalar(self, rhs)
29    }
30}
31
32impl Add<DiagonalMatrix> for f64 {
33    type Output = DiagonalMatrix;
34
35    fn add(self, rhs: DiagonalMatrix) -> Self::Output {
36        add_scalar(rhs, self)
37    }
38}
39
40impl Add<DiagonalMatrix<c64>> for c64 {
41    type Output = DiagonalMatrix<c64>;
42
43    fn add(self, rhs: DiagonalMatrix<c64>) -> Self::Output {
44        add_scalar(rhs, self)
45    }
46}
47
48fn add<T>(lhs: DiagonalMatrix<T>, rhs: &DiagonalMatrix<T>) -> DiagonalMatrix<T>
49where
50    T: Number,
51{
52    if lhs.dim() != rhs.dim() {
53        panic!("Dimension mismatch.")
54    }
55    let mut lhs = lhs;
56
57    lhs.d
58        .par_iter_mut()
59        .zip(rhs.d.par_iter())
60        .for_each(|(l, &r)| {
61            *l += r;
62        });
63
64    lhs
65}
66
67impl<T> Add<DiagonalMatrix<T>> for DiagonalMatrix<T>
68where
69    T: Number,
70{
71    type Output = DiagonalMatrix<T>;
72
73    fn add(self, rhs: DiagonalMatrix<T>) -> Self::Output {
74        add(self, &rhs)
75    }
76}
77
78impl<T> Add<&DiagonalMatrix<T>> for DiagonalMatrix<T>
79where
80    T: Number,
81{
82    type Output = DiagonalMatrix<T>;
83
84    fn add(self, rhs: &DiagonalMatrix<T>) -> Self::Output {
85        add(self, rhs)
86    }
87}
88
89impl<T> Add<DiagonalMatrix<T>> for &DiagonalMatrix<T>
90where
91    T: Number,
92{
93    type Output = DiagonalMatrix<T>;
94
95    fn add(self, rhs: DiagonalMatrix<T>) -> Self::Output {
96        add(rhs, self)
97    }
98}
99
100fn add_mat<T>(lhs: Matrix<T>, rhs: &DiagonalMatrix<T>) -> Matrix<T>
101where
102    T: Number,
103{
104    let n = rhs.dim();
105    if lhs.rows() != n || lhs.cols() != n {
106        panic!("Dimension mismatch.")
107    }
108    let mut lhs = lhs;
109
110    for i in 0..n {
111        lhs[i][i] += rhs[i];
112    }
113
114    lhs
115}
116
117impl<T> Add<Matrix<T>> for DiagonalMatrix<T>
118where
119    T: Number,
120{
121    type Output = Matrix<T>;
122
123    fn add(self, rhs: Matrix<T>) -> Self::Output {
124        add_mat(rhs, &self)
125    }
126}
127
128impl<T> Add<Matrix<T>> for &DiagonalMatrix<T>
129where
130    T: Number,
131{
132    type Output = Matrix<T>;
133
134    fn add(self, rhs: Matrix<T>) -> Self::Output {
135        add_mat(rhs, self)
136    }
137}
138
139impl<T> Add<DiagonalMatrix<T>> for Matrix<T>
140where
141    T: Number,
142{
143    type Output = Matrix<T>;
144
145    fn add(self, rhs: DiagonalMatrix<T>) -> Self::Output {
146        add_mat(self, &rhs)
147    }
148}
149
150impl<T> Add<&DiagonalMatrix<T>> for Matrix<T>
151where
152    T: Number,
153{
154    type Output = Matrix<T>;
155
156    fn add(self, rhs: &DiagonalMatrix<T>) -> Self::Output {
157        add_mat(self, rhs)
158    }
159}
160
#[cfg(test)]
mod tests {
    use crate::*;

    /// Owned diagonal + owned diagonal: every diagonal entry is summed.
    #[test]
    fn add() {
        let a = DiagonalMatrix::new(vec![2.0, 3.0]) + DiagonalMatrix::new(vec![4.0, 5.0]);
        assert_eq!(a[0], 6.0);
        assert_eq!(a[1], 8.0);
    }

    /// The reference-operand impls give the same result as the owned one.
    #[test]
    fn add_ref() {
        let b: DiagonalMatrix<f64> = DiagonalMatrix::new(vec![4.0, 5.0]);

        let a = DiagonalMatrix::new(vec![2.0, 3.0]) + &b;
        assert_eq!(a[0], 6.0);
        assert_eq!(a[1], 8.0);

        let c = &b + DiagonalMatrix::new(vec![2.0, 3.0]);
        assert_eq!(c[0], 6.0);
        assert_eq!(c[1], 8.0);
    }

    /// Mismatched diagonal lengths must panic rather than silently truncate.
    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn add_dimension_mismatch() {
        let a: DiagonalMatrix<f64> = DiagonalMatrix::new(vec![1.0]);
        let b: DiagonalMatrix<f64> = DiagonalMatrix::new(vec![1.0, 2.0]);
        let _ = a + b;
    }

    /// A scalar is added to each diagonal entry, in either operand order.
    #[test]
    fn add_scalar() {
        let d: DiagonalMatrix<f64> = DiagonalMatrix::new(vec![2.0, 3.0]);
        let a = d + 1.0_f64;
        assert_eq!(a[0], 3.0);
        assert_eq!(a[1], 4.0);

        let d: DiagonalMatrix<f64> = DiagonalMatrix::new(vec![2.0, 3.0]);
        let b = 1.0_f64 + d;
        assert_eq!(b[0], 3.0);
        assert_eq!(b[1], 4.0);
    }

    /// Diagonal + dense: the diagonal entries of the dense matrix are shifted.
    #[test]
    fn add_mat() {
        let a = DiagonalMatrix::new(vec![2.0, 3.0])
            + mat!(
              4.0, 5.0;
              6.0, 7.0
            );
        assert_eq!(a[(0, 0)], 6.0);
        assert_eq!(a[(1, 1)], 10.0);
    }

    /// Dense + diagonal leaves off-diagonal entries untouched. A symmetric
    /// input keeps the assertions independent of index orientation.
    #[test]
    fn add_mat_off_diagonal_untouched() {
        let dense = mat!(
          4.0, 5.0;
          5.0, 7.0
        );
        let a = dense + DiagonalMatrix::new(vec![2.0, 3.0]);
        assert_eq!(a[(0, 0)], 6.0);
        assert_eq!(a[(0, 1)], 5.0);
        assert_eq!(a[(1, 0)], 5.0);
        assert_eq!(a[(1, 1)], 10.0);
    }

    /// A dense matrix whose size differs from the diagonal must panic.
    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn add_mat_dimension_mismatch() {
        let d: DiagonalMatrix<f64> = DiagonalMatrix::new(vec![1.0]);
        let _ = d + mat!(
          1.0, 2.0;
          3.0, 4.0
        );
    }
}