1use dyn_stack::{PodStack, SizeOverflow, StackReq};
2use faer_core::{
3 assert, debug_assert, group_helpers::*, mul::triangular::BlockStructure, solve, temp_mat_req,
4 temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism,
5 SimdCtx,
6};
7use faer_entity::*;
8use reborrow::*;
9
10pub(crate) struct RankUpdate<'a, E: ComplexField> {
11 pub a21: MatMut<'a, E>,
12 pub l20: MatRef<'a, E>,
13 pub l10: MatRef<'a, E>,
14}
15
16impl<E: ComplexField> pulp::WithSimd for RankUpdate<'_, E> {
17 type Output = ();
18
19 #[inline(always)]
20 fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
21 let Self { a21, l20, l10 } = self;
22
23 debug_assert_eq!(a21.row_stride(), 1);
24 debug_assert_eq!(l20.row_stride(), 1);
25 debug_assert_eq!(l20.nrows(), a21.nrows());
26 debug_assert_eq!(l20.ncols(), l10.ncols());
27 debug_assert_eq!(a21.ncols(), 1);
28 debug_assert_eq!(l10.nrows(), 1);
29
30 let m = l20.nrows();
31 let n = l20.ncols();
32
33 if m == 0 {
34 return;
35 }
36
37 let simd = SimdFor::<E, S>::new(simd);
38 let acc = SliceGroupMut::<'_, E>::new(a21.try_get_contiguous_col_mut(0));
39 let offset = simd.align_offset(acc.rb());
40
41 let (mut acc_head, mut acc_body, mut acc_tail) = simd.as_aligned_simd_mut(acc, offset);
42
43 for j in 0..n {
44 let l10 = simd.splat(l10.read(0, j).faer_neg().faer_conj());
45 let l20 = SliceGroup::<'_, E>::new(l20.try_get_contiguous_col(j));
46
47 let (l20_head, l20_body, l20_tail) = simd.as_aligned_simd(l20, offset);
48
49 #[inline(always)]
50 fn process<E: ComplexField, S: pulp::Simd>(
51 simd: SimdFor<E, S>,
52 mut acc: impl Write<Output = SimdGroupFor<E, S>>,
53 l20: impl Read<Output = SimdGroupFor<E, S>>,
54 l10: SimdGroupFor<E, S>,
55 ) {
56 let zero = simd.splat(E::faer_zero());
57 acc.write(simd.mul_add_e(l10, l20.read_or(zero), acc.read_or(zero)));
58 }
59
60 process(simd, acc_head.rb_mut(), l20_head, l10);
61 for (acc, l20) in acc_body
62 .rb_mut()
63 .into_mut_iter()
64 .zip(l20_body.into_ref_iter())
65 {
66 process(simd, acc, l20, l10)
67 }
68 process(simd, acc_tail.rb_mut(), l20_tail, l10);
69 }
70 }
71}
72
73fn cholesky_in_place_left_looking_impl<E: ComplexField>(
74 matrix: MatMut<'_, E>,
75 regularization: LdltRegularization<'_, E>,
76 parallelism: Parallelism,
77 params: LdltDiagParams,
78) -> usize {
79 let mut matrix = matrix;
80 let _ = parallelism;
81 let _ = params;
82
83 debug_assert!(
84 matrix.ncols() == matrix.nrows(),
85 "only square matrices can be decomposed into cholesky factors",
86 );
87
88 let n = matrix.nrows();
89
90 if n == 0 {
91 return 0;
92 }
93
94 let mut idx = 0;
95 let arch = E::Simd::default();
96
97 let eps = regularization.dynamic_regularization_epsilon.faer_abs();
98 let delta = regularization.dynamic_regularization_delta.faer_abs();
99 let has_eps = delta > E::Real::faer_zero();
100 let mut dynamic_regularization_count = 0usize;
101 loop {
102 let block_size = 1;
103
104 let (top_left, top_right, bottom_left, bottom_right) =
119 matrix.rb_mut().split_at_mut(idx, idx);
120 let l00 = top_left.into_const();
121 let d0 = l00.diagonal().column_vector();
122 let (_, l10, _, l20) = bottom_left.into_const().split_at(block_size, 0);
123 let (mut a11, _, a21, _) = bottom_right.split_at_mut(block_size, block_size);
124
125 let mut l10xd0 = top_right
127 .submatrix_mut(0, 0, idx, block_size)
128 .transpose_mut();
129
130 zipped!(l10xd0.rb_mut(), l10, d0.transpose().as_2d()).for_each(
131 |unzipped!(mut dst, src, factor)| {
132 dst.write(
133 src.read()
134 .faer_scale_real(factor.read().faer_real().faer_inv()),
135 )
136 },
137 );
138
139 let l10xd0 = l10xd0.into_const();
140
141 let mut d = a11
142 .read(0, 0)
143 .faer_sub(faer_core::mul::inner_prod::inner_prod_with_conj_arch(
144 arch,
145 l10xd0.row(0).transpose().as_2d(),
146 Conj::Yes,
147 l10.row(0).transpose().as_2d(),
148 Conj::No,
149 ))
150 .faer_real();
151
152 if has_eps {
154 if let Some(signs) = regularization.dynamic_regularization_signs {
155 if signs[idx] > 0 && d <= eps {
156 d = delta;
157 dynamic_regularization_count += 1;
158 } else if signs[idx] < 0 && d >= eps.faer_neg() {
159 d = delta.faer_neg();
160 dynamic_regularization_count += 1;
161 }
162 } else if d.faer_abs() <= eps {
163 if d < E::Real::faer_zero() {
164 d = delta.faer_neg();
165 } else {
166 d = delta;
167 }
168 dynamic_regularization_count += 1;
169 }
170 }
171
172 let d = d.faer_inv();
173 a11.write(0, 0, E::faer_from_real(d));
174
175 if idx + block_size == n {
176 break;
177 }
178
179 let mut a21 = a21.col_mut(0);
180
181 if a21.row_stride() == 1 {
183 arch.dispatch(RankUpdate {
184 a21: a21.rb_mut().as_2d_mut(),
185 l20,
186 l10: l10xd0,
187 });
188 } else {
189 for j in 0..idx {
190 let l20_col = l20.col(j);
191 let l10_conj = l10xd0.read(0, j).faer_conj();
192
193 zipped!(a21.rb_mut().as_2d_mut(), l20_col.as_2d()).for_each(
194 |unzipped!(mut dst, src)| {
195 dst.write(dst.read().faer_sub(src.read().faer_mul(l10_conj)))
196 },
197 );
198 }
199 }
200
201 zipped!(a21.rb_mut().as_2d_mut())
202 .for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d)));
203
204 idx += block_size;
205 }
206 dynamic_regularization_count
207}
208
/// Tuning parameters for the diagonal `L D L^H` factorization.
///
/// Currently carries no options; it is `#[non_exhaustive]` so fields can be added later without
/// breaking callers, who should construct it with `Default::default()`.
#[derive(Default, Copy, Clone, Debug)]
#[non_exhaustive]
pub struct LdltDiagParams {}
212
213pub fn raw_cholesky_in_place_req<E: Entity>(
215 dim: usize,
216 parallelism: Parallelism,
217 params: LdltDiagParams,
218) -> Result<StackReq, SizeOverflow> {
219 let _ = parallelism;
220 let _ = params;
221 temp_mat_req::<E>(dim, dim)
222}
223
224fn cholesky_in_place_impl<E: ComplexField>(
226 count: &mut usize,
227 matrix: MatMut<'_, E>,
228 regularization: LdltRegularization<'_, E>,
229 parallelism: Parallelism,
230 stack: PodStack<'_>,
231 params: LdltDiagParams,
232) {
233 debug_assert!(matrix.nrows() == matrix.ncols());
236 let mut matrix = matrix;
237 let mut stack = stack;
238
239 let n = matrix.nrows();
240 if n < 32 {
241 *count += cholesky_in_place_left_looking_impl(matrix, regularization, parallelism, params)
242 } else {
243 let block_size = Ord::min(n / 2, 128);
244 let rem = n - block_size;
245 let (mut l00, _, mut a10, mut a11) = matrix.rb_mut().split_at_mut(block_size, block_size);
246
247 cholesky_in_place_impl(
248 count,
249 l00.rb_mut(),
250 regularization,
251 parallelism,
252 stack.rb_mut(),
253 params,
254 );
255
256 let l00 = l00.into_const();
257 let d0 = l00.diagonal().column_vector();
258
259 solve::solve_unit_lower_triangular_in_place(
260 l00.conjugate(),
261 a10.rb_mut().transpose_mut(),
262 parallelism,
263 );
264
265 {
266 let (mut l10xd0, _) = temp_mat_uninit::<E>(rem, block_size, stack.rb_mut());
268 let mut l10xd0 = l10xd0.as_mut();
269
270 for j in 0..block_size {
271 let l10xd0_col = l10xd0.rb_mut().col_mut(j);
272 let a10_col = a10.rb_mut().col_mut(j);
273 let d0_elem = d0.read(j);
274
275 zipped!(l10xd0_col.as_2d_mut(), a10_col.as_2d_mut()).for_each(
276 |unzipped!(mut l10xd0_elem, mut a10_elem)| {
277 let a10_elem_read = a10_elem.read();
278 a10_elem.write(a10_elem_read.faer_mul(d0_elem));
279 l10xd0_elem.write(a10_elem_read);
280 },
281 );
282 }
283
284 faer_core::mul::triangular::matmul(
285 a11.rb_mut(),
286 BlockStructure::TriangularLower,
287 a10.into_const(),
288 BlockStructure::Rectangular,
289 l10xd0.adjoint_mut().into_const(),
290 BlockStructure::Rectangular,
291 Some(E::faer_one()),
292 E::faer_one().faer_neg(),
293 parallelism,
294 );
295 }
296
297 cholesky_in_place_impl(
298 count,
299 a11,
300 LdltRegularization {
301 dynamic_regularization_signs: regularization
302 .dynamic_regularization_signs
303 .map(|signs| &signs[block_size..]),
304 dynamic_regularization_delta: regularization.dynamic_regularization_delta,
305 dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
306 },
307 parallelism,
308 stack,
309 params,
310 )
311 }
312}
313
314#[derive(Copy, Clone, Debug)]
316pub struct LdltRegularization<'a, E: ComplexField> {
317 pub dynamic_regularization_signs: Option<&'a [i8]>,
318 pub dynamic_regularization_delta: E::Real,
319 pub dynamic_regularization_epsilon: E::Real,
320}
321
/// Diagnostics reported by the in-place `L D L^H` factorization.
#[derive(Debug, Clone, Copy)]
pub struct LdltInfo {
    /// Number of pivots that were replaced by dynamic regularization.
    pub dynamic_regularization_count: usize,
}
326
327impl<E: ComplexField> Default for LdltRegularization<'_, E> {
328 fn default() -> Self {
329 Self {
330 dynamic_regularization_signs: None,
331 dynamic_regularization_delta: E::Real::faer_zero(),
332 dynamic_regularization_epsilon: E::Real::faer_zero(),
333 }
334 }
335}
336
337#[track_caller]
365#[inline]
366pub fn raw_cholesky_in_place<E: ComplexField>(
367 matrix: MatMut<'_, E>,
368 regularization: LdltRegularization<'_, E>,
369 parallelism: Parallelism,
370 stack: PodStack<'_>,
371 params: LdltDiagParams,
372) -> LdltInfo {
373 assert!(matrix.ncols() == matrix.nrows());
374 #[cfg(feature = "perf-warn")]
375 if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) {
376 if matrix.col_stride().unsigned_abs() == 1 {
377 log::warn!(target: "faer_perf", "LDLT prefers column-major matrix. Found row-major matrix.");
378 } else {
379 log::warn!(target: "faer_perf", "LDLT prefers column-major matrix. Found matrix with generic strides.");
380 }
381 }
382
383 let mut count = 0;
384 cholesky_in_place_impl(
385 &mut count,
386 matrix,
387 regularization,
388 parallelism,
389 stack,
390 params,
391 );
392 LdltInfo {
393 dynamic_regularization_count: count,
394 }
395}