1use super::CholeskyError;
2use crate::ldlt_diagonal::compute::RankUpdate;
3use dyn_stack::{PodStack, SizeOverflow, StackReq};
4use faer_core::{
5 assert, debug_assert, mul::triangular::BlockStructure, parallelism_degree, solve, unzipped,
6 zipped, ComplexField, Entity, MatMut, Parallelism, SimdCtx,
7};
8use reborrow::*;
9
/// Unblocked, left-looking Cholesky (LLT) factorization of `matrix`, in place.
///
/// Processes one column per iteration. For column `idx`, the already-computed row
/// `l10 = L[idx, ..idx]` and block `l20 = L[idx+1.., ..idx]` are used to:
/// 1. update the diagonal entry: `a11 <- re(a11) - sum_j |l10[j]|^2`,
/// 2. optionally regularize it, then replace it with its square root `l11`,
/// 3. update the sub-diagonal column: `a21 <- (a21 - l20 * conj(l10)^T) / l11`.
///
/// Only the lower triangle (including the diagonal) is read and written; the
/// imaginary part of diagonal entries is discarded.
///
/// * `offset` — added to the local pivot index when reporting an error, so a caller
///   factoring a trailing sub-block can report the position in the full matrix.
/// * `regularization` — if `|delta| > 0`, any updated pivot whose real value lies in
///   `[0, |epsilon|]` is replaced by `|delta|` before taking the square root.
/// * `parallelism`, `params` — accepted for signature uniformity; unused here.
///
/// Returns the number of pivots that were replaced by the regularization, or a
/// [`CholeskyError`] whose `non_positive_definite_minor` is `offset + idx + 1`, the
/// 1-based index of the first pivot that is not strictly positive (e.g. negative or NaN).
///
/// # Panics
///
/// Panics if `matrix` is not square.
fn cholesky_in_place_left_looking_impl<E: ComplexField>(
    offset: usize,
    matrix: MatMut<'_, E>,
    regularization: LltRegularization<E>,
    parallelism: Parallelism,
    params: LltParams,
) -> Result<usize, CholeskyError> {
    let mut matrix = matrix;
    // Neither is needed by this unblocked kernel.
    let _ = params;
    let _ = parallelism;
    assert_eq!(matrix.ncols(), matrix.nrows());

    let n = matrix.nrows();

    if n == 0 {
        return Ok(0);
    }

    // Index of the column currently being factored.
    let mut idx = 0;
    let arch = E::Simd::default();
    // Only the absolute values of the regularization parameters are used.
    let eps = regularization
        .dynamic_regularization_epsilon
        .faer_real()
        .faer_abs();
    let delta = regularization
        .dynamic_regularization_delta
        .faer_real()
        .faer_abs();
    // Regularization is active only when the replacement value `delta` is non-zero
    // (note: despite the name, this tests `delta`, not `eps`).
    let has_eps = delta > E::Real::faer_zero();
    let mut dynamic_regularization_count = 0usize;
    loop {
        // One column per iteration.
        let block_size = 1;

        // Rows idx.. are split into columns ..idx (bottom_left: already-computed part
        // of L) and columns idx.. (bottom_right: still to be factored).
        let (_, _, bottom_left, bottom_right) = matrix.rb_mut().split_at_mut(idx, idx);
        // l10: row idx of L, columns ..idx; l20: rows idx+1.., columns ..idx.
        let (_, l10, _, l20) = bottom_left.into_const().split_at(block_size, 0);
        // a11: diagonal entry (idx, idx); a21: the entries below it in column idx.
        let (mut a11, _, a21, _) = bottom_right.split_at_mut(block_size, block_size);

        let l10 = l10.row(0);
        let mut a21 = a21.col_mut(0);

        // dot = sum_j |l10[j]|^2
        let mut dot = E::Real::faer_zero();
        for j in 0..idx {
            dot = dot.faer_add(l10.read(j).faer_abs2());
        }
        // a11 <- re(a11) - dot; the imaginary part of the diagonal is discarded.
        a11.write(
            0,
            0,
            E::faer_from_real(a11.read(0, 0).faer_real().faer_sub(dot)),
        );

        let mut real = a11.read(0, 0).faer_real();
        // Small non-negative pivots are replaced by `delta` when regularization is on.
        if has_eps && real >= E::Real::faer_zero() && real <= eps {
            real = delta;
            dynamic_regularization_count += 1;
        }

        // This comparison is false for NaN, so NaN pivots are reported as errors too.
        if real > E::Real::faer_zero() {
            a11.write(0, 0, E::faer_from_real(real.faer_sqrt()));
        } else {
            // 1-based pivot position within the full matrix.
            return Err(CholeskyError {
                non_positive_definite_minor: offset + idx + 1,
            });
        };

        if idx + block_size == n {
            break;
        }

        let l11 = a11.read(0, 0);

        // a21 <- a21 - l20 * conj(l10)^T
        if a21.row_stride() == 1 {
            // Contiguous column: vectorized kernel dispatched on `E::Simd`, presumably
            // computing the same update as the fallback loop below — TODO confirm
            // against `RankUpdate`.
            arch.dispatch(RankUpdate {
                a21: a21.rb_mut().as_2d_mut(),
                l20,
                l10: l10.as_2d(),
            });
        } else {
            // Strided column: one axpy-style update per previously computed column j.
            for j in 0..idx {
                let l20_col = l20.col(j);
                let l10_conj = l10.read(j).faer_conj();

                zipped!(a21.rb_mut().as_2d_mut(), l20_col.as_2d()).for_each(
                    |unzipped!(mut dst, src)| {
                        dst.write(dst.read().faer_sub(src.read().faer_mul(l10_conj)))
                    },
                );
            }
        }

        // a21 <- a21 / l11; l11 is real and strictly positive here, so scale by its inverse.
        let r = l11.faer_real().faer_inv();
        zipped!(a21.rb_mut().as_2d_mut())
            .for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(r)));

        idx += block_size;
    }
    Ok(dynamic_regularization_count)
}
131
/// Tuning parameters for the LLT (Cholesky) factorization.
///
/// Currently carries no options. The type is `#[non_exhaustive]` so that fields can be
/// added later without breaking callers, who should obtain it via `Default::default()`.
#[derive(Default, Copy, Clone, Debug)]
#[non_exhaustive]
pub struct LltParams {}
135
/// Dynamic regularization settings for the LLT factorization.
///
/// When the absolute value of `dynamic_regularization_delta` is non-zero, any pivot
/// whose real value lies in `[0, |dynamic_regularization_epsilon|]` during the
/// factorization is replaced by `|dynamic_regularization_delta|`. Setting both fields
/// to zero disables regularization.
#[derive(Copy, Clone, Debug)]
pub struct LltRegularization<E: ComplexField> {
    /// Replacement value for small pivots (its absolute value is used).
    /// A value of zero disables dynamic regularization.
    pub dynamic_regularization_delta: E::Real,
    /// Threshold at or below which a non-negative pivot is regularized (its absolute
    /// value is used).
    pub dynamic_regularization_epsilon: E::Real,
}
142
143impl<E: ComplexField> Default for LltRegularization<E> {
144 fn default() -> Self {
145 Self {
146 dynamic_regularization_delta: E::Real::faer_zero(),
147 dynamic_regularization_epsilon: E::Real::faer_zero(),
148 }
149 }
150}
151
152pub fn cholesky_in_place_req<E: Entity>(
155 dim: usize,
156 parallelism: Parallelism,
157 params: LltParams,
158) -> Result<StackReq, SizeOverflow> {
159 let _ = dim;
160 let _ = parallelism;
161 let _ = params;
162 Ok(StackReq::default())
163}
164
165fn cholesky_in_place_impl<E: ComplexField>(
167 offset: usize,
168 count: &mut usize,
169 matrix: MatMut<'_, E>,
170 regularization: LltRegularization<E>,
171 parallelism: Parallelism,
172 stack: PodStack<'_>,
173 params: LltParams,
174) -> Result<(), CholeskyError> {
175 debug_assert!(matrix.nrows() == matrix.ncols());
178 let mut matrix = matrix;
179 let mut stack = stack;
180
181 let n = matrix.nrows();
182 if n < 32 {
183 *count += cholesky_in_place_left_looking_impl(
184 offset,
185 matrix,
186 regularization,
187 parallelism,
188 params,
189 )?;
190 Ok(())
191 } else {
192 let block_size = Ord::min(n / 2, 128 * parallelism_degree(parallelism));
193 let (mut l00, _, mut a10, mut a11) = matrix.rb_mut().split_at_mut(block_size, block_size);
194
195 cholesky_in_place_impl(
196 offset,
197 count,
198 l00.rb_mut(),
199 regularization,
200 parallelism,
201 stack.rb_mut(),
202 params,
203 )?;
204
205 let l00 = l00.into_const();
206
207 solve::solve_lower_triangular_in_place(
208 l00.conjugate(),
209 a10.rb_mut().transpose_mut(),
210 parallelism,
211 );
212
213 faer_core::mul::triangular::matmul(
214 a11.rb_mut(),
215 BlockStructure::TriangularLower,
216 a10.rb(),
217 BlockStructure::Rectangular,
218 a10.rb().adjoint(),
219 BlockStructure::Rectangular,
220 Some(E::faer_one()),
221 E::faer_one().faer_neg(),
222 parallelism,
223 );
224
225 cholesky_in_place_impl(
226 offset + block_size,
227 count,
228 a11,
229 regularization,
230 parallelism,
231 stack,
232 params,
233 )
234 }
235}
236
/// Summary information returned by a successful LLT (Cholesky) factorization.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct LltInfo {
    /// Number of pivots that were replaced by the dynamic regularization value
    /// (`LltRegularization::dynamic_regularization_delta`) during the factorization.
    pub dynamic_regularization_count: usize,
}
241
242#[track_caller]
261#[inline]
262pub fn cholesky_in_place<E: ComplexField>(
263 matrix: MatMut<'_, E>,
264 regularization: LltRegularization<E>,
265 parallelism: Parallelism,
266 stack: PodStack<'_>,
267 params: LltParams,
268) -> Result<LltInfo, CholeskyError> {
269 let _ = params;
270 assert!(matrix.ncols() == matrix.nrows());
271 #[cfg(feature = "perf-warn")]
272 if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) {
273 if matrix.col_stride().unsigned_abs() == 1 {
274 log::warn!(target: "faer_perf", "LLT prefers column-major matrix. Found row-major matrix.");
275 } else {
276 log::warn!(target: "faer_perf", "LLT prefers column-major matrix. Found matrix with generic strides.");
277 }
278 }
279
280 let mut count = 0;
281 cholesky_in_place_impl(
282 0,
283 &mut count,
284 matrix,
285 regularization,
286 parallelism,
287 stack,
288 params,
289 )?;
290 Ok(LltInfo {
291 dynamic_regularization_count: count,
292 })
293}