faer_cholesky/ldlt_diagonal/
compute.rs

1use dyn_stack::{PodStack, SizeOverflow, StackReq};
2use faer_core::{
3    assert, debug_assert, group_helpers::*, mul::triangular::BlockStructure, solve, temp_mat_req,
4    temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism,
5    SimdCtx,
6};
7use faer_entity::*;
8use reborrow::*;
9
10pub(crate) struct RankUpdate<'a, E: ComplexField> {
11    pub a21: MatMut<'a, E>,
12    pub l20: MatRef<'a, E>,
13    pub l10: MatRef<'a, E>,
14}
15
16impl<E: ComplexField> pulp::WithSimd for RankUpdate<'_, E> {
17    type Output = ();
18
19    #[inline(always)]
20    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
21        let Self { a21, l20, l10 } = self;
22
23        debug_assert_eq!(a21.row_stride(), 1);
24        debug_assert_eq!(l20.row_stride(), 1);
25        debug_assert_eq!(l20.nrows(), a21.nrows());
26        debug_assert_eq!(l20.ncols(), l10.ncols());
27        debug_assert_eq!(a21.ncols(), 1);
28        debug_assert_eq!(l10.nrows(), 1);
29
30        let m = l20.nrows();
31        let n = l20.ncols();
32
33        if m == 0 {
34            return;
35        }
36
37        let simd = SimdFor::<E, S>::new(simd);
38        let acc = SliceGroupMut::<'_, E>::new(a21.try_get_contiguous_col_mut(0));
39        let offset = simd.align_offset(acc.rb());
40
41        let (mut acc_head, mut acc_body, mut acc_tail) = simd.as_aligned_simd_mut(acc, offset);
42
43        for j in 0..n {
44            let l10 = simd.splat(l10.read(0, j).faer_neg().faer_conj());
45            let l20 = SliceGroup::<'_, E>::new(l20.try_get_contiguous_col(j));
46
47            let (l20_head, l20_body, l20_tail) = simd.as_aligned_simd(l20, offset);
48
49            #[inline(always)]
50            fn process<E: ComplexField, S: pulp::Simd>(
51                simd: SimdFor<E, S>,
52                mut acc: impl Write<Output = SimdGroupFor<E, S>>,
53                l20: impl Read<Output = SimdGroupFor<E, S>>,
54                l10: SimdGroupFor<E, S>,
55            ) {
56                let zero = simd.splat(E::faer_zero());
57                acc.write(simd.mul_add_e(l10, l20.read_or(zero), acc.read_or(zero)));
58            }
59
60            process(simd, acc_head.rb_mut(), l20_head, l10);
61            for (acc, l20) in acc_body
62                .rb_mut()
63                .into_mut_iter()
64                .zip(l20_body.into_ref_iter())
65            {
66                process(simd, acc, l20, l10)
67            }
68            process(simd, acc_tail.rb_mut(), l20_tail, l10);
69        }
70    }
71}
72
73fn cholesky_in_place_left_looking_impl<E: ComplexField>(
74    matrix: MatMut<'_, E>,
75    regularization: LdltRegularization<'_, E>,
76    parallelism: Parallelism,
77    params: LdltDiagParams,
78) -> usize {
79    let mut matrix = matrix;
80    let _ = parallelism;
81    let _ = params;
82
83    debug_assert!(
84        matrix.ncols() == matrix.nrows(),
85        "only square matrices can be decomposed into cholesky factors",
86    );
87
88    let n = matrix.nrows();
89
90    if n == 0 {
91        return 0;
92    }
93
94    let mut idx = 0;
95    let arch = E::Simd::default();
96
97    let eps = regularization.dynamic_regularization_epsilon.faer_abs();
98    let delta = regularization.dynamic_regularization_delta.faer_abs();
99    let has_eps = delta > E::Real::faer_zero();
100    let mut dynamic_regularization_count = 0usize;
101    loop {
102        let block_size = 1;
103
104        // we split L/D rows/cols into 3 sections each
105        //     ┌             ┐
106        //     | L00         |
107        // L = | L10 A11     |
108        //     | L20 A21 A22 |
109        //     └             ┘
110        //     ┌          ┐
111        //     | D0       |
112        // D = |    D1    |
113        //     |       D2 |
114        //     └          ┘
115        //
116        // we already computed L00, L10, L20, and D0. we now compute L11, L21, and D1
117
118        let (top_left, top_right, bottom_left, bottom_right) =
119            matrix.rb_mut().split_at_mut(idx, idx);
120        let l00 = top_left.into_const();
121        let d0 = l00.diagonal().column_vector();
122        let (_, l10, _, l20) = bottom_left.into_const().split_at(block_size, 0);
123        let (mut a11, _, a21, _) = bottom_right.split_at_mut(block_size, block_size);
124
125        // reserve space for L10×D0
126        let mut l10xd0 = top_right
127            .submatrix_mut(0, 0, idx, block_size)
128            .transpose_mut();
129
130        zipped!(l10xd0.rb_mut(), l10, d0.transpose().as_2d()).for_each(
131            |unzipped!(mut dst, src, factor)| {
132                dst.write(
133                    src.read()
134                        .faer_scale_real(factor.read().faer_real().faer_inv()),
135                )
136            },
137        );
138
139        let l10xd0 = l10xd0.into_const();
140
141        let mut d = a11
142            .read(0, 0)
143            .faer_sub(faer_core::mul::inner_prod::inner_prod_with_conj_arch(
144                arch,
145                l10xd0.row(0).transpose().as_2d(),
146                Conj::Yes,
147                l10.row(0).transpose().as_2d(),
148                Conj::No,
149            ))
150            .faer_real();
151
152        // dynamic regularization code taken from clarabel.rs with modifications
153        if has_eps {
154            if let Some(signs) = regularization.dynamic_regularization_signs {
155                if signs[idx] > 0 && d <= eps {
156                    d = delta;
157                    dynamic_regularization_count += 1;
158                } else if signs[idx] < 0 && d >= eps.faer_neg() {
159                    d = delta.faer_neg();
160                    dynamic_regularization_count += 1;
161                }
162            } else if d.faer_abs() <= eps {
163                if d < E::Real::faer_zero() {
164                    d = delta.faer_neg();
165                } else {
166                    d = delta;
167                }
168                dynamic_regularization_count += 1;
169            }
170        }
171
172        let d = d.faer_inv();
173        a11.write(0, 0, E::faer_from_real(d));
174
175        if idx + block_size == n {
176            break;
177        }
178
179        let mut a21 = a21.col_mut(0);
180
181        // A21 -= L20 × L10^H
182        if a21.row_stride() == 1 {
183            arch.dispatch(RankUpdate {
184                a21: a21.rb_mut().as_2d_mut(),
185                l20,
186                l10: l10xd0,
187            });
188        } else {
189            for j in 0..idx {
190                let l20_col = l20.col(j);
191                let l10_conj = l10xd0.read(0, j).faer_conj();
192
193                zipped!(a21.rb_mut().as_2d_mut(), l20_col.as_2d()).for_each(
194                    |unzipped!(mut dst, src)| {
195                        dst.write(dst.read().faer_sub(src.read().faer_mul(l10_conj)))
196                    },
197                );
198            }
199        }
200
201        zipped!(a21.rb_mut().as_2d_mut())
202            .for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d)));
203
204        idx += block_size;
205    }
206    dynamic_regularization_count
207}
208
/// Tuning parameters for the diagonal LDLT factorization.
///
/// This currently carries no options. It is `#[non_exhaustive]` so that fields can be added
/// later without breaking callers, which should obtain a value through [`Default`].
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct LdltDiagParams {}
212
213/// Computes the size and alignment of required workspace for performing a Cholesky decomposition.
214pub fn raw_cholesky_in_place_req<E: Entity>(
215    dim: usize,
216    parallelism: Parallelism,
217    params: LdltDiagParams,
218) -> Result<StackReq, SizeOverflow> {
219    let _ = parallelism;
220    let _ = params;
221    temp_mat_req::<E>(dim, dim)
222}
223
224// uses an out parameter for tail recursion
225fn cholesky_in_place_impl<E: ComplexField>(
226    count: &mut usize,
227    matrix: MatMut<'_, E>,
228    regularization: LdltRegularization<'_, E>,
229    parallelism: Parallelism,
230    stack: PodStack<'_>,
231    params: LdltDiagParams,
232) {
233    // right looking cholesky
234
235    debug_assert!(matrix.nrows() == matrix.ncols());
236    let mut matrix = matrix;
237    let mut stack = stack;
238
239    let n = matrix.nrows();
240    if n < 32 {
241        *count += cholesky_in_place_left_looking_impl(matrix, regularization, parallelism, params)
242    } else {
243        let block_size = Ord::min(n / 2, 128);
244        let rem = n - block_size;
245        let (mut l00, _, mut a10, mut a11) = matrix.rb_mut().split_at_mut(block_size, block_size);
246
247        cholesky_in_place_impl(
248            count,
249            l00.rb_mut(),
250            regularization,
251            parallelism,
252            stack.rb_mut(),
253            params,
254        );
255
256        let l00 = l00.into_const();
257        let d0 = l00.diagonal().column_vector();
258
259        solve::solve_unit_lower_triangular_in_place(
260            l00.conjugate(),
261            a10.rb_mut().transpose_mut(),
262            parallelism,
263        );
264
265        {
266            // reserve space for L10×D0
267            let (mut l10xd0, _) = temp_mat_uninit::<E>(rem, block_size, stack.rb_mut());
268            let mut l10xd0 = l10xd0.as_mut();
269
270            for j in 0..block_size {
271                let l10xd0_col = l10xd0.rb_mut().col_mut(j);
272                let a10_col = a10.rb_mut().col_mut(j);
273                let d0_elem = d0.read(j);
274
275                zipped!(l10xd0_col.as_2d_mut(), a10_col.as_2d_mut()).for_each(
276                    |unzipped!(mut l10xd0_elem, mut a10_elem)| {
277                        let a10_elem_read = a10_elem.read();
278                        a10_elem.write(a10_elem_read.faer_mul(d0_elem));
279                        l10xd0_elem.write(a10_elem_read);
280                    },
281                );
282            }
283
284            faer_core::mul::triangular::matmul(
285                a11.rb_mut(),
286                BlockStructure::TriangularLower,
287                a10.into_const(),
288                BlockStructure::Rectangular,
289                l10xd0.adjoint_mut().into_const(),
290                BlockStructure::Rectangular,
291                Some(E::faer_one()),
292                E::faer_one().faer_neg(),
293                parallelism,
294            );
295        }
296
297        cholesky_in_place_impl(
298            count,
299            a11,
300            LdltRegularization {
301                dynamic_regularization_signs: regularization
302                    .dynamic_regularization_signs
303                    .map(|signs| &signs[block_size..]),
304                dynamic_regularization_delta: regularization.dynamic_regularization_delta,
305                dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
306            },
307            parallelism,
308            stack,
309            params,
310        )
311    }
312}
313
314/// Dynamic LDLT regularization.
315#[derive(Copy, Clone, Debug)]
316pub struct LdltRegularization<'a, E: ComplexField> {
317    pub dynamic_regularization_signs: Option<&'a [i8]>,
318    pub dynamic_regularization_delta: E::Real,
319    pub dynamic_regularization_epsilon: E::Real,
320}
321
/// Information reported by the LDLT factorization.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct LdltInfo {
    /// Number of pivots that were replaced by dynamic regularization.
    pub dynamic_regularization_count: usize,
}
326
327impl<E: ComplexField> Default for LdltRegularization<'_, E> {
328    fn default() -> Self {
329        Self {
330            dynamic_regularization_signs: None,
331            dynamic_regularization_delta: E::Real::faer_zero(),
332            dynamic_regularization_epsilon: E::Real::faer_zero(),
333        }
334    }
335}
336
337/// Computes the Cholesky factors $L$ and $D$ of the input matrix such that $L$ is strictly lower
338/// triangular, $D$ is real-valued diagonal, and
339/// $$LDL^H = A.$$
340///
341/// The result is stored back in the same matrix.
342///
343/// The input matrix is interpreted as symmetric and only the lower triangular part is read.
344///
345/// The matrix $L$ is stored in the strictly lower triangular part of the input matrix, and the
346/// inverses of the diagonal elements of $D$ are stored on the diagonal.
347///
348/// The strictly upper triangular part of the matrix is clobbered and may be filled with garbage
349/// values.
350///
351/// # Warning
352///
353/// The Cholesky decomposition with diagonal may have poor numerical stability properties when used
354/// with non positive definite matrices. In the general case, it is recommended to first permute
355/// (and conjugate when necessary) the rows and columns of the matrix using the permutation obtained
356/// from [`crate::compute_cholesky_permutation`].
357///
358/// # Panics
359///
360/// Panics if the input matrix is not square.
361///
362/// This can also panic if the provided memory in `stack` is insufficient (see
363/// [`raw_cholesky_in_place_req`]).
364#[track_caller]
365#[inline]
366pub fn raw_cholesky_in_place<E: ComplexField>(
367    matrix: MatMut<'_, E>,
368    regularization: LdltRegularization<'_, E>,
369    parallelism: Parallelism,
370    stack: PodStack<'_>,
371    params: LdltDiagParams,
372) -> LdltInfo {
373    assert!(matrix.ncols() == matrix.nrows());
374    #[cfg(feature = "perf-warn")]
375    if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) {
376        if matrix.col_stride().unsigned_abs() == 1 {
377            log::warn!(target: "faer_perf", "LDLT prefers column-major matrix. Found row-major matrix.");
378        } else {
379            log::warn!(target: "faer_perf", "LDLT prefers column-major matrix. Found matrix with generic strides.");
380        }
381    }
382
383    let mut count = 0;
384    cholesky_in_place_impl(
385        &mut count,
386        matrix,
387        regularization,
388        parallelism,
389        stack,
390        params,
391    );
392    LdltInfo {
393        dynamic_regularization_count: count,
394    }
395}