// differential_equations/linalg/matrix/sub.rs
use core::ops::{Sub, SubAssign};

use super::base::{Matrix, MatrixStorage};
use crate::traits::Real;

8impl<T: Real> Sub for Matrix<T> {
10 type Output = Matrix<T>;
11
12 fn sub(self, rhs: Matrix<T>) -> Self::Output {
13 match (self, rhs) {
14 (
15 Matrix {
16 n: n1,
17 storage: MatrixStorage::Identity,
18 ..
19 },
20 Matrix {
21 n: n2,
22 storage: MatrixStorage::Identity,
23 ..
24 },
25 ) => {
26 assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
27 Matrix {
28 n: n1,
29 m: n1,
30 data: vec![T::zero(); n1 * n1],
31 storage: MatrixStorage::Full,
32 }
33 }
34 (
35 Matrix {
36 n,
37 data: mut a,
38 storage: MatrixStorage::Full,
39 ..
40 },
41 Matrix {
42 n: n2,
43 data: b,
44 storage: MatrixStorage::Full,
45 ..
46 },
47 ) => {
48 assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
49 for (x, y) in a.iter_mut().zip(b.iter()) {
50 *x -= *y;
51 }
52 Matrix {
53 n,
54 m: n,
55 data: a,
56 storage: MatrixStorage::Full,
57 }
58 }
59 (
60 Matrix {
61 n,
62 data: a,
63 storage: MatrixStorage::Banded { ml, mu, .. },
64 ..
65 },
66 Matrix {
67 n: n2,
68 data: b,
69 storage:
70 MatrixStorage::Banded {
71 ml: ml2, mu: mu2, ..
72 },
73 ..
74 },
75 ) => {
76 assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
77 let ml_out = ml.max(ml2);
78 let mu_out = mu.max(mu2);
79 let rows_out = ml_out + mu_out + 1;
80 let mut out = Matrix {
81 n,
82 m: n,
83 data: vec![T::zero(); rows_out * n],
84 storage: MatrixStorage::Banded {
85 ml: ml_out,
86 mu: mu_out,
87 zero: T::zero(),
88 },
89 };
90 for j in 0..n {
92 for r in 0..(ml + mu + 1) {
93 let k = r as isize - mu as isize; let i_signed = j as isize + k;
95 if i_signed >= 0 && (i_signed as usize) < n {
96 let row_out = (k + mu_out as isize) as usize;
97 out.data[row_out * n + j] += a[r * n + j];
98 }
99 }
100 }
101 for j in 0..n {
103 for r in 0..(ml2 + mu2 + 1) {
104 let k = r as isize - mu2 as isize; let i_signed = j as isize + k;
106 if i_signed >= 0 && (i_signed as usize) < n {
107 let row_out = (k + mu_out as isize) as usize;
108 out.data[row_out * n + j] -= b[r * n + j];
109 }
110 }
111 }
112 out
113 }
114 (
116 Matrix {
117 n: n1,
118 data: a,
119 storage: sa,
120 ..
121 },
122 Matrix {
123 n: n2,
124 data: b,
125 storage: sb,
126 ..
127 },
128 ) => {
129 assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
130 let to_full = |n: usize, data: Vec<T>, storage: MatrixStorage<T>| -> Vec<T> {
131 match storage {
132 MatrixStorage::Full => data,
133 MatrixStorage::Identity => {
134 let mut d = vec![T::zero(); n * n];
135 for i in 0..n {
136 d[i * n + i] = T::one();
137 }
138 d
139 }
140 MatrixStorage::Banded { ml, mu, .. } => {
141 let mut d = vec![T::zero(); n * n];
142 for j in 0..n {
143 for r in 0..(ml + mu + 1) {
144 let k = r as isize - mu as isize; let i_signed = j as isize + k;
146 if i_signed >= 0 && (i_signed as usize) < n {
147 let i = i_signed as usize;
148 d[i * n + j] += data[r * n + j];
149 }
150 }
151 }
152 d
153 }
154 }
155 };
156 let aa = to_full(n1, a, sa);
157 let bb = to_full(n2, b, sb);
158 let data = aa.into_iter().zip(bb).map(|(x, y)| x - y).collect();
159 Matrix {
160 n: n1,
161 m: n1,
162 data,
163 storage: MatrixStorage::Full,
164 }
165 }
166 }
167 }
168}
169
170impl<T: Real> SubAssign<Matrix<T>> for Matrix<T> {
174 fn sub_assign(&mut self, rhs: Matrix<T>) {
175 let n = self.n;
176 let lhs = core::mem::replace(self, Matrix::zeros(n, n));
177 *self = lhs - rhs;
178 }
179}
180
181impl<T: Real> SubAssign<&Matrix<T>> for Matrix<T> {
183 fn sub_assign(&mut self, rhs: &Matrix<T>) {
184 let n = self.n;
185 let lhs = core::mem::replace(self, Matrix::zeros(n, n));
186 *self = lhs - rhs.clone();
187 }
188}
189
190impl<T: Real> Matrix<T> {
191 pub fn component_sub(mut self, rhs: T) -> Self {
193 match &mut self.storage {
194 MatrixStorage::Identity => {
195 let n = self.n;
196 let mut data = vec![T::zero() - rhs; n * n];
197 for i in 0..n {
198 data[i * n + i] = T::one() - rhs;
199 }
200 Matrix {
201 n,
202 m: n,
203 data,
204 storage: MatrixStorage::Full,
205 }
206 }
207 MatrixStorage::Full => {
208 for v in &mut self.data {
209 *v -= rhs;
210 }
211 self
212 }
213 MatrixStorage::Banded { ml, mu, .. } => {
214 let n = self.n;
215 if rhs == T::zero() {
216 self
217 } else {
218 let rows = *ml + *mu + 1;
219 let mut dense = vec![T::zero() - rhs; n * n];
220 for j in 0..n {
221 for r in 0..rows {
222 let k = r as isize - *mu as isize;
223 let i_signed = j as isize + k;
224 if i_signed >= 0 && (i_signed as usize) < n {
225 let i = i_signed as usize;
226 let val = self.data[r * n + j];
227 dense[i * n + j] = val - rhs;
228 }
229 }
230 }
231 Matrix {
232 n,
233 m: n,
234 data: dense,
235 storage: MatrixStorage::Full,
236 }
237 }
238 }
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::MatrixStorage;
    use crate::linalg::matrix::Matrix;

    /// Builds an `n × n` identity matrix directly from its fields. `data` is
    /// left empty because the subtraction paths in this file never read it for
    /// `Identity` storage.
    fn identity(n: usize) -> Matrix<f64> {
        Matrix {
            n,
            m: n,
            data: Vec::new(),
            storage: MatrixStorage::Identity,
        }
    }

    #[test]
    fn sub_scalar_full() {
        let m: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let r = m.component_sub(1.0);
        assert_eq!(r[(0, 0)], 0.0);
        assert_eq!(r[(0, 1)], 1.0);
        assert_eq!(r[(1, 0)], 2.0);
        assert_eq!(r[(1, 1)], 3.0);
    }

    #[test]
    fn sub_scalar_banded_zero_keeps_banded() {
        let m: Matrix<f64> = Matrix::banded(3, 1, 1);
        let r = m.component_sub(0.0);
        // The test name promises the storage is preserved, so check it.
        assert!(matches!(r.storage, MatrixStorage::Banded { .. }));
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(r[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn sub_scalar_banded_nonzero_densifies() {
        let mut m: Matrix<f64> = Matrix::banded(3, 0, 0);
        m[(1, 1)] = 2.0;
        let r = m.component_sub(1.0);
        assert_eq!(r[(1, 1)], 1.0);
        assert_eq!(r[(0, 0)], -1.0);
        // Out-of-band zeros become -rhs.
        assert_eq!(r[(0, 2)], -1.0);
        assert_eq!(r[(2, 0)], -1.0);
    }

    #[test]
    fn sub_scalar_identity() {
        let r = identity(2).component_sub(0.5);
        assert_eq!(r[(0, 0)], 0.5);
        assert_eq!(r[(0, 1)], -0.5);
        assert_eq!(r[(1, 0)], -0.5);
        assert_eq!(r[(1, 1)], 0.5);
    }

    #[test]
    fn sub_matrix_full_full() {
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let b: Matrix<f64> = Matrix::from_vec(2, 2, vec![4.0, 3.0, 2.0, 1.0]);
        let r = a - b;
        assert_eq!(r[(0, 0)], -3.0);
        assert_eq!(r[(0, 1)], -1.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(1, 1)], 3.0);
    }

    #[test]
    fn sub_matrix_identity_identity_is_zero() {
        let r = identity(3) - identity(3);
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(r[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn sub_matrix_identity_full_mixed_storage() {
        let f: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let r = identity(2) - f;
        assert_eq!(r[(0, 0)], 0.0);
        assert_eq!(r[(0, 1)], -2.0);
        assert_eq!(r[(1, 0)], -3.0);
        assert_eq!(r[(1, 1)], -3.0);
    }

    #[test]
    fn sub_matrix_full_banded_mixed_storage() {
        let f: Matrix<f64> = Matrix::from_vec(3, 3, vec![1.0; 9]);
        let mut b: Matrix<f64> = Matrix::banded(3, 1, 1);
        b[(0, 0)] = 1.0;
        b[(1, 0)] = 2.0;
        b[(0, 1)] = 3.0;
        let r = f - b;
        assert_eq!(r[(0, 0)], 0.0);
        assert_eq!(r[(1, 0)], -1.0);
        assert_eq!(r[(0, 1)], -2.0);
        // (2, 0) lies outside b's band, so only f contributes.
        assert_eq!(r[(2, 0)], 1.0);
        assert_eq!(r[(2, 2)], 1.0);
    }

    #[test]
    fn sub_matrix_banded_banded() {
        let mut a: Matrix<f64> = Matrix::banded(3, 1, 0);
        let mut b: Matrix<f64> = Matrix::banded(3, 0, 1);
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;
        a[(2, 2)] = 1.0;
        a[(1, 0)] = 1.0;
        a[(2, 1)] = 1.0;
        b[(0, 0)] = 2.0;
        b[(1, 1)] = 2.0;
        b[(2, 2)] = 2.0;
        b[(0, 1)] = 2.0;
        b[(1, 2)] = 2.0;
        let r = a - b;
        assert_eq!(r[(0, 0)], -1.0);
        assert_eq!(r[(1, 1)], -1.0);
        assert_eq!(r[(2, 2)], -1.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(2, 1)], 1.0);
        assert_eq!(r[(0, 1)], -2.0);
        assert_eq!(r[(1, 2)], -2.0);
        assert_eq!(r[(0, 2)], 0.0);
    }

    #[test]
    fn sub_assign_owned_and_borrowed() {
        let mut a: Matrix<f64> = Matrix::from_vec(2, 2, vec![5.0, 5.0, 5.0, 5.0]);
        let b: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        a -= &b;
        assert_eq!(a[(0, 0)], 4.0);
        assert_eq!(a[(1, 1)], 1.0);
        a -= b;
        assert_eq!(a[(0, 0)], 3.0);
        assert_eq!(a[(0, 1)], 1.0);
        assert_eq!(a[(1, 0)], -1.0);
        assert_eq!(a[(1, 1)], -3.0);
    }
}