differential_equations/linalg/matrix/
sub.rs1use core::ops::{Sub, SubAssign};
4
5use super::base::{Matrix, MatrixStorage};
6use crate::traits::Real;
7
8impl<T: Real> Sub for Matrix<T> {
10 type Output = Matrix<T>;
11
12 fn sub(self, rhs: Matrix<T>) -> Self::Output {
13 match (self, rhs) {
14 (
15 Matrix {
16 nrows: n1,
17 storage: MatrixStorage::Identity,
18 ..
19 },
20 Matrix {
21 nrows: n2,
22 storage: MatrixStorage::Identity,
23 ..
24 },
25 ) => {
26 assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
27 Matrix {
28 nrows: n1,
29 ncols: n1,
30 data: vec![T::zero(); n1 * n1],
31 storage: MatrixStorage::Full,
32 }
33 }
34 (
35 Matrix {
36 nrows: n,
37 data: mut a,
38 storage: MatrixStorage::Full,
39 ..
40 },
41 Matrix {
42 nrows: n2,
43 data: b,
44 storage: MatrixStorage::Full,
45 ..
46 },
47 ) => {
48 assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
49 for (x, y) in a.iter_mut().zip(b.iter()) {
50 *x = *x - *y;
51 }
52 Matrix {
53 nrows: n,
54 ncols: n,
55 data: a,
56 storage: MatrixStorage::Full,
57 }
58 }
59 (
60 Matrix {
61 nrows: n,
62 data: a,
63 storage: MatrixStorage::Banded { ml, mu, .. },
64 ..
65 },
66 Matrix {
67 nrows: n2,
68 data: b,
69 storage:
70 MatrixStorage::Banded {
71 ml: ml2, mu: mu2, ..
72 },
73 ..
74 },
75 ) => {
76 assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
77 let ml_out = ml.max(ml2);
78 let mu_out = mu.max(mu2);
79 let rows_out = ml_out + mu_out + 1;
80 let mut out = Matrix {
81 nrows: n,
82 ncols: n,
83 data: vec![T::zero(); rows_out * n],
84 storage: MatrixStorage::Banded {
85 ml: ml_out,
86 mu: mu_out,
87 zero: T::zero(),
88 },
89 };
90 for j in 0..n {
92 for r in 0..(ml + mu + 1) {
93 let k = r as isize - mu as isize; let i_signed = j as isize + k;
95 if i_signed >= 0 && (i_signed as usize) < n {
96 let row_out = (k + mu_out as isize) as usize;
97 out.data[row_out * n + j] = out.data[row_out * n + j] + a[r * n + j];
98 }
99 }
100 }
101 for j in 0..n {
103 for r in 0..(ml2 + mu2 + 1) {
104 let k = r as isize - mu2 as isize; let i_signed = j as isize + k;
106 if i_signed >= 0 && (i_signed as usize) < n {
107 let row_out = (k + mu_out as isize) as usize;
108 out.data[row_out * n + j] = out.data[row_out * n + j] - b[r * n + j];
109 }
110 }
111 }
112 out
113 }
114 (
116 Matrix {
117 nrows: n1,
118 data: a,
119 storage: sa,
120 ..
121 },
122 Matrix {
123 nrows: n2,
124 data: b,
125 storage: sb,
126 ..
127 },
128 ) => {
129 assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
130 let to_full = |n: usize, data: Vec<T>, storage: MatrixStorage<T>| -> Vec<T> {
131 match storage {
132 MatrixStorage::Full => data,
133 MatrixStorage::Identity => {
134 let mut d = vec![T::zero(); n * n];
135 for i in 0..n {
136 d[i * n + i] = T::one();
137 }
138 d
139 }
140 MatrixStorage::Banded { ml, mu, .. } => {
141 let mut d = vec![T::zero(); n * n];
142 for j in 0..n {
143 for r in 0..(ml + mu + 1) {
144 let k = r as isize - mu as isize; let i_signed = j as isize + k;
146 if i_signed >= 0 && (i_signed as usize) < n {
147 let i = i_signed as usize;
148 d[i * n + j] = d[i * n + j] + data[r * n + j];
149 }
150 }
151 }
152 d
153 }
154 }
155 };
156 let aa = to_full(n1, a, sa);
157 let bb = to_full(n2, b, sb);
158 let data = aa
159 .into_iter()
160 .zip(bb.into_iter())
161 .map(|(x, y)| x - y)
162 .collect();
163 Matrix {
164 nrows: n1,
165 ncols: n1,
166 data,
167 storage: MatrixStorage::Full,
168 }
169 }
170 }
171 }
172}
173
174impl<T: Real> SubAssign<Matrix<T>> for Matrix<T> {
178 fn sub_assign(&mut self, rhs: Matrix<T>) {
179 let n = self.n();
180 let lhs = core::mem::replace(self, Matrix::zeros(n));
181 *self = lhs - rhs;
182 }
183}
184
185impl<T: Real> SubAssign<&Matrix<T>> for Matrix<T> {
187 fn sub_assign(&mut self, rhs: &Matrix<T>) {
188 let n = self.n();
189 let lhs = core::mem::replace(self, Matrix::zeros(n));
190 *self = lhs - rhs.clone();
191 }
192}
193
194impl<T: Real> Matrix<T> {
195 pub fn component_sub(mut self, rhs: T) -> Self {
197 match &mut self.storage {
198 MatrixStorage::Identity => {
199 let n = self.nrows;
200 let mut data = vec![T::zero() - rhs; n * n];
201 for i in 0..n {
202 data[i * n + i] = T::one() - rhs;
203 }
204 Matrix {
205 nrows: n,
206 ncols: n,
207 data,
208 storage: MatrixStorage::Full,
209 }
210 }
211 MatrixStorage::Full => {
212 for v in &mut self.data {
213 *v = *v - rhs;
214 }
215 self
216 }
217 MatrixStorage::Banded { ml, mu, .. } => {
218 let n = self.nrows;
219 if rhs == T::zero() {
220 self
221 } else {
222 let rows = *ml + *mu + 1;
223 let mut dense = vec![T::zero() - rhs; n * n];
224 for j in 0..n {
225 for r in 0..rows {
226 let k = r as isize - *mu as isize;
227 let i_signed = j as isize + k;
228 if i_signed >= 0 && (i_signed as usize) < n {
229 let i = i_signed as usize;
230 let val = self.data[r * n + j];
231 dense[i * n + j] = val - rhs;
232 }
233 }
234 }
235 Matrix {
236 nrows: n,
237 ncols: n,
238 data: dense,
239 storage: MatrixStorage::Full,
240 }
241 }
242 }
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;

    /// Scalar subtraction on a dense matrix shifts every entry by the scalar.
    #[test]
    fn sub_scalar_full() {
        let shifted = Matrix::<f64>::full(2, vec![1.0, 2.0, 3.0, 4.0]).component_sub(1.0);
        let expected: [((usize, usize), f64); 4] =
            [((0, 0), 0.0), ((0, 1), 1.0), ((1, 0), 2.0), ((1, 1), 3.0)];
        for &(idx, want) in expected.iter() {
            assert_eq!(shifted[idx], want);
        }
    }

    /// Subtracting zero from an all-zero banded matrix leaves every entry reading as zero.
    #[test]
    fn sub_scalar_banded_zero_keeps_banded() {
        let unchanged = Matrix::<f64>::banded(3, 1, 1).component_sub(0.0);
        for idx in (0..3usize).flat_map(|i| (0..3usize).map(move |j| (i, j))) {
            assert_eq!(unchanged[idx], 0.0);
        }
    }

    /// Dense minus dense subtracts entry by entry.
    #[test]
    fn sub_matrix_full_full() {
        let lhs = Matrix::<f64>::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = Matrix::<f64>::full(2, vec![4.0, 3.0, 2.0, 1.0]);
        let diff = lhs - rhs;
        let expected: [((usize, usize), f64); 4] =
            [((0, 0), -3.0), ((0, 1), -1.0), ((1, 0), 1.0), ((1, 1), 3.0)];
        for &(idx, want) in expected.iter() {
            assert_eq!(diff[idx], want);
        }
    }

    /// A lower-bidiagonal matrix minus an upper-bidiagonal one covers both bands in the result.
    #[test]
    fn sub_matrix_banded_banded() {
        let mut lower = Matrix::<f64>::banded(3, 1, 0);
        let mut upper = Matrix::<f64>::banded(3, 0, 1);

        // Diagonal plus sub-diagonal of ones.
        let lower_entries: [(usize, usize); 5] = [(0, 0), (1, 1), (2, 2), (1, 0), (2, 1)];
        for &idx in lower_entries.iter() {
            lower[idx] = 1.0;
        }
        // Diagonal plus super-diagonal of twos.
        let upper_entries: [(usize, usize); 5] = [(0, 0), (1, 1), (2, 2), (0, 1), (1, 2)];
        for &idx in upper_entries.iter() {
            upper[idx] = 2.0;
        }

        let diff = lower - upper;
        let expected: [((usize, usize), f64); 8] = [
            ((0, 0), -1.0),
            ((1, 1), -1.0),
            ((2, 2), -1.0),
            ((1, 0), 1.0),
            ((2, 1), 1.0),
            ((0, 1), -2.0),
            ((1, 2), -2.0),
            ((0, 2), 0.0),
        ];
        for &(idx, want) in expected.iter() {
            assert_eq!(diff[idx], want);
        }
    }
}