1use crate::{
4 assert, debug_assert, join_raw, unzipped, zipped, ComplexField, Conj, Conjugate, MatMut,
5 MatRef, Parallelism,
6};
7use faer_entity::SimdCtx;
8use reborrow::*;
9
/// Returns its argument unchanged.
///
/// Used as the "no conjugation" counterpart of [`conj`] when a kernel is generic over an
/// entry-wise transformation of the left-hand side.
///
/// The value is moved straight through: cloning an owned argument only to drop the original
/// afterwards is wasted work. The `Clone` bound is kept so that existing callers and explicit
/// instantiations keep compiling.
#[inline(always)]
fn identity<E: Clone>(x: E) -> E {
    x
}
14
15#[inline(always)]
16fn conj<E: ComplexField>(x: E) -> E {
17 x.faer_conj()
18}
19
20#[inline(always)]
21unsafe fn solve_unit_lower_triangular_in_place_base_case_generic_unchecked<E: ComplexField>(
22 tril: MatRef<'_, E>,
23 rhs: MatMut<'_, E>,
24 maybe_conj_lhs: impl Fn(E) -> E,
25) {
26 let n = tril.nrows();
27 match n {
28 0 | 1 => (),
29 2 => {
30 let nl10_div_l11 = maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_neg();
31
32 let (_, x0, _, x1) = rhs.split_at_mut(1, 0);
33 let x0 = x0.subrows_mut(0, 1);
34 let x1 = x1.subrows_mut(0, 1);
35
36 zipped!(x0, x1).for_each(|unzipped!(x0, mut x1)| {
37 x1.write(x1.read().faer_add(nl10_div_l11.faer_mul(x0.read())));
38 });
39 }
40 3 => {
41 let nl10_div_l11 = maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_neg();
42 let nl20_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_neg();
43 let nl21_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_neg();
44
45 let (_, x0, _, x1_2) = rhs.split_at_mut(1, 0);
46 let (_, x1, _, x2) = x1_2.split_at_mut(1, 0);
47 let x0 = x0.subrows_mut(0, 1);
48 let x1 = x1.subrows_mut(0, 1);
49 let x2 = x2.subrows_mut(0, 1);
50
51 zipped!(x0, x1, x2).for_each(|unzipped!(mut x0, mut x1, mut x2)| {
52 let y0 = x0.read();
53 let mut y1 = x1.read();
54 let mut y2 = x2.read();
55 y1 = y1.faer_add(nl10_div_l11.faer_mul(y0));
56 y2 = y2
57 .faer_add(nl20_div_l22.faer_mul(y0))
58 .faer_add(nl21_div_l22.faer_mul(y1));
59 x0.write(y0);
60 x1.write(y1);
61 x2.write(y2);
62 });
63 }
64 4 => {
65 let nl10_div_l11 = maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_neg();
66 let nl20_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_neg();
67 let nl21_div_l22 = maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_neg();
68 let nl30_div_l33 = maybe_conj_lhs(tril.read_unchecked(3, 0)).faer_neg();
69 let nl31_div_l33 = maybe_conj_lhs(tril.read_unchecked(3, 1)).faer_neg();
70 let nl32_div_l33 = maybe_conj_lhs(tril.read_unchecked(3, 2)).faer_neg();
71
72 let (_, x0, _, x1_2_3) = rhs.split_at_mut(1, 0);
73 let (_, x1, _, x2_3) = x1_2_3.split_at_mut(1, 0);
74 let (_, x2, _, x3) = x2_3.split_at_mut(1, 0);
75 let x0 = x0.subrows_mut(0, 1);
76 let x1 = x1.subrows_mut(0, 1);
77 let x2 = x2.subrows_mut(0, 1);
78 let x3 = x3.subrows_mut(0, 1);
79
80 zipped!(x0, x1, x2, x3).for_each(|unzipped!(mut x0, mut x1, mut x2, mut x3)| {
81 let y0 = x0.read();
82 let mut y1 = x1.read();
83 let mut y2 = x2.read();
84 let mut y3 = x3.read();
85 y1 = y1.faer_add(nl10_div_l11.faer_mul(y0));
86 y2 = y2.faer_add(
87 nl20_div_l22
88 .faer_mul(y0)
89 .faer_add(nl21_div_l22.faer_mul(y1)),
90 );
91 y3 = (y3.faer_add(nl30_div_l33.faer_mul(y0))).faer_add(
92 nl31_div_l33
93 .faer_mul(y1)
94 .faer_add(nl32_div_l33.faer_mul(y2)),
95 );
96 x0.write(y0);
97 x1.write(y1);
98 x2.write(y2);
99 x3.write(y3);
100 });
101 }
102 _ => unreachable!(),
103 }
104}
105
106#[inline(always)]
107unsafe fn solve_lower_triangular_in_place_base_case_generic_unchecked<E: ComplexField>(
108 tril: MatRef<'_, E>,
109 rhs: MatMut<'_, E>,
110 maybe_conj_lhs: impl Fn(E) -> E,
111) {
112 let n = tril.nrows();
113 match n {
114 0 => (),
115 1 => {
116 let inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
117 let x0 = rhs.subrows_mut(0, 1);
118 zipped!(x0).for_each(|unzipped!(mut x0)| x0.write(x0.read().faer_mul(inv)));
119 }
120 2 => {
121 let l00_inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
122 let l11_inv = maybe_conj_lhs(tril.read_unchecked(1, 1)).faer_inv();
123 let nl10_div_l11 =
124 (maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_mul(l11_inv)).faer_neg();
125
126 let (_, x0, _, x1) = rhs.split_at_mut(1, 0);
127 let x0 = x0.subrows_mut(0, 1);
128 let x1 = x1.subrows_mut(0, 1);
129
130 zipped!(x0, x1).for_each(|unzipped!(mut x0, mut x1)| {
131 x0.write(x0.read().faer_mul(l00_inv));
132 x1.write(
133 x1.read()
134 .faer_mul(l11_inv)
135 .faer_add(nl10_div_l11.faer_mul(x0.read())),
136 );
137 });
138 }
139 3 => {
140 let l00_inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
141 let l11_inv = maybe_conj_lhs(tril.read_unchecked(1, 1)).faer_inv();
142 let l22_inv = maybe_conj_lhs(tril.read_unchecked(2, 2)).faer_inv();
143 let nl10_div_l11 =
144 (maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_mul(l11_inv)).faer_neg();
145 let nl20_div_l22 =
146 (maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_mul(l22_inv)).faer_neg();
147 let nl21_div_l22 =
148 (maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_mul(l22_inv)).faer_neg();
149
150 let (_, x0, _, x1_2) = rhs.split_at_mut(1, 0);
151 let (_, x1, _, x2) = x1_2.split_at_mut(1, 0);
152 let x0 = x0.subrows_mut(0, 1);
153 let x1 = x1.subrows_mut(0, 1);
154 let x2 = x2.subrows_mut(0, 1);
155
156 zipped!(x0, x1, x2).for_each(|unzipped!(mut x0, mut x1, mut x2)| {
157 let mut y0 = x0.read();
158 let mut y1 = x1.read();
159 let mut y2 = x2.read();
160 y0 = y0.faer_mul(l00_inv);
161 y1 = y1.faer_mul(l11_inv).faer_add(nl10_div_l11.faer_mul(y0));
162 y2 = y2
163 .faer_mul(l22_inv)
164 .faer_add(nl20_div_l22.faer_mul(y0))
165 .faer_add(nl21_div_l22.faer_mul(y1));
166 x0.write(y0);
167 x1.write(y1);
168 x2.write(y2);
169 });
170 }
171 4 => {
172 let l00_inv = maybe_conj_lhs(tril.read_unchecked(0, 0)).faer_inv();
173 let l11_inv = maybe_conj_lhs(tril.read_unchecked(1, 1)).faer_inv();
174 let l22_inv = maybe_conj_lhs(tril.read_unchecked(2, 2)).faer_inv();
175 let l33_inv = maybe_conj_lhs(tril.read_unchecked(3, 3)).faer_inv();
176 let nl10_div_l11 =
177 (maybe_conj_lhs(tril.read_unchecked(1, 0)).faer_mul(l11_inv)).faer_neg();
178 let nl20_div_l22 =
179 (maybe_conj_lhs(tril.read_unchecked(2, 0)).faer_mul(l22_inv)).faer_neg();
180 let nl21_div_l22 =
181 (maybe_conj_lhs(tril.read_unchecked(2, 1)).faer_mul(l22_inv)).faer_neg();
182 let nl30_div_l33 =
183 (maybe_conj_lhs(tril.read_unchecked(3, 0)).faer_mul(l33_inv)).faer_neg();
184 let nl31_div_l33 =
185 (maybe_conj_lhs(tril.read_unchecked(3, 1)).faer_mul(l33_inv)).faer_neg();
186 let nl32_div_l33 =
187 (maybe_conj_lhs(tril.read_unchecked(3, 2)).faer_mul(l33_inv)).faer_neg();
188
189 let (_, x0, _, x1_2_3) = rhs.split_at_mut(1, 0);
190 let (_, x1, _, x2_3) = x1_2_3.split_at_mut(1, 0);
191 let (_, x2, _, x3) = x2_3.split_at_mut(1, 0);
192 let x0 = x0.subrows_mut(0, 1);
193 let x1 = x1.subrows_mut(0, 1);
194 let x2 = x2.subrows_mut(0, 1);
195 let x3 = x3.subrows_mut(0, 1);
196
197 zipped!(x0, x1, x2, x3).for_each(|unzipped!(mut x0, mut x1, mut x2, mut x3)| {
198 let mut y0 = x0.read();
199 let mut y1 = x1.read();
200 let mut y2 = x2.read();
201 let mut y3 = x3.read();
202 y0 = y0.faer_mul(l00_inv);
203 y1 = y1.faer_mul(l11_inv).faer_add(nl10_div_l11.faer_mul(y0));
204 y2 = y2.faer_mul(l22_inv).faer_add(
205 nl20_div_l22
206 .faer_mul(y0)
207 .faer_add(nl21_div_l22.faer_mul(y1)),
208 );
209 y3 = (y3.faer_mul(l33_inv).faer_add(nl30_div_l33.faer_mul(y0))).faer_add(
210 nl31_div_l33
211 .faer_mul(y1)
212 .faer_add(nl32_div_l33.faer_mul(y2)),
213 );
214 x0.write(y0);
215 x1.write(y1);
216 x2.write(y2);
217 x3.write(y3);
218 });
219 }
220 _ => unreachable!(),
221 }
222}
223
/// Chooses where to split a triangular system of dimension `n` into a top and a bottom part.
///
/// Returns the size of the top part. The bottom part starts at `n / 2` rows and is rounded up
/// to a multiple of 16 when `n >= 32`, of 8 when `n >= 16`, and of 4 when `n >= 8`; below that
/// it is left at `n / 2`.
// NOTE(review): the rounding presumably keeps the bottom block a multiple of the SIMD register
// width; confirm against the matmul kernels.
#[inline]
fn blocksize(n: usize) -> usize {
    let half = n / 2;
    let align = match n {
        0..=7 => 1,
        8..=15 => 4,
        16..=31 => 8,
        _ => 16,
    };
    // Round `half` up to the next multiple of `align`.
    let bottom = (half + (align - 1)) / align * align;
    n - bottom
}
239
/// Largest system dimension that is handed to the unrolled base-case kernels instead of being
/// split further.
#[inline]
fn recursion_threshold() -> usize {
    const BASE_CASE_MAX_DIM: usize = 4;
    BASE_CASE_MAX_DIM
}
244
245#[track_caller]
297#[inline]
298pub fn solve_lower_triangular_in_place_with_conj<E: ComplexField>(
299 triangular_lower: MatRef<'_, E>,
300 conj_lhs: Conj,
301 rhs: MatMut<'_, E>,
302 parallelism: Parallelism,
303) {
304 assert!(all(
305 triangular_lower.nrows() == triangular_lower.ncols(),
306 rhs.nrows() == triangular_lower.ncols(),
307 ));
308
309 unsafe {
310 solve_lower_triangular_in_place_unchecked(triangular_lower, conj_lhs, rhs, parallelism);
311 }
312}
313
314#[track_caller]
320#[inline]
321pub fn solve_lower_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
322 triangular_lower: MatRef<'_, TriE>,
323 rhs: MatMut<'_, E>,
324 parallelism: Parallelism,
325) {
326 let (tri, conj) = triangular_lower.canonicalize();
327 solve_lower_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
328}
329
330#[track_caller]
382#[inline]
383pub fn solve_upper_triangular_in_place_with_conj<E: ComplexField>(
384 triangular_upper: MatRef<'_, E>,
385 conj_lhs: Conj,
386 rhs: MatMut<'_, E>,
387 parallelism: Parallelism,
388) {
389 assert!(all(
390 triangular_upper.nrows() == triangular_upper.ncols(),
391 rhs.nrows() == triangular_upper.ncols(),
392 ));
393
394 unsafe {
395 solve_upper_triangular_in_place_unchecked(triangular_upper, conj_lhs, rhs, parallelism);
396 }
397}
398
399#[track_caller]
405#[inline]
406pub fn solve_upper_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
407 triangular_upper: MatRef<'_, TriE>,
408 rhs: MatMut<'_, E>,
409 parallelism: Parallelism,
410) {
411 let (tri, conj) = triangular_upper.canonicalize();
412 solve_upper_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
413}
414
415#[track_caller]
467#[inline]
468pub fn solve_unit_lower_triangular_in_place_with_conj<E: ComplexField>(
469 triangular_lower: MatRef<'_, E>,
470 conj_lhs: Conj,
471 rhs: MatMut<'_, E>,
472 parallelism: Parallelism,
473) {
474 assert!(all(
475 triangular_lower.nrows() == triangular_lower.ncols(),
476 rhs.nrows() == triangular_lower.ncols(),
477 ));
478
479 unsafe {
480 solve_unit_lower_triangular_in_place_unchecked(
481 triangular_lower,
482 conj_lhs,
483 rhs,
484 parallelism,
485 );
486 }
487}
488
489#[track_caller]
495#[inline]
496pub fn solve_unit_lower_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
497 triangular_lower: MatRef<'_, TriE>,
498 rhs: MatMut<'_, E>,
499 parallelism: Parallelism,
500) {
501 let (tri, conj) = triangular_lower.canonicalize();
502 solve_unit_lower_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
503}
504
505#[track_caller]
555#[inline]
556pub fn solve_unit_upper_triangular_in_place_with_conj<E: ComplexField>(
557 triangular_upper: MatRef<'_, E>,
558 conj_lhs: Conj,
559 rhs: MatMut<'_, E>,
560 parallelism: Parallelism,
561) {
562 assert!(all(
563 triangular_upper.nrows() == triangular_upper.ncols(),
564 rhs.nrows() == triangular_upper.ncols(),
565 ));
566
567 unsafe {
568 solve_unit_upper_triangular_in_place_unchecked(
569 triangular_upper,
570 conj_lhs,
571 rhs,
572 parallelism,
573 );
574 }
575}
576
577#[track_caller]
583#[inline]
584pub fn solve_unit_upper_triangular_in_place<E: ComplexField, TriE: Conjugate<Canonical = E>>(
585 triangular_upper: MatRef<'_, TriE>,
586 rhs: MatMut<'_, E>,
587 parallelism: Parallelism,
588) {
589 let (tri, conj) = triangular_upper.canonicalize();
590 solve_unit_upper_triangular_in_place_with_conj(tri, conj, rhs, parallelism)
591}
592
593unsafe fn solve_unit_lower_triangular_in_place_unchecked<E: ComplexField>(
601 tril: MatRef<'_, E>,
602 conj_lhs: Conj,
603 rhs: MatMut<'_, E>,
604 parallelism: Parallelism,
605) {
606 let n = tril.nrows();
607 let k = rhs.ncols();
608
609 if k > 64 && n <= 128 {
610 let (_, _, rhs_left, rhs_right) = rhs.split_at_mut(0, k / 2);
611 join_raw(
612 |_| {
613 solve_unit_lower_triangular_in_place_unchecked(
614 tril,
615 conj_lhs,
616 rhs_left,
617 parallelism,
618 )
619 },
620 |_| {
621 solve_unit_lower_triangular_in_place_unchecked(
622 tril,
623 conj_lhs,
624 rhs_right,
625 parallelism,
626 )
627 },
628 parallelism,
629 );
630 return;
631 }
632
633 debug_assert!(all(
634 tril.nrows() == tril.ncols(),
635 rhs.nrows() == tril.ncols(),
636 ));
637
638 if n <= recursion_threshold() {
639 E::Simd::default().dispatch(
640 #[inline(always)]
641 || match conj_lhs {
642 Conj::Yes => solve_unit_lower_triangular_in_place_base_case_generic_unchecked(
643 tril, rhs, conj,
644 ),
645 Conj::No => solve_unit_lower_triangular_in_place_base_case_generic_unchecked(
646 tril, rhs, identity,
647 ),
648 },
649 );
650 return;
651 }
652
653 let bs = blocksize(n);
654
655 let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_at(bs, bs);
656 let (_, mut rhs_top, _, mut rhs_bot) = rhs.split_at_mut(bs, 0);
657
658 solve_unit_lower_triangular_in_place_unchecked(
668 tril_top_left,
669 conj_lhs,
670 rhs_top.rb_mut(),
671 parallelism,
672 );
673
674 crate::mul::matmul_with_conj(
675 rhs_bot.rb_mut(),
676 tril_bot_left,
677 conj_lhs,
678 rhs_top.into_const(),
679 Conj::No,
680 Some(E::faer_one()),
681 E::faer_one().faer_neg(),
682 parallelism,
683 );
684
685 solve_unit_lower_triangular_in_place_unchecked(tril_bot_right, conj_lhs, rhs_bot, parallelism);
686}
687
688#[inline]
696unsafe fn solve_unit_upper_triangular_in_place_unchecked<E: ComplexField>(
697 triu: MatRef<'_, E>,
698 conj_lhs: Conj,
699 rhs: MatMut<'_, E>,
700 parallelism: Parallelism,
701) {
702 solve_unit_lower_triangular_in_place_unchecked(
703 triu.reverse_rows_and_cols(),
704 conj_lhs,
705 rhs.reverse_rows_mut(),
706 parallelism,
707 );
708}
709
710unsafe fn solve_lower_triangular_in_place_unchecked<E: ComplexField>(
718 tril: MatRef<'_, E>,
719 conj_lhs: Conj,
720 rhs: MatMut<'_, E>,
721 parallelism: Parallelism,
722) {
723 let n = tril.nrows();
724 let k = rhs.ncols();
725
726 if k > 64 && n <= 128 {
727 let (_, _, rhs_left, rhs_right) = rhs.split_at_mut(0, k / 2);
728 join_raw(
729 |_| solve_lower_triangular_in_place_unchecked(tril, conj_lhs, rhs_left, parallelism),
730 |_| solve_lower_triangular_in_place_unchecked(tril, conj_lhs, rhs_right, parallelism),
731 parallelism,
732 );
733 return;
734 }
735
736 debug_assert!(all(
737 tril.nrows() == tril.ncols(),
738 rhs.nrows() == tril.ncols(),
739 ));
740
741 let n = tril.nrows();
742
743 if n <= recursion_threshold() {
744 E::Simd::default().dispatch(
745 #[inline(always)]
746 || match conj_lhs {
747 Conj::Yes => {
748 solve_lower_triangular_in_place_base_case_generic_unchecked(tril, rhs, conj)
749 }
750 Conj::No => {
751 solve_lower_triangular_in_place_base_case_generic_unchecked(tril, rhs, identity)
752 }
753 },
754 );
755 return;
756 }
757
758 let bs = blocksize(n);
759
760 let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_at(bs, bs);
761 let (_, mut rhs_top, _, mut rhs_bot) = rhs.split_at_mut(bs, 0);
762
763 solve_lower_triangular_in_place_unchecked(
764 tril_top_left,
765 conj_lhs,
766 rhs_top.rb_mut(),
767 parallelism,
768 );
769
770 crate::mul::matmul_with_conj(
771 rhs_bot.rb_mut(),
772 tril_bot_left,
773 conj_lhs,
774 rhs_top.into_const(),
775 Conj::No,
776 Some(E::faer_one()),
777 E::faer_one().faer_neg(),
778 parallelism,
779 );
780
781 solve_lower_triangular_in_place_unchecked(tril_bot_right, conj_lhs, rhs_bot, parallelism);
782}
783
784#[inline]
792unsafe fn solve_upper_triangular_in_place_unchecked<E: ComplexField>(
793 triu: MatRef<'_, E>,
794 conj_lhs: Conj,
795 rhs: MatMut<'_, E>,
796 parallelism: Parallelism,
797) {
798 solve_lower_triangular_in_place_unchecked(
799 triu.reverse_rows_and_cols(),
800 conj_lhs,
801 rhs.reverse_rows_mut(),
802 parallelism,
803 );
804}