faer/linalg/cholesky/bunch_kaufman/
factor.rs

1use crate::internal_prelude::*;
2use crate::{assert, perm};
3use linalg::matmul::triangular::BlockStructure;
4
/// mask with only the most significant bit of a `usize` set
///
/// the factorization routines in this file or it into both stored pivot indices of a 2x2 pivot
/// block, and mask it out again before using a stored pivot as a row index
const TOP_BIT: usize = 1 << (usize::BITS - 1);
6
/// pivoting strategy used by the factorization for choosing the pivots
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum PivotingStrategy {
	/// deprecated, corresponds to partial pivoting
	#[deprecated]
	Diagonal,

	/// searches for the k-th pivot in the k-th column
	Partial,
	/// searches for the k-th pivot in the k-th column, as well as the tail of the diagonal of the
	/// matrix
	PartialDiag,
	/// searches for pivots that are locally optimal
	Rook,
	/// searches for pivots that are locally optimal, as well as the tail of the diagonal of the
	/// matrix
	RookDiag,

	/// searches for pivots that are globally optimal
	Full,
}
29
/// tuning parameters for the decomposition
#[derive(Copy, Clone, Debug)]
pub struct LbltParams {
	/// pivoting strategy
	pub pivoting: PivotingStrategy,
	/// block size of the algorithm
	///
	/// no blocked workspace is requested when this is smaller than 2 or not smaller than the
	/// matrix dimension
	pub block_size: usize,

	/// threshold at which size parallelism should be disabled
	///
	/// the full pivoting implementation switches to sequential execution once the squared
	/// dimension of the remaining submatrix drops below this value
	pub par_threshold: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
44
45#[math]
46fn swap_self_adjoint<T: ComplexField>(A: MatMut<'_, T>, i: usize, j: usize) {
47	assert_ne!(i, j);
48
49	let mut A = A;
50	let (i, j) = (Ord::min(i, j), Ord::max(i, j));
51
52	perm::swap_cols_idx(A.rb_mut().get_mut(j + 1.., ..), i, j);
53	perm::swap_rows_idx(A.rb_mut().get_mut(.., ..i), i, j);
54
55	let tmp = real(A[(i, i)]);
56	A[(i, i)] = from_real(real(A[(j, j)]));
57	A[(j, j)] = from_real(tmp);
58
59	A[(j, i)] = conj(A[(j, i)]);
60
61	let (Ai, Aj) = A.split_at_row_mut(j);
62	let Ai = Ai.get_mut(i + 1..j, i);
63	let Aj = Aj.get_mut(0, i + 1..j).transpose_mut();
64	zip!(Ai, Aj).for_each(|unzip!(x, y): Zip!(&mut _, &mut _)| {
65		let tmp = conj(*x);
66		*x = conj(*y);
67		*y = tmp;
68	});
69}
70
71#[math]
72fn rank_1_update_and_argmax_fallback<'M, 'N, T: ComplexField>(
73	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
74	L: ColRef<'_, T, Dim<'N>>,
75	d: T::Real,
76	start: IdxInc<'N>,
77	end: IdxInc<'N>,
78) -> (usize, usize, T::Real) {
79	let mut A = A;
80	let n = A.nrows();
81
82	let mut max_j = n.idx(0);
83	let mut max_i = n.idx(0);
84	let mut max_offdiag = zero();
85
86	for j in start.to(end) {
87		for i in j.next().to(n.end()) {
88			A[(i, j)] = A[(i, j)] - mul_real(L[i] * conj(L[j]), d);
89			let val = abs2(A[(i, j)]);
90			if val > max_offdiag {
91				max_offdiag = val;
92				max_i = i;
93				max_j = j;
94			}
95		}
96	}
97
98	(*max_i, *max_j, max_offdiag)
99}
100
101#[math]
102fn rank_2_update_and_argmax_fallback<'N, T: ComplexField>(
103	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
104	L0: ColRef<'_, T, Dim<'N>>,
105	L1: ColRef<'_, T, Dim<'N>>,
106	d: T::Real,
107	d00: T::Real,
108	d11: T::Real,
109	d10: T,
110	start: IdxInc<'N>,
111	end: IdxInc<'N>,
112) -> (usize, usize, T::Real) {
113	let mut A = A;
114	let n = A.nrows();
115
116	let mut max_j = n.idx(0);
117	let mut max_i = n.idx(0);
118	let mut max_offdiag = zero();
119
120	for j in start.to(end) {
121		let x0 = copy(L0[j]);
122		let x1 = copy(L1[j]);
123
124		let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
125		let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
126
127		for i in j.next().to(n.end()) {
128			A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
129
130			let val = abs2(A[(i, j)]);
131			if val > max_offdiag {
132				max_offdiag = val;
133				max_i = i;
134				max_j = j;
135			}
136		}
137	}
138	(*max_i, *max_j, max_offdiag)
139}
140
/// sequential rank-1 update and argmax over the columns `start..end` of `A`
///
/// currently forwards all of its arguments to [`rank_1_update_and_argmax_fallback`]
#[math]
fn rank_1_update_and_argmax_seq<'M, 'N, T: ComplexField>(
	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	L: ColRef<'_, T, Dim<'N>>,
	d: T::Real,
	start: IdxInc<'N>,
	end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
	rank_1_update_and_argmax_fallback(A, L, d, start, end)
}
151
/// sequential rank-2 update and argmax over the columns `start..end` of `A`
///
/// currently forwards all of its arguments to [`rank_2_update_and_argmax_fallback`]
#[math]
fn rank_2_update_and_argmax_seq<'N, T: ComplexField>(
	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	L0: ColRef<'_, T, Dim<'N>>,
	L1: ColRef<'_, T, Dim<'N>>,
	d: T::Real,
	d00: T::Real,
	d11: T::Real,
	d10: T,
	start: IdxInc<'N>,
	end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
	rank_2_update_and_argmax_fallback(A, L0, L1, d, d00, d11, d10, start, end)
}
166
/// applies the rank-1 update `A[i, j] -= d * L[i] * conj(L[j])` to the strictly lower triangular
/// part of `A` and returns `(row, col, |A[row, col]|²)` for the largest updated entry
///
/// with `Par::Seq` all the columns are handled by a single call, with `Par::Rayon` the columns are
/// split into `nthreads` contiguous ranges, each handled by its own task
#[math]
fn rank_1_update_and_argmax<T: ComplexField>(A: MatMut<'_, T>, L: ColRef<'_, T>, d: T::Real, par: Par) -> (usize, usize, T::Real) {
	with_dim!(N, A.nrows());

	match par {
		Par::Seq => rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), d, IdxInc::ZERO, N.end()),
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			use rayon::prelude::*;
			let nthreads = nthreads.get();
			let n = *N;

			// to check that integers can be represented exactly as floats
			assert!((n as u64) < (1u64 << 50));

			// first column of the task with index `idx`: `n * (1 - sqrt(1 - idx / nthreads))`
			// NOTE(review): presumably chosen so that every task gets a similar share of the
			// lower triangle - confirm
			let idx_to_col_start = |idx: usize| {
				let idx_as_percent = idx as f64 / nthreads as f64;
				let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
				(col_start_percent * n as f64) as usize
			};

			// one `(row, col, value)` result slot per task
			let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];

			spindle::for_each(nthreads, r.par_iter_mut().enumerate(), |(idx, out)| {
				// NOTE(review): every task gets a mutable view over the whole matrix; the update
				// only writes to the columns in `start..end`, so this relies on the column ranges
				// of different tasks not overlapping
				let A = unsafe { A.rb().const_cast() };
				let start = N.idx_inc(idx_to_col_start(idx));
				let end = N.idx_inc(idx_to_col_start(idx + 1));

				*out = rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), copy(d), start, end);
			});

			// keep the per-task result with the largest value
			r.into_iter()
				.max_by(|(_, _, a), (_, _, b)| {
					if a == b {
						core::cmp::Ordering::Equal
					} else if a > b {
						core::cmp::Ordering::Greater
					} else {
						core::cmp::Ordering::Less
					}
				})
				.unwrap()
		},
	}
}
212
/// applies the rank-2 update built from the columns `L0` and `L1` and the scalars `d`, `d00`,
/// `d11`, `d10` to the strictly lower triangular part of `A`, and returns
/// `(row, col, |A[row, col]|²)` for the largest updated entry
///
/// with `Par::Seq` all the columns are handled by a single call, with `Par::Rayon` the columns are
/// split into `nthreads` contiguous ranges, each handled by its own task
#[math]
fn rank_2_update_and_argmax<'N, T: ComplexField>(
	A: MatMut<'_, T>,
	L0: ColRef<'_, T>,
	L1: ColRef<'_, T>,
	d: T::Real,
	d00: T::Real,
	d11: T::Real,
	d10: T,
	par: Par,
) -> (usize, usize, T::Real) {
	with_dim!(N, A.nrows());

	match par {
		Par::Seq => rank_2_update_and_argmax_seq(
			A.as_shape_mut(N, N),
			L0.as_row_shape(N),
			L1.as_row_shape(N),
			d,
			d00,
			d11,
			d10,
			IdxInc::ZERO,
			N.end(),
		),
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			use rayon::prelude::*;
			let nthreads = nthreads.get();
			let n = *N;

			// to check that integers can be represented exactly as floats
			assert!((n as u64) < (1u64 << 50));

			// first column of the task with index `idx`: `n * (1 - sqrt(1 - idx / nthreads))`
			// NOTE(review): presumably chosen so that every task gets a similar share of the
			// lower triangle - confirm
			let idx_to_col_start = |idx: usize| {
				let idx_as_percent = idx as f64 / nthreads as f64;
				let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
				(col_start_percent * n as f64) as usize
			};

			// one `(row, col, value)` result slot per task
			let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];

			spindle::for_each(nthreads, r.par_iter_mut().enumerate(), |(idx, out)| {
				// NOTE(review): every task gets a mutable view over the whole matrix; the update
				// only writes to the columns in `start..end`, so this relies on the column ranges
				// of different tasks not overlapping
				let A = unsafe { A.rb().const_cast() };
				let start = N.idx_inc(idx_to_col_start(idx));
				let end = N.idx_inc(idx_to_col_start(idx + 1));

				*out = rank_2_update_and_argmax_seq(
					A.as_shape_mut(N, N),
					L0.as_row_shape(N),
					L1.as_row_shape(N),
					copy(d),
					copy(d00),
					copy(d11),
					copy(d10),
					start,
					end,
				);
			});

			// keep the per-task result with the largest value
			r.into_iter()
				.max_by(|(_, _, a), (_, _, b)| {
					if a == b {
						core::cmp::Ordering::Equal
					} else if a < b {
						core::cmp::Ordering::Less
					} else {
						core::cmp::Ordering::Greater
					}
				})
				.unwrap()
		},
	}
}
287
288#[math]
289fn lblt_full_piv<T: ComplexField>(A: MatMut<'_, T>, subdiag: DiagMut<'_, T>, pivots: &mut [usize], par: Par, params: LbltParams) {
290	let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);
291	let alpha = alpha * alpha;
292
293	let mut A = A;
294	let mut subdiag = subdiag.column_vector_mut();
295	let mut par = par;
296	let n = A.nrows();
297
298	let scale_fwd = A.norm_max();
299	let scale_bwd = recip(scale_fwd);
300	zip!(A.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_bwd));
301
302	let mut max_i = 0;
303	let mut max_j = 0;
304	let mut max_offdiag = zero();
305
306	for j in 0..n {
307		for i in j + 1..n {
308			let val = abs2(A[(i, j)]);
309			if val > max_offdiag {
310				max_offdiag = val;
311				max_i = i;
312				max_j = j;
313			}
314		}
315	}
316
317	let mut k = 0;
318	while k < n {
319		if max_offdiag == zero() {
320			break;
321		}
322
323		let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
324		let mut subdiag = subdiag.rb_mut().get_mut(k..);
325		let pivots = &mut pivots[k..];
326
327		let n = A.nrows();
328		let mut max_s = 0;
329		let mut max_diag = zero();
330
331		for s in 0..n {
332			let val = abs2(A[(s, s)]);
333			if val > max_diag {
334				max_diag = val;
335				max_s = s;
336			}
337		}
338
339		let npiv;
340		let i0;
341		let i1;
342
343		if max_diag >= alpha * max_offdiag {
344			npiv = 1;
345			i0 = max_s;
346			i1 = usize::MAX;
347		} else {
348			npiv = 2;
349			i0 = max_j;
350			i1 = max_i;
351		}
352
353		let rem = n - npiv;
354		if rem * rem < params.par_threshold {
355			par = Par::Seq;
356		}
357
358		// swap pivots to first (and second) column
359		if i0 != 0 {
360			swap_self_adjoint(A.rb_mut(), 0, i0);
361			perm::swap_rows_idx(Aprev.rb_mut(), 0, i0);
362		}
363		if npiv == 2 && i1 != 1 {
364			swap_self_adjoint(A.rb_mut(), 1, i1);
365			perm::swap_rows_idx(Aprev.rb_mut(), 1, i1);
366		}
367
368		if npiv == 1 {
369			let diag = real(A[(0, 0)]);
370			let diag_inv = recip(diag);
371			subdiag[0] = zero();
372
373			let (_, _, L, mut A) = A.rb_mut().split_at_mut(1, 1);
374			let n = A.nrows();
375			let mut L = L.col_mut(0);
376
377			zip!(L.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, diag_inv));
378
379			for i in 0..n {
380				A[(i, i)] = from_real(real(A[(i, i)]) - diag * abs2(L[i]));
381			}
382
383			if n < params.par_threshold {}
384			if n != 0 {
385				(max_i, max_j, max_offdiag) = rank_1_update_and_argmax(A.rb_mut(), L.rb(), diag, par);
386			}
387		} else {
388			let a00 = real(A[(0, 0)]);
389			let a11 = real(A[(1, 1)]);
390			let a10 = copy(A[(1, 0)]);
391
392			subdiag[0] = copy(a10);
393			subdiag[1] = zero();
394			A[(1, 0)] = zero();
395
396			let d10 = abs(a10);
397			let d10_inv = recip(d10);
398			let d00 = a00 * d10_inv;
399			let d11 = a11 * d10_inv;
400
401			// t = (d00/|d10| * d11/|d10| - 1.0)
402			let t = recip(d00 * d11 - one());
403			let d10 = mul_real(a10, d10_inv);
404			let d = t * d10_inv;
405
406			//         [ a00  a01 ]
407			// L_new * [ a10  a11 ] = L
408			let (_, _, L, mut A) = A.rb_mut().split_at_mut(2, 2);
409			let (mut L0, mut L1) = L.two_cols_mut(0, 1);
410			let n = A.nrows();
411
412			if n != 0 {
413				(max_i, max_j, max_offdiag) = rank_2_update_and_argmax(A.rb_mut(), L0.rb(), L1.rb(), copy(d), copy(d00), copy(d11), copy(d10), par);
414			}
415
416			for j in 0..n {
417				let x0 = copy(L0[j]);
418				let x1 = copy(L1[j]);
419
420				let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
421				let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
422
423				A[(j, j)] = from_real(real(A[(j, j)] - L0[j] * conj(w0) - L1[j] * conj(w1)));
424
425				L0[j] = w0;
426				L1[j] = w1;
427			}
428		}
429
430		if npiv == 2 {
431			pivots[0] = (i0 + k) | TOP_BIT;
432			pivots[1] = (i1 + k) | TOP_BIT;
433		} else {
434			pivots[0] = i0 + k;
435		}
436		k += npiv;
437	}
438
439	while k < n {
440		let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
441		let mut subdiag = subdiag.rb_mut().get_mut(k..);
442		let pivots = &mut pivots[k..];
443
444		let n = A.nrows();
445		let mut max_s = 0;
446		let mut max_diag = zero();
447
448		for s in 0..n {
449			let val = abs2(A[(s, s)]);
450			if val > max_diag {
451				max_diag = val;
452				max_s = s;
453			}
454		}
455
456		if max_s != 0 {
457			let (mut A0, mut As) = A.rb_mut().two_cols_mut(0, max_s);
458			core::mem::swap(&mut A0[0], &mut As[max_s]);
459
460			perm::swap_rows_idx(Aprev.rb_mut(), 0, max_s);
461		}
462
463		subdiag[0] = zero();
464		pivots[0] = max_s + k;
465
466		k += 1;
467	}
468
469	zip!(A.rb_mut().diagonal_mut().column_vector_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
470	zip!(subdiag.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
471}
472
473#[math]
474#[track_caller]
475fn l1_argmax<T: ComplexField>(col: ColRef<'_, T>) -> (Option<usize>, T::Real) {
476	let n = col.nrows();
477	if n == 0 {
478		return (None, zero());
479	}
480
481	let mut i = 0;
482	let mut best = zero();
483
484	for j in 0..n {
485		let val = abs1(col[j]);
486		if val > best {
487			best = val;
488			i = j;
489		}
490	}
491
492	(Some(i), best)
493}
494
495#[math]
496#[track_caller]
497fn offdiag_argmax<T: ComplexField>(A: MatRef<'_, T>, idx: usize) -> (Option<usize>, T::Real) {
498	let (mut col_argmax, col_max) = l1_argmax(A.rb().get(idx + 1.., idx));
499	col_argmax.as_mut().map(|col_argmax| *col_argmax += idx + 1);
500	let (row_argmax, row_max) = l1_argmax(A.rb().get(idx, ..idx).transpose());
501
502	if col_max > row_max {
503		(col_argmax, col_max)
504	} else {
505		(row_argmax, row_max)
506	}
507}
508
509#[math]
510fn update_and_offdiag_argmax<T: ComplexField>(
511	mut dst: ColMut<'_, T>,
512	Wl: MatRef<'_, T>,
513	Al: MatRef<'_, T>,
514	Ar: MatRef<'_, T>,
515	i0: usize,
516	par: Par,
517) -> (Option<usize>, T::Real) {
518	let n = Al.nrows();
519	for j in 0..i0 {
520		dst[j] = conj(Ar[(i0, j)]);
521	}
522	dst[i0] = zero();
523	for j in i0 + 1..n {
524		dst[j] = copy(Ar[(j, i0)]);
525	}
526
527	linalg::matmul::matmul(dst.rb_mut(), Accum::Add, Al.rb(), Wl.row(i0).adjoint(), -one::<T>(), par);
528	dst[i0] = zero();
529
530	let ret = l1_argmax(dst.rb());
531	dst[i0] = from_real(real(Ar[(i0, i0)]));
532	if n == 1 { (None, zero()) } else { ret }
533}
534
535#[math]
536#[inline(never)]
537fn lblt_blocked_step<T: ComplexField>(
538	alpha: T::Real,
539	W: MatMut<'_, T>,
540	A_left: MatMut<'_, T>,
541	A: MatMut<'_, T>,
542	subdiag: DiagMut<'_, T>,
543	pivots: &mut [usize],
544	rook: bool,
545	diagonal: bool,
546	par: Par,
547) -> usize {
548	let mut A = A;
549	let mut A_left = A_left;
550	let mut subdiag = subdiag;
551	let mut W = W;
552
553	let n = A.nrows();
554	let block_size = W.ncols();
555
556	assert!(all(A.nrows() == n, A.ncols() == n, W.nrows() == n, subdiag.dim() == n, block_size >= 2,));
557
558	let kmax = Ord::min(block_size - 1, n);
559	let mut k = 0usize;
560	while k < kmax {
561		let mut A = A.rb_mut();
562		let mut W = W.rb_mut();
563		let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
564		let A_left = A_left.rb_mut().get_mut(k.., ..);
565
566		let (mut Wl, mut Wr) = W.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
567		let (mut Al, mut Ar) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
568		let mut Al = Al.rb_mut();
569		let mut Wr = Wr.rb_mut().get_mut(.., ..2);
570
571		let npiv;
572		let mut i0 = if diagonal {
573			l1_argmax(Ar.rb().diagonal().column_vector()).0.unwrap()
574		} else {
575			0
576		};
577		let mut i1 = usize::MAX;
578
579		let mut nothing_to_do = false;
580
581		let (mut Wr0, mut Wr1) = Wr.rb_mut().two_cols_mut(0, 1);
582
583		let (r, mut gamma_i) = update_and_offdiag_argmax(Wr0.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i0, par);
584
585		if k + 1 == n || gamma_i == zero() {
586			nothing_to_do = true;
587			npiv = 1;
588		} else if abs(real(Ar[(i0, i0)])) >= alpha * gamma_i {
589			npiv = 1;
590		} else {
591			i1 = r.unwrap();
592			if rook {
593				loop {
594					let (s, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);
595
596					if abs1(Ar[(i1, i1)]) >= alpha * gamma_r {
597						npiv = 1;
598						i0 = i1;
599						i1 = usize::MAX;
600						Wr0.copy_from(&Wr1);
601						break;
602					} else if s == Some(i0) || gamma_i == gamma_r {
603						npiv = 2;
604						break;
605					} else {
606						i0 = i1;
607						i1 = s.unwrap();
608						gamma_i = gamma_r;
609						Wr0.copy_from(&Wr1);
610					}
611				}
612			} else {
613				let (_, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);
614
615				if abs(real(Ar[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
616					npiv = 1;
617				} else if abs(real(Ar[(i1, i1)])) >= alpha * gamma_r {
618					npiv = 1;
619					i0 = i1;
620					i1 = usize::MAX;
621					Wr0.copy_from(&Wr1);
622				} else {
623					npiv = 2;
624				}
625			}
626		}
627
628		if npiv == 2 && i0 > i1 {
629			perm::swap_cols_idx(Wr.rb_mut(), 0, 1);
630			(i0, i1) = (i1, i0);
631		}
632
633		let mut Wr = Wr.rb_mut().get_mut(.., ..npiv);
634
635		'next_iter: {
636			// swap pivots to first (and second) column
637			if i0 != 0 {
638				swap_self_adjoint(Ar.rb_mut(), 0, i0);
639				perm::swap_rows_idx(Al.rb_mut(), 0, i0);
640				perm::swap_rows_idx(Wl.rb_mut(), 0, i0);
641				perm::swap_rows_idx(Wr.rb_mut(), 0, i0);
642			}
643			if npiv == 2 && i1 != 1 {
644				swap_self_adjoint(Ar.rb_mut(), 1, i1);
645				perm::swap_rows_idx(Al.rb_mut(), 1, i1);
646				perm::swap_rows_idx(Wl.rb_mut(), 1, i1);
647				perm::swap_rows_idx(Wr.rb_mut(), 1, i1);
648			}
649
650			if nothing_to_do {
651				break 'next_iter;
652			}
653
654			if npiv == 1 {
655				let W0 = Wr.rb_mut().col_mut(0);
656
657				let diag = real(W0[0]);
658				let diag_inv = recip(diag);
659				subdiag[0] = zero();
660
661				let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(1, 1);
662				let W0 = W0.rb().get(1..);
663				let n = A.nrows();
664
665				let mut L = L.col_mut(0);
666				zip!(W0, L.rb_mut()).for_each(|unzip!(w, a): Zip!(&T, &mut T)| *a = mul_real(*w, diag_inv));
667
668				for j in 0..n {
669					A[(j, j)] = from_real(real(A[(j, j)]) - diag * abs2(L[j]));
670				}
671			} else {
672				let a00 = real(Wr[(0, 0)]);
673				let a11 = real(Wr[(1, 1)]);
674				let a10 = copy(Wr[(1, 0)]);
675
676				subdiag[0] = copy(a10);
677				subdiag[1] = zero();
678				Wr[(1, 0)] = zero();
679				Ar[(1, 0)] = zero();
680
681				let d10 = abs(a10);
682				let d10_inv = recip(d10);
683				let d00 = a00 * d10_inv;
684				let d11 = a11 * d10_inv;
685
686				// t = (d00/|d10| * d11/|d10| - 1.0)
687				let t = recip(d00 * d11 - one());
688				let d10 = mul_real(a10, d10_inv);
689				let d = t * d10_inv;
690
691				//         [ a00  a01 ]
692				// L_new * [ a10  a11 ] = L
693				let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(2, 2);
694				let (mut L0, mut L1) = L.two_cols_mut(0, 1);
695				let Wr = Wr.rb().get(2.., ..);
696				let W0 = Wr.col(0);
697				let W1 = Wr.col(1);
698
699				let n = A.nrows();
700				for j in 0..n {
701					let x0 = copy(W0[j]);
702					let x1 = copy(W1[j]);
703
704					let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
705					let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
706
707					A[(j, j)] = from_real(real(A[(j, j)] - W0[j] * conj(w0) - W1[j] * conj(w1)));
708
709					L0[j] = w0;
710					L1[j] = w1;
711				}
712			}
713		}
714
715		let offset = A_left.ncols();
716
717		if npiv == 2 {
718			pivots[k] = (offset + i0 + k) | TOP_BIT;
719			pivots[k + 1] = (offset + i1 + k) | TOP_BIT;
720		} else {
721			pivots[k] = offset + i0 + k;
722		}
723		k += npiv;
724	}
725
726	let W = W.rb().get(k.., ..k);
727	let (_, _, Al, mut Ar) = A.rb_mut().split_at_mut(k, k);
728	let Al = Al.rb();
729
730	linalg::matmul::triangular::matmul(
731		Ar.rb_mut(),
732		BlockStructure::StrictTriangularLower,
733		Accum::Add,
734		W,
735		BlockStructure::Rectangular,
736		Al.adjoint(),
737		BlockStructure::Rectangular,
738		-one::<T>(),
739		par,
740	);
741
742	for j in 0..n - k {
743		Ar[(j, j)] = from_real(real(Ar[(j, j)]));
744	}
745
746	k
747}
748
/// computes the factorization of `A` in place, block by block
///
/// each iteration factorizes the trailing matrix starting at column `k`, either with
/// [`lblt_blocked_step`], or with [`lblt_unblocked`] when `block_size` is smaller than 2 or the
/// remaining dimension is at most `block_size`. the row swaps of the iteration are then applied
/// to the columns to the left of `k`
///
/// `stack` provides the memory of the temporary workspace handed to [`lblt_blocked_step`]
#[math]
fn lblt_blocked<T: ComplexField>(
	A: MatMut<'_, T>,
	subdiag: DiagMut<'_, T>,
	pivots: &mut [usize],
	block_size: usize,
	rook: bool,
	diagonal: bool,
	par: Par,
	stack: &mut MemStack,
) {
	// threshold of the pivot tests: (1 + sqrt(17)) / 8
	let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);

	let mut A = A;
	let mut subdiag = subdiag.column_vector_mut();
	let n = A.nrows();

	let mut k = 0;
	while k < n {
		let (_, _, mut A_left, A_right) = A.rb_mut().split_at_mut(k, k);
		// uninitialized workspace, allocated again for every iteration
		let (mut W, _) = unsafe { temp_mat_uninit::<T, _, _>(n - k, block_size, stack) };
		let W = W.as_mat_mut();

		// index of the first column that is left to factorize after this iteration
		let next;

		if block_size < 2 || n - k <= block_size {
			// the unblocked code factorizes everything that is left
			lblt_unblocked(
				copy(alpha),
				A_left.rb_mut(),
				A_right,
				subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
				&mut pivots[k..],
				rook,
				diagonal,
				par,
			);

			next = n;
		} else {
			// the blocked step returns the number of columns it factorized
			let block_size = lblt_blocked_step(
				copy(alpha),
				W,
				A_left.rb_mut(),
				A_right,
				subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
				&mut pivots[k..],
				rook,
				diagonal,
				par,
			);

			next = k + block_size;
		}

		let pivots = &pivots[k..next];

		let A_left = A.rb_mut().get_mut(.., ..k);

		// apply the row swaps of this iteration to the columns to the left of `k`
		if A_left.ncols() > 0 {
			match par {
				Par::Seq => {
					for mut col in A_left.col_iter_mut() {
						for (i, &j) in core::iter::zip(k..next, pivots) {
							// clear the 2x2 pivot marker to recover the row index
							let j = j & !TOP_BIT;
							linalg::lu::partial_pivoting::factor::swap_elems(col.rb_mut(), i, j);
						}
					}
				},
				#[cfg(feature = "rayon")]
				Par::Rayon(nthreads) => {
					let nthreads = nthreads.get();
					spindle::for_each(nthreads, A_left.par_col_iter_mut(), |mut col| {
						for (i, &j) in core::iter::zip(k..next, pivots) {
							// clear the 2x2 pivot marker to recover the row index
							let j = j & !TOP_BIT;
							linalg::lu::partial_pivoting::factor::swap_elems(col.rb_mut(), i, j);
						}
					});
				},
			}
		}

		k = next;
	}
}
833
834#[math]
835#[inline(never)]
836fn lblt_unblocked<T: ComplexField>(
837	alpha: T::Real,
838	A_left: MatMut<'_, T>,
839	A: MatMut<'_, T>,
840	subdiag: DiagMut<'_, T>,
841	pivots: &mut [usize],
842	rook: bool,
843	diagonal: bool,
844	par: Par,
845) {
846	let _ = par;
847	let mut A = A;
848	let mut A_left = A_left;
849	let mut subdiag = subdiag;
850
851	let n = A.nrows();
852	assert!(all(A.nrows() == n, A.ncols() == n, subdiag.dim() == n));
853
854	let mut k = 0usize;
855	while k < n {
856		let (_, _, mut L_prev, mut A) = A.rb_mut().split_at_mut(k, k);
857		let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
858		let A_left = A_left.rb_mut().get_mut(k.., ..);
859
860		let npiv;
861
862		// find the diagonal pivot candidate, if requested
863		let mut i0 = if diagonal {
864			l1_argmax(A.rb().diagonal().column_vector()).0.unwrap()
865		} else {
866			0
867		};
868		let mut i1 = usize::MAX;
869
870		// find the largest off-diagonal in the pivot's column
871		let (r, mut gamma_i) = offdiag_argmax(A.rb(), i0);
872
873		let mut nothing_to_do = false;
874
875		if k + 1 == n || gamma_i == zero() {
876			nothing_to_do = true;
877			npiv = 1;
878		} else if abs(real(A[(i0, i0)])) >= alpha * gamma_i {
879			npiv = 1;
880		} else {
881			i1 = r.unwrap();
882
883			// pivot search
884			if rook {
885				loop {
886					let (s, gamma_r) = offdiag_argmax(A.rb(), i1);
887
888					if abs1(A[(i1, i1)]) >= alpha * gamma_r {
889						npiv = 1;
890						i0 = i1;
891						i1 = usize::MAX;
892						break;
893					} else if gamma_i == gamma_r {
894						npiv = 2;
895						break;
896					} else {
897						i0 = i1;
898						i1 = s.unwrap();
899						gamma_i = gamma_r;
900					}
901				}
902			} else {
903				let (_, gamma_r) = offdiag_argmax(A.rb(), i1);
904				if abs(real(A[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
905					npiv = 1;
906				} else if abs(real(A[(i1, i1)])) >= alpha * gamma_r {
907					npiv = 1;
908					i0 = i1;
909				} else {
910					npiv = 2;
911				}
912			}
913		}
914
915		if npiv == 2 && i0 > i1 {
916			(i0, i1) = (i1, i0);
917		}
918
919		'next_iter: {
920			// swap pivots to first (and second) column
921			if i0 != 0 {
922				swap_self_adjoint(A.rb_mut(), 0, i0);
923				perm::swap_rows_idx(L_prev.rb_mut(), 0, i0);
924			}
925			if npiv == 2 && i1 != 1 {
926				swap_self_adjoint(A.rb_mut(), 1, i1);
927				perm::swap_rows_idx(L_prev.rb_mut(), 1, i1);
928			}
929
930			if nothing_to_do {
931				break 'next_iter;
932			}
933
934			// rank downdate
935			if npiv == 1 {
936				let diag = real(A[(0, 0)]);
937				let diag_inv = recip(diag);
938				subdiag[0] = zero();
939
940				let (_, _, L, A) = A.rb_mut().split_at_mut(1, 1);
941				let L = L.col_mut(0);
942				rank1_update(A, L, diag_inv);
943			} else {
944				let a00 = real(A[(0, 0)]);
945				let a11 = real(A[(1, 1)]);
946				let a10 = copy(A[(1, 0)]);
947
948				subdiag[0] = copy(a10);
949				subdiag[1] = zero();
950				A[(1, 0)] = zero();
951
952				let d10 = abs(a10);
953				let d10_inv = recip(d10);
954				let d00 = a00 * d10_inv;
955				let d11 = a11 * d10_inv;
956
957				// t = (d00/|d10| * d11/|d10| - 1.0)
958				let t = recip(d00 * d11 - one());
959				let d10 = mul_real(a10, d10_inv);
960				let d = t * d10_inv;
961
962				//         [ a00  a01 ]
963				// L_new * [ a10  a11 ] = L
964				let (_, _, L, A) = A.rb_mut().split_at_mut(2, 2);
965				let (L0, L1) = L.two_cols_mut(0, 1);
966				rank2_update(A, L0, L1, d, d00, d10, d11);
967			}
968		}
969
970		let offset = A_left.ncols();
971		if npiv == 2 {
972			pivots[k] = (offset + i0 + k) | TOP_BIT;
973			pivots[k + 1] = (offset + i1 + k) | TOP_BIT;
974		} else {
975			pivots[k] = offset + i0 + k;
976		}
977		k += npiv;
978	}
979}
980
impl<T: ComplexField> Auto<T> for LbltParams {
	/// default parameters: [`PivotingStrategy::PartialDiag`] pivoting, a block size of 64, and a
	/// parallelism threshold of `128 * 128`
	fn auto() -> Self {
		Self {
			pivoting: PivotingStrategy::PartialDiag,
			block_size: 64,
			par_threshold: 128 * 128,
			non_exhaustive: NonExhaustive(()),
		}
	}
}
991
992pub fn rank2_update<'a, T: ComplexField>(
993	mut A: MatMut<'a, T>,
994	mut L0: ColMut<'a, T>,
995	mut L1: ColMut<'a, T>,
996	d: T::Real,
997	d00: T::Real,
998	d10: T,
999	d11: T::Real,
1000) {
1001	if const { T::SIMD_CAPABILITIES.is_simd() } {
1002		if let (Some(A), Some(L0), Some(L1)) = (
1003			A.rb_mut().try_as_col_major_mut(),
1004			L0.rb_mut().try_as_col_major_mut(),
1005			L1.rb_mut().try_as_col_major_mut(),
1006		) {
1007			rank2_update_simd(A, L0, L1, d, d00, d10, d11);
1008		} else {
1009			rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
1010		}
1011	} else {
1012		rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
1013	}
1014}
1015
/// simd implementation of [`rank2_update`] for column major inputs
///
/// for every column `j`, computes `w0 = d * (d11 * L0[j] - d10 * L1[j])` and
/// `w1 = d * (d00 * L1[j] - conj(d10) * L0[j])`, subtracts `L0[i] * conj(w0) + L1[i] * conj(w1)`
/// from `A[i, j]` for `i >= j`, drops the imaginary part of `A[j, j]`, then stores `w0` and `w1`
/// in `L0[j]` and `L1[j]`
#[math]
pub fn rank2_update_simd<'a, T: ComplexField>(
	A: MatMut<'a, T, usize, usize, ContiguousFwd>,
	L0: ColMut<'a, T, usize, ContiguousFwd>,
	L1: ColMut<'a, T, usize, ContiguousFwd>,
	d: T::Real,
	d00: T::Real,
	d10: T,
	d11: T::Real,
) {
	// carries the arguments through the simd dispatch
	struct Impl<'a, T: ComplexField> {
		A: MatMut<'a, T, usize, usize, ContiguousFwd>,
		L0: ColMut<'a, T, usize, ContiguousFwd>,
		L1: ColMut<'a, T, usize, ContiguousFwd>,
		d: T::Real,
		d00: T::Real,
		d10: T,
		d11: T::Real,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) {
			let Self {
				mut A,
				mut L0,
				mut L1,
				d,
				d00,
				d10,
				d11,
			} = self;
			let n = A.nrows();
			for j in 0..n {
				// `L0[j]` and `L1[j]` still hold their original values here
				let x0 = copy(L0[j]);
				let x1 = copy(L1[j]);
				let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
				let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);

				// length of the part of column `j` that starts at the diagonal
				with_dim!({
					let subrange_len = n - j;
				});
				{
					let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
					let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
					let L1 = L1.rb().get(j..).as_row_shape(subrange_len);
					let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
					let (head, body, tail) = simd.indices();

					// the factors are negated so that the subtraction becomes a fused multiply-add
					let w0_conj = conj(w0);
					let w1_conj = conj(w1);
					let w0_conj_neg = -w0_conj;
					let w1_conj_neg = -w1_conj;
					let w0_splat = simd.splat(&w0_conj_neg);
					let w1_splat = simd.splat(&w1_conj_neg);

					// the same update is applied to the head, body and tail index sets
					if let Some(i) = head {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					for i in body.clone() {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					if let Some(i) = tail {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}
				}
				// drop the imaginary part of the diagonal
				A[(j, j)] = from_real(real(A[(j, j)]));

				// the columns are only overwritten once the update of column `j` is done
				L0[j] = w0;
				L1[j] = w1;
			}
		}
	}
	dispatch!(Impl { A, L0, L1, d, d00, d10, d11 }, Impl, T)
}
1110
1111#[math]
1112pub fn rank2_update_fallback<'a, T: ComplexField>(
1113	mut A: MatMut<'a, T>,
1114	mut L0: ColMut<'a, T>,
1115	mut L1: ColMut<'a, T>,
1116	d: T::Real,
1117	d00: T::Real,
1118	d10: T,
1119	d11: T::Real,
1120) {
1121	let n = A.nrows();
1122	for j in 0..n {
1123		let x0 = copy(L0[j]);
1124		let x1 = copy(L1[j]);
1125
1126		let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
1127		let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
1128
1129		for i in j..n {
1130			A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
1131		}
1132		A[(j, j)] = from_real(real(A[(j, j)]));
1133
1134		L0[j] = w0;
1135		L1[j] = w1;
1136	}
1137}
1138
1139pub fn rank1_update<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1140	if const { T::SIMD_CAPABILITIES.is_simd() } {
1141		if let (Some(A), Some(L0)) = (A.rb_mut().try_as_col_major_mut(), L0.rb_mut().try_as_col_major_mut()) {
1142			rank1_update_simd(A, L0, d);
1143		} else {
1144			rank1_update_fallback(A, L0, d);
1145		}
1146	} else {
1147		rank1_update_fallback(A, L0, d);
1148	}
1149}
1150
/// simd implementation of [`rank1_update`] for column major inputs
///
/// for every column `j`, computes `w0 = d * L0[j]`, subtracts `L0[i] * conj(w0)` from `A[i, j]`
/// for `i >= j`, drops the imaginary part of `A[j, j]`, then stores `w0` in `L0[j]`
#[math]
pub fn rank1_update_simd<'a, T: ComplexField>(A: MatMut<'a, T, usize, usize, ContiguousFwd>, L0: ColMut<'a, T, usize, ContiguousFwd>, d: T::Real) {
	// carries the arguments through the simd dispatch
	struct Impl<'a, T: ComplexField> {
		A: MatMut<'a, T, usize, usize, ContiguousFwd>,
		L0: ColMut<'a, T, usize, ContiguousFwd>,
		d: T::Real,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) {
			let Self { mut A, mut L0, d } = self;

			let n = A.nrows();
			for j in 0..n {
				// `L0[j]` still holds its original value here
				let x0 = copy(L0[j]);
				let w0 = mul_real(x0, d);

				// length of the part of column `j` that starts at the diagonal
				with_dim!({
					let subrange_len = n - j;
				});
				{
					let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
					let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
					let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
					let (head, body, tail) = simd.indices();

					// the factor is negated so that the subtraction becomes a fused multiply-add
					let w0_conj = conj(w0);
					let w0_conj_neg = -w0_conj;
					let w0_splat = simd.splat(&w0_conj_neg);

					// the same update is applied to the head, body and tail index sets
					if let Some(i) = head {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					for i in body.clone() {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					if let Some(i) = tail {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}
				}
				// drop the imaginary part of the diagonal
				A[(j, j)] = from_real(real(A[(j, j)]));

				// the column is only overwritten once the update of column `j` is done
				L0[j] = w0;
			}
		}
	}
	dispatch!(Impl { A, L0, d }, Impl, T)
}
1213
1214#[math]
1215pub fn rank1_update_fallback<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1216	let n = A.nrows();
1217	for j in 0..n {
1218		let x0 = copy(L0[j]);
1219		let w0 = mul_real(x0, d);
1220
1221		for i in j..n {
1222			A[(i, j)] = A[(i, j)] - L0[i] * conj(w0);
1223		}
1224		A[(j, j)] = from_real(real(A[(j, j)]));
1225		L0[j] = w0;
1226	}
1227}
1228/// computes the layout of required workspace for performing an $LBL^\top$
1229/// decomposition
1230pub fn cholesky_in_place_scratch<I: Index, T: ComplexField>(dim: usize, par: Par, params: Spec<LbltParams, T>) -> StackReq {
1231	let params = params.config;
1232	let _ = par;
1233	let mut bs = params.block_size;
1234	if bs < 2 || dim <= bs {
1235		bs = 0;
1236	}
1237	StackReq::new::<usize>(dim).and(temp_mat_scratch::<T>(dim, bs))
1238}
1239
/// info about the result of the $LBL^\top$ factorization
#[derive(Copy, Clone, Debug)]
pub struct LbltInfo {
	/// number of pivoting transpositions, i.e. the number of indices whose recorded pivot
	/// differs from the index itself
	pub transposition_count: usize,
}
1246
1247/// computes the $LBL^\top$ factorization of $A$ and stores the factorization in `matrix` and
1248/// `subdiag`
1249///
1250/// the diagonal of the block diagonal matrix is stored on the diagonal
1251/// of `matrix`, while the subdiagonal elements of the blocks are stored in `subdiag`
1252///
1253/// # panics
1254///
1255/// panics if the input matrix is not square
1256///
1257/// this can also panic if the provided memory in `stack` is insufficient (see
1258/// [`cholesky_in_place_scratch`]).
1259
1260#[track_caller]
1261#[math]
1262pub fn cholesky_in_place<'out, I: Index, T: ComplexField>(
1263	A: MatMut<'_, T>,
1264	subdiag: DiagMut<'_, T>,
1265	perm: &'out mut [I],
1266	perm_inv: &'out mut [I],
1267	par: Par,
1268	stack: &mut MemStack,
1269	params: Spec<LbltParams, T>,
1270) -> (LbltInfo, PermRef<'out, I>) {
1271	let params = params.config;
1272
1273	let truncate = <I::Signed as SignedIndex>::truncate;
1274
1275	let n = A.nrows();
1276	assert!(all(A.nrows() == A.ncols(), subdiag.dim() == n, perm.len() == n, perm_inv.len() == n));
1277
1278	#[cfg(feature = "perf-warn")]
1279	if A.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) {
1280		if A.col_stride().unsigned_abs() == 1 {
1281			log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1282    matrix. Found row-major matrix.");
1283		} else {
1284			log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1285    matrix. Found matrix with generic strides.");
1286		}
1287	}
1288
1289	let (mut pivots, stack) = stack.make_with::<usize>(n, |_| 0);
1290	let pivots = &mut *pivots;
1291
1292	let mut bs = params.block_size;
1293	if bs < 2 || n <= bs {
1294		bs = 0;
1295	}
1296
1297	let (rook, diagonal) = match params.pivoting {
1298		PivotingStrategy::Partial => (false, false),
1299		PivotingStrategy::PartialDiag => (false, true),
1300		PivotingStrategy::Rook => (true, false),
1301		PivotingStrategy::RookDiag => (true, true),
1302		_ => (false, false),
1303	};
1304
1305	if params.pivoting == PivotingStrategy::Full {
1306		lblt_full_piv(A, subdiag, pivots, par, params);
1307	} else {
1308		lblt_blocked(A, subdiag, pivots, bs, rook, diagonal, par, stack);
1309	}
1310
1311	for (i, p) in perm.iter_mut().enumerate() {
1312		*p = I::from_signed(truncate(i));
1313	}
1314
1315	let mut transposition_count = 0usize;
1316	for i in 0..n {
1317		let p = pivots[i] & !TOP_BIT;
1318		if i != p {
1319			transposition_count += 1;
1320		}
1321		perm.swap(i, p);
1322	}
1323	for (i, &p) in perm.iter().enumerate() {
1324		perm_inv[p.to_signed().zx()] = I::from_signed(truncate(i));
1325	}
1326
1327	(LbltInfo { transposition_count }, unsafe { PermRef::new_unchecked(perm, perm_inv, n) })
1328}