faer_cholesky/
lib.rs

1#![allow(clippy::type_complexity)]
2#![allow(clippy::too_many_arguments)]
3#![cfg_attr(not(feature = "std"), no_std)]
4
5use core::cmp::Ordering;
6use faer_core::{
7    assert,
8    permutation::{Index, PermutationMut, SignedIndex},
9    ComplexField, MatRef,
10};
11
12pub mod bunch_kaufman;
13pub mod ldlt_diagonal;
14pub mod llt;
15
16/// Computes a permutation that reduces the chance of numerical errors during the $LDL^H$
17/// factorization with diagonal $D$, then stores the result in `perm_indices` and
18/// `perm_inv_indices`.
19#[track_caller]
20pub fn compute_cholesky_permutation<'a, E: ComplexField, I: Index>(
21    perm_indices: &'a mut [I],
22    perm_inv_indices: &'a mut [I],
23    matrix: MatRef<'_, E>,
24) -> PermutationMut<'a, I, E> {
25    let n = matrix.nrows();
26    let truncate = <I::Signed as SignedIndex>::truncate;
27    assert!(
28        matrix.nrows() == matrix.ncols(),
29        "input matrix must be square",
30    );
31    assert!(
32        perm_indices.len() == n,
33        "length of permutation must be equal to the matrix dimension",
34    );
35    assert!(
36        perm_inv_indices.len() == n,
37        "length of inverse permutation must be equal to the matrix dimension",
38    );
39
40    for (i, p) in perm_indices.iter_mut().enumerate() {
41        *p = I::from_signed(truncate(i));
42    }
43
44    perm_indices.sort_unstable_by(move |&i, &j| {
45        let lhs = matrix
46            .read(i.to_signed().zx(), i.to_signed().zx())
47            .faer_abs();
48        let rhs = matrix
49            .read(j.to_signed().zx(), j.to_signed().zx())
50            .faer_abs();
51        let cmp = rhs.partial_cmp(&lhs);
52        if let Some(cmp) = cmp {
53            cmp
54        } else {
55            Ordering::Equal
56        }
57    });
58
59    for (i, p) in perm_indices.iter().copied().enumerate() {
60        perm_inv_indices[p.to_signed().zx()] = I::from_signed(truncate(i));
61    }
62
63    unsafe { PermutationMut::new_unchecked(perm_indices, perm_inv_indices) }
64}