// faer_cholesky/llt/compute.rs

1use super::CholeskyError;
2use crate::ldlt_diagonal::compute::RankUpdate;
3use dyn_stack::{PodStack, SizeOverflow, StackReq};
4use faer_core::{
5    assert, debug_assert, mul::triangular::BlockStructure, parallelism_degree, solve, unzipped,
6    zipped, ComplexField, Entity, MatMut, Parallelism, SimdCtx,
7};
8use reborrow::*;
9
10fn cholesky_in_place_left_looking_impl<E: ComplexField>(
11    offset: usize,
12    matrix: MatMut<'_, E>,
13    regularization: LltRegularization<E>,
14    parallelism: Parallelism,
15    params: LltParams,
16) -> Result<usize, CholeskyError> {
17    let mut matrix = matrix;
18    let _ = params;
19    let _ = parallelism;
20    assert_eq!(matrix.ncols(), matrix.nrows());
21
22    let n = matrix.nrows();
23
24    if n == 0 {
25        return Ok(0);
26    }
27
28    let mut idx = 0;
29    let arch = E::Simd::default();
30    let eps = regularization
31        .dynamic_regularization_epsilon
32        .faer_real()
33        .faer_abs();
34    let delta = regularization
35        .dynamic_regularization_delta
36        .faer_real()
37        .faer_abs();
38    let has_eps = delta > E::Real::faer_zero();
39    let mut dynamic_regularization_count = 0usize;
40    loop {
41        let block_size = 1;
42
43        let (_, _, bottom_left, bottom_right) = matrix.rb_mut().split_at_mut(idx, idx);
44        let (_, l10, _, l20) = bottom_left.into_const().split_at(block_size, 0);
45        let (mut a11, _, a21, _) = bottom_right.split_at_mut(block_size, block_size);
46
47        let l10 = l10.row(0);
48        let mut a21 = a21.col_mut(0);
49
50        //
51        //      L00
52        // A =  L10  A11
53        //      L20  A21  A22
54        //
55        // the first column block is already computed
56        // we now compute A11 and A21
57        //
58        // L00           L00^H L10^H L20^H
59        // L10 L11             L11^H L21^H
60        // L20 L21 L22 ×             L22^H
61        //
62        //
63        // L00×L00^H
64        // L10×L00^H  L10×L10^H + L11×L11^H
65        // L20×L00^H  L20×L10^H + L21×L11^H  L20×L20^H + L21×L21^H + L22×L22^H
66
67        // A11 -= L10 × L10^H
68        let mut dot = E::Real::faer_zero();
69        for j in 0..idx {
70            dot = dot.faer_add(l10.read(j).faer_abs2());
71        }
72        a11.write(
73            0,
74            0,
75            E::faer_from_real(a11.read(0, 0).faer_real().faer_sub(dot)),
76        );
77
78        let mut real = a11.read(0, 0).faer_real();
79        if has_eps && real >= E::Real::faer_zero() && real <= eps {
80            real = delta;
81            dynamic_regularization_count += 1;
82        }
83
84        if real > E::Real::faer_zero() {
85            a11.write(0, 0, E::faer_from_real(real.faer_sqrt()));
86        } else {
87            return Err(CholeskyError {
88                non_positive_definite_minor: offset + idx + 1,
89            });
90        };
91
92        if idx + block_size == n {
93            break;
94        }
95
96        let l11 = a11.read(0, 0);
97
98        // A21 -= L20 × L10^H
99        if a21.row_stride() == 1 {
100            arch.dispatch(RankUpdate {
101                a21: a21.rb_mut().as_2d_mut(),
102                l20,
103                l10: l10.as_2d(),
104            });
105        } else {
106            for j in 0..idx {
107                let l20_col = l20.col(j);
108                let l10_conj = l10.read(j).faer_conj();
109
110                zipped!(a21.rb_mut().as_2d_mut(), l20_col.as_2d()).for_each(
111                    |unzipped!(mut dst, src)| {
112                        dst.write(dst.read().faer_sub(src.read().faer_mul(l10_conj)))
113                    },
114                );
115            }
116        }
117
118        // A21 is now L21×L11^H
119        // find L21
120        //
121        // conj(L11) L21^T = A21^T
122
123        let r = l11.faer_real().faer_inv();
124        zipped!(a21.rb_mut().as_2d_mut())
125            .for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(r)));
126
127        idx += block_size;
128    }
129    Ok(dynamic_regularization_count)
130}
131
/// Tuning parameters for the LLT (Cholesky) factorization.
///
/// This type currently has no fields. It is `#[non_exhaustive]`, so code outside this crate
/// must obtain a value through [`Default::default`] rather than a struct literal.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct LltParams {}
135
136/// Dynamic LLT regularization.
137#[derive(Copy, Clone, Debug)]
138pub struct LltRegularization<E: ComplexField> {
139    pub dynamic_regularization_delta: E::Real,
140    pub dynamic_regularization_epsilon: E::Real,
141}
142
143impl<E: ComplexField> Default for LltRegularization<E> {
144    fn default() -> Self {
145        Self {
146            dynamic_regularization_delta: E::Real::faer_zero(),
147            dynamic_regularization_epsilon: E::Real::faer_zero(),
148        }
149    }
150}
151
152/// Computes the size and alignment of required workspace for performing a Cholesky
153/// decomposition.
154pub fn cholesky_in_place_req<E: Entity>(
155    dim: usize,
156    parallelism: Parallelism,
157    params: LltParams,
158) -> Result<StackReq, SizeOverflow> {
159    let _ = dim;
160    let _ = parallelism;
161    let _ = params;
162    Ok(StackReq::default())
163}
164
165// uses an out parameter for tail recursion
166fn cholesky_in_place_impl<E: ComplexField>(
167    offset: usize,
168    count: &mut usize,
169    matrix: MatMut<'_, E>,
170    regularization: LltRegularization<E>,
171    parallelism: Parallelism,
172    stack: PodStack<'_>,
173    params: LltParams,
174) -> Result<(), CholeskyError> {
175    // right looking cholesky
176
177    debug_assert!(matrix.nrows() == matrix.ncols());
178    let mut matrix = matrix;
179    let mut stack = stack;
180
181    let n = matrix.nrows();
182    if n < 32 {
183        *count += cholesky_in_place_left_looking_impl(
184            offset,
185            matrix,
186            regularization,
187            parallelism,
188            params,
189        )?;
190        Ok(())
191    } else {
192        let block_size = Ord::min(n / 2, 128 * parallelism_degree(parallelism));
193        let (mut l00, _, mut a10, mut a11) = matrix.rb_mut().split_at_mut(block_size, block_size);
194
195        cholesky_in_place_impl(
196            offset,
197            count,
198            l00.rb_mut(),
199            regularization,
200            parallelism,
201            stack.rb_mut(),
202            params,
203        )?;
204
205        let l00 = l00.into_const();
206
207        solve::solve_lower_triangular_in_place(
208            l00.conjugate(),
209            a10.rb_mut().transpose_mut(),
210            parallelism,
211        );
212
213        faer_core::mul::triangular::matmul(
214            a11.rb_mut(),
215            BlockStructure::TriangularLower,
216            a10.rb(),
217            BlockStructure::Rectangular,
218            a10.rb().adjoint(),
219            BlockStructure::Rectangular,
220            Some(E::faer_one()),
221            E::faer_one().faer_neg(),
222            parallelism,
223        );
224
225        cholesky_in_place_impl(
226            offset + block_size,
227            count,
228            a11,
229            regularization,
230            parallelism,
231            stack,
232            params,
233        )
234    }
235}
236
/// Information about a successful LLT (Cholesky) factorization.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct LltInfo {
    /// Number of pivots that were replaced by `|dynamic_regularization_delta|` because they
    /// fell inside the regularization window `[0, |dynamic_regularization_epsilon|]`.
    pub dynamic_regularization_count: usize,
}
241
242/// Computes the Cholesky factor $L$ of a hermitian positive definite input matrix $A$ such that
243/// $L$ is lower triangular, and
244/// $$LL^H == A.$$
245///
246/// The result is stored back in the lower half of the same matrix, or an error is returned if the
247/// matrix is not positive definite.
248///
249/// The input matrix is interpreted as symmetric and only the lower triangular part is read.
250///
251/// The strictly upper triangular part of the matrix is clobbered and may be filled with garbage
252/// values.
253///
254/// # Panics
255///
256/// Panics if the input matrix is not square.
257///
258/// This can also panic if the provided memory in `stack` is insufficient (see
259/// [`cholesky_in_place_req`]).
260#[track_caller]
261#[inline]
262pub fn cholesky_in_place<E: ComplexField>(
263    matrix: MatMut<'_, E>,
264    regularization: LltRegularization<E>,
265    parallelism: Parallelism,
266    stack: PodStack<'_>,
267    params: LltParams,
268) -> Result<LltInfo, CholeskyError> {
269    let _ = params;
270    assert!(matrix.ncols() == matrix.nrows());
271    #[cfg(feature = "perf-warn")]
272    if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) {
273        if matrix.col_stride().unsigned_abs() == 1 {
274            log::warn!(target: "faer_perf", "LLT prefers column-major matrix. Found row-major matrix.");
275        } else {
276            log::warn!(target: "faer_perf", "LLT prefers column-major matrix. Found matrix with generic strides.");
277        }
278    }
279
280    let mut count = 0;
281    cholesky_in_place_impl(
282        0,
283        &mut count,
284        matrix,
285        regularization,
286        parallelism,
287        stack,
288        params,
289    )?;
290    Ok(LltInfo {
291        dynamic_regularization_count: count,
292    })
293}