differential_equations/linalg/matrix/
sub.rs

1//! Matrix subtraction.
2
3use core::ops::{Sub, SubAssign};
4
5use super::base::{Matrix, MatrixStorage};
6use crate::traits::Real;
7
8// Matrix - Matrix
9impl<T: Real> Sub for Matrix<T> {
10    type Output = Matrix<T>;
11
12    fn sub(self, rhs: Matrix<T>) -> Self::Output {
13        match (self, rhs) {
14            (
15                Matrix {
16                    nrows: n1,
17                    storage: MatrixStorage::Identity,
18                    ..
19                },
20                Matrix {
21                    nrows: n2,
22                    storage: MatrixStorage::Identity,
23                    ..
24                },
25            ) => {
26                assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
27                Matrix {
28                    nrows: n1,
29                    ncols: n1,
30                    data: vec![T::zero(); n1 * n1],
31                    storage: MatrixStorage::Full,
32                }
33            }
34            (
35                Matrix {
36                    nrows: n,
37                    data: mut a,
38                    storage: MatrixStorage::Full,
39                    ..
40                },
41                Matrix {
42                    nrows: n2,
43                    data: b,
44                    storage: MatrixStorage::Full,
45                    ..
46                },
47            ) => {
48                assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
49                for (x, y) in a.iter_mut().zip(b.iter()) {
50                    *x = *x - *y;
51                }
52                Matrix {
53                    nrows: n,
54                    ncols: n,
55                    data: a,
56                    storage: MatrixStorage::Full,
57                }
58            }
59            (
60                Matrix {
61                    nrows: n,
62                    data: a,
63                    storage: MatrixStorage::Banded { ml, mu, .. },
64                    ..
65                },
66                Matrix {
67                    nrows: n2,
68                    data: b,
69                    storage:
70                        MatrixStorage::Banded {
71                            ml: ml2, mu: mu2, ..
72                        },
73                    ..
74                },
75            ) => {
76                assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
77                let ml_out = ml.max(ml2);
78                let mu_out = mu.max(mu2);
79                let rows_out = ml_out + mu_out + 1;
80                let mut out = Matrix {
81                    nrows: n,
82                    ncols: n,
83                    data: vec![T::zero(); rows_out * n],
84                    storage: MatrixStorage::Banded {
85                        ml: ml_out,
86                        mu: mu_out,
87                        zero: T::zero(),
88                    },
89                };
90                // Add first banded
91                for j in 0..n {
92                    for r in 0..(ml + mu + 1) {
93                        let k = r as isize - mu as isize; // i - j for first
94                        let i_signed = j as isize + k;
95                        if i_signed >= 0 && (i_signed as usize) < n {
96                            let row_out = (k + mu_out as isize) as usize;
97                            out.data[row_out * n + j] = out.data[row_out * n + j] + a[r * n + j];
98                        }
99                    }
100                }
101                // Subtract second banded
102                for j in 0..n {
103                    for r in 0..(ml2 + mu2 + 1) {
104                        let k = r as isize - mu2 as isize; // i - j for second
105                        let i_signed = j as isize + k;
106                        if i_signed >= 0 && (i_signed as usize) < n {
107                            let row_out = (k + mu_out as isize) as usize;
108                            out.data[row_out * n + j] = out.data[row_out * n + j] - b[r * n + j];
109                        }
110                    }
111                }
112                out
113            }
114            // Mixed storage: densify
115            (
116                Matrix {
117                    nrows: n1,
118                    data: a,
119                    storage: sa,
120                    ..
121                },
122                Matrix {
123                    nrows: n2,
124                    data: b,
125                    storage: sb,
126                    ..
127                },
128            ) => {
129                assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
130                let to_full = |n: usize, data: Vec<T>, storage: MatrixStorage<T>| -> Vec<T> {
131                    match storage {
132                        MatrixStorage::Full => data,
133                        MatrixStorage::Identity => {
134                            let mut d = vec![T::zero(); n * n];
135                            for i in 0..n {
136                                d[i * n + i] = T::one();
137                            }
138                            d
139                        }
140                        MatrixStorage::Banded { ml, mu, .. } => {
141                            let mut d = vec![T::zero(); n * n];
142                            for j in 0..n {
143                                for r in 0..(ml + mu + 1) {
144                                    let k = r as isize - mu as isize; // i - j
145                                    let i_signed = j as isize + k;
146                                    if i_signed >= 0 && (i_signed as usize) < n {
147                                        let i = i_signed as usize;
148                                        d[i * n + j] = d[i * n + j] + data[r * n + j];
149                                    }
150                                }
151                            }
152                            d
153                        }
154                    }
155                };
156                let aa = to_full(n1, a, sa);
157                let bb = to_full(n2, b, sb);
158                let data = aa
159                    .into_iter()
160                    .zip(bb.into_iter())
161                    .map(|(x, y)| x - y)
162                    .collect();
163                Matrix {
164                    nrows: n1,
165                    ncols: n1,
166                    data,
167                    storage: MatrixStorage::Full,
168                }
169            }
170        }
171    }
172}
173
174// For scalars, use `component_sub`.
175
176// Sub-assign by value
177impl<T: Real> SubAssign<Matrix<T>> for Matrix<T> {
178    fn sub_assign(&mut self, rhs: Matrix<T>) {
179        let n = self.n();
180        let lhs = core::mem::replace(self, Matrix::zeros(n));
181        *self = lhs - rhs;
182    }
183}
184
185// Sub-assign by reference (clones rhs)
186impl<T: Real> SubAssign<&Matrix<T>> for Matrix<T> {
187    fn sub_assign(&mut self, rhs: &Matrix<T>) {
188        let n = self.n();
189        let lhs = core::mem::replace(self, Matrix::zeros(n));
190        *self = lhs - rhs.clone();
191    }
192}
193
194impl<T: Real> Matrix<T> {
195    /// Return a new matrix where each stored entry has `rhs` subtracted. Off-band handling similar to add.
196    pub fn component_sub(mut self, rhs: T) -> Self {
197        match &mut self.storage {
198            MatrixStorage::Identity => {
199                let n = self.nrows;
200                let mut data = vec![T::zero() - rhs; n * n];
201                for i in 0..n {
202                    data[i * n + i] = T::one() - rhs;
203                }
204                Matrix {
205                    nrows: n,
206                    ncols: n,
207                    data,
208                    storage: MatrixStorage::Full,
209                }
210            }
211            MatrixStorage::Full => {
212                for v in &mut self.data {
213                    *v = *v - rhs;
214                }
215                self
216            }
217            MatrixStorage::Banded { ml, mu, .. } => {
218                let n = self.nrows;
219                if rhs == T::zero() {
220                    self
221                } else {
222                    let rows = *ml + *mu + 1;
223                    let mut dense = vec![T::zero() - rhs; n * n];
224                    for j in 0..n {
225                        for r in 0..rows {
226                            let k = r as isize - *mu as isize;
227                            let i_signed = j as isize + k;
228                            if i_signed >= 0 && (i_signed as usize) < n {
229                                let i = i_signed as usize;
230                                let val = self.data[r * n + j];
231                                dense[i * n + j] = val - rhs;
232                            }
233                        }
234                    }
235                    Matrix {
236                        nrows: n,
237                        ncols: n,
238                        data: dense,
239                        storage: MatrixStorage::Full,
240                    }
241                }
242            }
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;

    /// Asserts that every entry of `m` equals the corresponding entry of `expected`.
    fn assert_entries(m: &Matrix<f64>, expected: &[&[f64]]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(m[(i, j)], want, "entry ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn sub_scalar_full() {
        let m: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        let r = m.component_sub(1.0);
        assert_eq!(r[(0, 0)], 0.0);
        assert_eq!(r[(0, 1)], 1.0);
        assert_eq!(r[(1, 0)], 2.0);
        assert_eq!(r[(1, 1)], 3.0);
    }

    #[test]
    fn sub_scalar_banded_zero_keeps_banded() {
        let m: Matrix<f64> = Matrix::banded(3, 1, 1);
        let r = m.component_sub(0.0);
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(r[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn sub_scalar_banded_nonzero_affects_off_band_entries() {
        let mut m: Matrix<f64> = Matrix::banded(3, 1, 1);
        m[(0, 0)] = 5.0;
        let r = m.component_sub(2.0);
        // Stored band entry, in-band zero, and off-band zeros all shift by -2.
        assert_entries(
            &r,
            &[&[3.0, -2.0, -2.0], &[-2.0, -2.0, -2.0], &[-2.0, -2.0, -2.0]],
        );
    }

    #[test]
    fn sub_matrix_full_full() {
        let a: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        let b: Matrix<f64> = Matrix::full(2, vec![4.0, 3.0, 2.0, 1.0]);
        let r = a - b;
        assert_eq!(r[(0, 0)], -3.0);
        assert_eq!(r[(0, 1)], -1.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(1, 1)], 3.0);
    }

    #[test]
    fn sub_matrix_banded_banded() {
        // 3x3, ml=1, mu=0 and 0,1
        let mut a: Matrix<f64> = Matrix::banded(3, 1, 0);
        let mut b: Matrix<f64> = Matrix::banded(3, 0, 1);
        // set a main and lower
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;
        a[(2, 2)] = 1.0;
        a[(1, 0)] = 1.0;
        a[(2, 1)] = 1.0;
        // set b main and upper
        b[(0, 0)] = 2.0;
        b[(1, 1)] = 2.0;
        b[(2, 2)] = 2.0;
        b[(0, 1)] = 2.0;
        b[(1, 2)] = 2.0;
        let r = a - b;
        // Check entries of the resulting tri-diagonal
        assert_eq!(r[(0, 0)], -1.0);
        assert_eq!(r[(1, 1)], -1.0);
        assert_eq!(r[(2, 2)], -1.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(2, 1)], 1.0);
        assert_eq!(r[(0, 1)], -2.0);
        assert_eq!(r[(1, 2)], -2.0);
        assert_eq!(r[(0, 2)], 0.0);
    }

    #[test]
    fn sub_matrix_full_banded_densifies() {
        let a: Matrix<f64> =
            Matrix::full(3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let mut b: Matrix<f64> = Matrix::banded(3, 1, 1);
        b[(0, 0)] = 1.0;
        b[(1, 0)] = 2.0;
        b[(0, 1)] = 3.0;
        b[(2, 2)] = 4.0;
        let r = a - b;
        assert_entries(&r, &[&[0.0, -1.0, 3.0], &[2.0, 5.0, 6.0], &[7.0, 8.0, 5.0]]);
    }

    #[test]
    fn sub_matrix_banded_full_densifies() {
        let mut a: Matrix<f64> = Matrix::banded(2, 1, 0);
        a[(0, 0)] = 5.0;
        a[(1, 0)] = 2.0;
        a[(1, 1)] = 1.0;
        let b: Matrix<f64> = Matrix::full(2, vec![1.0, 1.0, 1.0, 1.0]);
        let r = a - b;
        // (0, 1) lies outside a's band, so it is 0 - 1.
        assert_entries(&r, &[&[4.0, -1.0], &[1.0, 0.0]]);
    }

    #[test]
    fn sub_assign_by_value_and_by_ref() {
        let mut a: Matrix<f64> = Matrix::full(2, vec![5.0, 6.0, 7.0, 8.0]);
        let b: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 3.0, 4.0]);
        a -= &b;
        assert_entries(&a, &[&[4.0, 4.0], &[4.0, 4.0]]);
        a -= b;
        assert_entries(&a, &[&[3.0, 2.0], &[1.0, 0.0]]);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn sub_matrix_dimension_mismatch_panics() {
        let a: Matrix<f64> = Matrix::full(2, vec![0.0; 4]);
        let b: Matrix<f64> = Matrix::full(3, vec![0.0; 9]);
        let _ = a - b;
    }
}