faer_core/
solve.rs

1//! Triangular solve module.
2
3use crate::{
4    assert, debug_assert, join_raw, unzipped, zipped, ComplexField, Conj, Conjugate, MatMut,
5    MatRef, Parallelism,
6};
7use faer_entity::SimdCtx;
8use reborrow::*;
9
/// Returns its argument unchanged.
///
/// Used as the "no conjugation" element transform handed to the base-case kernels. The value is
/// moved straight through, so no `Clone` bound or copy is required.
#[inline(always)]
fn identity<E>(x: E) -> E {
    x
}
14
15#[inline(always)]
16fn conj<E: ComplexField>(x: E) -> E {
17    x.faer_conj()
18}
19
/// Base-case kernel: solves `op(L)×X = rhs` in place, where `L` is a unit lower triangular
/// matrix of dimension at most 4, by explicit forward substitution.
///
/// Only the strictly lower triangular entries of `tril` are read. Its diagonal is taken to be
/// one, and its upper part is never touched. Every entry read from `tril` is first passed through
/// `maybe_conj_lhs` (callers in this file pass either `identity` or `conj`).
///
/// The `zipped!` loops walk the 1-row views `x0..x3` element by element, so each column of
/// `rhs` is solved independently.
///
/// # Safety
///
/// Entries of `tril` are read with `read_unchecked`, so the caller must ensure that `tril` is
/// square and that `rhs.nrows() == tril.nrows()`. Dimensions above 4 hit `unreachable!()`.
#[inline(always)]
unsafe fn solve_unit_lower_triangular_in_place_base_case_generic_unchecked<E: ComplexField>(
    tril: MatRef<'_, E>,
    rhs: MatMut<'_, E>,
    maybe_conj_lhs: impl Fn(E) -> E,
) {
    let n = tril.nrows();
    match n {
        // With a unit diagonal, a 0×0 or 1×1 system leaves `rhs` unchanged.
        0 | 1 => (),
        2 => {
            // The diagonal is implicitly 1, so the `_div_lii` factors are just the negated
            // off-diagonal entries (the names mirror the non-unit base-case kernel).
            let nl10_div_l11 = maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_neg();

            // Peel `rhs` into single-row views x0, x1.
            let (_, x0, _, x1) = rhs.split_at_mut(1, 0);
            let x0 = x0.subrows_mut(0, 1);
            let x1 = x1.subrows_mut(0, 1);

            // x1 <- x1 - l10 * x0
            zipped!(x0, x1).for_each(|unzipped!(x0, mut x1)| {
                x1.write(x1.read().faer_add(nl10_div_l11.faer_mul(x0.read())));
            });
        }
        3 => {
            let nl10_div_l11 = maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_neg();
            let nl20_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_neg();
            let nl21_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_neg();

            let (_, x0, _, x1_2) = rhs.split_at_mut(1, 0);
            let (_, x1, _, x2) = x1_2.split_at_mut(1, 0);
            let x0 = x0.subrows_mut(0, 1);
            let x1 = x1.subrows_mut(0, 1);
            let x2 = x2.subrows_mut(0, 1);

            zipped!(x0, x1, x2).for_each(|unzipped!(mut x0, mut x1, mut x2)| {
                let y0 = x0.read();
                let mut y1 = x1.read();
                let mut y2 = x2.read();
                // Forward substitution; y0 needs no update because l00 is implicitly 1.
                y1 = y1.faer_add(nl10_div_l11.faer_mul(y0));
                y2 = y2
                    .faer_add(nl20_div_l22.faer_mul(y0))
                    .faer_add(nl21_div_l22.faer_mul(y1));
                // y0 is written back unchanged.
                x0.write(y0);
                x1.write(y1);
                x2.write(y2);
            });
        }
        4 => {
            let nl10_div_l11 = maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_neg();
            let nl20_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_neg();
            let nl21_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_neg();
            let nl30_div_l33 = maybe_conj_lhs(tril.read_unchecked(3, 0)).faer_neg();
            let nl31_div_l33 = maybe_conj_lhs(tril.read_unchecked(3, 1)).faer_neg();
            let nl32_div_l33 = maybe_conj_lhs(tril.read_unchecked(3, 2)).faer_neg();

            let (_, x0, _, x1_2_3) = rhs.split_at_mut(1, 0);
            let (_, x1, _, x2_3) = x1_2_3.split_at_mut(1, 0);
            let (_, x2, _, x3) = x2_3.split_at_mut(1, 0);
            let x0 = x0.subrows_mut(0, 1);
            let x1 = x1.subrows_mut(0, 1);
            let x2 = x2.subrows_mut(0, 1);
            let x3 = x3.subrows_mut(0, 1);

            zipped!(x0, x1, x2, x3).for_each(|unzipped!(mut x0, mut x1, mut x2, mut x3)| {
                let y0 = x0.read();
                let mut y1 = x1.read();
                let mut y2 = x2.read();
                let mut y3 = x3.read();
                // Forward substitution; the grouping of the additions is part of the
                // numerical behavior and is kept as written.
                y1 = y1.faer_add(nl10_div_l11.faer_mul(y0));
                y2 = y2.faer_add(
                    nl20_div_l22
                        .faer_mul(y0)
                        .faer_add(nl21_div_l22.faer_mul(y1)),
                );
                y3 = (y3.faer_add(nl30_div_l33.faer_mul(y0))).faer_add(
                    nl31_div_l33
                        .faer_mul(y1)
                        .faer_add(nl32_div_l33.faer_mul(y2)),
                );
                x0.write(y0);
                x1.write(y1);
                x2.write(y2);
                x3.write(y3);
            });
        }
        // Callers only dispatch here when `n <= recursion_threshold()` (4).
        _ => unreachable!(),
    }
}
105
/// Base-case kernel: solves `op(L)×X = rhs` in place, where `L` is a lower triangular matrix
/// of dimension at most 4, by explicit forward substitution.
///
/// Only the lower triangular entries of `tril` (diagonal included) are read; the strictly upper
/// part is never touched. Every entry read from `tril` is first passed through `maybe_conj_lhs`
/// (callers in this file pass either `identity` or `conj`).
///
/// The reciprocal of each diagonal entry is computed once, together with the negated
/// off-diagonal entries scaled by their row's diagonal reciprocal. This way the per-column loop
/// only multiplies and adds.
///
/// # Safety
///
/// Entries of `tril` are read with `read_unchecked`, so the caller must ensure that `tril` is
/// square and that `rhs.nrows() == tril.nrows()`. Dimensions above 4 hit `unreachable!()`.
#[inline(always)]
unsafe fn solve_lower_triangular_in_place_base_case_generic_unchecked<E: ComplexField>(
    tril: MatRef<'_, E>,
    rhs: MatMut<'_, E>,
    maybe_conj_lhs: impl Fn(E) -> E,
) {
    let n = tril.nrows();
    match n {
        0 => (),
        1 => {
            // x0 <- x0 / l00
            let inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
            let x0 = rhs.subrows_mut(0, 1);
            zipped!(x0).for_each(|unzipped!(mut x0)| x0.write(x0.read().faer_mul(inv)));
        }
        2 => {
            let l00_inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
            let l11_inv = maybe_conj_lhs(tril.read_unchecked(1, 1)).faer_inv();
            // -l10 / l11
            let nl10_div_l11 =
                (maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_mul(l11_inv)).faer_neg();

            // Peel `rhs` into single-row views x0, x1.
            let (_, x0, _, x1) = rhs.split_at_mut(1, 0);
            let x0 = x0.subrows_mut(0, 1);
            let x1 = x1.subrows_mut(0, 1);

            zipped!(x0, x1).for_each(|unzipped!(mut x0, mut x1)| {
                x0.write(x0.read().faer_mul(l00_inv));
                // x1 <- x1 / l11 - (l10 / l11) * x0, using the already-updated x0.
                x1.write(
                    x1.read()
                        .faer_mul(l11_inv)
                        .faer_add(nl10_div_l11.faer_mul(x0.read())),
                );
            });
        }
        3 => {
            let l00_inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
            let l11_inv = maybe_conj_lhs(tril.read_unchecked(1, 1)).faer_inv();
            let l22_inv = maybe_conj_lhs(tril.read_unchecked(2, 2)).faer_inv();
            let nl10_div_l11 =
                (maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_mul(l11_inv)).faer_neg();
            let nl20_div_l22 =
                (maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_mul(l22_inv)).faer_neg();
            let nl21_div_l22 =
                (maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_mul(l22_inv)).faer_neg();

            let (_, x0, _, x1_2) = rhs.split_at_mut(1, 0);
            let (_, x1, _, x2) = x1_2.split_at_mut(1, 0);
            let x0 = x0.subrows_mut(0, 1);
            let x1 = x1.subrows_mut(0, 1);
            let x2 = x2.subrows_mut(0, 1);

            zipped!(x0, x1, x2).for_each(|unzipped!(mut x0, mut x1, mut x2)| {
                let mut y0 = x0.read();
                let mut y1 = x1.read();
                let mut y2 = x2.read();
                // Forward substitution, each row using the already-solved rows above it.
                y0 = y0.faer_mul(l00_inv);
                y1 = y1.faer_mul(l11_inv).faer_add(nl10_div_l11.faer_mul(y0));
                y2 = y2
                    .faer_mul(l22_inv)
                    .faer_add(nl20_div_l22.faer_mul(y0))
                    .faer_add(nl21_div_l22.faer_mul(y1));
                x0.write(y0);
                x1.write(y1);
                x2.write(y2);
            });
        }
        4 => {
            let l00_inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
            let l11_inv = maybe_conj_lhs(tril.read_unchecked(1, 1)).faer_inv();
            let l22_inv = maybe_conj_lhs(tril.read_unchecked(2, 2)).faer_inv();
            let l33_inv = maybe_conj_lhs(tril.read_unchecked(3, 3)).faer_inv();
            let nl10_div_l11 =
                (maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_mul(l11_inv)).faer_neg();
            let nl20_div_l22 =
                (maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_mul(l22_inv)).faer_neg();
            let nl21_div_l22 =
                (maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_mul(l22_inv)).faer_neg();
            let nl30_div_l33 =
                (maybe_conj_lhs(tril.read_unchecked(3, 0)).faer_mul(l33_inv)).faer_neg();
            let nl31_div_l33 =
                (maybe_conj_lhs(tril.read_unchecked(3, 1)).faer_mul(l33_inv)).faer_neg();
            let nl32_div_l33 =
                (maybe_conj_lhs(tril.read_unchecked(3, 2)).faer_mul(l33_inv)).faer_neg();

            let (_, x0, _, x1_2_3) = rhs.split_at_mut(1, 0);
            let (_, x1, _, x2_3) = x1_2_3.split_at_mut(1, 0);
            let (_, x2, _, x3) = x2_3.split_at_mut(1, 0);
            let x0 = x0.subrows_mut(0, 1);
            let x1 = x1.subrows_mut(0, 1);
            let x2 = x2.subrows_mut(0, 1);
            let x3 = x3.subrows_mut(0, 1);

            zipped!(x0, x1, x2, x3).for_each(|unzipped!(mut x0, mut x1, mut x2, mut x3)| {
                let mut y0 = x0.read();
                let mut y1 = x1.read();
                let mut y2 = x2.read();
                let mut y3 = x3.read();
                // Forward substitution; the grouping of the additions is part of the
                // numerical behavior and is kept as written.
                y0 = y0.faer_mul(l00_inv);
                y1 = y1.faer_mul(l11_inv).faer_add(nl10_div_l11.faer_mul(y0));
                y2 = y2.faer_mul(l22_inv).faer_add(
                    nl20_div_l22
                        .faer_mul(y0)
                        .faer_add(nl21_div_l22.faer_mul(y1)),
                );
                y3 = (y3.faer_mul(l33_inv).faer_add(nl30_div_l33.faer_mul(y0))).faer_add(
                    nl31_div_l33
                        .faer_mul(y1)
                        .faer_add(nl32_div_l33.faer_mul(y2)),
                );
                x0.write(y0);
                x1.write(y1);
                x2.write(y2);
                x3.write(y3);
            });
        }
        // Callers only dispatch here when `n <= recursion_threshold()` (4).
        _ => unreachable!(),
    }
}
223
/// Returns the size of the leading diagonal block used when splitting an `n×n` triangular
/// solve in two.
///
/// The trailing block gets roughly `n / 2` rows, rounded up to a register-friendly granule
/// (16 for `n >= 32`, 8 for `n >= 16`, 4 for `n >= 8`, no rounding below that). The leading
/// block receives whatever remains.
#[inline]
fn blocksize(n: usize) -> usize {
    let granule: usize = match n {
        32..=usize::MAX => 16,
        16..=31 => 8,
        8..=15 => 4,
        _ => 1,
    };
    let half = n / 2;
    let trailing = (half + granule - 1) / granule * granule;
    n - trailing
}
239
/// Largest triangle dimension handled directly by the unrolled base-case kernels.
///
/// Larger problems are split recursively until they reach this size.
#[inline]
fn recursion_threshold() -> usize {
    const MAX_BASE_CASE_DIM: usize = 4;
    MAX_BASE_CASE_DIM
}
244
245/// Computes the solution of `Op_lhs(triangular_lower)×X = rhs`, and stores the result in
246/// `rhs`.
247///
248/// `triangular_lower` is interpreted as a lower triangular matrix (diagonal included).
249/// Its strictly upper triangular part is not accessed.
250///
251/// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is
252/// `Conj::Yes`.  
253///
254/// # Panics
255///
256///  - Panics if `triangular_lower` is not a square matrix.
257///  - Panics if `rhs.nrows() != triangular_lower.ncols()`
258///
259/// # Example
260///
261/// ```
262/// use faer_core::{
263///     mat,
264///     mul::triangular::{matmul, BlockStructure},
265///     solve::solve_lower_triangular_in_place_with_conj,
266///     unzipped, zipped, Conj, Mat, Parallelism,
267/// };
268///
269/// let m = mat![[1.0, 0.0], [2.0, 3.0]];
270/// let rhs = mat![[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
271///
272/// let mut sol = rhs.clone();
273/// solve_lower_triangular_in_place_with_conj(
274///     m.as_ref(),
275///     Conj::No,
276///     sol.as_mut(),
277///     Parallelism::None,
278/// );
279///
280/// let mut m_times_sol = Mat::<f64>::zeros(2, 3);
281/// matmul(
282///     m_times_sol.as_mut(),
283///     BlockStructure::Rectangular,
284///     m.as_ref(),
285///     BlockStructure::TriangularLower,
286///     sol.as_ref(),
287///     BlockStructure::Rectangular,
288///     None,
289///     1.0,
290///     Parallelism::None,
291/// );
292///
293/// zipped!(m_times_sol.as_ref(), rhs.as_ref())
294///     .for_each(|unzipped!(x, target)| assert!((x.read() - target.read()).abs() < 1e-10));
295/// ```
296#[track_caller]
297#[inline]
298pub fn solve_lower_triangular_in_place_with_conj<E: ComplexField>(
299    triangular_lower: MatRef<'_, E>,
300    conj_lhs: Conj,
301    rhs: MatMut<'_, E>,
302    parallelism: Parallelism,
303) {
304    assert!(all(
305        triangular_lower.nrows() == triangular_lower.ncols(),
306        rhs.nrows() == triangular_lower.ncols(),
307    ));
308
309    unsafe {
310        solve_lower_triangular_in_place_unchecked(triangular_lower, conj_lhs, rhs, parallelism);
311    }
312}
313
314/// Computes the solution of `triangular_lower×X = rhs`, and stores the result in
315/// `rhs`.
316///
317/// `triangular_lower` is interpreted as a lower triangular matrix (diagonal included).
318/// Its strictly upper triangular part is not accessed.
319#[track_caller]
320#[inline]
321pub fn solve_lower_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
322    triangular_lower: MatRef<'_, TriE>,
323    rhs: MatMut<'_, E>,
324    parallelism: Parallelism,
325) {
326    let (tri, conj) = triangular_lower.canonicalize();
327    solve_lower_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
328}
329
330/// Computes the solution of `Op_lhs(triangular_upper)×X = rhs`, and stores the result in
331/// `rhs`.
332///
333/// `triangular_upper` is interpreted as a upper triangular matrix (diagonal included).
334/// Its strictly lower triangular part is not accessed.
335///
336/// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is
337/// `Conj::Yes`.  
338///
339/// # Panics
340///
341///  - Panics if `triangular_upper` is not a square matrix.
342///  - Panics if `rhs.nrows() != triangular_lower.ncols()`
343///
344/// # Example
345///
346/// ```
347/// use faer_core::{
348///     mat,
349///     mul::triangular::{matmul, BlockStructure},
350///     solve::solve_upper_triangular_in_place_with_conj,
351///     unzipped, zipped, Conj, Mat, Parallelism,
352/// };
353///
354/// let m = mat![[1.0, 2.0], [0.0, 3.0]];
355/// let rhs = mat![[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
356///
357/// let mut sol = rhs.clone();
358/// solve_upper_triangular_in_place_with_conj(
359///     m.as_ref(),
360///     Conj::No,
361///     sol.as_mut(),
362///     Parallelism::None,
363/// );
364///
365/// let mut m_times_sol = Mat::<f64>::zeros(2, 3);
366/// matmul(
367///     m_times_sol.as_mut(),
368///     BlockStructure::Rectangular,
369///     m.as_ref(),
370///     BlockStructure::TriangularUpper,
371///     sol.as_ref(),
372///     BlockStructure::Rectangular,
373///     None,
374///     1.0,
375///     Parallelism::None,
376/// );
377///
378/// zipped!(m_times_sol.as_ref(), rhs.as_ref())
379///     .for_each(|unzipped!(x, target)| assert!((x.read() - target.read()).abs() < 1e-10));
380/// ```
381#[track_caller]
382#[inline]
383pub fn solve_upper_triangular_in_place_with_conj<E: ComplexField>(
384    triangular_upper: MatRef<'_, E>,
385    conj_lhs: Conj,
386    rhs: MatMut<'_, E>,
387    parallelism: Parallelism,
388) {
389    assert!(all(
390        triangular_upper.nrows() == triangular_upper.ncols(),
391        rhs.nrows() == triangular_upper.ncols(),
392    ));
393
394    unsafe {
395        solve_upper_triangular_in_place_unchecked(triangular_upper, conj_lhs, rhs, parallelism);
396    }
397}
398
399/// Computes the solution of `triangular_upper×X = rhs`, and stores the result in
400/// `rhs`.
401///
402/// `triangular_upper` is interpreted as a upper triangular matrix (diagonal included).
403/// Its strictly upper triangular part is not accessed.
404#[track_caller]
405#[inline]
406pub fn solve_upper_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
407    triangular_upper: MatRef<'_, TriE>,
408    rhs: MatMut<'_, E>,
409    parallelism: Parallelism,
410) {
411    let (tri, conj) = triangular_upper.canonicalize();
412    solve_upper_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
413}
414
415/// Computes the solution of `Op_lhs(triangular_lower)×X = rhs`, and stores the result in
416/// `rhs`.
417///
418/// `triangular_lower` is interpreted as a lower triangular matrix, and its diagonal elements are
419/// implicitly considered to be `1.0`. Its upper triangular part is not accessed.
420///
421/// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is
422/// `Conj::Yes`.  
423///
424/// # Panics
425///
426///  - Panics if `triangular_lower` is not a square matrix.
427///  - Panics if `rhs.nrows() != triangular_lower.ncols()`
428///
429/// # Example
430///
431/// ```
432/// use faer_core::{
433///     mat,
434///     mul::triangular::{matmul, BlockStructure},
435///     solve::solve_unit_lower_triangular_in_place_with_conj,
436///     unzipped, zipped, Conj, Mat, Parallelism,
437/// };
438///
439/// let m = mat![[0.0, 0.0], [2.0, 0.0]];
440/// let rhs = mat![[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
441///
442/// let mut sol = rhs.clone();
443/// solve_unit_lower_triangular_in_place_with_conj(
444///     m.as_ref(),
445///     Conj::No,
446///     sol.as_mut(),
447///     Parallelism::None,
448/// );
449///
450/// let mut m_times_sol = Mat::<f64>::zeros(2, 3);
451/// matmul(
452///     m_times_sol.as_mut(),
453///     BlockStructure::Rectangular,
454///     m.as_ref(),
455///     BlockStructure::UnitTriangularLower,
456///     sol.as_ref(),
457///     BlockStructure::Rectangular,
458///     None,
459///     1.0,
460///     Parallelism::None,
461/// );
462///
463/// zipped!(m_times_sol.as_ref(), rhs.as_ref())
464///     .for_each(|unzipped!(x, target)| assert!((x.read() - target.read()).abs() < 1e-10));
465/// ```
466#[track_caller]
467#[inline]
468pub fn solve_unit_lower_triangular_in_place_with_conj<E: ComplexField>(
469    triangular_lower: MatRef<'_, E>,
470    conj_lhs: Conj,
471    rhs: MatMut<'_, E>,
472    parallelism: Parallelism,
473) {
474    assert!(all(
475        triangular_lower.nrows() == triangular_lower.ncols(),
476        rhs.nrows() == triangular_lower.ncols(),
477    ));
478
479    unsafe {
480        solve_unit_lower_triangular_in_place_unchecked(
481            triangular_lower,
482            conj_lhs,
483            rhs,
484            parallelism,
485        );
486    }
487}
488
489/// Computes the solution of `triangular_lower×X = rhs`, and stores the result in
490/// `rhs`.
491///
492/// `triangular_lower` is interpreted as a lower triangular matrix with an implicit unit diagonal.
493/// Its lower triangular part is not accessed.
494#[track_caller]
495#[inline]
496pub fn solve_unit_lower_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
497    triangular_lower: MatRef<'_, TriE>,
498    rhs: MatMut<'_, E>,
499    parallelism: Parallelism,
500) {
501    let (tri, conj) = triangular_lower.canonicalize();
502    solve_unit_lower_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
503}
504
505/// Computes the solution of `Op_lhs(triangular_upper)×X = rhs`, and stores the result in
506/// `rhs`.
507///
508/// `triangular_upper` is interpreted as a upper triangular matrix, and its diagonal elements are
509/// implicitly considered to be `1.0`. Its lower triangular part is not accessed.
510///
511/// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is
512/// `Conj::Yes`.  
513///
514/// # Panics
515///
516///  - Panics if `triangular_upper` is not a square matrix.
517///  - Panics if `rhs.nrows() != triangular_lower.ncols()`
518///
519/// ```
520/// use faer_core::{
521///     mat,
522///     mul::triangular::{matmul, BlockStructure},
523///     solve::solve_unit_upper_triangular_in_place_with_conj,
524///     unzipped, zipped, Conj, Mat, Parallelism,
525/// };
526///
527/// let m = mat![[0.0, 2.0], [0.0, 0.0]];
528/// let rhs = mat![[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
529///
530/// let mut sol = rhs.clone();
531/// solve_unit_upper_triangular_in_place_with_conj(
532///     m.as_ref(),
533///     Conj::No,
534///     sol.as_mut(),
535///     Parallelism::None,
536/// );
537///
538/// let mut m_times_sol = Mat::<f64>::zeros(2, 3);
539/// matmul(
540///     m_times_sol.as_mut(),
541///     BlockStructure::Rectangular,
542///     m.as_ref(),
543///     BlockStructure::UnitTriangularUpper,
544///     sol.as_ref(),
545///     BlockStructure::Rectangular,
546///     None,
547///     1.0,
548///     Parallelism::None,
549/// );
550///
551/// zipped!(m_times_sol.as_ref(), rhs.as_ref())
552///     .for_each(|unzipped!(x, target)| assert!((x.read() - target.read()).abs() < 1e-10));
553/// ```
554#[track_caller]
555#[inline]
556pub fn solve_unit_upper_triangular_in_place_with_conj<E: ComplexField>(
557    triangular_upper: MatRef<'_, E>,
558    conj_lhs: Conj,
559    rhs: MatMut<'_, E>,
560    parallelism: Parallelism,
561) {
562    assert!(all(
563        triangular_upper.nrows() == triangular_upper.ncols(),
564        rhs.nrows() == triangular_upper.ncols(),
565    ));
566
567    unsafe {
568        solve_unit_upper_triangular_in_place_unchecked(
569            triangular_upper,
570            conj_lhs,
571            rhs,
572            parallelism,
573        );
574    }
575}
576
577/// Computes the solution of `triangular_upper×X = rhs`, and stores the result in
578/// `rhs`.
579///
580/// `triangular_upper` is interpreted as a upper triangular matrix with an implicit unit diagonal.
581/// Its upper triangular part is not accessed.
582#[track_caller]
583#[inline]
584pub fn solve_unit_upper_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
585    triangular_upper: MatRef<'_, TriE>,
586    rhs: MatMut<'_, E>,
587    parallelism: Parallelism,
588) {
589    let (tri, conj) = triangular_upper.canonicalize();
590    solve_unit_upper_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
591}
592
593/// # Safety
594///
595/// Same as [`solve_unit_lower_triangular_in_place`], except that panics become undefined behavior.
596///
597/// # Example
598///
599/// See [`solve_unit_lower_triangular_in_place`].
600unsafe fn solve_unit_lower_triangular_in_place_unchecked<E: ComplexField>(
601    tril: MatRef<'_, E>,
602    conj_lhs: Conj,
603    rhs: MatMut<'_, E>,
604    parallelism: Parallelism,
605) {
606    let n = tril.nrows();
607    let k = rhs.ncols();
608
609    if k > 64 && n <= 128 {
610        let (_, _, rhs_left, rhs_right) = rhs.split_at_mut(0, k / 2);
611        join_raw(
612            |_| {
613                solve_unit_lower_triangular_in_place_unchecked(
614                    tril,
615                    conj_lhs,
616                    rhs_left,
617                    parallelism,
618                )
619            },
620            |_| {
621                solve_unit_lower_triangular_in_place_unchecked(
622                    tril,
623                    conj_lhs,
624                    rhs_right,
625                    parallelism,
626                )
627            },
628            parallelism,
629        );
630        return;
631    }
632
633    debug_assert!(all(
634        tril.nrows() == tril.ncols(),
635        rhs.nrows() == tril.ncols(),
636    ));
637
638    if n <= recursion_threshold() {
639        E::Simd::default().dispatch(
640            #[inline(always)]
641            || match conj_lhs {
642                Conj::Yes => solve_unit_lower_triangular_in_place_base_case_generic_unchecked(
643                    tril, rhs, conj,
644                ),
645                Conj::No => solve_unit_lower_triangular_in_place_base_case_generic_unchecked(
646                    tril, rhs, identity,
647                ),
648            },
649        );
650        return;
651    }
652
653    let bs = blocksize(n);
654
655    let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_at(bs, bs);
656    let (_, mut rhs_top, _, mut rhs_bot) = rhs.split_at_mut(bs, 0);
657
658    //       (A00    )   X0         (B0)
659    // ConjA?(A10 A11)   X1 = ConjB?(B1)
660    //
661    //
662    // 1. ConjA?(A00) X0 = ConjB?(B0)
663    //
664    // 2. ConjA?(A10) X0 + ConjA?(A11) X1 = ConjB?(B1)
665    // => ConjA?(A11) X1 = ConjB?(B1) - ConjA?(A10) X0
666
667    solve_unit_lower_triangular_in_place_unchecked(
668        tril_top_left,
669        conj_lhs,
670        rhs_top.rb_mut(),
671        parallelism,
672    );
673
674    crate::mul::matmul_with_conj(
675        rhs_bot.rb_mut(),
676        tril_bot_left,
677        conj_lhs,
678        rhs_top.into_const(),
679        Conj::No,
680        Some(E::faer_one()),
681        E::faer_one().faer_neg(),
682        parallelism,
683    );
684
685    solve_unit_lower_triangular_in_place_unchecked(tril_bot_right, conj_lhs, rhs_bot, parallelism);
686}
687
688/// # Safety
689///
690/// Same as [`solve_unit_upper_triangular_in_place`], except that panics become undefined behavior.
691///
692/// # Example
693///
694/// See [`solve_unit_upper_triangular_in_place`].
695#[inline]
696unsafe fn solve_unit_upper_triangular_in_place_unchecked<E: ComplexField>(
697    triu: MatRef<'_, E>,
698    conj_lhs: Conj,
699    rhs: MatMut<'_, E>,
700    parallelism: Parallelism,
701) {
702    solve_unit_lower_triangular_in_place_unchecked(
703        triu.reverse_rows_and_cols(),
704        conj_lhs,
705        rhs.reverse_rows_mut(),
706        parallelism,
707    );
708}
709
710/// # Safety
711///
712/// Same as [`solve_lower_triangular_in_place`], except that panics become undefined behavior.
713///
714/// # Example
715///
716/// See [`solve_lower_triangular_in_place`].
717unsafe fn solve_lower_triangular_in_place_unchecked<E: ComplexField>(
718    tril: MatRef<'_, E>,
719    conj_lhs: Conj,
720    rhs: MatMut<'_, E>,
721    parallelism: Parallelism,
722) {
723    let n = tril.nrows();
724    let k = rhs.ncols();
725
726    if k > 64 && n <= 128 {
727        let (_, _, rhs_left, rhs_right) = rhs.split_at_mut(0, k / 2);
728        join_raw(
729            |_| solve_lower_triangular_in_place_unchecked(tril, conj_lhs, rhs_left, parallelism),
730            |_| solve_lower_triangular_in_place_unchecked(tril, conj_lhs, rhs_right, parallelism),
731            parallelism,
732        );
733        return;
734    }
735
736    debug_assert!(all(
737        tril.nrows() == tril.ncols(),
738        rhs.nrows() == tril.ncols(),
739    ));
740
741    let n = tril.nrows();
742
743    if n <= recursion_threshold() {
744        E::Simd::default().dispatch(
745            #[inline(always)]
746            || match conj_lhs {
747                Conj::Yes => {
748                    solve_lower_triangular_in_place_base_case_generic_unchecked(tril, rhs, conj)
749                }
750                Conj::No => {
751                    solve_lower_triangular_in_place_base_case_generic_unchecked(tril, rhs, identity)
752                }
753            },
754        );
755        return;
756    }
757
758    let bs = blocksize(n);
759
760    let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_at(bs, bs);
761    let (_, mut rhs_top, _, mut rhs_bot) = rhs.split_at_mut(bs, 0);
762
763    solve_lower_triangular_in_place_unchecked(
764        tril_top_left,
765        conj_lhs,
766        rhs_top.rb_mut(),
767        parallelism,
768    );
769
770    crate::mul::matmul_with_conj(
771        rhs_bot.rb_mut(),
772        tril_bot_left,
773        conj_lhs,
774        rhs_top.into_const(),
775        Conj::No,
776        Some(E::faer_one()),
777        E::faer_one().faer_neg(),
778        parallelism,
779    );
780
781    solve_lower_triangular_in_place_unchecked(tril_bot_right, conj_lhs, rhs_bot, parallelism);
782}
783
784/// # Safety
785///
786/// Same as [`solve_upper_triangular_in_place`], except that panics become undefined behavior.
787///
788/// # Example
789///
790/// See [`solve_upper_triangular_in_place`].
791#[inline]
792unsafe fn solve_upper_triangular_in_place_unchecked<E: ComplexField>(
793    triu: MatRef<'_, E>,
794    conj_lhs: Conj,
795    rhs: MatMut<'_, E>,
796    parallelism: Parallelism,
797) {
798    solve_lower_triangular_in_place_unchecked(
799        triu.reverse_rows_and_cols(),
800        conj_lhs,
801        rhs.reverse_rows_mut(),
802        parallelism,
803    );
804}