1use dyn_stack::{PodStack, SizeOverflow, StackReq};
7use faer_core::{
8 mul::triangular::{self, BlockStructure},
9 permutation::{
10 permute_rows, swap_cols, swap_rows, Index, PermutationMut, PermutationRef, SignedIndex,
11 },
12 solve::{
13 solve_unit_lower_triangular_in_place_with_conj,
14 solve_unit_upper_triangular_in_place_with_conj,
15 },
16 temp_mat_req, temp_mat_uninit, unzipped, zipped, Conj, MatMut, MatRef, Parallelism,
17};
18use faer_entity::{ComplexField, Entity, RealField};
19use reborrow::*;
20
21pub mod compute {
22 use super::*;
23 use faer_core::assert;
24
25 #[derive(Copy, Clone)]
26 #[non_exhaustive]
27 pub enum PivotingStrategy {
28 Diagonal,
29 }
30
31 #[derive(Copy, Clone)]
32 #[non_exhaustive]
33 pub struct BunchKaufmanParams {
34 pub pivoting: PivotingStrategy,
35 pub blocksize: usize,
36 }
37
38 #[derive(Debug)]
40 pub struct BunchKaufmanRegularization<'a, E: ComplexField> {
41 pub dynamic_regularization_signs: Option<&'a mut [i8]>,
42 pub dynamic_regularization_delta: E::Real,
43 pub dynamic_regularization_epsilon: E::Real,
44 }
45
46 impl<E: ComplexField> Default for BunchKaufmanRegularization<'_, E> {
47 fn default() -> Self {
48 Self {
49 dynamic_regularization_signs: None,
50 dynamic_regularization_delta: E::Real::faer_zero(),
51 dynamic_regularization_epsilon: E::Real::faer_zero(),
52 }
53 }
54 }
55
56 impl Default for BunchKaufmanParams {
57 fn default() -> Self {
58 Self {
59 pivoting: PivotingStrategy::Diagonal,
60 blocksize: 64,
61 }
62 }
63 }
64
65 fn best_score_idx<E: ComplexField>(a: MatRef<'_, E>) -> Option<(usize, usize, E::Real)> {
66 let m = a.nrows();
67 let n = a.ncols();
68
69 if m == 0 || n == 0 {
70 return None;
71 }
72
73 let mut best_row = 0usize;
74 let mut best_col = 0usize;
75 let mut best_score = E::Real::faer_zero();
76
77 for j in 0..n {
78 for i in 0..m {
79 let score = a.read(i, j).faer_abs();
80 if score > best_score {
81 best_row = i;
82 best_col = j;
83 best_score = score;
84 }
85 }
86 }
87
88 Some((best_row, best_col, best_score))
89 }
90
91 fn assign_col<E: ComplexField>(a: MatMut<'_, E>, i: usize, j: usize) {
92 if i < j {
93 let (ai, aj) = a.subcols_mut(i, j - i + 1).split_at_col_mut(1);
94 ai.col_mut(0).copy_from(aj.rb().col(j - i - 1));
95 } else if j < i {
96 let (aj, ai) = a.subcols_mut(j, i - j + 1).split_at_col_mut(1);
97 ai.col_mut(i - j - 1).copy_from(aj.rb().col(0));
98 }
99 }
100
101 fn best_score<E: ComplexField>(a: MatRef<'_, E>) -> E::Real {
102 let m = a.nrows();
103 let n = a.ncols();
104
105 let mut best_score = E::Real::faer_zero();
106
107 for j in 0..n {
108 for i in 0..m {
109 let score = a.read(i, j).faer_abs();
110 if score > best_score {
111 best_score = score;
112 }
113 }
114 }
115
116 best_score
117 }
118
119 #[inline(always)]
120 fn max<E: RealField>(a: E, b: E) -> E {
121 if a > b {
122 a
123 } else {
124 b
125 }
126 }
127
128 fn swap_elems_conj<E: ComplexField>(
129 a: MatMut<'_, E>,
130 i0: usize,
131 j0: usize,
132 i1: usize,
133 j1: usize,
134 ) {
135 let mut a = a;
136 let tmp = a.read(i0, j0).faer_conj();
137 a.write(i0, j0, a.read(i1, j1).faer_conj());
138 a.write(i1, j1, tmp);
139 }
140 fn swap_elems<E: ComplexField>(a: MatMut<'_, E>, i0: usize, j0: usize, i1: usize, j1: usize) {
141 let mut a = a;
142 let tmp = a.read(i0, j0);
143 a.write(i0, j0, a.read(i1, j1));
144 a.write(i1, j1, tmp);
145 }
146
    /// Performs one blocked step of the lower Bunch-Kaufman (diagonal pivoting) factorization
    /// of the Hermitian matrix whose lower triangle is stored in `a`.
    ///
    /// With `nb = w.ncols()`, the first `nb - 1` columns (or `nb`, when the last pivot is 2×2)
    /// are factored using `w` as a workspace for the partially updated columns. The remaining
    /// lower triangle of `a` is then updated with one triangular matrix product. The structure
    /// appears to mirror LAPACK's `zlahef` (lower variant); NOTE(review): confirm.
    ///
    /// For a 1×1 pivot at `k`, `pivots[k]` receives the block-relative row `kp` it was exchanged
    /// with. For a 2×2 pivot both `pivots[k]` and `pivots[k + 1]` receive `!kp`.
    ///
    /// Returns `(number of columns factored, number of interchanges, number of regularized
    /// 1×1 pivots)`.
    ///
    /// # Panics
    /// Panics if `a` is not square or if `w.ncols() >= a.nrows()`.
    fn cholesky_diagonal_pivoting_blocked_step<I: Index, E: ComplexField>(
        mut a: MatMut<'_, E>,
        regularization: BunchKaufmanRegularization<'_, E>,
        mut w: MatMut<'_, E>,
        pivots: &mut [I],
        alpha: E::Real,
        parallelism: Parallelism,
    ) -> (usize, usize, usize) {
        assert!(a.nrows() == a.ncols());
        let n = a.nrows();
        let nb = w.ncols();
        assert!(nb < n);
        // unreachable after the assert above (it implies n >= 1); kept as a guard
        if n == 0 {
            return (0, 0, 0);
        }

        let eps = regularization.dynamic_regularization_epsilon.faer_abs();
        let delta = regularization.dynamic_regularization_delta.faer_abs();
        let mut signs = regularization.dynamic_regularization_signs;
        // despite its name, this flag enables regularization based on `delta` being nonzero
        let has_eps = delta > E::Real::faer_zero();
        let mut dynamic_regularization_count = 0usize;
        let mut pivot_count = 0usize;

        let truncate = <I::Signed as SignedIndex>::truncate;

        let mut k = 0;
        // stop once the workspace is (nearly) full; a final 2×2 pivot may use its last column
        while k < n && k + 1 < nb {
            // Hermitian diagonals are real: drop any imaginary rounding noise
            let make_real = |mut mat: MatMut<'_, E>, i, j| {
                mat.write(i, j, E::faer_from_real(mat.read(i, j).faer_real()))
            };

            // w[k.., k] = a[k.., k] - a[k.., ..k] * w[k, ..k]^T  (updated column k)
            w.rb_mut()
                .subrows_mut(k, n - k)
                .col_mut(k)
                .copy_from(a.rb().subrows(k, n - k).col(k));

            let (w_left, w_right) = w
                .rb_mut()
                .submatrix_mut(k, 0, n - k, k + 1)
                .split_at_col_mut(k);
            let w_row = w_left.rb().row(0);
            let w_col = w_right.col_mut(0);
            faer_core::mul::matmul(
                w_col.as_2d_mut(),
                a.rb().submatrix(k, 0, n - k, k),
                w_row.rb().transpose().as_2d(),
                Some(E::faer_one()),
                E::faer_one().faer_neg(),
                parallelism,
            );
            make_real(w.rb_mut(), k, k);

            let mut k_step = 1;

            let abs_akk = w.read(k, k).faer_real().faer_abs();
            let imax;
            let colmax;

            // largest off-diagonal magnitude in the updated column k, and its (absolute) row
            if k + 1 < n {
                (imax, _, colmax) =
                    best_score_idx(w.rb().col(k).as_2d().subrows(k + 1, n - k - 1)).unwrap();
            } else {
                imax = 0;
                colmax = E::Real::faer_zero();
            }
            let imax = imax + k + 1;

            let kp;
            if max(abs_akk, colmax) == E::Real::faer_zero() {
                // The updated column is zero: take a 1×1 pivot without interchange. Only
                // sign-based regularization is applied here.
                // NOTE(review): unlike the other 1×1 path below, a[k+1.., k] is not overwritten
                // from w here, and without regularization d11 is zero, so its inverse is
                // presumably not finite; confirm this is intended.
                kp = k;

                let mut d11 = w.read(k, k).faer_real();
                if has_eps {
                    if let Some(signs) = signs.rb_mut() {
                        if signs[k] > 0 && d11 <= eps {
                            d11 = delta;
                            dynamic_regularization_count += 1;
                        } else if signs[k] < 0 && d11 >= eps.faer_neg() {
                            d11 = delta.faer_neg();
                            dynamic_regularization_count += 1;
                        }
                    }
                }
                let d11 = d11.faer_inv();
                a.write(k, k, E::faer_from_real(d11));
            } else {
                if abs_akk >= colmax.faer_mul(alpha) {
                    // the diagonal entry is large enough: no interchange
                    kp = k;
                } else {
                    // Build w[k.., k+1] = updated column imax: first the lower-triangle row imax
                    // (conjugated), then column imax on and below the diagonal.
                    zipped!(
                        w.rb_mut()
                            .subrows_mut(k, imax - k)
                            .col_mut(k + 1)
                            .as_2d_mut(),
                        a.rb().row(imax).subcols(k, imax - k).transpose().as_2d(),
                    )
                    .for_each(|unzipped!(mut dst, src)| dst.write(src.read().faer_conj()));

                    w.rb_mut()
                        .subrows_mut(imax, n - imax)
                        .col_mut(k + 1)
                        .copy_from(a.rb().subrows(imax, n - imax).col(imax));

                    let (w_left, w_right) = w
                        .rb_mut()
                        .submatrix_mut(k, 0, n - k, nb)
                        .split_at_col_mut(k + 1);
                    let w_row = w_left.rb().row(imax - k).subcols(0, k);
                    let w_col = w_right.col_mut(0);

                    faer_core::mul::matmul(
                        w_col.as_2d_mut(),
                        a.rb().submatrix(k, 0, n - k, k),
                        w_row.rb().transpose().as_2d(),
                        Some(E::faer_one()),
                        E::faer_one().faer_neg(),
                        parallelism,
                    );
                    make_real(w.rb_mut(), imax, k + 1);

                    // largest off-diagonal magnitude in the updated column imax
                    let rowmax = max(
                        best_score(w.rb().subrows(k, imax - k).col(k + 1).as_2d()),
                        best_score(w.rb().subrows(imax + 1, n - imax - 1).col(k + 1).as_2d()),
                    );

                    if abs_akk >= alpha.faer_mul(colmax).faer_mul(colmax.faer_div(rowmax)) {
                        kp = k;
                    } else if w.read(imax, k + 1).faer_real().faer_abs() >= alpha.faer_mul(rowmax) {
                        // 1×1 pivot on imax: its updated column becomes the working column k
                        kp = imax;
                        assign_col(w.rb_mut().subrows_mut(k, n - k), k, k + 1);
                    } else {
                        // 2×2 pivot on rows/columns (k, imax)
                        kp = imax;
                        k_step = 2;
                    }
                }

                // index of the row/column that gets interchanged with kp
                let kk = k + k_step - 1;

                if kp != kk {
                    pivot_count += 1;
                    // keep the expected pivot signs aligned with the permuted rows
                    if let Some(signs) = signs.rb_mut() {
                        signs.swap(kp, kk);
                    }
                    // copy the not-yet-updated column kk of `a` into position kp (lower part)
                    a.write(kp, kp, a.read(kk, kk));
                    for j in kk + 1..kp {
                        a.write(kp, j, a.read(j, kk).faer_conj());
                    }
                    assign_col(a.rb_mut().subrows_mut(kp + 1, n - kp - 1), kp, kk);

                    // interchange rows kk and kp in the already-factored columns of `a` and `w`
                    swap_rows(a.rb_mut().subcols_mut(0, k), kk, kp);
                    swap_rows(w.rb_mut().subcols_mut(0, kk + 1), kk, kp);
                }

                if k_step == 1 {
                    // store column k of the factor: a[k.., k] = w[k.., k], then scale by 1/d11
                    a.rb_mut()
                        .subrows_mut(k, n - k)
                        .col_mut(k)
                        .copy_from(w.rb().subrows(k, n - k).col(k));

                    let mut d11 = w.read(k, k).faer_real();
                    if has_eps {
                        if let Some(signs) = signs.rb_mut() {
                            if signs[k] > 0 && d11 <= eps {
                                d11 = delta;
                                dynamic_regularization_count += 1;
                            } else if signs[k] < 0 && d11 >= eps.faer_neg() {
                                d11 = delta.faer_neg();
                                dynamic_regularization_count += 1;
                            }
                        } else {
                            // no sign hints: replace tiny pivots by ±delta, keeping their sign
                            if d11.faer_abs() <= eps {
                                if d11 < E::Real::faer_zero() {
                                    d11 = delta.faer_neg();
                                } else {
                                    d11 = delta;
                                }
                                dynamic_regularization_count += 1;
                            }
                        }
                    }
                    // the diagonal stores the inverse of the (possibly regularized) pivot
                    let d11 = d11.faer_inv();
                    a.write(k, k, E::faer_from_real(d11));

                    let x = a.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k);
                    zipped!(x.as_2d_mut())
                        .for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
                    // conjugate w[k+1.., k] for the trailing update after the loop
                    zipped!(w
                        .rb_mut()
                        .subrows_mut(k + 1, n - k - 1)
                        .col_mut(k)
                        .as_2d_mut())
                    .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
                } else {
                    // 2×2 pivot. Quantities are normalized by |d21| so that d11, d22 and t are
                    // real. The inverse of the block goes into a[k..k+2, k..k+2] (lower part)
                    // and the multipliers into a[k+2.., k..k+2].
                    // NOTE(review): regularizations applied in this branch are not added to
                    // `dynamic_regularization_count`.
                    let d21 = w.read(k + 1, k).faer_abs();
                    let d21_inv = d21.faer_inv();
                    let mut d11 = d21_inv.faer_scale_real(w.read(k + 1, k + 1).faer_real());
                    let mut d22 = d21_inv.faer_scale_real(w.read(k, k).faer_real());

                    let eps = eps.faer_mul(d21_inv);
                    let delta = delta.faer_mul(d21_inv);
                    if has_eps {
                        if let Some(signs) = signs.rb_mut() {
                            if signs[k] > 0 && signs[k + 1] > 0 {
                                if d11 <= eps {
                                    d11 = delta;
                                }
                                if d22 <= eps {
                                    d22 = delta;
                                }
                            } else if signs[k] < 0 && signs[k + 1] < 0 {
                                if d11 >= eps.faer_neg() {
                                    d11 = delta.faer_neg();
                                }
                                if d22 >= eps.faer_neg() {
                                    d22 = delta.faer_neg();
                                }
                            }
                        }
                    }

                    let mut t = d11.faer_mul(d22).faer_sub(E::Real::faer_one());
                    if has_eps {
                        if let Some(signs) = signs.rb_mut() {
                            if ((signs[k] > 0 && signs[k + 1] > 0)
                                || (signs[k] < 0 && signs[k + 1] < 0))
                                && t <= eps
                            {
                                t = delta;
                            } else if ((signs[k] > 0 && signs[k + 1] < 0)
                                || (signs[k] < 0 && signs[k + 1] > 0))
                                && t >= eps.faer_neg()
                            {
                                t = delta.faer_neg();
                            }
                        }
                    }

                    let t = t.faer_inv();
                    let d21 = w.read(k + 1, k).faer_scale_real(d21_inv);
                    let d = t.faer_mul(d21_inv);

                    a.write(k, k, E::faer_from_real(d11.faer_mul(d)));
                    a.write(k + 1, k, d21.faer_scale_real(d.faer_neg()));
                    a.write(k + 1, k + 1, E::faer_from_real(d22.faer_mul(d)));

                    for j in k + 2..n {
                        let wk = (w
                            .read(j, k)
                            .faer_scale_real(d11)
                            .faer_sub(w.read(j, k + 1).faer_mul(d21)))
                        .faer_scale_real(d);
                        let wkp1 = (w
                            .read(j, k + 1)
                            .faer_scale_real(d22)
                            .faer_sub(w.read(j, k).faer_mul(d21.faer_conj())))
                        .faer_scale_real(d);

                        a.write(j, k, wk);
                        a.write(j, k + 1, wkp1);
                    }

                    // conjugate the two w columns for the trailing update after the loop
                    zipped!(w
                        .rb_mut()
                        .subrows_mut(k + 1, n - k - 1)
                        .col_mut(k)
                        .as_2d_mut())
                    .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
                    zipped!(w
                        .rb_mut()
                        .subrows_mut(k + 2, n - k - 2)
                        .col_mut(k + 1)
                        .as_2d_mut())
                    .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
                }
            }

            // 1×1 pivots record kp; both entries of a 2×2 pivot record !kp (negative as signed)
            if k_step == 1 {
                pivots[k] = I::from_signed(truncate(kp));
            } else {
                pivots[k] = I::from_signed(truncate(!kp));
                pivots[k + 1] = I::from_signed(truncate(!kp));
            }

            k += k_step;
        }

        // trailing update: a[k.., k..] -= a[k.., ..k] * w[k.., ..k]^T (lower triangle only)
        let (a_left, mut a_right) = a.rb_mut().subrows_mut(k, n - k).split_at_col_mut(k);
        triangular::matmul(
            a_right.rb_mut(),
            BlockStructure::TriangularLower,
            a_left.rb(),
            BlockStructure::Rectangular,
            w.rb().submatrix(k, 0, n - k, k).transpose(),
            BlockStructure::Rectangular,
            Some(E::faer_one()),
            E::faer_one().faer_neg(),
            parallelism,
        );

        zipped!(a_right.diagonal_mut().column_vector_mut().as_2d_mut())
            .for_each(|unzipped!(mut x)| x.write(E::faer_from_real(x.read().faer_real())));

        // Walk the pivots backwards and re-apply each interchange to the columns left of its
        // pivot block. This undoes the row swaps applied to the already-factored columns during
        // the loop, so those columns end up unpermuted, as in the unblocked routine.
        let mut j = k - 1;
        loop {
            let jj = j;
            let mut jp = pivots[j].to_signed().sx();
            if (jp as isize) < 0 {
                // 2×2 block: skip its first column as well
                jp = !jp;
                j -= 1;
            }

            if j == 0 {
                return (k, pivot_count, dynamic_regularization_count);
            }
            j -= 1;

            if jp != jj {
                swap_rows(a.rb_mut().subcols_mut(0, j + 1), jp, jj);
            }
            if j == 0 {
                return (k, pivot_count, dynamic_regularization_count);
            }
        }
    }
472
473 fn cholesky_diagonal_pivoting_unblocked<I: Index, E: ComplexField>(
474 mut a: MatMut<'_, E>,
475 regularization: BunchKaufmanRegularization<'_, E>,
476 pivots: &mut [I],
477 alpha: E::Real,
478 ) -> (usize, usize) {
479 let truncate = <I::Signed as SignedIndex>::truncate;
480
481 assert!(a.nrows() == a.ncols());
482 let n = a.nrows();
483 if n == 0 {
484 return (0, 0);
485 }
486
487 let eps = regularization.dynamic_regularization_epsilon.faer_abs();
488 let delta = regularization.dynamic_regularization_delta.faer_abs();
489 let mut signs = regularization.dynamic_regularization_signs;
490 let has_eps = delta > E::Real::faer_zero();
491 let mut dynamic_regularization_count = 0usize;
492 let mut pivot_count = 0usize;
493
494 let mut k = 0;
495 while k < n {
496 let make_real = |mut mat: MatMut<'_, E>, i, j| {
497 mat.write(i, j, E::faer_from_real(mat.read(i, j).faer_real()))
498 };
499
500 let mut k_step = 1;
501
502 let abs_akk = a.read(k, k).faer_abs();
503 let imax;
504 let colmax;
505
506 if k + 1 < n {
507 (imax, _, colmax) =
508 best_score_idx(a.rb().col(k).subrows(k + 1, n - k - 1).as_2d()).unwrap();
509 } else {
510 imax = 0;
511 colmax = E::Real::faer_zero();
512 }
513 let imax = imax + k + 1;
514
515 let kp;
516 if max(abs_akk, colmax) == E::Real::faer_zero() {
517 kp = k;
518
519 let mut d11 = a.read(k, k).faer_real();
520 if has_eps {
521 if let Some(signs) = signs.rb_mut() {
522 if signs[k] > 0 && d11 <= eps {
523 d11 = delta;
524 dynamic_regularization_count += 1;
525 } else if signs[k] < 0 && d11 >= eps.faer_neg() {
526 d11 = delta.faer_neg();
527 dynamic_regularization_count += 1;
528 }
529 }
530 }
531 let d11 = d11.faer_inv();
532 a.write(k, k, E::faer_from_real(d11));
533 } else {
534 if abs_akk >= colmax.faer_mul(alpha) {
535 kp = k;
536 } else {
537 let rowmax = max(
538 best_score(a.rb().row(imax).subcols(k, imax - k).as_2d()),
539 best_score(a.rb().subrows(imax + 1, n - imax - 1).col(imax).as_2d()),
540 );
541
542 if abs_akk >= alpha.faer_mul(colmax).faer_mul(colmax.faer_div(rowmax)) {
543 kp = k;
544 } else if a.read(imax, imax).faer_abs() >= alpha.faer_mul(rowmax) {
545 kp = imax
546 } else {
547 kp = imax;
548 k_step = 2;
549 }
550 }
551
552 let kk = k + k_step - 1;
553
554 if kp != kk {
555 pivot_count += 1;
556 swap_cols(a.rb_mut().subrows_mut(kp + 1, n - kp - 1), kk, kp);
557 for j in kk + 1..kp {
558 swap_elems_conj(a.rb_mut(), j, kk, kp, j);
559 }
560
561 a.write(kp, kk, a.read(kp, kk).faer_conj());
562 swap_elems(a.rb_mut(), kk, kk, kp, kp);
563
564 if k_step == 2 {
565 swap_elems(a.rb_mut(), k + 1, k, kp, k);
566 }
567 }
568
569 if k_step == 1 {
570 let mut d11 = a.read(k, k).faer_real();
571 if has_eps {
572 if let Some(signs) = signs.rb_mut() {
573 if signs[k] > 0 && d11 <= eps {
574 d11 = delta;
575 dynamic_regularization_count += 1;
576 } else if signs[k] < 0 && d11 >= eps.faer_neg() {
577 d11 = delta.faer_neg();
578 dynamic_regularization_count += 1;
579 }
580 } else {
581 if d11.faer_abs() <= eps {
582 if d11 < E::Real::faer_zero() {
583 d11 = delta.faer_neg();
584 } else {
585 d11 = delta;
586 }
587 dynamic_regularization_count += 1;
588 }
589 }
590 }
591 let d11 = d11.faer_inv();
592 a.write(k, k, E::faer_from_real(d11));
593
594 let (x, mut trailing) = a
595 .rb_mut()
596 .subrows_mut(k + 1, n - k - 1)
597 .subcols_mut(k, n - k)
598 .split_at_col_mut(1);
599
600 for j in 0..n - k - 1 {
601 let d11xj = x.read(j, 0).faer_conj().faer_scale_real(d11);
602 for i in j..n - k - 1 {
603 let xi = x.read(i, 0);
604 trailing.write(i, j, trailing.read(i, j).faer_sub(d11xj.faer_mul(xi)));
605 }
606 make_real(trailing.rb_mut(), j, j);
607 }
608 zipped!(x).for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
609 } else {
610 let d21 = a.read(k + 1, k).faer_abs();
611 let d21_inv = d21.faer_inv();
612 let mut d11 = d21_inv.faer_scale_real(a.read(k + 1, k + 1).faer_real());
613 let mut d22 = d21_inv.faer_scale_real(a.read(k, k).faer_real());
614
615 let eps = eps.faer_mul(d21_inv);
616 let delta = delta.faer_mul(d21_inv);
617 if has_eps {
618 if let Some(signs) = signs.rb_mut() {
619 if signs[k] > 0 && signs[k + 1] > 0 {
620 if d11 <= eps {
621 d11 = delta;
622 }
623 if d22 <= eps {
624 d22 = delta;
625 }
626 } else if signs[k] < 0 && signs[k + 1] < 0 {
627 if d11 >= eps.faer_neg() {
628 d11 = delta.faer_neg();
629 }
630 if d22 >= eps.faer_neg() {
631 d22 = delta.faer_neg();
632 }
633 }
634 }
635 }
636
637 let mut t = d11.faer_mul(d22).faer_sub(E::Real::faer_one());
639 if has_eps {
640 if let Some(signs) = signs.rb_mut() {
641 if ((signs[k] > 0 && signs[k + 1] > 0)
642 || (signs[k] < 0 && signs[k + 1] < 0))
643 && t <= eps
644 {
645 t = delta;
646 } else if ((signs[k] > 0 && signs[k + 1] < 0)
647 || (signs[k] < 0 && signs[k + 1] > 0))
648 && t >= eps.faer_neg()
649 {
650 t = delta.faer_neg();
651 }
652 }
653 }
654
655 let t = t.faer_inv();
656 let d21 = a.read(k + 1, k).faer_scale_real(d21_inv);
657 let d = t.faer_mul(d21_inv);
658
659 a.write(k, k, E::faer_from_real(d11.faer_mul(d)));
660 a.write(k + 1, k, d21.faer_scale_real(d.faer_neg()));
661 a.write(k + 1, k + 1, E::faer_from_real(d22.faer_mul(d)));
662
663 for j in k + 2..n {
664 let wk = (a
665 .read(j, k)
666 .faer_scale_real(d11)
667 .faer_sub(a.read(j, k + 1).faer_mul(d21)))
668 .faer_scale_real(d);
669 let wkp1 = (a
670 .read(j, k + 1)
671 .faer_scale_real(d22)
672 .faer_sub(a.read(j, k).faer_mul(d21.faer_conj())))
673 .faer_scale_real(d);
674
675 for i in j..n {
676 a.write(
677 i,
678 j,
679 a.read(i, j)
680 .faer_sub(a.read(i, k).faer_mul(wk.faer_conj()))
681 .faer_sub(a.read(i, k + 1).faer_mul(wkp1.faer_conj())),
682 );
683 }
684 make_real(a.rb_mut(), j, j);
685
686 a.write(j, k, wk);
687 a.write(j, k + 1, wkp1);
688 }
689 }
690 }
691
692 if k_step == 1 {
693 pivots[k] = I::from_signed(truncate(kp));
694 } else {
695 pivots[k] = I::from_signed(truncate(!kp));
696 pivots[k + 1] = I::from_signed(truncate(!kp));
697 }
698
699 k += k_step;
700 }
701
702 (pivot_count, dynamic_regularization_count)
703 }
704
705 fn convert<I: Index, E: ComplexField>(
706 mut a: MatMut<'_, E>,
707 pivots: &[I],
708 mut subdiag: MatMut<'_, E>,
709 ) {
710 assert!(a.nrows() == a.ncols());
711 let n = a.nrows();
712
713 let mut i = 0;
714 while i < n {
715 if (pivots[i].to_signed().sx() as isize) < 0 {
716 subdiag.write(i, 0, a.read(i + 1, i));
717 subdiag.write(i + 1, 0, E::faer_zero());
718 a.write(i + 1, i, E::faer_zero());
719 i += 2;
720 } else {
721 subdiag.write(i, 0, E::faer_zero());
722 i += 1;
723 }
724 }
725
726 let mut i = 0;
727 while i < n {
728 let p = pivots[i].to_signed().sx();
729 if (p as isize) < 0 {
730 let p = !p;
731 swap_rows(a.rb_mut().subcols_mut(0, i), i + 1, p);
732 i += 2;
733 } else {
734 swap_rows(a.rb_mut().subcols_mut(0, i), i, p);
735 i += 1;
736 }
737 }
738 }
739
740 pub fn cholesky_in_place_req<I: Index, E: Entity>(
743 dim: usize,
744 parallelism: Parallelism,
745 params: BunchKaufmanParams,
746 ) -> Result<StackReq, SizeOverflow> {
747 let _ = parallelism;
748 let mut bs = params.blocksize;
749 if bs < 2 || dim <= bs {
750 bs = 0;
751 }
752 StackReq::try_new::<I>(dim)?.try_and(temp_mat_req::<E>(dim, bs)?)
753 }
754
755 #[derive(Copy, Clone, Debug)]
756 pub struct BunchKaufmanInfo {
757 pub dynamic_regularization_count: usize,
758 pub transposition_count: usize,
759 }
760
    /// Computes the Bunch-Kaufman (diagonal pivoting) `L D Lᴴ` factorization of the Hermitian
    /// matrix whose lower triangle is stored in `matrix`, in place.
    ///
    /// On return:
    /// - the strictly lower part of `matrix` holds the unit lower factor `L`;
    /// - the diagonal of `matrix` holds the diagonal of `D⁻¹`;
    /// - `subdiag[i]` holds the sub-diagonal entry of `D⁻¹` for each 2×2 pivot block (zero
    ///   elsewhere);
    /// - `perm` / `perm_inv` hold the pivoting permutation and its inverse. They are returned
    ///   as a [`PermutationMut`] for use with `solve::solve_in_place_with_conj`.
    ///
    /// `regularization` optionally perturbs small pivots; its sign slice may be reordered in
    /// place as rows are interchanged (the blocked step swaps its entries).
    ///
    /// `stack` must provide at least the memory reported by [`cholesky_in_place_req`] for the
    /// same dimension and `params`.
    ///
    /// # Panics
    /// Panics if `matrix` is not square, or if `subdiag`, `perm` or `perm_inv` do not match its
    /// dimension.
    #[track_caller]
    pub fn cholesky_in_place<'out, I: Index, E: ComplexField>(
        matrix: MatMut<'_, E>,
        subdiag: MatMut<'_, E>,
        regularization: BunchKaufmanRegularization<'_, E>,
        perm: &'out mut [I],
        perm_inv: &'out mut [I],
        parallelism: Parallelism,
        stack: PodStack<'_>,
        params: BunchKaufmanParams,
    ) -> (BunchKaufmanInfo, PermutationMut<'out, I, E>) {
        let truncate = <I::Signed as SignedIndex>::truncate;
        let mut regularization = regularization;

        let n = matrix.nrows();
        assert!(all(
            matrix.nrows() == matrix.ncols(),
            subdiag.nrows() == n,
            subdiag.ncols() == 1,
            perm.len() == n,
            perm_inv.len() == n
        ));

        #[cfg(feature = "perf-warn")]
        if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) {
            if matrix.col_stride().unsigned_abs() == 1 {
                log::warn!(target: "faer_perf", "Bunch-Kaufman decomposition prefers column-major matrix. Found row-major matrix.");
            } else {
                log::warn!(target: "faer_perf", "Bunch-Kaufman decomposition prefers column-major matrix. Found matrix with generic strides.");
            }
        }

        let _ = parallelism;
        let mut matrix = matrix;

        // Bunch-Kaufman pivoting threshold alpha = (1 + sqrt(17)) / 8
        let alpha = E::Real::faer_one()
            .faer_add(E::Real::faer_from_f64(17.0).faer_sqrt())
            .faer_scale_power_of_two(E::Real::faer_from_f64(1.0 / 8.0));

        let (pivots, stack) = stack.make_raw::<I>(n);

        // same block-size rule as `cholesky_in_place_req`: 0 selects the unblocked kernel
        let mut bs = params.blocksize;
        if bs < 2 || n <= bs {
            bs = 0;
        }
        let mut work = temp_mat_uninit(n, bs, stack).0;

        let mut k = 0;
        let mut dynamic_regularization_count = 0;
        let mut transposition_count = 0;
        while k < n {
            // view of the regularization settings restricted to the trailing submatrix
            let regularization = BunchKaufmanRegularization {
                dynamic_regularization_signs: regularization
                    .dynamic_regularization_signs
                    .rb_mut()
                    .map(|signs| &mut signs[k..]),
                dynamic_regularization_delta: regularization.dynamic_regularization_delta,
                dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
            };

            let kb;
            let reg_count;
            let piv_count;
            if bs >= 2 && bs < n - k {
                // factor a panel of the trailing matrix and update the rest
                (kb, piv_count, reg_count) = cholesky_diagonal_pivoting_blocked_step(
                    matrix.rb_mut().submatrix_mut(k, k, n - k, n - k),
                    regularization,
                    work.rb_mut(),
                    &mut pivots[k..],
                    alpha,
                    parallelism,
                );
            } else {
                // finish the remaining trailing matrix in one go
                (piv_count, reg_count) = cholesky_diagonal_pivoting_unblocked(
                    matrix.rb_mut().submatrix_mut(k, k, n - k, n - k),
                    regularization,
                    &mut pivots[k..],
                    alpha,
                );
                kb = n - k;
            }
            dynamic_regularization_count += reg_count;
            transposition_count += piv_count;

            // Shift the block-relative pivot indices by k. Complemented (2×2) entries stay
            // complemented, since !(p) - k == !(p + k).
            for pivot in &mut pivots[k..k + kb] {
                let pv = (*pivot).to_signed().sx();
                if pv as isize >= 0 {
                    *pivot = I::from_signed(truncate(pv + k));
                } else {
                    *pivot = I::from_signed(truncate(pv - k));
                }
            }

            k += kb;
        }

        convert(matrix.rb_mut(), pivots, subdiag);

        // Build the permutation by applying the recorded interchanges, in order, to the
        // identity. A 2×2 block's interchange involves its second row.
        for (i, p) in perm.iter_mut().enumerate() {
            *p = I::from_signed(truncate(i));
        }
        let mut i = 0;
        while i < n {
            let p = pivots[i].to_signed().sx();
            if (p as isize) < 0 {
                let p = !p;
                perm.swap(i + 1, p);
                i += 2;
            } else {
                perm.swap(i, p);
                i += 1;
            }
        }
        for (i, &p) in perm.iter().enumerate() {
            perm_inv[p.to_signed().zx()] = I::from_signed(truncate(i));
        }

        (
            BunchKaufmanInfo {
                dynamic_regularization_count,
                transposition_count,
            },
            // SAFETY: `perm` is obtained from the identity through transpositions only, so it
            // is a valid permutation of 0..n, and `perm_inv` was just filled as its inverse.
            unsafe { PermutationMut::new_unchecked(perm, perm_inv) },
        )
    }
898}
899
900pub mod solve {
901 use super::*;
902 use faer_core::assert;
903
904 #[track_caller]
905 pub fn solve_in_place_req<I: Index, E: Entity>(
906 dim: usize,
907 rhs_ncols: usize,
908 parallelism: Parallelism,
909 ) -> Result<StackReq, SizeOverflow> {
910 let _ = parallelism;
911 temp_mat_req::<E>(dim, rhs_ncols)
912 }
913
914 #[track_caller]
915 pub fn solve_in_place_with_conj<I: Index, E: ComplexField>(
916 lb_factors: MatRef<'_, E>,
917 subdiag: MatRef<'_, E>,
918 conj: Conj,
919 perm: PermutationRef<'_, I, E>,
920 rhs: MatMut<'_, E>,
921 parallelism: Parallelism,
922 stack: PodStack<'_>,
923 ) {
924 let n = lb_factors.nrows();
925 let k = rhs.ncols();
926
927 assert!(all(
928 lb_factors.nrows() == lb_factors.ncols(),
929 rhs.nrows() == n,
930 subdiag.nrows() == n,
931 subdiag.ncols() == 1,
932 perm.len() == n
933 ));
934
935 let a = lb_factors;
936 let par = parallelism;
937 let not_conj = conj.compose(Conj::Yes);
938
939 let mut rhs = rhs;
940 let mut x = temp_mat_uninit::<E>(n, k, stack).0;
941
942 permute_rows(x.rb_mut(), rhs.rb(), perm);
943 solve_unit_lower_triangular_in_place_with_conj(a, conj, x.rb_mut(), par);
944
945 let mut i = 0;
946 while i < n {
947 if subdiag.read(i, 0) == E::faer_zero() {
948 let d_inv = a.read(i, i).faer_real();
949 for j in 0..k {
950 x.write(i, j, x.read(i, j).faer_scale_real(d_inv));
951 }
952 i += 1;
953 } else {
954 if conj == Conj::Yes {
955 let akp1k = subdiag.read(i, 0);
956 let ak = a.read(i, i).faer_real();
957 let akp1 = a.read(i + 1, i + 1).faer_real();
958
959 for j in 0..k {
960 let xk = x.read(i, j);
961 let xkp1 = x.read(i + 1, j);
962
963 x.write(i, j, xk.faer_scale_real(ak).faer_add(xkp1.faer_mul(akp1k)));
964 x.write(
965 i + 1,
966 j,
967 xkp1.faer_scale_real(akp1)
968 .faer_add(xk.faer_mul(akp1k.faer_conj())),
969 );
970 }
971 } else {
972 let akp1k = subdiag.read(i, 0);
973 let ak = a.read(i, i).faer_real();
974 let akp1 = a.read(i + 1, i + 1).faer_real();
975
976 for j in 0..k {
977 let xk = x.read(i, j);
978 let xkp1 = x.read(i + 1, j);
979
980 x.write(
981 i,
982 j,
983 xk.faer_scale_real(ak)
984 .faer_add(xkp1.faer_mul(akp1k.faer_conj())),
985 );
986 x.write(
987 i + 1,
988 j,
989 xkp1.faer_scale_real(akp1).faer_add(xk.faer_mul(akp1k)),
990 );
991 }
992 }
993 i += 2;
994 }
995 }
996
997 solve_unit_upper_triangular_in_place_with_conj(a.transpose(), not_conj, x.rb_mut(), par);
998 permute_rows(rhs.rb_mut(), x.rb(), perm.inverse());
999 }
1000}
1001
#[cfg(test)]
mod tests {
    use crate::bunch_kaufman::compute::BunchKaufmanParams;

    use super::*;
    use dyn_stack::GlobalPodBuffer;
    use faer_core::{assert, c64, Mat};
    use rand::random;

    /// Factor and solve random real symmetric systems, checking the residual `A x - b`.
    #[test]
    fn test_real() {
        for n in [3, 6, 19, 100, 421] {
            let a = Mat::<f64>::from_fn(n, n, |_, _| random());
            let a = &a + a.adjoint();
            let rhs = Mat::<f64>::from_fn(n, 2, |_, _| random());

            let mut ldl = a.clone();
            let mut subdiag = Mat::<f64>::zeros(n, 1);
            let mut perm = vec![0usize; n];
            let mut perm_inv = vec![0; n];

            let params = Default::default();
            let mut factor_mem = GlobalPodBuffer::new(
                compute::cholesky_in_place_req::<usize, f64>(n, Parallelism::None, params).unwrap(),
            );
            let (_, perm) = compute::cholesky_in_place(
                ldl.as_mut(),
                subdiag.as_mut(),
                Default::default(),
                &mut perm,
                &mut perm_inv,
                Parallelism::None,
                PodStack::new(&mut factor_mem),
                params,
            );

            let mut solve_mem = GlobalPodBuffer::new(
                solve::solve_in_place_req::<usize, f64>(n, rhs.ncols(), Parallelism::None).unwrap(),
            );
            let mut x = rhs.clone();
            solve::solve_in_place_with_conj(
                ldl.as_ref(),
                subdiag.as_ref(),
                Conj::No,
                perm.rb(),
                x.as_mut(),
                Parallelism::None,
                PodStack::new(&mut solve_mem),
            );

            let residual = &a * &x - &rhs;
            let max = (0..residual.ncols())
                .flat_map(|j| (0..residual.nrows()).map(move |i| (i, j)))
                .map(|(i, j)| residual[(i, j)].abs())
                .fold(0.0, f64::max);
            assert!(max < 1e-9);
        }
    }

    /// Factor and solve random complex Hermitian systems with a small block size, solving the
    /// conjugated system and checking that the stored diagonal is real.
    #[test]
    fn test_cplx() {
        for n in [3, 6, 19, 100, 421] {
            let a = Mat::<c64>::from_fn(n, n, |_, _| c64::new(random(), random()));
            let a = &a + a.adjoint();
            let rhs = Mat::<c64>::from_fn(n, 2, |_, _| c64::new(random(), random()));

            let mut ldl = a.clone();
            let mut subdiag = Mat::<c64>::zeros(n, 1);
            let mut perm = vec![0usize; n];
            let mut perm_inv = vec![0; n];

            let params = BunchKaufmanParams {
                pivoting: compute::PivotingStrategy::Diagonal,
                blocksize: 32,
            };
            let mut factor_mem = GlobalPodBuffer::new(
                compute::cholesky_in_place_req::<usize, c64>(n, Parallelism::None, params).unwrap(),
            );
            let (_, perm) = compute::cholesky_in_place(
                ldl.as_mut(),
                subdiag.as_mut(),
                Default::default(),
                &mut perm,
                &mut perm_inv,
                Parallelism::None,
                PodStack::new(&mut factor_mem),
                params,
            );

            let mut x = rhs.clone();
            let mut solve_mem = GlobalPodBuffer::new(
                solve::solve_in_place_req::<usize, c64>(n, rhs.ncols(), Parallelism::None).unwrap(),
            );
            solve::solve_in_place_with_conj(
                ldl.as_ref(),
                subdiag.as_ref(),
                Conj::Yes,
                perm.rb(),
                x.as_mut(),
                Parallelism::None,
                PodStack::new(&mut solve_mem),
            );

            let residual = a.conjugate() * &x - &rhs;
            let max = (0..residual.ncols())
                .flat_map(|j| (0..residual.nrows()).map(move |i| (i, j)))
                .map(|(i, j)| residual[(i, j)].abs())
                .fold(0.0, f64::max);
            for i in 0..n {
                assert!(ldl[(i, i)].faer_imag() == 0.0);
            }
            assert!(max < 1e-9);
        }
    }
}