differential_equations/linalg/matrix/
add.rs1use core::ops::Add;
4use core::ops::AddAssign;
5
6use crate::traits::Real;
7
8use super::base::{Matrix, MatrixStorage};
9
10impl<T: Real> AddAssign<Matrix<T>> for Matrix<T> {
12 fn add_assign(&mut self, rhs: Matrix<T>) {
13 let n = self.n();
14 let lhs = core::mem::replace(self, Matrix::zeros(n));
15 *self = lhs + rhs;
16 }
17}
18
19impl<T: Real> Add for Matrix<T> {
21 type Output = Matrix<T>;
22
23 fn add(self, rhs: Matrix<T>) -> Self::Output {
24 assert_eq!(
25 self.nrows, rhs.nrows,
26 "dimension mismatch in Matrix + Matrix"
27 );
28 let n = self.nrows;
29 match (self, rhs) {
30 (
31 Matrix {
32 nrows: n1,
33 ncols: _,
34 data: _,
35 storage: MatrixStorage::Identity,
36 },
37 Matrix {
38 nrows: n2,
39 ncols: _,
40 data: _,
41 storage: MatrixStorage::Identity,
42 },
43 ) => {
44 assert_eq!(n1, n2);
45 let mut data = vec![T::zero(); n * n];
46 for i in 0..n {
47 data[i * n + i] = T::one() + T::one();
48 }
49 Matrix {
50 nrows: n,
51 ncols: n,
52 data,
53 storage: MatrixStorage::Full,
54 }
55 }
56 (
57 Matrix {
58 data: a,
59 storage: MatrixStorage::Full,
60 ..
61 },
62 Matrix {
63 data: b,
64 storage: MatrixStorage::Full,
65 ..
66 },
67 ) => {
68 let data = a
69 .into_iter()
70 .zip(b.into_iter())
71 .map(|(x, y)| x + y)
72 .collect();
73 Matrix {
74 nrows: n,
75 ncols: n,
76 data,
77 storage: MatrixStorage::Full,
78 }
79 }
80 (
81 Matrix {
82 data: a,
83 storage: MatrixStorage::Banded { ml, mu, .. },
84 ..
85 },
86 Matrix {
87 data: b,
88 storage:
89 MatrixStorage::Banded {
90 ml: ml2, mu: mu2, ..
91 },
92 ..
93 },
94 ) => {
95 let ml_out = ml.max(ml2);
96 let mu_out = mu.max(mu2);
97 let rows_out = ml_out + mu_out + 1;
98 let mut out = Matrix {
99 nrows: n,
100 ncols: n,
101 data: vec![T::zero(); rows_out * n],
102 storage: MatrixStorage::Banded {
103 ml: ml_out,
104 mu: mu_out,
105 zero: T::zero(),
106 },
107 };
108 for j in 0..n {
110 for r in 0..(ml + mu + 1) {
111 let k = r as isize - mu as isize;
112 let i_signed = j as isize + k;
113 if i_signed >= 0 && (i_signed as usize) < n {
114 let row_out = (k + mu_out as isize) as usize;
115 out.data[row_out * n + j] = out.data[row_out * n + j] + a[r * n + j];
116 }
117 }
118 }
119 for j in 0..n {
121 for r in 0..(ml2 + mu2 + 1) {
122 let k = r as isize - mu2 as isize;
123 let i_signed = j as isize + k;
124 if i_signed >= 0 && (i_signed as usize) < n {
125 let row_out = (k + mu_out as isize) as usize;
126 out.data[row_out * n + j] = out.data[row_out * n + j] + b[r * n + j];
127 }
128 }
129 }
130 out
131 }
132 (
134 Matrix {
135 data: a,
136 storage: sa,
137 ..
138 },
139 Matrix {
140 data: b,
141 storage: sb,
142 ..
143 },
144 ) => {
145 let to_full = |n: usize, data: Vec<T>, storage: MatrixStorage<T>| -> Vec<T> {
146 match storage {
147 MatrixStorage::Full => data,
148 MatrixStorage::Identity => {
149 let mut d = vec![T::zero(); n * n];
150 for i in 0..n {
151 d[i * n + i] = T::one();
152 }
153 d
154 }
155 MatrixStorage::Banded { ml, mu, .. } => {
156 let mut d = vec![T::zero(); n * n];
157 for j in 0..n {
158 for r in 0..(ml + mu + 1) {
159 let k = r as isize - mu as isize;
160 let i_signed = j as isize + k;
161 if i_signed >= 0 && (i_signed as usize) < n {
162 let i = i_signed as usize;
163 d[i * n + j] = d[i * n + j] + data[r * n + j];
164 }
165 }
166 }
167 d
168 }
169 }
170 };
171 let aa = to_full(n, a, sa);
172 let bb = to_full(n, b, sb);
173 let data = aa
174 .into_iter()
175 .zip(bb.into_iter())
176 .map(|(x, y)| x + y)
177 .collect();
178 Matrix {
179 nrows: n,
180 ncols: n,
181 data,
182 storage: MatrixStorage::Full,
183 }
184 }
185 }
186 }
187}
188
189impl<T: Real> Matrix<T> {
190 pub fn component_add(mut self, rhs: T) -> Self {
192 match &mut self.storage {
193 MatrixStorage::Identity => {
194 let n = self.nrows;
196 let mut data = vec![rhs; n * n];
197 for i in 0..n {
198 data[i * n + i] = rhs + T::one();
199 }
200 Matrix {
201 nrows: n,
202 ncols: n,
203 data,
204 storage: MatrixStorage::Full,
205 }
206 }
207 MatrixStorage::Full => {
208 for v in &mut self.data {
209 *v = *v + rhs;
210 }
211 self
212 }
213 MatrixStorage::Banded { ml, mu, .. } => {
214 let n = self.nrows;
215 if rhs == T::zero() {
216 self
217 } else {
218 let rows = *ml + *mu + 1;
219 let mut dense = vec![rhs; n * n];
220 for j in 0..n {
221 for r in 0..rows {
222 let k = r as isize - *mu as isize;
223 let i_signed = j as isize + k;
224 if i_signed >= 0 && (i_signed as usize) < n {
225 let i = i_signed as usize;
226 let val = self.data[r * n + j];
227 dense[i * n + j] = val + rhs;
228 }
229 }
230 }
231 Matrix {
232 nrows: n,
233 ncols: n,
234 data: dense,
235 storage: MatrixStorage::Full,
236 }
237 }
238 }
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::MatrixStorage;
    use crate::linalg::matrix::Matrix;

    #[test]
    fn add_scalar_full() {
        let m: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        let r = m.component_add(1.0);
        assert_eq!(r[(0, 0)], 2.0);
        assert_eq!(r[(0, 1)], 3.0);
        assert_eq!(r[(1, 0)], 4.0);
        assert_eq!(r[(1, 1)], 5.0);
    }

    #[test]
    fn add_scalar_banded_zero_keeps_banded() {
        let m: Matrix<f64> = Matrix::banded(3, 1, 1);
        let r = m.component_add(0.0);
        // The test name promises the banded layout survives; check it explicitly.
        assert!(matches!(
            r.storage,
            MatrixStorage::Banded { ml: 1, mu: 1, .. }
        ));
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(r[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn add_scalar_banded_nonzero_densifies() {
        let mut m: Matrix<f64> = Matrix::banded(3, 1, 0);
        m[(0, 0)] = 1.0;
        m[(1, 0)] = 2.0;
        let r = m.component_add(1.0);
        assert!(matches!(r.storage, MatrixStorage::Full));
        assert_eq!(r[(0, 0)], 2.0);
        assert_eq!(r[(1, 0)], 3.0);
        assert_eq!(r[(1, 1)], 1.0);
        // Entries outside the original band pick up the shift as well.
        assert_eq!(r[(0, 1)], 1.0);
        assert_eq!(r[(2, 0)], 1.0);
    }

    #[test]
    fn add_matrix_full_full() {
        let a: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        let b: Matrix<f64> = Matrix::full(2, vec![4.0, 3.0, 2.0, 1.0]);
        let r = a + b;
        assert_eq!(r[(0, 0)], 5.0);
        assert_eq!(r[(0, 1)], 5.0);
        assert_eq!(r[(1, 0)], 5.0);
        assert_eq!(r[(1, 1)], 5.0);
    }

    #[test]
    fn add_matrix_banded_banded() {
        let mut a: Matrix<f64> = Matrix::banded(3, 1, 0);
        let mut b: Matrix<f64> = Matrix::banded(3, 0, 1);
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;
        a[(2, 2)] = 1.0;
        a[(1, 0)] = 1.0;
        a[(2, 1)] = 1.0;
        b[(0, 0)] = 2.0;
        b[(1, 1)] = 2.0;
        b[(2, 2)] = 2.0;
        b[(0, 1)] = 2.0;
        b[(1, 2)] = 2.0;
        let r = a + b;
        // The sum stays banded, widened to cover both operands' bands.
        assert!(matches!(
            r.storage,
            MatrixStorage::Banded { ml: 1, mu: 1, .. }
        ));
        assert_eq!(r[(0, 0)], 3.0);
        assert_eq!(r[(1, 1)], 3.0);
        assert_eq!(r[(2, 2)], 3.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(2, 1)], 1.0);
        assert_eq!(r[(0, 1)], 2.0);
        assert_eq!(r[(1, 2)], 2.0);
        assert_eq!(r[(0, 2)], 0.0);
    }

    #[test]
    fn add_matrix_full_banded_mixed() {
        let a: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        let mut b: Matrix<f64> = Matrix::banded(2, 0, 0);
        b[(0, 0)] = 10.0;
        b[(1, 1)] = 20.0;
        let r = a + b;
        assert!(matches!(r.storage, MatrixStorage::Full));
        assert_eq!(r[(0, 0)], 11.0);
        assert_eq!(r[(0, 1)], 2.0);
        assert_eq!(r[(1, 0)], 3.0);
        assert_eq!(r[(1, 1)], 24.0);
    }

    #[test]
    fn add_assign_full_full() {
        let mut a: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        a += Matrix::full(2, vec![4.0, 3.0, 2.0, 1.0]);
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(a[(i, j)], 5.0);
            }
        }
    }

    #[test]
    #[should_panic(expected = "dimension mismatch in Matrix + Matrix")]
    fn add_matrix_dimension_mismatch_panics() {
        let a: Matrix<f64> = Matrix::full(2, vec![0.0; 4]);
        let b: Matrix<f64> = Matrix::full(3, vec![0.0; 9]);
        let _ = a + b;
    }
}