1use crate::{
4 assert, join_raw,
5 mul::triangular::{self, BlockStructure},
6 solve, ComplexField, MatMut, MatRef, Parallelism,
7};
8use reborrow::*;
9
10unsafe fn invert_lower_triangular_impl_small<E: ComplexField>(
11 mut dst: MatMut<'_, E>,
12 src: MatRef<'_, E>,
13) {
14 let m = dst.nrows();
15 let src = {
16 #[inline(always)]
17 |i: usize, j: usize| src.read_unchecked(i, j)
18 };
19 match m {
20 0 => {}
21 1 => dst.write_unchecked(0, 0, src(0, 0).faer_inv()),
22 2 => {
23 let dst00 = src(0, 0).faer_inv();
24 let dst11 = src(1, 1).faer_inv();
25 let dst10 = (dst11.faer_mul(src(1, 0)).faer_mul(dst00)).faer_neg();
26
27 dst.write_unchecked(0, 0, dst00);
28 dst.write_unchecked(1, 1, dst11);
29 dst.write_unchecked(1, 0, dst10);
30 }
31 _ => unreachable!(),
32 }
33}
34
35unsafe fn invert_unit_lower_triangular_impl_small<E: ComplexField>(
36 mut dst: MatMut<'_, E>,
37 src: MatRef<'_, E>,
38) {
39 let m = dst.nrows();
40 let src = |i: usize, j: usize| src.read_unchecked(i, j);
41 match m {
42 0 | 1 => {}
43 2 => {
44 dst.write_unchecked(1, 0, src(1, 0).faer_neg());
45 }
46 _ => unreachable!(),
47 }
48}
49
50unsafe fn invert_lower_triangular_impl<E: ComplexField>(
51 dst: MatMut<'_, E>,
52 src: MatRef<'_, E>,
53 parallelism: Parallelism,
54) {
55 let m = dst.nrows();
57 let n = dst.ncols();
58
59 if m <= 2 {
60 invert_lower_triangular_impl_small(dst, src);
61 return;
62 }
63
64 let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_at_mut(m / 2, n / 2) };
65
66 let m = src.nrows();
67 let n = src.ncols();
68 let (src_tl, _, src_bl, src_br) = { src.split_at(m / 2, n / 2) };
69
70 join_raw(
71 |parallelism| invert_lower_triangular_impl(dst_tl.rb_mut(), src_tl, parallelism),
72 |parallelism| invert_lower_triangular_impl(dst_br.rb_mut(), src_br, parallelism),
73 parallelism,
74 );
75
76 triangular::matmul(
77 dst_bl.rb_mut(),
78 BlockStructure::Rectangular,
79 src_bl,
80 BlockStructure::Rectangular,
81 dst_tl.rb(),
82 BlockStructure::TriangularLower,
83 None,
84 E::faer_one().faer_neg(),
85 parallelism,
86 );
87 solve::solve_lower_triangular_in_place(src_br, dst_bl, parallelism);
88}
89
90unsafe fn invert_unit_lower_triangular_impl<E: ComplexField>(
91 dst: MatMut<'_, E>,
92 src: MatRef<'_, E>,
93 parallelism: Parallelism,
94) {
95 let m = dst.nrows();
97 let n = dst.ncols();
98
99 if m <= 2 {
100 invert_unit_lower_triangular_impl_small(dst, src);
101 return;
102 }
103
104 let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_at_mut(m / 2, n / 2) };
105
106 let m = src.nrows();
107 let n = src.ncols();
108 let (src_tl, _, src_bl, src_br) = { src.split_at(m / 2, n / 2) };
109
110 join_raw(
111 |parallelism| invert_unit_lower_triangular_impl(dst_tl.rb_mut(), src_tl, parallelism),
112 |parallelism| invert_unit_lower_triangular_impl(dst_br.rb_mut(), src_br, parallelism),
113 parallelism,
114 );
115
116 triangular::matmul(
117 dst_bl.rb_mut(),
118 BlockStructure::Rectangular,
119 src_bl,
120 BlockStructure::Rectangular,
121 dst_tl.rb(),
122 BlockStructure::UnitTriangularLower,
123 None,
124 E::faer_one().faer_neg(),
125 parallelism,
126 );
127 solve::solve_unit_lower_triangular_in_place(src_br, dst_bl, parallelism);
128}
129
130#[track_caller]
137pub fn invert_unit_lower_triangular<E: ComplexField>(
138 dst: MatMut<'_, E>,
139 src: MatRef<'_, E>,
140 parallelism: Parallelism,
141) {
142 assert!(all(
143 dst.nrows() == src.nrows(),
144 dst.ncols() == src.ncols(),
145 dst.nrows() == dst.ncols()
146 ));
147
148 unsafe { invert_unit_lower_triangular_impl(dst, src, parallelism) }
149}
150
151#[track_caller]
158pub fn invert_lower_triangular<E: ComplexField>(
159 dst: MatMut<'_, E>,
160 src: MatRef<'_, E>,
161 parallelism: Parallelism,
162) {
163 assert!(all(
164 dst.nrows() == src.nrows(),
165 dst.ncols() == src.ncols(),
166 dst.nrows() == dst.ncols()
167 ));
168
169 unsafe { invert_lower_triangular_impl(dst, src, parallelism) }
170}
171
172#[track_caller]
179pub fn invert_unit_upper_triangular<E: ComplexField>(
180 dst: MatMut<'_, E>,
181 src: MatRef<'_, E>,
182 parallelism: Parallelism,
183) {
184 invert_unit_lower_triangular(
185 dst.reverse_rows_and_cols_mut(),
186 src.reverse_rows_and_cols(),
187 parallelism,
188 )
189}
190
191#[track_caller]
198pub fn invert_upper_triangular<E: ComplexField>(
199 dst: MatMut<'_, E>,
200 src: MatRef<'_, E>,
201 parallelism: Parallelism,
202) {
203 invert_lower_triangular(
204 dst.reverse_rows_and_cols_mut(),
205 src.reverse_rows_and_cols(),
206 parallelism,
207 )
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert, Mat};
    use assert_approx_eq::assert_approx_eq;
    use rand::random;

    /// Signature shared by the four public inversion routines, specialized to `f64`.
    type InvertFn = fn(MatMut<'_, f64>, MatRef<'_, f64>, Parallelism);

    /// For every size `n` in `0..32`, inverts a random `n x n` matrix with `invert`.
    /// It then checks that `a * inv`, with both factors read through `structure`,
    /// equals the identity within `1e-4`.
    fn check_inverse(invert: InvertFn, structure: BlockStructure) {
        for n in 0..32 {
            let a = Mat::from_fn(n, n, |_, _| 2.0 + random::<f64>());
            let mut inv = Mat::zeros(n, n);
            invert(inv.as_mut(), a.as_ref(), Parallelism::Rayon(0));

            let mut prod = Mat::zeros(n, n);
            triangular::matmul(
                prod.as_mut(),
                BlockStructure::Rectangular,
                a.as_ref(),
                structure,
                inv.as_ref(),
                structure,
                None,
                1.0,
                Parallelism::Rayon(0),
            );

            for i in 0..n {
                for j in 0..n {
                    let expected = if i == j { 1.0 } else { 0.0 };
                    assert_approx_eq!(prod.read(i, j), expected, 1e-4);
                }
            }
        }
    }

    #[test]
    fn test_invert_lower() {
        check_inverse(
            invert_lower_triangular::<f64>,
            BlockStructure::TriangularLower,
        );
    }

    #[test]
    fn test_invert_unit_lower() {
        check_inverse(
            invert_unit_lower_triangular::<f64>,
            BlockStructure::UnitTriangularLower,
        );
    }

    #[test]
    fn test_invert_upper() {
        check_inverse(
            invert_upper_triangular::<f64>,
            BlockStructure::TriangularUpper,
        );
    }

    #[test]
    fn test_invert_unit_upper() {
        check_inverse(
            invert_unit_upper_triangular::<f64>,
            BlockStructure::UnitTriangularUpper,
        );
    }
}