faer_cholesky/bunch_kaufman/
mod.rs

//! The Bunch-Kaufman decomposition of a Hermitian matrix $A$ is such that:
//! $$P A P^\top = LBL^H,$$
//! where $P$ is a permutation matrix, $B$ is a block diagonal matrix with $1 \times 1$ or
//! $2 \times 2$ diagonal blocks, and $L$ is a unit lower triangular matrix.
5
6use dyn_stack::{PodStack, SizeOverflow, StackReq};
7use faer_core::{
8    mul::triangular::{self, BlockStructure},
9    permutation::{
10        permute_rows, swap_cols, swap_rows, Index, PermutationMut, PermutationRef, SignedIndex,
11    },
12    solve::{
13        solve_unit_lower_triangular_in_place_with_conj,
14        solve_unit_upper_triangular_in_place_with_conj,
15    },
16    temp_mat_req, temp_mat_uninit, unzipped, zipped, Conj, MatMut, MatRef, Parallelism,
17};
18use faer_entity::{ComplexField, Entity, RealField};
19use reborrow::*;
20
21pub mod compute {
22    use super::*;
23    use faer_core::assert;
24
25    #[derive(Copy, Clone)]
26    #[non_exhaustive]
27    pub enum PivotingStrategy {
28        Diagonal,
29    }
30
31    #[derive(Copy, Clone)]
32    #[non_exhaustive]
33    pub struct BunchKaufmanParams {
34        pub pivoting: PivotingStrategy,
35        pub blocksize: usize,
36    }
37
    /// Dynamic Bunch-Kaufman regularization.
    ///
    /// The factorization routines in this module only use the absolute values of
    /// `dynamic_regularization_delta` and `dynamic_regularization_epsilon`, and they only
    /// regularize when the absolute value of the delta is strictly positive.
    #[derive(Debug)]
    pub struct BunchKaufmanRegularization<'a, E: ComplexField> {
        /// Optional expected sign of each diagonal pivot (`> 0` or `< 0`).
        ///
        /// When present, a `1x1` pivot whose expected sign is positive and whose value is
        /// `<= epsilon` is replaced by `delta`; one whose expected sign is negative and whose
        /// value is `>= -epsilon` is replaced by `-delta`. The factorization routines may
        /// permute this slice in place as rows and columns are interchanged.
        pub dynamic_regularization_signs: Option<&'a mut [i8]>,
        /// Magnitude of the value substituted for a regularized pivot.
        pub dynamic_regularization_delta: E::Real,
        /// Threshold against which pivots are compared to decide whether to regularize.
        pub dynamic_regularization_epsilon: E::Real,
    }
45
46    impl<E: ComplexField> Default for BunchKaufmanRegularization<'_, E> {
47        fn default() -> Self {
48            Self {
49                dynamic_regularization_signs: None,
50                dynamic_regularization_delta: E::Real::faer_zero(),
51                dynamic_regularization_epsilon: E::Real::faer_zero(),
52            }
53        }
54    }
55
56    impl Default for BunchKaufmanParams {
57        fn default() -> Self {
58            Self {
59                pivoting: PivotingStrategy::Diagonal,
60                blocksize: 64,
61            }
62        }
63    }
64
65    fn best_score_idx<E: ComplexField>(a: MatRef<'_, E>) -> Option<(usize, usize, E::Real)> {
66        let m = a.nrows();
67        let n = a.ncols();
68
69        if m == 0 || n == 0 {
70            return None;
71        }
72
73        let mut best_row = 0usize;
74        let mut best_col = 0usize;
75        let mut best_score = E::Real::faer_zero();
76
77        for j in 0..n {
78            for i in 0..m {
79                let score = a.read(i, j).faer_abs();
80                if score > best_score {
81                    best_row = i;
82                    best_col = j;
83                    best_score = score;
84                }
85            }
86        }
87
88        Some((best_row, best_col, best_score))
89    }
90
91    fn assign_col<E: ComplexField>(a: MatMut<'_, E>, i: usize, j: usize) {
92        if i < j {
93            let (ai, aj) = a.subcols_mut(i, j - i + 1).split_at_col_mut(1);
94            ai.col_mut(0).copy_from(aj.rb().col(j - i - 1));
95        } else if j < i {
96            let (aj, ai) = a.subcols_mut(j, i - j + 1).split_at_col_mut(1);
97            ai.col_mut(i - j - 1).copy_from(aj.rb().col(0));
98        }
99    }
100
101    fn best_score<E: ComplexField>(a: MatRef<'_, E>) -> E::Real {
102        let m = a.nrows();
103        let n = a.ncols();
104
105        let mut best_score = E::Real::faer_zero();
106
107        for j in 0..n {
108            for i in 0..m {
109                let score = a.read(i, j).faer_abs();
110                if score > best_score {
111                    best_score = score;
112                }
113            }
114        }
115
116        best_score
117    }
118
119    #[inline(always)]
120    fn max<E: RealField>(a: E, b: E) -> E {
121        if a > b {
122            a
123        } else {
124            b
125        }
126    }
127
128    fn swap_elems_conj<E: ComplexField>(
129        a: MatMut<'_, E>,
130        i0: usize,
131        j0: usize,
132        i1: usize,
133        j1: usize,
134    ) {
135        let mut a = a;
136        let tmp = a.read(i0, j0).faer_conj();
137        a.write(i0, j0, a.read(i1, j1).faer_conj());
138        a.write(i1, j1, tmp);
139    }
140    fn swap_elems<E: ComplexField>(a: MatMut<'_, E>, i0: usize, j0: usize, i1: usize, j1: usize) {
141        let mut a = a;
142        let tmp = a.read(i0, j0);
143        a.write(i0, j0, a.read(i1, j1));
144        a.write(i1, j1, tmp);
145    }
146
    /// Performs one blocked panel step of the diagonal-pivoting (Bunch-Kaufman) factorization.
    ///
    /// Leading columns of `a` are factorized one pivot block (`1x1` or `2x2`) at a time while
    /// the updated columns are kept in the workspace `w`, whose column count `nb` bounds the
    /// panel width (the loop runs while `k + 1 < nb`). Afterwards the lower triangle of the
    /// trailing block `a[k.., k..]` is updated with a single triangular matrix product, and
    /// the recorded row interchanges are applied to the columns to the left of each pivot
    /// block.
    ///
    /// Only the lower triangle of `a` is referenced. The diagonal of the processed part of
    /// `a` receives the entries of the inverted pivot blocks.
    ///
    /// * `regularization`: dynamic regularization settings; only the absolute values of
    ///   delta and epsilon are used.
    /// * `pivots`: for a `1x1` pivot at column `k`, `pivots[k]` receives the pivot row `kp`;
    ///   for a `2x2` pivot, both `pivots[k]` and `pivots[k + 1]` receive `!kp`.
    /// * `alpha`: threshold constant used in the pivot selection tests.
    ///
    /// Returns `(k, pivot_count, dynamic_regularization_count)`: the number of columns
    /// processed, the number of interchanges performed, and the number of regularized `1x1`
    /// pivots (the `2x2` branch does not increment the counter).
    ///
    /// # Panics
    ///
    /// Panics if `a` is not square or if `w.ncols() >= a.nrows()`.
    fn cholesky_diagonal_pivoting_blocked_step<I: Index, E: ComplexField>(
        mut a: MatMut<'_, E>,
        regularization: BunchKaufmanRegularization<'_, E>,
        mut w: MatMut<'_, E>,
        pivots: &mut [I],
        alpha: E::Real,
        parallelism: Parallelism,
    ) -> (usize, usize, usize) {
        assert!(a.nrows() == a.ncols());
        let n = a.nrows();
        let nb = w.ncols();
        assert!(nb < n);
        // NOTE(review): with `nb < n` asserted above, `n == 0` cannot be reached here.
        if n == 0 {
            return (0, 0, 0);
        }

        // Only magnitudes are used. Note that `has_eps` tests `delta`, not `eps`.
        let eps = regularization.dynamic_regularization_epsilon.faer_abs();
        let delta = regularization.dynamic_regularization_delta.faer_abs();
        let mut signs = regularization.dynamic_regularization_signs;
        let has_eps = delta > E::Real::faer_zero();
        let mut dynamic_regularization_count = 0usize;
        let mut pivot_count = 0usize;

        let truncate = <I::Signed as SignedIndex>::truncate;

        let mut k = 0;
        // Stop when the matrix is exhausted or when the workspace no longer has room for a
        // potential second column of a 2x2 pivot.
        while k < n && k + 1 < nb {
            // Drops the imaginary part of the entry at (i, j).
            let make_real = |mut mat: MatMut<'_, E>, i, j| {
                mat.write(i, j, E::faer_from_real(mat.read(i, j).faer_real()))
            };

            // Load column k of `a` (rows k..n) into column k of the workspace.
            w.rb_mut()
                .subrows_mut(k, n - k)
                .col_mut(k)
                .copy_from(a.rb().subrows(k, n - k).col(k));

            // Apply the updates of the previous k columns of this panel:
            // w[k.., k] -= a[k.., ..k] * w[k, ..k]^T.
            let (w_left, w_right) = w
                .rb_mut()
                .submatrix_mut(k, 0, n - k, k + 1)
                .split_at_col_mut(k);
            let w_row = w_left.rb().row(0);
            let w_col = w_right.col_mut(0);
            faer_core::mul::matmul(
                w_col.as_2d_mut(),
                a.rb().submatrix(k, 0, n - k, k),
                w_row.rb().transpose().as_2d(),
                Some(E::faer_one()),
                E::faer_one().faer_neg(),
                parallelism,
            );
            make_real(w.rb_mut(), k, k);

            let mut k_step = 1;

            // Magnitude of the diagonal entry, and the largest magnitude strictly below it in
            // the updated column together with its row index.
            let abs_akk = w.read(k, k).faer_real().faer_abs();
            let imax;
            let colmax;

            if k + 1 < n {
                (imax, _, colmax) =
                    best_score_idx(w.rb().col(k).as_2d().subrows(k + 1, n - k - 1)).unwrap();
            } else {
                imax = 0;
                colmax = E::Real::faer_zero();
            }
            // Convert the index relative to row k + 1 into an absolute row index.
            let imax = imax + k + 1;

            let kp;
            if max(abs_akk, colmax) == E::Real::faer_zero() {
                // The whole column is zero: no interchange. The diagonal is regularized only
                // when sign hints are available, then its inverse is stored.
                kp = k;

                let mut d11 = w.read(k, k).faer_real();
                if has_eps {
                    if let Some(signs) = signs.rb_mut() {
                        if signs[k] > 0 && d11 <= eps {
                            d11 = delta;
                            dynamic_regularization_count += 1;
                        } else if signs[k] < 0 && d11 >= eps.faer_neg() {
                            d11 = delta.faer_neg();
                            dynamic_regularization_count += 1;
                        }
                    }
                }
                let d11 = d11.faer_inv();
                a.write(k, k, E::faer_from_real(d11));
            } else {
                if abs_akk >= colmax.faer_mul(alpha) {
                    // The diagonal entry is large enough: 1x1 pivot, no interchange.
                    kp = k;
                } else {
                    // Build the updated column `imax` in workspace column k + 1. Rows k..imax
                    // come from the conjugate of row `imax` of `a`, rows imax..n from column
                    // `imax` of `a`.
                    zipped!(
                        w.rb_mut()
                            .subrows_mut(k, imax - k)
                            .col_mut(k + 1)
                            .as_2d_mut(),
                        a.rb().row(imax).subcols(k, imax - k).transpose().as_2d(),
                    )
                    .for_each(|unzipped!(mut dst, src)| dst.write(src.read().faer_conj()));

                    w.rb_mut()
                        .subrows_mut(imax, n - imax)
                        .col_mut(k + 1)
                        .copy_from(a.rb().subrows(imax, n - imax).col(imax));

                    // w[k.., k + 1] -= a[k.., ..k] * w[imax, ..k]^T.
                    let (w_left, w_right) = w
                        .rb_mut()
                        .submatrix_mut(k, 0, n - k, nb)
                        .split_at_col_mut(k + 1);
                    let w_row = w_left.rb().row(imax - k).subcols(0, k);
                    let w_col = w_right.col_mut(0);

                    faer_core::mul::matmul(
                        w_col.as_2d_mut(),
                        a.rb().submatrix(k, 0, n - k, k),
                        w_row.rb().transpose().as_2d(),
                        Some(E::faer_one()),
                        E::faer_one().faer_neg(),
                        parallelism,
                    );
                    make_real(w.rb_mut(), imax, k + 1);

                    // Largest magnitude in that column, excluding its own diagonal entry at
                    // row `imax`.
                    let rowmax = max(
                        best_score(w.rb().subrows(k, imax - k).col(k + 1).as_2d()),
                        best_score(w.rb().subrows(imax + 1, n - imax - 1).col(k + 1).as_2d()),
                    );

                    if abs_akk >= alpha.faer_mul(colmax).faer_mul(colmax.faer_div(rowmax)) {
                        // 1x1 pivot, no interchange.
                        kp = k;
                    } else if w.read(imax, k + 1).faer_real().faer_abs() >= alpha.faer_mul(rowmax) {
                        // 1x1 pivot using row/column `imax`: its updated column becomes
                        // workspace column k.
                        kp = imax;
                        assign_col(w.rb_mut().subrows_mut(k, n - k), k, k + 1);
                    } else {
                        // 2x2 pivot using rows/columns k and `imax`.
                        kp = imax;
                        k_step = 2;
                    }
                }

                // Index of the last column of the current pivot block.
                let kk = k + k_step - 1;

                if kp != kk {
                    // Interchange rows/columns `kk` and `kp`. Sign hints move with them.
                    pivot_count += 1;
                    if let Some(signs) = signs.rb_mut() {
                        signs.swap(kp, kk);
                    }
                    // Move the data of row/column `kk` of `a` into position `kp` (lower
                    // triangle only): diagonal entry, conjugated middle part, trailing part.
                    a.write(kp, kp, a.read(kk, kk));
                    for j in kk + 1..kp {
                        a.write(kp, j, a.read(j, kk).faer_conj());
                    }
                    assign_col(a.rb_mut().subrows_mut(kp + 1, n - kp - 1), kp, kk);

                    // Swap rows `kk` and `kp` in the first k columns of `a` and in the first
                    // kk + 1 columns of the workspace.
                    swap_rows(a.rb_mut().subcols_mut(0, k), kk, kp);
                    swap_rows(w.rb_mut().subcols_mut(0, kk + 1), kk, kp);
                }

                if k_step == 1 {
                    // 1x1 pivot: write the updated column back into `a`.
                    a.rb_mut()
                        .subrows_mut(k, n - k)
                        .col_mut(k)
                        .copy_from(w.rb().subrows(k, n - k).col(k));

                    let mut d11 = w.read(k, k).faer_real();
                    if has_eps {
                        if let Some(signs) = signs.rb_mut() {
                            if signs[k] > 0 && d11 <= eps {
                                d11 = delta;
                                dynamic_regularization_count += 1;
                            } else if signs[k] < 0 && d11 >= eps.faer_neg() {
                                d11 = delta.faer_neg();
                                dynamic_regularization_count += 1;
                            }
                        } else {
                            // Without sign hints, keep the sign of the pivot itself.
                            if d11.faer_abs() <= eps {
                                if d11 < E::Real::faer_zero() {
                                    d11 = delta.faer_neg();
                                } else {
                                    d11 = delta;
                                }
                                dynamic_regularization_count += 1;
                            }
                        }
                    }
                    // Store the inverse pivot on the diagonal and scale the sub-diagonal part
                    // of the column by it.
                    let d11 = d11.faer_inv();
                    a.write(k, k, E::faer_from_real(d11));

                    let x = a.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k);
                    zipped!(x.as_2d_mut())
                        .for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
                    // Conjugate the sub-diagonal part of workspace column k in place.
                    zipped!(w
                        .rb_mut()
                        .subrows_mut(k + 1, n - k - 1)
                        .col_mut(k)
                        .as_2d_mut())
                    .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
                } else {
                    // 2x2 pivot: work with the block scaled by 1 / |d21|.
                    let d21 = w.read(k + 1, k).faer_abs();
                    let d21_inv = d21.faer_inv();
                    let mut d11 = d21_inv.faer_scale_real(w.read(k + 1, k + 1).faer_real());
                    let mut d22 = d21_inv.faer_scale_real(w.read(k, k).faer_real());

                    // Thresholds are scaled the same way as the block entries.
                    let eps = eps.faer_mul(d21_inv);
                    let delta = delta.faer_mul(d21_inv);
                    if has_eps {
                        if let Some(signs) = signs.rb_mut() {
                            if signs[k] > 0 && signs[k + 1] > 0 {
                                if d11 <= eps {
                                    d11 = delta;
                                }
                                if d22 <= eps {
                                    d22 = delta;
                                }
                            } else if signs[k] < 0 && signs[k + 1] < 0 {
                                if d11 >= eps.faer_neg() {
                                    d11 = delta.faer_neg();
                                }
                                if d22 >= eps.faer_neg() {
                                    d22 = delta.faer_neg();
                                }
                            }
                        }
                    }

                    // t = (d11/|d21| * d22/|d21| - 1.0)
                    let mut t = d11.faer_mul(d22).faer_sub(E::Real::faer_one());
                    if has_eps {
                        if let Some(signs) = signs.rb_mut() {
                            if ((signs[k] > 0 && signs[k + 1] > 0)
                                || (signs[k] < 0 && signs[k + 1] < 0))
                                && t <= eps
                            {
                                t = delta;
                            } else if ((signs[k] > 0 && signs[k + 1] < 0)
                                || (signs[k] < 0 && signs[k + 1] > 0))
                                && t >= eps.faer_neg()
                            {
                                t = delta.faer_neg();
                            }
                        }
                    }

                    let t = t.faer_inv();
                    // Shadows the magnitude with the off-diagonal entry divided by it.
                    let d21 = w.read(k + 1, k).faer_scale_real(d21_inv);
                    let d = t.faer_mul(d21_inv);

                    // Store the three entries computed from the scaled block at (k, k),
                    // (k + 1, k) and (k + 1, k + 1).
                    a.write(k, k, E::faer_from_real(d11.faer_mul(d)));
                    a.write(k + 1, k, d21.faer_scale_real(d.faer_neg()));
                    a.write(k + 1, k + 1, E::faer_from_real(d22.faer_mul(d)));

                    // Columns k and k + 1 of `a` below the block, computed from the two
                    // workspace columns and the scaled block.
                    for j in k + 2..n {
                        let wk = (w
                            .read(j, k)
                            .faer_scale_real(d11)
                            .faer_sub(w.read(j, k + 1).faer_mul(d21)))
                        .faer_scale_real(d);
                        let wkp1 = (w
                            .read(j, k + 1)
                            .faer_scale_real(d22)
                            .faer_sub(w.read(j, k).faer_mul(d21.faer_conj())))
                        .faer_scale_real(d);

                        a.write(j, k, wk);
                        a.write(j, k + 1, wkp1);
                    }

                    // Conjugate the sub-diagonal parts of both workspace columns in place.
                    zipped!(w
                        .rb_mut()
                        .subrows_mut(k + 1, n - k - 1)
                        .col_mut(k)
                        .as_2d_mut())
                    .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
                    zipped!(w
                        .rb_mut()
                        .subrows_mut(k + 2, n - k - 2)
                        .col_mut(k + 1)
                        .as_2d_mut())
                    .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
                }
            }

            // Record the pivot: the row index for a 1x1 block, its bitwise complement (in
            // both slots) for a 2x2 block.
            if k_step == 1 {
                pivots[k] = I::from_signed(truncate(kp));
            } else {
                pivots[k] = I::from_signed(truncate(!kp));
                pivots[k + 1] = I::from_signed(truncate(!kp));
            }

            k += k_step;
        }

        // Trailing update on the lower triangle:
        // a[k.., k..] -= a[k.., ..k] * w[k.., ..k]^T.
        let (a_left, mut a_right) = a.rb_mut().subrows_mut(k, n - k).split_at_col_mut(k);
        triangular::matmul(
            a_right.rb_mut(),
            BlockStructure::TriangularLower,
            a_left.rb(),
            BlockStructure::Rectangular,
            w.rb().submatrix(k, 0, n - k, k).transpose(),
            BlockStructure::Rectangular,
            Some(E::faer_one()),
            E::faer_one().faer_neg(),
            parallelism,
        );

        // Drop the imaginary part of the trailing diagonal.
        zipped!(a_right.diagonal_mut().column_vector_mut().as_2d_mut())
            .for_each(|unzipped!(mut x)| x.write(E::faer_from_real(x.read().faer_real())));

        // Walk the processed columns backwards and apply each pivot block's row interchange
        // to the columns on its left.
        // NOTE(review): `k - 1` underflows if no column was processed (e.g. `nb < 2`); the
        // call site visible in this file is guarded by `bs >= 2` -- confirm for other callers.
        let mut j = k - 1;
        loop {
            let jj = j;
            let mut jp = pivots[j].to_signed().sx();
            if (jp as isize) < 0 {
                // 2x2 block: decode the pivot row and step to the block's first column.
                jp = !jp;
                j -= 1;
            }

            if j == 0 {
                return (k, pivot_count, dynamic_regularization_count);
            }
            j -= 1;

            if jp != jj {
                swap_rows(a.rb_mut().subcols_mut(0, j + 1), jp, jj);
            }
            if j == 0 {
                return (k, pivot_count, dynamic_regularization_count);
            }
        }
    }
472
473    fn cholesky_diagonal_pivoting_unblocked<I: Index, E: ComplexField>(
474        mut a: MatMut<'_, E>,
475        regularization: BunchKaufmanRegularization<'_, E>,
476        pivots: &mut [I],
477        alpha: E::Real,
478    ) -> (usize, usize) {
479        let truncate = <I::Signed as SignedIndex>::truncate;
480
481        assert!(a.nrows() == a.ncols());
482        let n = a.nrows();
483        if n == 0 {
484            return (0, 0);
485        }
486
487        let eps = regularization.dynamic_regularization_epsilon.faer_abs();
488        let delta = regularization.dynamic_regularization_delta.faer_abs();
489        let mut signs = regularization.dynamic_regularization_signs;
490        let has_eps = delta > E::Real::faer_zero();
491        let mut dynamic_regularization_count = 0usize;
492        let mut pivot_count = 0usize;
493
494        let mut k = 0;
495        while k < n {
496            let make_real = |mut mat: MatMut<'_, E>, i, j| {
497                mat.write(i, j, E::faer_from_real(mat.read(i, j).faer_real()))
498            };
499
500            let mut k_step = 1;
501
502            let abs_akk = a.read(k, k).faer_abs();
503            let imax;
504            let colmax;
505
506            if k + 1 < n {
507                (imax, _, colmax) =
508                    best_score_idx(a.rb().col(k).subrows(k + 1, n - k - 1).as_2d()).unwrap();
509            } else {
510                imax = 0;
511                colmax = E::Real::faer_zero();
512            }
513            let imax = imax + k + 1;
514
515            let kp;
516            if max(abs_akk, colmax) == E::Real::faer_zero() {
517                kp = k;
518
519                let mut d11 = a.read(k, k).faer_real();
520                if has_eps {
521                    if let Some(signs) = signs.rb_mut() {
522                        if signs[k] > 0 && d11 <= eps {
523                            d11 = delta;
524                            dynamic_regularization_count += 1;
525                        } else if signs[k] < 0 && d11 >= eps.faer_neg() {
526                            d11 = delta.faer_neg();
527                            dynamic_regularization_count += 1;
528                        }
529                    }
530                }
531                let d11 = d11.faer_inv();
532                a.write(k, k, E::faer_from_real(d11));
533            } else {
534                if abs_akk >= colmax.faer_mul(alpha) {
535                    kp = k;
536                } else {
537                    let rowmax = max(
538                        best_score(a.rb().row(imax).subcols(k, imax - k).as_2d()),
539                        best_score(a.rb().subrows(imax + 1, n - imax - 1).col(imax).as_2d()),
540                    );
541
542                    if abs_akk >= alpha.faer_mul(colmax).faer_mul(colmax.faer_div(rowmax)) {
543                        kp = k;
544                    } else if a.read(imax, imax).faer_abs() >= alpha.faer_mul(rowmax) {
545                        kp = imax
546                    } else {
547                        kp = imax;
548                        k_step = 2;
549                    }
550                }
551
552                let kk = k + k_step - 1;
553
554                if kp != kk {
555                    pivot_count += 1;
556                    swap_cols(a.rb_mut().subrows_mut(kp + 1, n - kp - 1), kk, kp);
557                    for j in kk + 1..kp {
558                        swap_elems_conj(a.rb_mut(), j, kk, kp, j);
559                    }
560
561                    a.write(kp, kk, a.read(kp, kk).faer_conj());
562                    swap_elems(a.rb_mut(), kk, kk, kp, kp);
563
564                    if k_step == 2 {
565                        swap_elems(a.rb_mut(), k + 1, k, kp, k);
566                    }
567                }
568
569                if k_step == 1 {
570                    let mut d11 = a.read(k, k).faer_real();
571                    if has_eps {
572                        if let Some(signs) = signs.rb_mut() {
573                            if signs[k] > 0 && d11 <= eps {
574                                d11 = delta;
575                                dynamic_regularization_count += 1;
576                            } else if signs[k] < 0 && d11 >= eps.faer_neg() {
577                                d11 = delta.faer_neg();
578                                dynamic_regularization_count += 1;
579                            }
580                        } else {
581                            if d11.faer_abs() <= eps {
582                                if d11 < E::Real::faer_zero() {
583                                    d11 = delta.faer_neg();
584                                } else {
585                                    d11 = delta;
586                                }
587                                dynamic_regularization_count += 1;
588                            }
589                        }
590                    }
591                    let d11 = d11.faer_inv();
592                    a.write(k, k, E::faer_from_real(d11));
593
594                    let (x, mut trailing) = a
595                        .rb_mut()
596                        .subrows_mut(k + 1, n - k - 1)
597                        .subcols_mut(k, n - k)
598                        .split_at_col_mut(1);
599
600                    for j in 0..n - k - 1 {
601                        let d11xj = x.read(j, 0).faer_conj().faer_scale_real(d11);
602                        for i in j..n - k - 1 {
603                            let xi = x.read(i, 0);
604                            trailing.write(i, j, trailing.read(i, j).faer_sub(d11xj.faer_mul(xi)));
605                        }
606                        make_real(trailing.rb_mut(), j, j);
607                    }
608                    zipped!(x).for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
609                } else {
610                    let d21 = a.read(k + 1, k).faer_abs();
611                    let d21_inv = d21.faer_inv();
612                    let mut d11 = d21_inv.faer_scale_real(a.read(k + 1, k + 1).faer_real());
613                    let mut d22 = d21_inv.faer_scale_real(a.read(k, k).faer_real());
614
615                    let eps = eps.faer_mul(d21_inv);
616                    let delta = delta.faer_mul(d21_inv);
617                    if has_eps {
618                        if let Some(signs) = signs.rb_mut() {
619                            if signs[k] > 0 && signs[k + 1] > 0 {
620                                if d11 <= eps {
621                                    d11 = delta;
622                                }
623                                if d22 <= eps {
624                                    d22 = delta;
625                                }
626                            } else if signs[k] < 0 && signs[k + 1] < 0 {
627                                if d11 >= eps.faer_neg() {
628                                    d11 = delta.faer_neg();
629                                }
630                                if d22 >= eps.faer_neg() {
631                                    d22 = delta.faer_neg();
632                                }
633                            }
634                        }
635                    }
636
637                    // t = (d11/|d21| * d22/|d21| - 1.0)
638                    let mut t = d11.faer_mul(d22).faer_sub(E::Real::faer_one());
639                    if has_eps {
640                        if let Some(signs) = signs.rb_mut() {
641                            if ((signs[k] > 0 && signs[k + 1] > 0)
642                                || (signs[k] < 0 && signs[k + 1] < 0))
643                                && t <= eps
644                            {
645                                t = delta;
646                            } else if ((signs[k] > 0 && signs[k + 1] < 0)
647                                || (signs[k] < 0 && signs[k + 1] > 0))
648                                && t >= eps.faer_neg()
649                            {
650                                t = delta.faer_neg();
651                            }
652                        }
653                    }
654
655                    let t = t.faer_inv();
656                    let d21 = a.read(k + 1, k).faer_scale_real(d21_inv);
657                    let d = t.faer_mul(d21_inv);
658
659                    a.write(k, k, E::faer_from_real(d11.faer_mul(d)));
660                    a.write(k + 1, k, d21.faer_scale_real(d.faer_neg()));
661                    a.write(k + 1, k + 1, E::faer_from_real(d22.faer_mul(d)));
662
663                    for j in k + 2..n {
664                        let wk = (a
665                            .read(j, k)
666                            .faer_scale_real(d11)
667                            .faer_sub(a.read(j, k + 1).faer_mul(d21)))
668                        .faer_scale_real(d);
669                        let wkp1 = (a
670                            .read(j, k + 1)
671                            .faer_scale_real(d22)
672                            .faer_sub(a.read(j, k).faer_mul(d21.faer_conj())))
673                        .faer_scale_real(d);
674
675                        for i in j..n {
676                            a.write(
677                                i,
678                                j,
679                                a.read(i, j)
680                                    .faer_sub(a.read(i, k).faer_mul(wk.faer_conj()))
681                                    .faer_sub(a.read(i, k + 1).faer_mul(wkp1.faer_conj())),
682                            );
683                        }
684                        make_real(a.rb_mut(), j, j);
685
686                        a.write(j, k, wk);
687                        a.write(j, k + 1, wkp1);
688                    }
689                }
690            }
691
692            if k_step == 1 {
693                pivots[k] = I::from_signed(truncate(kp));
694            } else {
695                pivots[k] = I::from_signed(truncate(!kp));
696                pivots[k + 1] = I::from_signed(truncate(!kp));
697            }
698
699            k += k_step;
700        }
701
702        (pivot_count, dynamic_regularization_count)
703    }
704
705    fn convert<I: Index, E: ComplexField>(
706        mut a: MatMut<'_, E>,
707        pivots: &[I],
708        mut subdiag: MatMut<'_, E>,
709    ) {
710        assert!(a.nrows() == a.ncols());
711        let n = a.nrows();
712
713        let mut i = 0;
714        while i < n {
715            if (pivots[i].to_signed().sx() as isize) < 0 {
716                subdiag.write(i, 0, a.read(i + 1, i));
717                subdiag.write(i + 1, 0, E::faer_zero());
718                a.write(i + 1, i, E::faer_zero());
719                i += 2;
720            } else {
721                subdiag.write(i, 0, E::faer_zero());
722                i += 1;
723            }
724        }
725
726        let mut i = 0;
727        while i < n {
728            let p = pivots[i].to_signed().sx();
729            if (p as isize) < 0 {
730                let p = !p;
731                swap_rows(a.rb_mut().subcols_mut(0, i), i + 1, p);
732                i += 2;
733            } else {
734                swap_rows(a.rb_mut().subcols_mut(0, i), i, p);
735                i += 1;
736            }
737        }
738    }
739
740    /// Computes the size and alignment of required workspace for performing a Cholesky
741    /// decomposition with Bunch-Kaufman pivoting.
742    pub fn cholesky_in_place_req<I: Index, E: Entity>(
743        dim: usize,
744        parallelism: Parallelism,
745        params: BunchKaufmanParams,
746    ) -> Result<StackReq, SizeOverflow> {
747        let _ = parallelism;
748        let mut bs = params.blocksize;
749        if bs < 2 || dim <= bs {
750            bs = 0;
751        }
752        StackReq::try_new::<I>(dim)?.try_and(temp_mat_req::<E>(dim, bs)?)
753    }
754
    /// Summary information returned by [`cholesky_in_place`] alongside the permutation.
    #[derive(Copy, Clone, Debug)]
    pub struct BunchKaufmanInfo {
        /// Number of dynamic regularizations -- presumably the sum of the counts reported by
        /// the factorization steps; confirm against `cholesky_in_place`.
        pub dynamic_regularization_count: usize,
        /// Number of transpositions -- presumably the total number of pivoting interchanges
        /// performed; confirm against `cholesky_in_place`.
        pub transposition_count: usize,
    }
760
761    /// Computes the Cholesky factorization with Bunch-Kaufman  pivoting of the input matrix and
762    /// stores the factorization in `matrix` and `subdiag`.
763    ///
764    /// The inverses of the diagonal blocks of the block diagonal matrix are stored on the diagonal
765    /// of `matrix`, while the subdiagonal elements of those inverses are stored in `subdiag`.
766    ///
767    /// # Panics
768    ///
769    /// Panics if the input matrix is not square.
770    ///
771    /// This can also panic if the provided memory in `stack` is insufficient (see
772    /// [`cholesky_in_place_req`]).
773    #[track_caller]
774    pub fn cholesky_in_place<'out, I: Index, E: ComplexField>(
775        matrix: MatMut<'_, E>,
776        subdiag: MatMut<'_, E>,
777        regularization: BunchKaufmanRegularization<'_, E>,
778        perm: &'out mut [I],
779        perm_inv: &'out mut [I],
780        parallelism: Parallelism,
781        stack: PodStack<'_>,
782        params: BunchKaufmanParams,
783    ) -> (BunchKaufmanInfo, PermutationMut<'out, I, E>) {
784        let truncate = <I::Signed as SignedIndex>::truncate;
785        let mut regularization = regularization;
786
787        let n = matrix.nrows();
788        assert!(all(
789            matrix.nrows() == matrix.ncols(),
790            subdiag.nrows() == n,
791            subdiag.ncols() == 1,
792            perm.len() == n,
793            perm_inv.len() == n
794        ));
795
796        #[cfg(feature = "perf-warn")]
797        if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) {
798            if matrix.col_stride().unsigned_abs() == 1 {
799                log::warn!(target: "faer_perf", "Bunch-Kaufman decomposition prefers column-major matrix. Found row-major matrix.");
800            } else {
801                log::warn!(target: "faer_perf", "Bunch-Kaufman decomposition prefers column-major matrix. Found matrix with generic strides.");
802            }
803        }
804
805        let _ = parallelism;
806        let mut matrix = matrix;
807
808        let alpha = E::Real::faer_one()
809            .faer_add(E::Real::faer_from_f64(17.0).faer_sqrt())
810            .faer_scale_power_of_two(E::Real::faer_from_f64(1.0 / 8.0));
811
812        let (pivots, stack) = stack.make_raw::<I>(n);
813
814        let mut bs = params.blocksize;
815        if bs < 2 || n <= bs {
816            bs = 0;
817        }
818        let mut work = temp_mat_uninit(n, bs, stack).0;
819
820        let mut k = 0;
821        let mut dynamic_regularization_count = 0;
822        let mut transposition_count = 0;
823        while k < n {
824            let regularization = BunchKaufmanRegularization {
825                dynamic_regularization_signs: regularization
826                    .dynamic_regularization_signs
827                    .rb_mut()
828                    .map(|signs| &mut signs[k..]),
829                dynamic_regularization_delta: regularization.dynamic_regularization_delta,
830                dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
831            };
832
833            let kb;
834            let reg_count;
835            let piv_count;
836            if bs >= 2 && bs < n - k {
837                (kb, piv_count, reg_count) = cholesky_diagonal_pivoting_blocked_step(
838                    matrix.rb_mut().submatrix_mut(k, k, n - k, n - k),
839                    regularization,
840                    work.rb_mut(),
841                    &mut pivots[k..],
842                    alpha,
843                    parallelism,
844                );
845            } else {
846                (piv_count, reg_count) = cholesky_diagonal_pivoting_unblocked(
847                    matrix.rb_mut().submatrix_mut(k, k, n - k, n - k),
848                    regularization,
849                    &mut pivots[k..],
850                    alpha,
851                );
852                kb = n - k;
853            }
854            dynamic_regularization_count += reg_count;
855            transposition_count += piv_count;
856
857            for pivot in &mut pivots[k..k + kb] {
858                let pv = (*pivot).to_signed().sx();
859                if pv as isize >= 0 {
860                    *pivot = I::from_signed(truncate(pv + k));
861                } else {
862                    *pivot = I::from_signed(truncate(pv - k));
863                }
864            }
865
866            k += kb;
867        }
868
869        convert(matrix.rb_mut(), pivots, subdiag);
870
871        for (i, p) in perm.iter_mut().enumerate() {
872            *p = I::from_signed(truncate(i));
873        }
874        let mut i = 0;
875        while i < n {
876            let p = pivots[i].to_signed().sx();
877            if (p as isize) < 0 {
878                let p = !p;
879                perm.swap(i + 1, p);
880                i += 2;
881            } else {
882                perm.swap(i, p);
883                i += 1;
884            }
885        }
886        for (i, &p) in perm.iter().enumerate() {
887            perm_inv[p.to_signed().zx()] = I::from_signed(truncate(i));
888        }
889
890        (
891            BunchKaufmanInfo {
892                dynamic_regularization_count,
893                transposition_count,
894            },
895            unsafe { PermutationMut::new_unchecked(perm, perm_inv) },
896        )
897    }
898}
899
900pub mod solve {
901    use super::*;
902    use faer_core::assert;
903
904    #[track_caller]
905    pub fn solve_in_place_req<I: Index, E: Entity>(
906        dim: usize,
907        rhs_ncols: usize,
908        parallelism: Parallelism,
909    ) -> Result<StackReq, SizeOverflow> {
910        let _ = parallelism;
911        temp_mat_req::<E>(dim, rhs_ncols)
912    }
913
914    #[track_caller]
915    pub fn solve_in_place_with_conj<I: Index, E: ComplexField>(
916        lb_factors: MatRef<'_, E>,
917        subdiag: MatRef<'_, E>,
918        conj: Conj,
919        perm: PermutationRef<'_, I, E>,
920        rhs: MatMut<'_, E>,
921        parallelism: Parallelism,
922        stack: PodStack<'_>,
923    ) {
924        let n = lb_factors.nrows();
925        let k = rhs.ncols();
926
927        assert!(all(
928            lb_factors.nrows() == lb_factors.ncols(),
929            rhs.nrows() == n,
930            subdiag.nrows() == n,
931            subdiag.ncols() == 1,
932            perm.len() == n
933        ));
934
935        let a = lb_factors;
936        let par = parallelism;
937        let not_conj = conj.compose(Conj::Yes);
938
939        let mut rhs = rhs;
940        let mut x = temp_mat_uninit::<E>(n, k, stack).0;
941
942        permute_rows(x.rb_mut(), rhs.rb(), perm);
943        solve_unit_lower_triangular_in_place_with_conj(a, conj, x.rb_mut(), par);
944
945        let mut i = 0;
946        while i < n {
947            if subdiag.read(i, 0) == E::faer_zero() {
948                let d_inv = a.read(i, i).faer_real();
949                for j in 0..k {
950                    x.write(i, j, x.read(i, j).faer_scale_real(d_inv));
951                }
952                i += 1;
953            } else {
954                if conj == Conj::Yes {
955                    let akp1k = subdiag.read(i, 0);
956                    let ak = a.read(i, i).faer_real();
957                    let akp1 = a.read(i + 1, i + 1).faer_real();
958
959                    for j in 0..k {
960                        let xk = x.read(i, j);
961                        let xkp1 = x.read(i + 1, j);
962
963                        x.write(i, j, xk.faer_scale_real(ak).faer_add(xkp1.faer_mul(akp1k)));
964                        x.write(
965                            i + 1,
966                            j,
967                            xkp1.faer_scale_real(akp1)
968                                .faer_add(xk.faer_mul(akp1k.faer_conj())),
969                        );
970                    }
971                } else {
972                    let akp1k = subdiag.read(i, 0);
973                    let ak = a.read(i, i).faer_real();
974                    let akp1 = a.read(i + 1, i + 1).faer_real();
975
976                    for j in 0..k {
977                        let xk = x.read(i, j);
978                        let xkp1 = x.read(i + 1, j);
979
980                        x.write(
981                            i,
982                            j,
983                            xk.faer_scale_real(ak)
984                                .faer_add(xkp1.faer_mul(akp1k.faer_conj())),
985                        );
986                        x.write(
987                            i + 1,
988                            j,
989                            xkp1.faer_scale_real(akp1).faer_add(xk.faer_mul(akp1k)),
990                        );
991                    }
992                }
993                i += 2;
994            }
995        }
996
997        solve_unit_upper_triangular_in_place_with_conj(a.transpose(), not_conj, x.rb_mut(), par);
998        permute_rows(rhs.rb_mut(), x.rb(), perm.inverse());
999    }
1000}
1001
#[cfg(test)]
mod tests {
    use crate::bunch_kaufman::compute::BunchKaufmanParams;

    use super::*;
    use dyn_stack::GlobalPodBuffer;
    use faer_core::{assert, c64, Mat};
    use rand::random;

    /// Factorizes random symmetric real matrices with the default parameters and checks the
    /// residual of the solve.
    #[test]
    fn test_real() {
        for n in [3, 6, 19, 100, 421] {
            // Random symmetric matrix and a two-column right-hand side.
            let a = Mat::<f64>::from_fn(n, n, |_, _| random());
            let a = &a + a.adjoint();
            let rhs = Mat::<f64>::from_fn(n, 2, |_, _| random());

            let mut factors = a.clone();
            let mut subdiag = Mat::<f64>::zeros(n, 1);

            let mut perm_fwd = vec![0usize; n];
            let mut perm_bwd = vec![0usize; n];

            let params = BunchKaufmanParams::default();
            let factor_req =
                compute::cholesky_in_place_req::<usize, f64>(n, Parallelism::None, params).unwrap();
            let mut factor_mem = GlobalPodBuffer::new(factor_req);
            let (_, perm) = compute::cholesky_in_place(
                factors.as_mut(),
                subdiag.as_mut(),
                Default::default(),
                &mut perm_fwd,
                &mut perm_bwd,
                Parallelism::None,
                PodStack::new(&mut factor_mem),
                params,
            );

            let solve_req =
                solve::solve_in_place_req::<usize, f64>(n, rhs.ncols(), Parallelism::None).unwrap();
            let mut solve_mem = GlobalPodBuffer::new(solve_req);
            let mut x = rhs.clone();
            solve::solve_in_place_with_conj(
                factors.as_ref(),
                subdiag.as_ref(),
                Conj::No,
                perm.rb(),
                x.as_mut(),
                Parallelism::None,
                PodStack::new(&mut solve_mem),
            );

            // Largest absolute entry of the residual `a * x - rhs`.
            let residual = &a * &x - &rhs;
            let mut worst = 0.0;
            zipped!(residual.as_ref()).for_each(|unzipped!(entry)| {
                let magnitude = entry.read().abs();
                if magnitude > worst {
                    worst = magnitude
                }
            });
            assert!(worst < 1e-9);
        }
    }

    /// Factorizes random hermitian complex matrices with a block size of 32, solves the
    /// conjugated system, and checks both the residual and that the stored diagonal is real.
    #[test]
    fn test_cplx() {
        for n in [3, 6, 19, 100, 421] {
            // Random hermitian matrix and a two-column right-hand side.
            let a = Mat::<c64>::from_fn(n, n, |_, _| c64::new(random(), random()));
            let a = &a + a.adjoint();
            let rhs = Mat::<c64>::from_fn(n, 2, |_, _| c64::new(random(), random()));

            let mut factors = a.clone();
            let mut subdiag = Mat::<c64>::zeros(n, 1);

            let mut perm_fwd = vec![0usize; n];
            let mut perm_bwd = vec![0usize; n];

            let params = BunchKaufmanParams {
                pivoting: compute::PivotingStrategy::Diagonal,
                blocksize: 32,
            };
            let factor_req =
                compute::cholesky_in_place_req::<usize, c64>(n, Parallelism::None, params).unwrap();
            let mut factor_mem = GlobalPodBuffer::new(factor_req);
            let (_, perm) = compute::cholesky_in_place(
                factors.as_mut(),
                subdiag.as_mut(),
                Default::default(),
                &mut perm_fwd,
                &mut perm_bwd,
                Parallelism::None,
                PodStack::new(&mut factor_mem),
                params,
            );

            let mut x = rhs.clone();
            let solve_req =
                solve::solve_in_place_req::<usize, c64>(n, rhs.ncols(), Parallelism::None).unwrap();
            let mut solve_mem = GlobalPodBuffer::new(solve_req);
            solve::solve_in_place_with_conj(
                factors.as_ref(),
                subdiag.as_ref(),
                Conj::Yes,
                perm.rb(),
                x.as_mut(),
                Parallelism::None,
                PodStack::new(&mut solve_mem),
            );

            // With `Conj::Yes` the solve targets the conjugated matrix.
            let residual = a.conjugate() * &x - &rhs;
            let mut worst = 0.0;
            zipped!(residual.as_ref()).for_each(|unzipped!(entry)| {
                let magnitude = entry.read().abs();
                if magnitude > worst {
                    worst = magnitude
                }
            });
            for i in 0..n {
                assert!(factors[(i, i)].faer_imag() == 0.0);
            }
            assert!(worst < 1e-9);
        }
    }
}
1124}