faer_core/
inverse.rs

1//! Triangular matrix inversion.
2
3use crate::{
4    assert, join_raw,
5    mul::triangular::{self, BlockStructure},
6    solve, ComplexField, MatMut, MatRef, Parallelism,
7};
8use reborrow::*;
9
10unsafe fn invert_lower_triangular_impl_small<E: ComplexField>(
11    mut dst: MatMut<'_, E>,
12    src: MatRef<'_, E>,
13) {
14    let m = dst.nrows();
15    let src = {
16        #[inline(always)]
17        |i: usize, j: usize| src.read_unchecked(i, j)
18    };
19    match m {
20        0 => {}
21        1 => dst.write_unchecked(0, 0, src(0, 0).faer_inv()),
22        2 => {
23            let dst00 = src(0, 0).faer_inv();
24            let dst11 = src(1, 1).faer_inv();
25            let dst10 = (dst11.faer_mul(src(1, 0)).faer_mul(dst00)).faer_neg();
26
27            dst.write_unchecked(0, 0, dst00);
28            dst.write_unchecked(1, 1, dst11);
29            dst.write_unchecked(1, 0, dst10);
30        }
31        _ => unreachable!(),
32    }
33}
34
35unsafe fn invert_unit_lower_triangular_impl_small<E: ComplexField>(
36    mut dst: MatMut<'_, E>,
37    src: MatRef<'_, E>,
38) {
39    let m = dst.nrows();
40    let src = |i: usize, j: usize| src.read_unchecked(i, j);
41    match m {
42        0 | 1 => {}
43        2 => {
44            dst.write_unchecked(1, 0, src(1, 0).faer_neg());
45        }
46        _ => unreachable!(),
47    }
48}
49
50unsafe fn invert_lower_triangular_impl<E: ComplexField>(
51    dst: MatMut<'_, E>,
52    src: MatRef<'_, E>,
53    parallelism: Parallelism,
54) {
55    // m must be equal to n
56    let m = dst.nrows();
57    let n = dst.ncols();
58
59    if m <= 2 {
60        invert_lower_triangular_impl_small(dst, src);
61        return;
62    }
63
64    let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_at_mut(m / 2, n / 2) };
65
66    let m = src.nrows();
67    let n = src.ncols();
68    let (src_tl, _, src_bl, src_br) = { src.split_at(m / 2, n / 2) };
69
70    join_raw(
71        |parallelism| invert_lower_triangular_impl(dst_tl.rb_mut(), src_tl, parallelism),
72        |parallelism| invert_lower_triangular_impl(dst_br.rb_mut(), src_br, parallelism),
73        parallelism,
74    );
75
76    triangular::matmul(
77        dst_bl.rb_mut(),
78        BlockStructure::Rectangular,
79        src_bl,
80        BlockStructure::Rectangular,
81        dst_tl.rb(),
82        BlockStructure::TriangularLower,
83        None,
84        E::faer_one().faer_neg(),
85        parallelism,
86    );
87    solve::solve_lower_triangular_in_place(src_br, dst_bl, parallelism);
88}
89
90unsafe fn invert_unit_lower_triangular_impl<E: ComplexField>(
91    dst: MatMut<'_, E>,
92    src: MatRef<'_, E>,
93    parallelism: Parallelism,
94) {
95    // m must be equal to n
96    let m = dst.nrows();
97    let n = dst.ncols();
98
99    if m <= 2 {
100        invert_unit_lower_triangular_impl_small(dst, src);
101        return;
102    }
103
104    let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_at_mut(m / 2, n / 2) };
105
106    let m = src.nrows();
107    let n = src.ncols();
108    let (src_tl, _, src_bl, src_br) = { src.split_at(m / 2, n / 2) };
109
110    join_raw(
111        |parallelism| invert_unit_lower_triangular_impl(dst_tl.rb_mut(), src_tl, parallelism),
112        |parallelism| invert_unit_lower_triangular_impl(dst_br.rb_mut(), src_br, parallelism),
113        parallelism,
114    );
115
116    triangular::matmul(
117        dst_bl.rb_mut(),
118        BlockStructure::Rectangular,
119        src_bl,
120        BlockStructure::Rectangular,
121        dst_tl.rb(),
122        BlockStructure::UnitTriangularLower,
123        None,
124        E::faer_one().faer_neg(),
125        parallelism,
126    );
127    solve::solve_unit_lower_triangular_in_place(src_br, dst_bl, parallelism);
128}
129
130/// Computes the inverse of the lower triangular matrix `src` (with implicit unit
131/// diagonal) and stores the strictly lower triangular part of the result to `dst`.
132///
133/// # Panics
134///
135/// Panics if `src` and `dst` have mismatching dimensions, or if they are not square.
136#[track_caller]
137pub fn invert_unit_lower_triangular<E: ComplexField>(
138    dst: MatMut<'_, E>,
139    src: MatRef<'_, E>,
140    parallelism: Parallelism,
141) {
142    assert!(all(
143        dst.nrows() == src.nrows(),
144        dst.ncols() == src.ncols(),
145        dst.nrows() == dst.ncols()
146    ));
147
148    unsafe { invert_unit_lower_triangular_impl(dst, src, parallelism) }
149}
150
151/// Computes the inverse of the lower triangular matrix `src` and stores the
152/// lower triangular part of the result to `dst`.
153///
154/// # Panics
155///
156/// Panics if `src` and `dst` have mismatching dimensions, or if they are not square.
157#[track_caller]
158pub fn invert_lower_triangular<E: ComplexField>(
159    dst: MatMut<'_, E>,
160    src: MatRef<'_, E>,
161    parallelism: Parallelism,
162) {
163    assert!(all(
164        dst.nrows() == src.nrows(),
165        dst.ncols() == src.ncols(),
166        dst.nrows() == dst.ncols()
167    ));
168
169    unsafe { invert_lower_triangular_impl(dst, src, parallelism) }
170}
171
172/// Computes the inverse of the upper triangular matrix `src` (with implicit unit
173/// diagonal) and stores the strictly upper triangular part of the result to `dst`.
174///
175/// # Panics
176///
177/// Panics if `src` and `dst` have mismatching dimensions, or if they are not square.
178#[track_caller]
179pub fn invert_unit_upper_triangular<E: ComplexField>(
180    dst: MatMut<'_, E>,
181    src: MatRef<'_, E>,
182    parallelism: Parallelism,
183) {
184    invert_unit_lower_triangular(
185        dst.reverse_rows_and_cols_mut(),
186        src.reverse_rows_and_cols(),
187        parallelism,
188    )
189}
190
191/// Computes the inverse of the upper triangular matrix `src` and stores the
192/// upper triangular part of the result to `dst`.
193///
194/// # Panics
195///
196/// Panics if `src` and `dst` have mismatching dimensions, or if they are not square.
197#[track_caller]
198pub fn invert_upper_triangular<E: ComplexField>(
199    dst: MatMut<'_, E>,
200    src: MatRef<'_, E>,
201    parallelism: Parallelism,
202) {
203    invert_lower_triangular(
204        dst.reverse_rows_and_cols_mut(),
205        src.reverse_rows_and_cols(),
206        parallelism,
207    )
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert, Mat};
    use assert_approx_eq::assert_approx_eq;
    use rand::random;

    // Each test below follows the same recipe for every dimension in `0..32`:
    // build a random matrix with entries in `[2, 3)`, invert it with the routine under
    // test, multiply the original by the computed inverse using the matching
    // triangular structure, and check the product against the identity.

    #[test]
    fn test_invert_lower() {
        for n in 0..32 {
            let mat = Mat::from_fn(n, n, |_, _| 2.0 + random::<f64>());
            let mut mat_inv = Mat::zeros(n, n);
            invert_lower_triangular(mat_inv.as_mut(), mat.as_ref(), Parallelism::Rayon(0));

            let mut product = Mat::zeros(n, n);
            triangular::matmul(
                product.as_mut(),
                BlockStructure::Rectangular,
                mat.as_ref(),
                BlockStructure::TriangularLower,
                mat_inv.as_ref(),
                BlockStructure::TriangularLower,
                None,
                1.0,
                Parallelism::Rayon(0),
            );

            for row in 0..n {
                for col in 0..n {
                    let expected = if row == col { 1.0 } else { 0.0 };
                    assert_approx_eq!(product.read(row, col), expected, 1e-4);
                }
            }
        }
    }

    #[test]
    fn test_invert_unit_lower() {
        for n in 0..32 {
            let mat = Mat::from_fn(n, n, |_, _| 2.0 + random::<f64>());
            let mut mat_inv = Mat::zeros(n, n);
            invert_unit_lower_triangular(mat_inv.as_mut(), mat.as_ref(), Parallelism::Rayon(0));

            let mut product = Mat::zeros(n, n);
            triangular::matmul(
                product.as_mut(),
                BlockStructure::Rectangular,
                mat.as_ref(),
                BlockStructure::UnitTriangularLower,
                mat_inv.as_ref(),
                BlockStructure::UnitTriangularLower,
                None,
                1.0,
                Parallelism::Rayon(0),
            );

            for row in 0..n {
                for col in 0..n {
                    let expected = if row == col { 1.0 } else { 0.0 };
                    assert_approx_eq!(product.read(row, col), expected, 1e-4);
                }
            }
        }
    }

    #[test]
    fn test_invert_upper() {
        for n in 0..32 {
            let mat = Mat::from_fn(n, n, |_, _| 2.0 + random::<f64>());
            let mut mat_inv = Mat::zeros(n, n);
            invert_upper_triangular(mat_inv.as_mut(), mat.as_ref(), Parallelism::Rayon(0));

            let mut product = Mat::zeros(n, n);
            triangular::matmul(
                product.as_mut(),
                BlockStructure::Rectangular,
                mat.as_ref(),
                BlockStructure::TriangularUpper,
                mat_inv.as_ref(),
                BlockStructure::TriangularUpper,
                None,
                1.0,
                Parallelism::Rayon(0),
            );

            for row in 0..n {
                for col in 0..n {
                    let expected = if row == col { 1.0 } else { 0.0 };
                    assert_approx_eq!(product.read(row, col), expected, 1e-4);
                }
            }
        }
    }

    #[test]
    fn test_invert_unit_upper() {
        for n in 0..32 {
            let mat = Mat::from_fn(n, n, |_, _| 2.0 + random::<f64>());
            let mut mat_inv = Mat::zeros(n, n);
            invert_unit_upper_triangular(mat_inv.as_mut(), mat.as_ref(), Parallelism::Rayon(0));

            let mut product = Mat::zeros(n, n);
            triangular::matmul(
                product.as_mut(),
                BlockStructure::Rectangular,
                mat.as_ref(),
                BlockStructure::UnitTriangularUpper,
                mat_inv.as_ref(),
                BlockStructure::UnitTriangularUpper,
                None,
                1.0,
                Parallelism::Rayon(0),
            );

            for row in 0..n {
                for col in 0..n {
                    let expected = if row == col { 1.0 } else { 0.0 };
                    assert_approx_eq!(product.read(row, col), expected, 1e-4);
                }
            }
        }
    }
}