1use crate::internal_prelude::*;
2use crate::{assert, perm};
3use linalg::matmul::triangular::BlockStructure;
4
/// Mask selecting the most significant bit of a `usize`.
///
/// The factorization routines in this file OR this bit into a stored pivot index to mark it as
/// belonging to a two-column pivot; readers strip it again with `& !TOP_BIT`.
const TOP_BIT: usize = 1 << (usize::BITS - 1);
6
/// Pivoting strategy used by `cholesky_in_place`.
///
/// `cholesky_in_place` translates every variant except [`PivotingStrategy::Full`] into a
/// `(rook, diagonal)` pair of flags; `Full` is routed to a dedicated full-pivoting routine.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum PivotingStrategy {
    /// Deprecated. `cholesky_in_place` currently handles it with the same flags as
    /// [`PivotingStrategy::Partial`].
    #[deprecated]
    Diagonal,

    /// `rook = false`, `diagonal = false`: the first remaining column is the initial pivot
    /// candidate and at most one further column is examined.
    Partial,
    /// `rook = false`, `diagonal = true`: like `Partial`, but the initial pivot candidate is the
    /// remaining diagonal entry of largest magnitude.
    PartialDiag,
    /// `rook = true`, `diagonal = false`: the candidate search is repeated in a loop until a
    /// pivot is accepted.
    Rook,
    /// `rook = true`, `diagonal = true`: the `Rook` search started from the remaining diagonal
    /// entry of largest magnitude.
    RookDiag,

    /// Searches the whole remaining matrix for its largest diagonal and off-diagonal entries
    /// at every step.
    Full,
}
29
/// Tuning parameters of the $LBL^\top$ factorization.
#[derive(Copy, Clone, Debug)]
pub struct LbltParams {
    /// Pivoting strategy to use.
    pub pivoting: PivotingStrategy,
    /// Number of columns of the workspace used by the blocked algorithm. Values below 2, or not
    /// smaller than the matrix dimension, make `cholesky_in_place` use the unblocked algorithm.
    pub block_size: usize,

    /// Threshold on the squared size of the remaining matrix below which the full-pivoting
    /// routine switches to sequential execution.
    pub par_threshold: usize,

    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
44
45#[math]
46fn swap_self_adjoint<T: ComplexField>(A: MatMut<'_, T>, i: usize, j: usize) {
47 assert_ne!(i, j);
48
49 let mut A = A;
50 let (i, j) = (Ord::min(i, j), Ord::max(i, j));
51
52 perm::swap_cols_idx(A.rb_mut().get_mut(j + 1.., ..), i, j);
53 perm::swap_rows_idx(A.rb_mut().get_mut(.., ..i), i, j);
54
55 let tmp = real(A[(i, i)]);
56 A[(i, i)] = from_real(real(A[(j, j)]));
57 A[(j, j)] = from_real(tmp);
58
59 A[(j, i)] = conj(A[(j, i)]);
60
61 let (Ai, Aj) = A.split_at_row_mut(j);
62 let Ai = Ai.get_mut(i + 1..j, i);
63 let Aj = Aj.get_mut(0, i + 1..j).transpose_mut();
64 zip!(Ai, Aj).for_each(|unzip!(x, y): Zip!(&mut _, &mut _)| {
65 let tmp = conj(*x);
66 *x = conj(*y);
67 *y = tmp;
68 });
69}
70
71#[math]
72fn rank_1_update_and_argmax_fallback<'M, 'N, T: ComplexField>(
73 A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
74 L: ColRef<'_, T, Dim<'N>>,
75 d: T::Real,
76 start: IdxInc<'N>,
77 end: IdxInc<'N>,
78) -> (usize, usize, T::Real) {
79 let mut A = A;
80 let n = A.nrows();
81
82 let mut max_j = n.idx(0);
83 let mut max_i = n.idx(0);
84 let mut max_offdiag = zero();
85
86 for j in start.to(end) {
87 for i in j.next().to(n.end()) {
88 A[(i, j)] = A[(i, j)] - mul_real(L[i] * conj(L[j]), d);
89 let val = abs2(A[(i, j)]);
90 if val > max_offdiag {
91 max_offdiag = val;
92 max_i = i;
93 max_j = j;
94 }
95 }
96 }
97
98 (*max_i, *max_j, max_offdiag)
99}
100
101#[math]
102fn rank_2_update_and_argmax_fallback<'N, T: ComplexField>(
103 A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
104 L0: ColRef<'_, T, Dim<'N>>,
105 L1: ColRef<'_, T, Dim<'N>>,
106 d: T::Real,
107 d00: T::Real,
108 d11: T::Real,
109 d10: T,
110 start: IdxInc<'N>,
111 end: IdxInc<'N>,
112) -> (usize, usize, T::Real) {
113 let mut A = A;
114 let n = A.nrows();
115
116 let mut max_j = n.idx(0);
117 let mut max_i = n.idx(0);
118 let mut max_offdiag = zero();
119
120 for j in start.to(end) {
121 let x0 = copy(L0[j]);
122 let x1 = copy(L1[j]);
123
124 let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
125 let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
126
127 for i in j.next().to(n.end()) {
128 A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
129
130 let val = abs2(A[(i, j)]);
131 if val > max_offdiag {
132 max_offdiag = val;
133 max_i = i;
134 max_j = j;
135 }
136 }
137 }
138 (*max_i, *max_j, max_offdiag)
139}
140
/// Single-threaded rank-1 update with arg-max tracking over the columns `start..end`.
///
/// Currently forwards every call unchanged to `rank_1_update_and_argmax_fallback`.
#[math]
fn rank_1_update_and_argmax_seq<'M, 'N, T: ComplexField>(
    A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
    L: ColRef<'_, T, Dim<'N>>,
    d: T::Real,
    start: IdxInc<'N>,
    end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
    rank_1_update_and_argmax_fallback(A, L, d, start, end)
}
151
/// Single-threaded two-column update with arg-max tracking over the columns `start..end`.
///
/// Currently forwards every call unchanged to `rank_2_update_and_argmax_fallback`.
#[math]
fn rank_2_update_and_argmax_seq<'N, T: ComplexField>(
    A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
    L0: ColRef<'_, T, Dim<'N>>,
    L1: ColRef<'_, T, Dim<'N>>,
    d: T::Real,
    d00: T::Real,
    d11: T::Real,
    d10: T,
    start: IdxInc<'N>,
    end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
    rank_2_update_and_argmax_fallback(A, L0, L1, d, d00, d11, d10, start, end)
}
166
167#[math]
168fn rank_1_update_and_argmax<T: ComplexField>(A: MatMut<'_, T>, L: ColRef<'_, T>, d: T::Real, par: Par) -> (usize, usize, T::Real) {
169 with_dim!(N, A.nrows());
170
171 match par {
172 Par::Seq => rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), d, IdxInc::ZERO, N.end()),
173 #[cfg(feature = "rayon")]
174 Par::Rayon(nthreads) => {
175 use rayon::prelude::*;
176 let nthreads = nthreads.get();
177 let n = *N;
178
179 assert!((n as u64) < (1u64 << 50));
181
182 let idx_to_col_start = |idx: usize| {
183 let idx_as_percent = idx as f64 / nthreads as f64;
184 let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
185 (col_start_percent * n as f64) as usize
186 };
187
188 let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];
189
190 spindle::for_each(nthreads, r.par_iter_mut().enumerate(), |(idx, out)| {
191 let A = unsafe { A.rb().const_cast() };
192 let start = N.idx_inc(idx_to_col_start(idx));
193 let end = N.idx_inc(idx_to_col_start(idx + 1));
194
195 *out = rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), copy(d), start, end);
196 });
197
198 r.into_iter()
199 .max_by(|(_, _, a), (_, _, b)| {
200 if a == b {
201 core::cmp::Ordering::Equal
202 } else if a > b {
203 core::cmp::Ordering::Greater
204 } else {
205 core::cmp::Ordering::Less
206 }
207 })
208 .unwrap()
209 },
210 }
211}
212
213#[math]
214fn rank_2_update_and_argmax<'N, T: ComplexField>(
215 A: MatMut<'_, T>,
216 L0: ColRef<'_, T>,
217 L1: ColRef<'_, T>,
218 d: T::Real,
219 d00: T::Real,
220 d11: T::Real,
221 d10: T,
222 par: Par,
223) -> (usize, usize, T::Real) {
224 with_dim!(N, A.nrows());
225
226 match par {
227 Par::Seq => rank_2_update_and_argmax_seq(
228 A.as_shape_mut(N, N),
229 L0.as_row_shape(N),
230 L1.as_row_shape(N),
231 d,
232 d00,
233 d11,
234 d10,
235 IdxInc::ZERO,
236 N.end(),
237 ),
238 #[cfg(feature = "rayon")]
239 Par::Rayon(nthreads) => {
240 use rayon::prelude::*;
241 let nthreads = nthreads.get();
242 let n = *N;
243
244 assert!((n as u64) < (1u64 << 50));
246
247 let idx_to_col_start = |idx: usize| {
248 let idx_as_percent = idx as f64 / nthreads as f64;
249 let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
250 (col_start_percent * n as f64) as usize
251 };
252
253 let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];
254
255 spindle::for_each(nthreads, r.par_iter_mut().enumerate(), |(idx, out)| {
256 let A = unsafe { A.rb().const_cast() };
257 let start = N.idx_inc(idx_to_col_start(idx));
258 let end = N.idx_inc(idx_to_col_start(idx + 1));
259
260 *out = rank_2_update_and_argmax_seq(
261 A.as_shape_mut(N, N),
262 L0.as_row_shape(N),
263 L1.as_row_shape(N),
264 copy(d),
265 copy(d00),
266 copy(d11),
267 copy(d10),
268 start,
269 end,
270 );
271 });
272
273 r.into_iter()
274 .max_by(|(_, _, a), (_, _, b)| {
275 if a == b {
276 core::cmp::Ordering::Equal
277 } else if a < b {
278 core::cmp::Ordering::Less
279 } else {
280 core::cmp::Ordering::Greater
281 }
282 })
283 .unwrap()
284 },
285 }
286}
287
288#[math]
289fn lblt_full_piv<T: ComplexField>(A: MatMut<'_, T>, subdiag: DiagMut<'_, T>, pivots: &mut [usize], par: Par, params: LbltParams) {
290 let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);
291 let alpha = alpha * alpha;
292
293 let mut A = A;
294 let mut subdiag = subdiag.column_vector_mut();
295 let mut par = par;
296 let n = A.nrows();
297
298 let scale_fwd = A.norm_max();
299 let scale_bwd = recip(scale_fwd);
300 zip!(A.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_bwd));
301
302 let mut max_i = 0;
303 let mut max_j = 0;
304 let mut max_offdiag = zero();
305
306 for j in 0..n {
307 for i in j + 1..n {
308 let val = abs2(A[(i, j)]);
309 if val > max_offdiag {
310 max_offdiag = val;
311 max_i = i;
312 max_j = j;
313 }
314 }
315 }
316
317 let mut k = 0;
318 while k < n {
319 if max_offdiag == zero() {
320 break;
321 }
322
323 let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
324 let mut subdiag = subdiag.rb_mut().get_mut(k..);
325 let pivots = &mut pivots[k..];
326
327 let n = A.nrows();
328 let mut max_s = 0;
329 let mut max_diag = zero();
330
331 for s in 0..n {
332 let val = abs2(A[(s, s)]);
333 if val > max_diag {
334 max_diag = val;
335 max_s = s;
336 }
337 }
338
339 let npiv;
340 let i0;
341 let i1;
342
343 if max_diag >= alpha * max_offdiag {
344 npiv = 1;
345 i0 = max_s;
346 i1 = usize::MAX;
347 } else {
348 npiv = 2;
349 i0 = max_j;
350 i1 = max_i;
351 }
352
353 let rem = n - npiv;
354 if rem * rem < params.par_threshold {
355 par = Par::Seq;
356 }
357
358 if i0 != 0 {
360 swap_self_adjoint(A.rb_mut(), 0, i0);
361 perm::swap_rows_idx(Aprev.rb_mut(), 0, i0);
362 }
363 if npiv == 2 && i1 != 1 {
364 swap_self_adjoint(A.rb_mut(), 1, i1);
365 perm::swap_rows_idx(Aprev.rb_mut(), 1, i1);
366 }
367
368 if npiv == 1 {
369 let diag = real(A[(0, 0)]);
370 let diag_inv = recip(diag);
371 subdiag[0] = zero();
372
373 let (_, _, L, mut A) = A.rb_mut().split_at_mut(1, 1);
374 let n = A.nrows();
375 let mut L = L.col_mut(0);
376
377 zip!(L.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, diag_inv));
378
379 for i in 0..n {
380 A[(i, i)] = from_real(real(A[(i, i)]) - diag * abs2(L[i]));
381 }
382
383 if n < params.par_threshold {}
384 if n != 0 {
385 (max_i, max_j, max_offdiag) = rank_1_update_and_argmax(A.rb_mut(), L.rb(), diag, par);
386 }
387 } else {
388 let a00 = real(A[(0, 0)]);
389 let a11 = real(A[(1, 1)]);
390 let a10 = copy(A[(1, 0)]);
391
392 subdiag[0] = copy(a10);
393 subdiag[1] = zero();
394 A[(1, 0)] = zero();
395
396 let d10 = abs(a10);
397 let d10_inv = recip(d10);
398 let d00 = a00 * d10_inv;
399 let d11 = a11 * d10_inv;
400
401 let t = recip(d00 * d11 - one());
403 let d10 = mul_real(a10, d10_inv);
404 let d = t * d10_inv;
405
406 let (_, _, L, mut A) = A.rb_mut().split_at_mut(2, 2);
409 let (mut L0, mut L1) = L.two_cols_mut(0, 1);
410 let n = A.nrows();
411
412 if n != 0 {
413 (max_i, max_j, max_offdiag) = rank_2_update_and_argmax(A.rb_mut(), L0.rb(), L1.rb(), copy(d), copy(d00), copy(d11), copy(d10), par);
414 }
415
416 for j in 0..n {
417 let x0 = copy(L0[j]);
418 let x1 = copy(L1[j]);
419
420 let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
421 let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
422
423 A[(j, j)] = from_real(real(A[(j, j)] - L0[j] * conj(w0) - L1[j] * conj(w1)));
424
425 L0[j] = w0;
426 L1[j] = w1;
427 }
428 }
429
430 if npiv == 2 {
431 pivots[0] = (i0 + k) | TOP_BIT;
432 pivots[1] = (i1 + k) | TOP_BIT;
433 } else {
434 pivots[0] = i0 + k;
435 }
436 k += npiv;
437 }
438
439 while k < n {
440 let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
441 let mut subdiag = subdiag.rb_mut().get_mut(k..);
442 let pivots = &mut pivots[k..];
443
444 let n = A.nrows();
445 let mut max_s = 0;
446 let mut max_diag = zero();
447
448 for s in 0..n {
449 let val = abs2(A[(s, s)]);
450 if val > max_diag {
451 max_diag = val;
452 max_s = s;
453 }
454 }
455
456 if max_s != 0 {
457 let (mut A0, mut As) = A.rb_mut().two_cols_mut(0, max_s);
458 core::mem::swap(&mut A0[0], &mut As[max_s]);
459
460 perm::swap_rows_idx(Aprev.rb_mut(), 0, max_s);
461 }
462
463 subdiag[0] = zero();
464 pivots[0] = max_s + k;
465
466 k += 1;
467 }
468
469 zip!(A.rb_mut().diagonal_mut().column_vector_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
470 zip!(subdiag.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
471}
472
473#[math]
474#[track_caller]
475fn l1_argmax<T: ComplexField>(col: ColRef<'_, T>) -> (Option<usize>, T::Real) {
476 let n = col.nrows();
477 if n == 0 {
478 return (None, zero());
479 }
480
481 let mut i = 0;
482 let mut best = zero();
483
484 for j in 0..n {
485 let val = abs1(col[j]);
486 if val > best {
487 best = val;
488 i = j;
489 }
490 }
491
492 (Some(i), best)
493}
494
495#[math]
496#[track_caller]
497fn offdiag_argmax<T: ComplexField>(A: MatRef<'_, T>, idx: usize) -> (Option<usize>, T::Real) {
498 let (mut col_argmax, col_max) = l1_argmax(A.rb().get(idx + 1.., idx));
499 col_argmax.as_mut().map(|col_argmax| *col_argmax += idx + 1);
500 let (row_argmax, row_max) = l1_argmax(A.rb().get(idx, ..idx).transpose());
501
502 if col_max > row_max {
503 (col_argmax, col_max)
504 } else {
505 (row_argmax, row_max)
506 }
507}
508
509#[math]
510fn update_and_offdiag_argmax<T: ComplexField>(
511 mut dst: ColMut<'_, T>,
512 Wl: MatRef<'_, T>,
513 Al: MatRef<'_, T>,
514 Ar: MatRef<'_, T>,
515 i0: usize,
516 par: Par,
517) -> (Option<usize>, T::Real) {
518 let n = Al.nrows();
519 for j in 0..i0 {
520 dst[j] = conj(Ar[(i0, j)]);
521 }
522 dst[i0] = zero();
523 for j in i0 + 1..n {
524 dst[j] = copy(Ar[(j, i0)]);
525 }
526
527 linalg::matmul::matmul(dst.rb_mut(), Accum::Add, Al.rb(), Wl.row(i0).adjoint(), -one::<T>(), par);
528 dst[i0] = zero();
529
530 let ret = l1_argmax(dst.rb());
531 dst[i0] = from_real(real(Ar[(i0, i0)]));
532 if n == 1 { (None, zero()) } else { ret }
533}
534
/// Performs one panel step of the blocked factorization on the square matrix `A`.
///
/// Pivot columns are eliminated one or two at a time while their up-to-date values are kept in
/// the workspace `W`. Only the diagonal of the trailing part of `A` is updated eagerly; its
/// strictly-lower part is updated once, after the loop, with a single triangular product.
///
/// - `alpha`: threshold used in the pivot acceptance tests.
/// - `W`: workspace with as many rows as `A` and at least two columns.
/// - `A_left`: only its column count is read here, as the offset added to the stored pivots.
/// - `subdiag`, `pivots`: outputs, indexed relative to the start of `A`.
/// - `rook`, `diagonal`: pivot search flags.
///
/// Returns the number of columns that were eliminated.
///
/// # Panics
/// Panics if the dimensions of `A`, `W` and `subdiag` disagree or `W` has fewer than 2 columns.
#[math]
#[inline(never)]
fn lblt_blocked_step<T: ComplexField>(
    alpha: T::Real,
    W: MatMut<'_, T>,
    A_left: MatMut<'_, T>,
    A: MatMut<'_, T>,
    subdiag: DiagMut<'_, T>,
    pivots: &mut [usize],
    rook: bool,
    diagonal: bool,
    par: Par,
) -> usize {
    let mut A = A;
    let mut A_left = A_left;
    let mut subdiag = subdiag;
    let mut W = W;

    let n = A.nrows();
    let block_size = W.ncols();

    assert!(all(A.nrows() == n, A.ncols() == n, W.nrows() == n, subdiag.dim() == n, block_size >= 2,));

    // Stop one column early so that a two-column pivot always fits in `W`.
    let kmax = Ord::min(block_size - 1, n);
    let mut k = 0usize;
    while k < kmax {
        let mut A = A.rb_mut();
        let mut W = W.rb_mut();
        let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
        let A_left = A_left.rb_mut().get_mut(k.., ..);

        // `*l`: the `k` columns already processed in this step; `*r`: the remaining ones.
        let (mut Wl, mut Wr) = W.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
        let (mut Al, mut Ar) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
        let mut Al = Al.rb_mut();
        let mut Wr = Wr.rb_mut().get_mut(.., ..2);

        let npiv;
        // Initial pivot candidate: largest remaining diagonal entry, or the first column.
        let mut i0 = if diagonal {
            l1_argmax(Ar.rb().diagonal().column_vector()).0.unwrap()
        } else {
            0
        };
        let mut i1 = usize::MAX;

        let mut nothing_to_do = false;

        let (mut Wr0, mut Wr1) = Wr.rb_mut().two_cols_mut(0, 1);

        // `Wr0` receives the up-to-date column `i0`; `gamma_i` is its largest off-diagonal value.
        let (r, mut gamma_i) = update_and_offdiag_argmax(Wr0.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i0, par);

        if k + 1 == n || gamma_i == zero() {
            // Last column, or a column without off-diagonal entries: no elimination needed.
            nothing_to_do = true;
            npiv = 1;
        } else if abs(real(Ar[(i0, i0)])) >= alpha * gamma_i {
            npiv = 1;
        } else {
            i1 = r.unwrap();
            if rook {
                // Keep moving to the column of the largest off-diagonal entry until a
                // single-column or two-column pivot is accepted.
                loop {
                    let (s, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);

                    if abs1(Ar[(i1, i1)]) >= alpha * gamma_r {
                        npiv = 1;
                        i0 = i1;
                        i1 = usize::MAX;
                        Wr0.copy_from(&Wr1);
                        break;
                    } else if s == Some(i0) || gamma_i == gamma_r {
                        npiv = 2;
                        break;
                    } else {
                        i0 = i1;
                        i1 = s.unwrap();
                        gamma_i = gamma_r;
                        Wr0.copy_from(&Wr1);
                    }
                }
            } else {
                // `Wr1` receives the up-to-date column `i1`.
                let (_, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);

                if abs(real(Ar[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
                    npiv = 1;
                } else if abs(real(Ar[(i1, i1)])) >= alpha * gamma_r {
                    npiv = 1;
                    i0 = i1;
                    i1 = usize::MAX;
                    Wr0.copy_from(&Wr1);
                } else {
                    npiv = 2;
                }
            }
        }

        // Order a two-column pivot so that `i0 < i1`, keeping the workspace columns in sync.
        if npiv == 2 && i0 > i1 {
            perm::swap_cols_idx(Wr.rb_mut(), 0, 1);
            (i0, i1) = (i1, i0);
        }

        let mut Wr = Wr.rb_mut().get_mut(.., ..npiv);

        'next_iter: {
            // Bring the pivot(s) to the leading position(s) in the matrix and the workspace.
            if i0 != 0 {
                swap_self_adjoint(Ar.rb_mut(), 0, i0);
                perm::swap_rows_idx(Al.rb_mut(), 0, i0);
                perm::swap_rows_idx(Wl.rb_mut(), 0, i0);
                perm::swap_rows_idx(Wr.rb_mut(), 0, i0);
            }
            if npiv == 2 && i1 != 1 {
                swap_self_adjoint(Ar.rb_mut(), 1, i1);
                perm::swap_rows_idx(Al.rb_mut(), 1, i1);
                perm::swap_rows_idx(Wl.rb_mut(), 1, i1);
                perm::swap_rows_idx(Wr.rb_mut(), 1, i1);
            }

            if nothing_to_do {
                break 'next_iter;
            }

            if npiv == 1 {
                let W0 = Wr.rb_mut().col_mut(0);

                let diag = real(W0[0]);
                let diag_inv = recip(diag);
                subdiag[0] = zero();

                let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(1, 1);
                let W0 = W0.rb().get(1..);
                let n = A.nrows();

                // Multipliers: the workspace column divided by the pivot.
                let mut L = L.col_mut(0);
                zip!(W0, L.rb_mut()).for_each(|unzip!(w, a): Zip!(&T, &mut T)| *a = mul_real(*w, diag_inv));

                // Eager update of the trailing diagonal only.
                for j in 0..n {
                    A[(j, j)] = from_real(real(A[(j, j)]) - diag * abs2(L[j]));
                }
            } else {
                let a00 = real(Wr[(0, 0)]);
                let a11 = real(Wr[(1, 1)]);
                let a10 = copy(Wr[(1, 0)]);

                subdiag[0] = copy(a10);
                subdiag[1] = zero();
                Wr[(1, 0)] = zero();
                Ar[(1, 0)] = zero();

                // Coefficients of the 2x2 pivot block, normalized by |a10|.
                let d10 = abs(a10);
                let d10_inv = recip(d10);
                let d00 = a00 * d10_inv;
                let d11 = a11 * d10_inv;

                let t = recip(d00 * d11 - one());
                let d10 = mul_real(a10, d10_inv);
                let d = t * d10_inv;

                let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(2, 2);
                let (mut L0, mut L1) = L.two_cols_mut(0, 1);
                let Wr = Wr.rb().get(2.., ..);
                let W0 = Wr.col(0);
                let W1 = Wr.col(1);

                let n = A.nrows();
                for j in 0..n {
                    let x0 = copy(W0[j]);
                    let x1 = copy(W1[j]);

                    let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
                    let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);

                    // Eager update of the trailing diagonal only.
                    A[(j, j)] = from_real(real(A[(j, j)] - W0[j] * conj(w0) - W1[j] * conj(w1)));

                    L0[j] = w0;
                    L1[j] = w1;
                }
            }
        }

        let offset = A_left.ncols();

        // Both entries of a two-column pivot are tagged with `TOP_BIT`.
        if npiv == 2 {
            pivots[k] = (offset + i0 + k) | TOP_BIT;
            pivots[k + 1] = (offset + i1 + k) | TOP_BIT;
        } else {
            pivots[k] = offset + i0 + k;
        }
        k += npiv;
    }

    // Deferred update of the strictly-lower part of the trailing matrix.
    let W = W.rb().get(k.., ..k);
    let (_, _, Al, mut Ar) = A.rb_mut().split_at_mut(k, k);
    let Al = Al.rb();

    linalg::matmul::triangular::matmul(
        Ar.rb_mut(),
        BlockStructure::StrictTriangularLower,
        Accum::Add,
        W,
        BlockStructure::Rectangular,
        Al.adjoint(),
        BlockStructure::Rectangular,
        -one::<T>(),
        par,
    );

    // Drop any imaginary part from the trailing diagonal.
    for j in 0..n - k {
        Ar[(j, j)] = from_real(real(Ar[(j, j)]));
    }

    k
}
748
/// Drives the factorization of `A` panel by panel.
///
/// Each iteration factorizes the trailing matrix starting at column `k`, either entirely with
/// `lblt_unblocked` (when `block_size < 2` or at most `block_size` columns remain) or for one
/// panel with `lblt_blocked_step`. The pivots recorded by that call are then applied as row
/// exchanges to the columns left of `k`.
///
/// - `subdiag`, `pivots`: outputs of the factorization, filled by the callees.
/// - `rook`, `diagonal`: pivot search flags forwarded to the callees.
/// - `stack`: scratch memory for the panel workspace.
#[math]
fn lblt_blocked<T: ComplexField>(
    A: MatMut<'_, T>,
    subdiag: DiagMut<'_, T>,
    pivots: &mut [usize],
    block_size: usize,
    rook: bool,
    diagonal: bool,
    par: Par,
    stack: &mut MemStack,
) {
    // (1 + sqrt(17)) / 8
    let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);

    let mut A = A;
    let mut subdiag = subdiag.column_vector_mut();
    let n = A.nrows();

    let mut k = 0;
    while k < n {
        let (_, _, mut A_left, A_right) = A.rb_mut().split_at_mut(k, k);
        // Uninitialized panel workspace, re-borrowed from `stack` on every iteration.
        // NOTE(review): presumably every entry is written before it is read — confirm.
        let (mut W, _) = unsafe { temp_mat_uninit::<T, _, _>(n - k, block_size, stack) };
        let W = W.as_mat_mut();

        let next;

        if block_size < 2 || n - k <= block_size {
            // Finish the whole remaining matrix without a workspace.
            lblt_unblocked(
                copy(alpha),
                A_left.rb_mut(),
                A_right,
                subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
                &mut pivots[k..],
                rook,
                diagonal,
                par,
            );

            next = n;
        } else {
            let block_size = lblt_blocked_step(
                copy(alpha),
                W,
                A_left.rb_mut(),
                A_right,
                subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
                &mut pivots[k..],
                rook,
                diagonal,
                par,
            );

            next = k + block_size;
        }

        let pivots = &pivots[k..next];

        let A_left = A.rb_mut().get_mut(.., ..k);

        // Apply the new row exchanges to the columns factorized in earlier iterations.
        if A_left.ncols() > 0 {
            match par {
                Par::Seq => {
                    for mut col in A_left.col_iter_mut() {
                        for (i, &j) in core::iter::zip(k..next, pivots) {
                            // Strip the two-column pivot tag to recover the row index.
                            let j = j & !TOP_BIT;
                            linalg::lu::partial_pivoting::factor::swap_elems(col.rb_mut(), i, j);
                        }
                    }
                },
                #[cfg(feature = "rayon")]
                Par::Rayon(nthreads) => {
                    let nthreads = nthreads.get();
                    spindle::for_each(nthreads, A_left.par_col_iter_mut(), |mut col| {
                        for (i, &j) in core::iter::zip(k..next, pivots) {
                            // Strip the two-column pivot tag to recover the row index.
                            let j = j & !TOP_BIT;
                            linalg::lu::partial_pivoting::factor::swap_elems(col.rb_mut(), i, j);
                        }
                    });
                },
            }
        }

        k = next;
    }
}
833
/// Factorizes the square matrix `A` in place, one pivot at a time, without a workspace.
///
/// After each pivot choice the trailing matrix is updated immediately through `rank1_update`
/// or `rank2_update`.
///
/// - `alpha`: threshold used in the pivot acceptance tests.
/// - `A_left`: only its column count is read here, as the offset added to the stored pivots.
/// - `subdiag`, `pivots`: outputs, indexed relative to the start of `A`.
/// - `rook`, `diagonal`: pivot search flags.
/// - `par`: unused.
///
/// # Panics
/// Panics if `A` is not square or `subdiag` does not have the dimension of `A`.
#[math]
#[inline(never)]
fn lblt_unblocked<T: ComplexField>(
    alpha: T::Real,
    A_left: MatMut<'_, T>,
    A: MatMut<'_, T>,
    subdiag: DiagMut<'_, T>,
    pivots: &mut [usize],
    rook: bool,
    diagonal: bool,
    par: Par,
) {
    let _ = par;
    let mut A = A;
    let mut A_left = A_left;
    let mut subdiag = subdiag;

    let n = A.nrows();
    assert!(all(A.nrows() == n, A.ncols() == n, subdiag.dim() == n));

    let mut k = 0usize;
    while k < n {
        // `L_prev`: rows `k..` of the columns already eliminated; `A`: the remaining matrix.
        let (_, _, mut L_prev, mut A) = A.rb_mut().split_at_mut(k, k);
        let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
        let A_left = A_left.rb_mut().get_mut(k.., ..);

        let npiv;

        // Initial pivot candidate: largest remaining diagonal entry, or the first column.
        let mut i0 = if diagonal {
            l1_argmax(A.rb().diagonal().column_vector()).0.unwrap()
        } else {
            0
        };
        let mut i1 = usize::MAX;

        // Largest off-diagonal entry associated with `i0`.
        let (r, mut gamma_i) = offdiag_argmax(A.rb(), i0);

        let mut nothing_to_do = false;

        if k + 1 == n || gamma_i == zero() {
            // Last column, or a column without off-diagonal entries: no elimination needed.
            nothing_to_do = true;
            npiv = 1;
        } else if abs(real(A[(i0, i0)])) >= alpha * gamma_i {
            npiv = 1;
        } else {
            i1 = r.unwrap();

            if rook {
                // Keep moving to the column of the largest off-diagonal entry until a
                // single-column or two-column pivot is accepted.
                loop {
                    let (s, gamma_r) = offdiag_argmax(A.rb(), i1);

                    if abs1(A[(i1, i1)]) >= alpha * gamma_r {
                        npiv = 1;
                        i0 = i1;
                        i1 = usize::MAX;
                        break;
                    } else if gamma_i == gamma_r {
                        npiv = 2;
                        break;
                    } else {
                        i0 = i1;
                        i1 = s.unwrap();
                        gamma_i = gamma_r;
                    }
                }
            } else {
                let (_, gamma_r) = offdiag_argmax(A.rb(), i1);
                if abs(real(A[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
                    npiv = 1;
                } else if abs(real(A[(i1, i1)])) >= alpha * gamma_r {
                    // `i1` keeps its value here; it is not read again when `npiv == 1`.
                    npiv = 1;
                    i0 = i1;
                } else {
                    npiv = 2;
                }
            }
        }

        // Order a two-column pivot so that `i0 < i1`.
        if npiv == 2 && i0 > i1 {
            (i0, i1) = (i1, i0);
        }

        'next_iter: {
            // Bring the pivot(s) to the leading position(s).
            if i0 != 0 {
                swap_self_adjoint(A.rb_mut(), 0, i0);
                perm::swap_rows_idx(L_prev.rb_mut(), 0, i0);
            }
            if npiv == 2 && i1 != 1 {
                swap_self_adjoint(A.rb_mut(), 1, i1);
                perm::swap_rows_idx(L_prev.rb_mut(), 1, i1);
            }

            if nothing_to_do {
                break 'next_iter;
            }

            if npiv == 1 {
                let diag = real(A[(0, 0)]);
                let diag_inv = recip(diag);
                subdiag[0] = zero();

                let (_, _, L, A) = A.rb_mut().split_at_mut(1, 1);
                let L = L.col_mut(0);
                rank1_update(A, L, diag_inv);
            } else {
                let a00 = real(A[(0, 0)]);
                let a11 = real(A[(1, 1)]);
                let a10 = copy(A[(1, 0)]);

                subdiag[0] = copy(a10);
                subdiag[1] = zero();
                A[(1, 0)] = zero();

                // Coefficients of the 2x2 pivot block, normalized by |a10|.
                let d10 = abs(a10);
                let d10_inv = recip(d10);
                let d00 = a00 * d10_inv;
                let d11 = a11 * d10_inv;

                let t = recip(d00 * d11 - one());
                let d10 = mul_real(a10, d10_inv);
                let d = t * d10_inv;

                let (_, _, L, A) = A.rb_mut().split_at_mut(2, 2);
                let (L0, L1) = L.two_cols_mut(0, 1);
                rank2_update(A, L0, L1, d, d00, d10, d11);
            }
        }

        let offset = A_left.ncols();
        // Both entries of a two-column pivot are tagged with `TOP_BIT`.
        if npiv == 2 {
            pivots[k] = (offset + i0 + k) | TOP_BIT;
            pivots[k + 1] = (offset + i1 + k) | TOP_BIT;
        } else {
            pivots[k] = offset + i0 + k;
        }
        k += npiv;
    }
}
980
impl<T: ComplexField> Auto<T> for LbltParams {
    /// Default parameters: `PartialDiag` pivoting, a block size of 64 and a parallelism
    /// threshold of `128 * 128`.
    fn auto() -> Self {
        Self {
            pivoting: PivotingStrategy::PartialDiag,
            block_size: 64,
            par_threshold: 128 * 128,
            non_exhaustive: NonExhaustive(()),
        }
    }
}
991
992pub fn rank2_update<'a, T: ComplexField>(
993 mut A: MatMut<'a, T>,
994 mut L0: ColMut<'a, T>,
995 mut L1: ColMut<'a, T>,
996 d: T::Real,
997 d00: T::Real,
998 d10: T,
999 d11: T::Real,
1000) {
1001 if const { T::SIMD_CAPABILITIES.is_simd() } {
1002 if let (Some(A), Some(L0), Some(L1)) = (
1003 A.rb_mut().try_as_col_major_mut(),
1004 L0.rb_mut().try_as_col_major_mut(),
1005 L1.rb_mut().try_as_col_major_mut(),
1006 ) {
1007 rank2_update_simd(A, L0, L1, d, d00, d10, d11);
1008 } else {
1009 rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
1010 }
1011 } else {
1012 rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
1013 }
1014}
1015
/// SIMD kernel of [`rank2_update`] for contiguous column-major operands.
///
/// For each column `j`, forms `w0` and `w1` from `L0[j]`, `L1[j]` and the block coefficients,
/// subtracts `L0[i] * conj(w0) + L1[i] * conj(w1)` from `A[(i, j)]` for all `i >= j`, drops the
/// imaginary part of `A[(j, j)]`, and finally stores `w0` and `w1` into `L0[j]` and `L1[j]`.
#[math]
pub fn rank2_update_simd<'a, T: ComplexField>(
    A: MatMut<'a, T, usize, usize, ContiguousFwd>,
    L0: ColMut<'a, T, usize, ContiguousFwd>,
    L1: ColMut<'a, T, usize, ContiguousFwd>,
    d: T::Real,
    d00: T::Real,
    d10: T,
    d11: T::Real,
) {
    // Carrier for the arguments, so that the kernel can be dispatched through `pulp::WithSimd`.
    struct Impl<'a, T: ComplexField> {
        A: MatMut<'a, T, usize, usize, ContiguousFwd>,
        L0: ColMut<'a, T, usize, ContiguousFwd>,
        L1: ColMut<'a, T, usize, ContiguousFwd>,
        d: T::Real,
        d00: T::Real,
        d10: T,
        d11: T::Real,
    }

    impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let Self {
                mut A,
                mut L0,
                mut L1,
                d,
                d00,
                d10,
                d11,
            } = self;
            let n = A.nrows();
            for j in 0..n {
                let x0 = copy(L0[j]);
                let x1 = copy(L1[j]);
                let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
                let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);

                // Length of the part of column `j` from the diagonal downwards.
                with_dim!({
                    let subrange_len = n - j;
                });
                {
                    let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
                    let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
                    let L1 = L1.rb().get(j..).as_row_shape(subrange_len);
                    let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
                    let (head, body, tail) = simd.indices();

                    // The factors are negated up front so that the subtraction can be
                    // expressed with `mul_add`.
                    let w0_conj = conj(w0);
                    let w1_conj = conj(w1);
                    let w0_conj_neg = -w0_conj;
                    let w1_conj_neg = -w1_conj;
                    let w0_splat = simd.splat(&w0_conj_neg);
                    let w1_splat = simd.splat(&w1_conj_neg);

                    // The same update is applied to the head, body and tail index groups.
                    if let Some(i) = head {
                        let mut acc = simd.read(A.rb(), i);
                        let l0_val = simd.read(L0, i);
                        let l1_val = simd.read(L1, i);
                        acc = simd.mul_add(l0_val, w0_splat, acc);
                        acc = simd.mul_add(l1_val, w1_splat, acc);
                        simd.write(A.rb_mut(), i, acc);
                    }

                    for i in body.clone() {
                        let mut acc = simd.read(A.rb(), i);
                        let l0_val = simd.read(L0, i);
                        let l1_val = simd.read(L1, i);
                        acc = simd.mul_add(l0_val, w0_splat, acc);
                        acc = simd.mul_add(l1_val, w1_splat, acc);
                        simd.write(A.rb_mut(), i, acc);
                    }

                    if let Some(i) = tail {
                        let mut acc = simd.read(A.rb(), i);
                        let l0_val = simd.read(L0, i);
                        let l1_val = simd.read(L1, i);
                        acc = simd.mul_add(l0_val, w0_splat, acc);
                        acc = simd.mul_add(l1_val, w1_splat, acc);
                        simd.write(A.rb_mut(), i, acc);
                    }
                }
                // Drop any imaginary part from the diagonal entry.
                A[(j, j)] = from_real(real(A[(j, j)]));

                // The raw entries of `L0`/`L1` at `j` are no longer needed: store the multipliers.
                L0[j] = w0;
                L1[j] = w1;
            }
        }
    }
    dispatch!(Impl { A, L0, L1, d, d00, d10, d11 }, Impl, T)
}
1110
1111#[math]
1112pub fn rank2_update_fallback<'a, T: ComplexField>(
1113 mut A: MatMut<'a, T>,
1114 mut L0: ColMut<'a, T>,
1115 mut L1: ColMut<'a, T>,
1116 d: T::Real,
1117 d00: T::Real,
1118 d10: T,
1119 d11: T::Real,
1120) {
1121 let n = A.nrows();
1122 for j in 0..n {
1123 let x0 = copy(L0[j]);
1124 let x1 = copy(L1[j]);
1125
1126 let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
1127 let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
1128
1129 for i in j..n {
1130 A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
1131 }
1132 A[(j, j)] = from_real(real(A[(j, j)]));
1133
1134 L0[j] = w0;
1135 L1[j] = w1;
1136 }
1137}
1138
1139pub fn rank1_update<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1140 if const { T::SIMD_CAPABILITIES.is_simd() } {
1141 if let (Some(A), Some(L0)) = (A.rb_mut().try_as_col_major_mut(), L0.rb_mut().try_as_col_major_mut()) {
1142 rank1_update_simd(A, L0, d);
1143 } else {
1144 rank1_update_fallback(A, L0, d);
1145 }
1146 } else {
1147 rank1_update_fallback(A, L0, d);
1148 }
1149}
1150
/// SIMD kernel of [`rank1_update`] for contiguous column-major operands.
///
/// For each column `j`, forms `w0 = L0[j] * d`, subtracts `L0[i] * conj(w0)` from `A[(i, j)]`
/// for all `i >= j`, drops the imaginary part of `A[(j, j)]`, and finally stores `w0` into
/// `L0[j]`.
#[math]
pub fn rank1_update_simd<'a, T: ComplexField>(A: MatMut<'a, T, usize, usize, ContiguousFwd>, L0: ColMut<'a, T, usize, ContiguousFwd>, d: T::Real) {
    // Carrier for the arguments, so that the kernel can be dispatched through `pulp::WithSimd`.
    struct Impl<'a, T: ComplexField> {
        A: MatMut<'a, T, usize, usize, ContiguousFwd>,
        L0: ColMut<'a, T, usize, ContiguousFwd>,
        d: T::Real,
    }

    impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let Self { mut A, mut L0, d } = self;

            let n = A.nrows();
            for j in 0..n {
                let x0 = copy(L0[j]);
                let w0 = mul_real(x0, d);

                // Length of the part of column `j` from the diagonal downwards.
                with_dim!({
                    let subrange_len = n - j;
                });
                {
                    let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
                    let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
                    let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
                    let (head, body, tail) = simd.indices();

                    // The factor is negated up front so that the subtraction can be
                    // expressed with `mul_add`.
                    let w0_conj = conj(w0);
                    let w0_conj_neg = -w0_conj;
                    let w0_splat = simd.splat(&w0_conj_neg);

                    // The same update is applied to the head, body and tail index groups.
                    if let Some(i) = head {
                        let mut acc = simd.read(A.rb(), i);
                        let l0_val = simd.read(L0, i);
                        acc = simd.mul_add(l0_val, w0_splat, acc);
                        simd.write(A.rb_mut(), i, acc);
                    }

                    for i in body.clone() {
                        let mut acc = simd.read(A.rb(), i);
                        let l0_val = simd.read(L0, i);
                        acc = simd.mul_add(l0_val, w0_splat, acc);
                        simd.write(A.rb_mut(), i, acc);
                    }

                    if let Some(i) = tail {
                        let mut acc = simd.read(A.rb(), i);
                        let l0_val = simd.read(L0, i);
                        acc = simd.mul_add(l0_val, w0_splat, acc);
                        simd.write(A.rb_mut(), i, acc);
                    }
                }
                // Drop any imaginary part from the diagonal entry.
                A[(j, j)] = from_real(real(A[(j, j)]));

                // The raw entry of `L0` at `j` is no longer needed: store the multiplier.
                L0[j] = w0;
            }
        }
    }
    dispatch!(Impl { A, L0, d }, Impl, T)
}
1213
1214#[math]
1215pub fn rank1_update_fallback<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1216 let n = A.nrows();
1217 for j in 0..n {
1218 let x0 = copy(L0[j]);
1219 let w0 = mul_real(x0, d);
1220
1221 for i in j..n {
1222 A[(i, j)] = A[(i, j)] - L0[i] * conj(w0);
1223 }
1224 A[(j, j)] = from_real(real(A[(j, j)]));
1225 L0[j] = w0;
1226 }
1227}
1228pub fn cholesky_in_place_scratch<I: Index, T: ComplexField>(dim: usize, par: Par, params: Spec<LbltParams, T>) -> StackReq {
1231 let params = params.config;
1232 let _ = par;
1233 let mut bs = params.block_size;
1234 if bs < 2 || dim <= bs {
1235 bs = 0;
1236 }
1237 StackReq::new::<usize>(dim).and(temp_mat_scratch::<T>(dim, bs))
1238}
1239
/// Information about the result of the $LBL^\top$ factorization.
#[derive(Copy, Clone, Debug)]
pub struct LbltInfo {
    /// Number of recorded pivots that differ from their own position, i.e. the number of
    /// non-trivial transpositions making up the returned permutation.
    pub transposition_count: usize,
}
1246
1247#[track_caller]
1261#[math]
1262pub fn cholesky_in_place<'out, I: Index, T: ComplexField>(
1263 A: MatMut<'_, T>,
1264 subdiag: DiagMut<'_, T>,
1265 perm: &'out mut [I],
1266 perm_inv: &'out mut [I],
1267 par: Par,
1268 stack: &mut MemStack,
1269 params: Spec<LbltParams, T>,
1270) -> (LbltInfo, PermRef<'out, I>) {
1271 let params = params.config;
1272
1273 let truncate = <I::Signed as SignedIndex>::truncate;
1274
1275 let n = A.nrows();
1276 assert!(all(A.nrows() == A.ncols(), subdiag.dim() == n, perm.len() == n, perm_inv.len() == n));
1277
1278 #[cfg(feature = "perf-warn")]
1279 if A.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) {
1280 if A.col_stride().unsigned_abs() == 1 {
1281 log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1282 matrix. Found row-major matrix.");
1283 } else {
1284 log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1285 matrix. Found matrix with generic strides.");
1286 }
1287 }
1288
1289 let (mut pivots, stack) = stack.make_with::<usize>(n, |_| 0);
1290 let pivots = &mut *pivots;
1291
1292 let mut bs = params.block_size;
1293 if bs < 2 || n <= bs {
1294 bs = 0;
1295 }
1296
1297 let (rook, diagonal) = match params.pivoting {
1298 PivotingStrategy::Partial => (false, false),
1299 PivotingStrategy::PartialDiag => (false, true),
1300 PivotingStrategy::Rook => (true, false),
1301 PivotingStrategy::RookDiag => (true, true),
1302 _ => (false, false),
1303 };
1304
1305 if params.pivoting == PivotingStrategy::Full {
1306 lblt_full_piv(A, subdiag, pivots, par, params);
1307 } else {
1308 lblt_blocked(A, subdiag, pivots, bs, rook, diagonal, par, stack);
1309 }
1310
1311 for (i, p) in perm.iter_mut().enumerate() {
1312 *p = I::from_signed(truncate(i));
1313 }
1314
1315 let mut transposition_count = 0usize;
1316 for i in 0..n {
1317 let p = pivots[i] & !TOP_BIT;
1318 if i != p {
1319 transposition_count += 1;
1320 }
1321 perm.swap(i, p);
1322 }
1323 for (i, &p) in perm.iter().enumerate() {
1324 perm_inv[p.to_signed().zx()] = I::from_signed(truncate(i));
1325 }
1326
1327 (LbltInfo { transposition_count }, unsafe { PermRef::new_unchecked(perm, perm_inv, n) })
1328}