differential_equations/linalg/matrix/
sub.rs

1//! Matrix subtraction.
2
3use core::ops::{Sub, SubAssign};
4
5use super::base::{Matrix, MatrixStorage};
6use crate::traits::Real;
7
8// Matrix - Matrix
9impl<T: Real> Sub for Matrix<T> {
10    type Output = Matrix<T>;
11
12    fn sub(self, rhs: Matrix<T>) -> Self::Output {
13        match (self, rhs) {
14            (
15                Matrix {
16                    n: n1,
17                    storage: MatrixStorage::Identity,
18                    ..
19                },
20                Matrix {
21                    n: n2,
22                    storage: MatrixStorage::Identity,
23                    ..
24                },
25            ) => {
26                assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
27                Matrix {
28                    n: n1,
29                    m: n1,
30                    data: vec![T::zero(); n1 * n1],
31                    storage: MatrixStorage::Full,
32                }
33            }
34            (
35                Matrix {
36                    n,
37                    data: mut a,
38                    storage: MatrixStorage::Full,
39                    ..
40                },
41                Matrix {
42                    n: n2,
43                    data: b,
44                    storage: MatrixStorage::Full,
45                    ..
46                },
47            ) => {
48                assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
49                for (x, y) in a.iter_mut().zip(b.iter()) {
50                    *x -= *y;
51                }
52                Matrix {
53                    n,
54                    m: n,
55                    data: a,
56                    storage: MatrixStorage::Full,
57                }
58            }
59            (
60                Matrix {
61                    n,
62                    data: a,
63                    storage: MatrixStorage::Banded { ml, mu, .. },
64                    ..
65                },
66                Matrix {
67                    n: n2,
68                    data: b,
69                    storage:
70                        MatrixStorage::Banded {
71                            ml: ml2, mu: mu2, ..
72                        },
73                    ..
74                },
75            ) => {
76                assert_eq!(n, n2, "dimension mismatch in Matrix - Matrix");
77                let ml_out = ml.max(ml2);
78                let mu_out = mu.max(mu2);
79                let rows_out = ml_out + mu_out + 1;
80                let mut out = Matrix {
81                    n,
82                    m: n,
83                    data: vec![T::zero(); rows_out * n],
84                    storage: MatrixStorage::Banded {
85                        ml: ml_out,
86                        mu: mu_out,
87                        zero: T::zero(),
88                    },
89                };
90                // Add first banded
91                for j in 0..n {
92                    for r in 0..(ml + mu + 1) {
93                        let k = r as isize - mu as isize; // i - j for first
94                        let i_signed = j as isize + k;
95                        if i_signed >= 0 && (i_signed as usize) < n {
96                            let row_out = (k + mu_out as isize) as usize;
97                            out.data[row_out * n + j] += a[r * n + j];
98                        }
99                    }
100                }
101                // Subtract second banded
102                for j in 0..n {
103                    for r in 0..(ml2 + mu2 + 1) {
104                        let k = r as isize - mu2 as isize; // i - j for second
105                        let i_signed = j as isize + k;
106                        if i_signed >= 0 && (i_signed as usize) < n {
107                            let row_out = (k + mu_out as isize) as usize;
108                            out.data[row_out * n + j] -= b[r * n + j];
109                        }
110                    }
111                }
112                out
113            }
114            // Mixed storage: densify
115            (
116                Matrix {
117                    n: n1,
118                    data: a,
119                    storage: sa,
120                    ..
121                },
122                Matrix {
123                    n: n2,
124                    data: b,
125                    storage: sb,
126                    ..
127                },
128            ) => {
129                assert_eq!(n1, n2, "dimension mismatch in Matrix - Matrix");
130                let to_full = |n: usize, data: Vec<T>, storage: MatrixStorage<T>| -> Vec<T> {
131                    match storage {
132                        MatrixStorage::Full => data,
133                        MatrixStorage::Identity => {
134                            let mut d = vec![T::zero(); n * n];
135                            for i in 0..n {
136                                d[i * n + i] = T::one();
137                            }
138                            d
139                        }
140                        MatrixStorage::Banded { ml, mu, .. } => {
141                            let mut d = vec![T::zero(); n * n];
142                            for j in 0..n {
143                                for r in 0..(ml + mu + 1) {
144                                    let k = r as isize - mu as isize; // i - j
145                                    let i_signed = j as isize + k;
146                                    if i_signed >= 0 && (i_signed as usize) < n {
147                                        let i = i_signed as usize;
148                                        d[i * n + j] += data[r * n + j];
149                                    }
150                                }
151                            }
152                            d
153                        }
154                    }
155                };
156                let aa = to_full(n1, a, sa);
157                let bb = to_full(n2, b, sb);
158                let data = aa.into_iter().zip(bb).map(|(x, y)| x - y).collect();
159                Matrix {
160                    n: n1,
161                    m: n1,
162                    data,
163                    storage: MatrixStorage::Full,
164                }
165            }
166        }
167    }
168}
169
170// For scalars, use `component_sub`.
171
172// Sub-assign by value
173impl<T: Real> SubAssign<Matrix<T>> for Matrix<T> {
174    fn sub_assign(&mut self, rhs: Matrix<T>) {
175        let n = self.n;
176        let lhs = core::mem::replace(self, Matrix::zeros(n, n));
177        *self = lhs - rhs;
178    }
179}
180
181// Sub-assign by reference (clones rhs)
182impl<T: Real> SubAssign<&Matrix<T>> for Matrix<T> {
183    fn sub_assign(&mut self, rhs: &Matrix<T>) {
184        let n = self.n;
185        let lhs = core::mem::replace(self, Matrix::zeros(n, n));
186        *self = lhs - rhs.clone();
187    }
188}
189
190impl<T: Real> Matrix<T> {
191    /// Return a new matrix where each stored entry has `rhs` subtracted. Off-band handling similar to add.
192    pub fn component_sub(mut self, rhs: T) -> Self {
193        match &mut self.storage {
194            MatrixStorage::Identity => {
195                let n = self.n;
196                let mut data = vec![T::zero() - rhs; n * n];
197                for i in 0..n {
198                    data[i * n + i] = T::one() - rhs;
199                }
200                Matrix {
201                    n,
202                    m: n,
203                    data,
204                    storage: MatrixStorage::Full,
205                }
206            }
207            MatrixStorage::Full => {
208                for v in &mut self.data {
209                    *v -= rhs;
210                }
211                self
212            }
213            MatrixStorage::Banded { ml, mu, .. } => {
214                let n = self.n;
215                if rhs == T::zero() {
216                    self
217                } else {
218                    let rows = *ml + *mu + 1;
219                    let mut dense = vec![T::zero() - rhs; n * n];
220                    for j in 0..n {
221                        for r in 0..rows {
222                            let k = r as isize - *mu as isize;
223                            let i_signed = j as isize + k;
224                            if i_signed >= 0 && (i_signed as usize) < n {
225                                let i = i_signed as usize;
226                                let val = self.data[r * n + j];
227                                dense[i * n + j] = val - rhs;
228                            }
229                        }
230                    }
231                    Matrix {
232                        n,
233                        m: n,
234                        data: dense,
235                        storage: MatrixStorage::Full,
236                    }
237                }
238            }
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;

    #[test]
    fn sub_scalar_full() {
        let m: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let r = m.component_sub(1.0);
        assert_eq!(r[(0, 0)], 0.0);
        assert_eq!(r[(0, 1)], 1.0);
        assert_eq!(r[(1, 0)], 2.0);
        assert_eq!(r[(1, 1)], 3.0);
    }

    #[test]
    fn sub_scalar_banded_zero_keeps_banded() {
        let m: Matrix<f64> = Matrix::banded(3, 1, 1);
        let r = m.component_sub(0.0);
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(r[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn sub_scalar_banded_nonzero_fills_off_band() {
        // Diagonal-only band: after subtracting 1, on-band entries become 2 - 1 and every
        // off-band (implicit zero) entry becomes -1.
        let mut m: Matrix<f64> = Matrix::banded(3, 0, 0);
        m[(0, 0)] = 2.0;
        m[(1, 1)] = 2.0;
        m[(2, 2)] = 2.0;
        let r = m.component_sub(1.0);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { -1.0 };
                assert_eq!(r[(i, j)], expected);
            }
        }
    }

    #[test]
    fn sub_matrix_full_full() {
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let b: Matrix<f64> = Matrix::from_vec(2, 2, vec![4.0, 3.0, 2.0, 1.0]);
        let r = a - b;
        assert_eq!(r[(0, 0)], -3.0);
        assert_eq!(r[(0, 1)], -1.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(1, 1)], 3.0);
    }

    #[test]
    fn sub_matrix_full_banded_densifies() {
        // Mixed storage exercises the densify branch.
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let mut b: Matrix<f64> = Matrix::banded(2, 0, 0);
        b[(0, 0)] = 1.0;
        b[(1, 1)] = 1.0;
        let r = a - b;
        assert_eq!(r[(0, 0)], 0.0);
        assert_eq!(r[(0, 1)], 2.0);
        assert_eq!(r[(1, 0)], 3.0);
        assert_eq!(r[(1, 1)], 3.0);
    }

    #[test]
    fn sub_matrix_banded_banded() {
        // 3x3, ml=1, mu=0 and 0,1
        let mut a: Matrix<f64> = Matrix::banded(3, 1, 0);
        let mut b: Matrix<f64> = Matrix::banded(3, 0, 1);
        // set a main and lower
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;
        a[(2, 2)] = 1.0;
        a[(1, 0)] = 1.0;
        a[(2, 1)] = 1.0;
        // set b main and upper
        b[(0, 0)] = 2.0;
        b[(1, 1)] = 2.0;
        b[(2, 2)] = 2.0;
        b[(0, 1)] = 2.0;
        b[(1, 2)] = 2.0;
        let r = a - b;
        // Check entries of the resulting tri-diagonal
        assert_eq!(r[(0, 0)], -1.0);
        assert_eq!(r[(1, 1)], -1.0);
        assert_eq!(r[(2, 2)], -1.0);
        assert_eq!(r[(1, 0)], 1.0);
        assert_eq!(r[(2, 1)], 1.0);
        assert_eq!(r[(0, 1)], -2.0);
        assert_eq!(r[(1, 2)], -2.0);
        assert_eq!(r[(0, 2)], 0.0);
    }

    #[test]
    fn sub_assign_by_reference_then_by_value() {
        let mut a: Matrix<f64> = Matrix::from_vec(2, 2, vec![5.0, 5.0, 5.0, 5.0]);
        let b: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        a -= &b;
        assert_eq!(a[(0, 0)], 4.0);
        assert_eq!(a[(0, 1)], 3.0);
        assert_eq!(a[(1, 0)], 2.0);
        assert_eq!(a[(1, 1)], 1.0);
        a -= b;
        assert_eq!(a[(0, 0)], 3.0);
        assert_eq!(a[(0, 1)], 1.0);
        assert_eq!(a[(1, 0)], -1.0);
        assert_eq!(a[(1, 1)], -3.0);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn sub_matrix_dimension_mismatch_panics() {
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![0.0; 4]);
        let b: Matrix<f64> = Matrix::from_vec(3, 3, vec![0.0; 9]);
        let _ = a - b;
    }
}