differential_equations/linalg/matrix/
add.rs1use core::ops::Add;
4use core::ops::AddAssign;
5
6use crate::traits::Real;
7
8use super::base::{Matrix, MatrixStorage};
9
10impl<T: Real> AddAssign<Matrix<T>> for Matrix<T> {
12 fn add_assign(&mut self, rhs: Matrix<T>) {
13 let n = self.n;
14 let m = self.m;
15 let lhs = core::mem::replace(self, Matrix::zeros(n, m));
16 *self = lhs + rhs;
17 }
18}
19
20impl<T: Real> Add for Matrix<T> {
22 type Output = Matrix<T>;
23
24 fn add(self, rhs: Matrix<T>) -> Self::Output {
25 assert_eq!(self.n, rhs.n, "dimension mismatch in Matrix + Matrix");
26 let n = self.n;
27 match (self, rhs) {
28 (
29 Matrix {
30 n: n1,
31 m: _,
32 data: _,
33 storage: MatrixStorage::Identity,
34 },
35 Matrix {
36 n: n2,
37 m: _,
38 data: _,
39 storage: MatrixStorage::Identity,
40 },
41 ) => {
42 assert_eq!(n1, n2);
43 let mut data = vec![T::zero(); n * n];
44 for i in 0..n {
45 data[i * n + i] = T::one() + T::one();
46 }
47 Matrix {
48 n,
49 m: n,
50 data,
51 storage: MatrixStorage::Full,
52 }
53 }
54 (
55 Matrix {
56 data: a,
57 storage: MatrixStorage::Full,
58 ..
59 },
60 Matrix {
61 data: b,
62 storage: MatrixStorage::Full,
63 ..
64 },
65 ) => {
66 let data = a.into_iter().zip(b).map(|(x, y)| x + y).collect();
67 Matrix {
68 n,
69 m: n,
70 data,
71 storage: MatrixStorage::Full,
72 }
73 }
74 (
75 Matrix {
76 data: a,
77 storage: MatrixStorage::Banded { ml, mu, .. },
78 ..
79 },
80 Matrix {
81 data: b,
82 storage:
83 MatrixStorage::Banded {
84 ml: ml2, mu: mu2, ..
85 },
86 ..
87 },
88 ) => {
89 let ml_out = ml.max(ml2);
90 let mu_out = mu.max(mu2);
91 let rows_out = ml_out + mu_out + 1;
92 let mut out = Matrix {
93 n,
94 m: n,
95 data: vec![T::zero(); rows_out * n],
96 storage: MatrixStorage::Banded {
97 ml: ml_out,
98 mu: mu_out,
99 zero: T::zero(),
100 },
101 };
102 for j in 0..n {
104 for r in 0..(ml + mu + 1) {
105 let k = r as isize - mu as isize;
106 let i_signed = j as isize + k;
107 if i_signed >= 0 && (i_signed as usize) < n {
108 let row_out = (k + mu_out as isize) as usize;
109 out.data[row_out * n + j] += a[r * n + j];
110 }
111 }
112 }
113 for j in 0..n {
115 for r in 0..(ml2 + mu2 + 1) {
116 let k = r as isize - mu2 as isize;
117 let i_signed = j as isize + k;
118 if i_signed >= 0 && (i_signed as usize) < n {
119 let row_out = (k + mu_out as isize) as usize;
120 out.data[row_out * n + j] += b[r * n + j];
121 }
122 }
123 }
124 out
125 }
126 (
128 Matrix {
129 data: a,
130 storage: sa,
131 ..
132 },
133 Matrix {
134 data: b,
135 storage: sb,
136 ..
137 },
138 ) => {
139 let to_full = |n: usize, data: Vec<T>, storage: MatrixStorage<T>| -> Vec<T> {
140 match storage {
141 MatrixStorage::Full => data,
142 MatrixStorage::Identity => {
143 let mut d = vec![T::zero(); n * n];
144 for i in 0..n {
145 d[i * n + i] = T::one();
146 }
147 d
148 }
149 MatrixStorage::Banded { ml, mu, .. } => {
150 let mut d = vec![T::zero(); n * n];
151 for j in 0..n {
152 for r in 0..(ml + mu + 1) {
153 let k = r as isize - mu as isize;
154 let i_signed = j as isize + k;
155 if i_signed >= 0 && (i_signed as usize) < n {
156 let i = i_signed as usize;
157 d[i * n + j] += data[r * n + j];
158 }
159 }
160 }
161 d
162 }
163 }
164 };
165 let aa = to_full(n, a, sa);
166 let bb = to_full(n, b, sb);
167 let data = aa.into_iter().zip(bb).map(|(x, y)| x + y).collect();
168 Matrix {
169 n,
170 m: n,
171 data,
172 storage: MatrixStorage::Full,
173 }
174 }
175 }
176 }
177}
178
179impl<T: Real> Matrix<T> {
180 pub fn component_add(mut self, rhs: T) -> Self {
182 match &mut self.storage {
183 MatrixStorage::Identity => {
184 let n = self.n;
186 let mut data = vec![rhs; n * n];
187 for i in 0..n {
188 data[i * n + i] = rhs + T::one();
189 }
190 Matrix {
191 n,
192 m: n,
193 data,
194 storage: MatrixStorage::Full,
195 }
196 }
197 MatrixStorage::Full => {
198 for v in &mut self.data {
199 *v += rhs;
200 }
201 self
202 }
203 MatrixStorage::Banded { ml, mu, .. } => {
204 let n = self.n;
205 if rhs == T::zero() {
206 self
207 } else {
208 let rows = *ml + *mu + 1;
209 let mut dense = vec![rhs; n * n];
210 for j in 0..n {
211 for r in 0..rows {
212 let k = r as isize - *mu as isize;
213 let i_signed = j as isize + k;
214 if i_signed >= 0 && (i_signed as usize) < n {
215 let i = i_signed as usize;
216 let val = self.data[r * n + j];
217 dense[i * n + j] = val + rhs;
218 }
219 }
220 }
221 Matrix {
222 n,
223 m: n,
224 data: dense,
225 storage: MatrixStorage::Full,
226 }
227 }
228 }
229 }
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::MatrixStorage;
    use crate::linalg::matrix::Matrix;

    /// Asserts that every entry of `r` equals the matching entry of `expected`.
    fn assert_entries(r: &Matrix<f64>, expected: &[[f64; 3]; 3]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() {
                assert_eq!(r[(i, j)], v, "entry ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn add_scalar_full() {
        let mut m: Matrix<f64> = Matrix::full(2, 2);
        m[(0, 0)] = 1.0;
        m[(0, 1)] = 2.0;
        m[(1, 0)] = 3.0;
        m[(1, 1)] = 4.0;
        let r = m.component_add(1.0);
        assert_eq!(r[(0, 0)], 2.0);
        assert_eq!(r[(0, 1)], 3.0);
        assert_eq!(r[(1, 0)], 4.0);
        assert_eq!(r[(1, 1)], 5.0);
    }

    #[test]
    fn add_scalar_banded_zero_keeps_banded() {
        let m: Matrix<f64> = Matrix::banded(3, 1, 1);
        let r = m.component_add(0.0);
        // Adding zero must not densify the band.
        assert!(matches!(r.storage, MatrixStorage::Banded { .. }));
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(r[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn add_scalar_banded_nonzero_densifies() {
        let mut m: Matrix<f64> = Matrix::banded(3, 1, 0);
        m[(0, 0)] = 1.0;
        m[(1, 0)] = 2.0;
        let r = m.component_add(1.0);
        // Entries outside the band become the scalar, so the result is dense.
        assert!(matches!(r.storage, MatrixStorage::Full));
        assert_entries(&r, &[[2.0, 1.0, 1.0], [3.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
    }

    #[test]
    fn add_matrix_full_full() {
        let mut a: Matrix<f64> = Matrix::full(2, 2);
        a[(0, 0)] = 1.0;
        a[(0, 1)] = 2.0;
        a[(1, 0)] = 3.0;
        a[(1, 1)] = 4.0;
        let mut b: Matrix<f64> = Matrix::full(2, 2);
        b[(0, 0)] = 4.0;
        b[(0, 1)] = 3.0;
        b[(1, 0)] = 2.0;
        b[(1, 1)] = 1.0;
        let r = a + b;
        assert_eq!(r[(0, 0)], 5.0);
        assert_eq!(r[(0, 1)], 5.0);
        assert_eq!(r[(1, 0)], 5.0);
        assert_eq!(r[(1, 1)], 5.0);
    }

    #[test]
    fn add_matrix_banded_banded() {
        let mut a: Matrix<f64> = Matrix::banded(3, 1, 0);
        let mut b: Matrix<f64> = Matrix::banded(3, 0, 1);
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;
        a[(2, 2)] = 1.0;
        a[(1, 0)] = 1.0;
        a[(2, 1)] = 1.0;
        b[(0, 0)] = 2.0;
        b[(1, 1)] = 2.0;
        b[(2, 2)] = 2.0;
        b[(0, 1)] = 2.0;
        b[(1, 2)] = 2.0;
        let r = a + b;
        assert_eq!(r[(0, 0)], 3.0);
        assert_eq!(r[(1, 1)], 3.0);
        assert_eq!(r[(2, 2)], 3.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(2, 1)], 1.0);
        assert_eq!(r[(0, 1)], 2.0);
        assert_eq!(r[(1, 2)], 2.0);
        assert_eq!(r[(0, 2)], 0.0);
    }

    #[test]
    fn add_matrix_full_banded_densifies() {
        let mut a: Matrix<f64> = Matrix::full(3, 3);
        for i in 0..3 {
            for j in 0..3 {
                a[(i, j)] = 1.0;
            }
        }
        let mut b: Matrix<f64> = Matrix::banded(3, 1, 1);
        b[(0, 0)] = 2.0;
        b[(1, 1)] = 2.0;
        b[(2, 2)] = 2.0;
        b[(1, 0)] = 3.0;
        b[(0, 1)] = 4.0;
        let r = a + b;
        assert!(matches!(r.storage, MatrixStorage::Full));
        assert_entries(&r, &[[3.0, 5.0, 1.0], [4.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
    }

    #[test]
    fn add_assign_full_full() {
        let mut a: Matrix<f64> = Matrix::full(2, 2);
        let mut b: Matrix<f64> = Matrix::full(2, 2);
        for i in 0..2 {
            for j in 0..2 {
                a[(i, j)] = (2 * i + j) as f64;
                b[(i, j)] = 10.0;
            }
        }
        a += b;
        assert_eq!(a[(0, 0)], 10.0);
        assert_eq!(a[(0, 1)], 11.0);
        assert_eq!(a[(1, 0)], 12.0);
        assert_eq!(a[(1, 1)], 13.0);
    }
}