1use core::hint::cold_path;
15
16use crate::matrix::{FiniteMatrix, Matrix, SymmetricMatrix};
17use crate::vector::{FiniteVector, Vector};
18use crate::{LaError, Tolerance};
19
20#[must_use]
37#[derive(Clone, Copy, Debug, PartialEq)]
38pub struct Ldlt<const D: usize> {
39 factors: LdltFactors<D>,
40}
41
42#[derive(Clone, Copy, Debug, PartialEq)]
47struct LdltFactors<const D: usize> {
48 storage: Matrix<D>,
49}
50
51impl<const D: usize> LdltFactors<D> {
52 #[inline]
54 const fn new_unchecked(storage: Matrix<D>) -> Self {
55 Self { storage }
56 }
57
58 #[inline]
60 #[must_use]
61 const fn row(&self, index: usize) -> &[f64; D] {
62 &self.storage.rows[index]
63 }
64
65 #[inline]
67 #[must_use]
68 const fn entry(&self, row: usize, col: usize) -> f64 {
69 self.storage.rows[row][col]
70 }
71
72 #[inline]
74 #[must_use]
75 const fn diag(&self, index: usize) -> f64 {
76 self.storage.rows[index][index]
77 }
78}
79
80impl<const D: usize> Ldlt<D> {
81 #[inline]
83 pub(crate) fn factor_symmetric(a: SymmetricMatrix<D>, tol: Tolerance) -> Result<Self, LaError> {
84 let mut f = a.into_matrix();
85 let tol = tol.get();
86
87 for j in 0..D {
89 let d = f.rows[j][j];
90 if !d.is_finite() {
91 cold_path();
92 return Err(LaError::non_finite_cell(j, j));
93 }
94 if d < 0.0 {
95 cold_path();
96 return Err(LaError::not_positive_semidefinite(j, d));
97 }
98 if d <= tol {
99 cold_path();
100 return Err(LaError::Singular { pivot_col: j });
101 }
102
103 for i in (j + 1)..D {
105 let l = f.rows[i][j] / d;
106 if !l.is_finite() {
107 cold_path();
108 return Err(LaError::non_finite_cell(i, j));
109 }
110 f.rows[i][j] = l;
111 }
112
113 for i in (j + 1)..D {
115 let l_i = f.rows[i][j];
116 let l_i_d = l_i * d;
117
118 for k in (j + 1)..=i {
119 let l_k = f.rows[k][j];
120 let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]);
121 f.rows[i][k] = new_val;
122 }
123 }
124 }
125
126 let f = FiniteMatrix::new(f)?.into_matrix();
127
128 Ok(Self {
129 factors: LdltFactors::new_unchecked(f),
130 })
131 }
132
133 #[inline]
155 pub const fn det(&self) -> Result<f64, LaError> {
156 let mut det = 1.0;
157 let mut i = 0;
158 while i < D {
159 det *= self.factors.diag(i);
160 if !det.is_finite() {
161 cold_path();
162 return Err(LaError::non_finite_at(i));
163 }
164 i += 1;
165 }
166 Ok(det)
167 }
168
169 #[inline]
198 pub const fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
199 let b = match FiniteVector::new(b) {
200 Ok(b) => b,
201 Err(err) => return Err(err),
202 };
203 match self.solve_finite_vec(b) {
204 Ok(x) => Ok(x.into_vector()),
205 Err(err) => Err(err),
206 }
207 }
208
209 #[inline]
218 pub(crate) const fn solve_finite_vec(
219 &self,
220 b: FiniteVector<D>,
221 ) -> Result<FiniteVector<D>, LaError> {
222 let mut x = b.into_array();
223
224 let mut i = 0;
226 while i < D {
227 let mut sum = x[i];
228 let row = self.factors.row(i);
229 let mut j = 0;
230 while j < i {
231 sum = (-row[j]).mul_add(x[j], sum);
232 j += 1;
233 }
234 if !sum.is_finite() {
235 cold_path();
236 return Err(LaError::non_finite_at(i));
237 }
238 x[i] = sum;
239 i += 1;
240 }
241
242 let mut i = 0;
244 while i < D {
245 let diag = self.factors.diag(i);
246
247 let quotient = x[i] / diag;
248 if !quotient.is_finite() {
249 cold_path();
250 return Err(LaError::non_finite_at(i));
251 }
252 x[i] = quotient;
253 i += 1;
254 }
255
256 let mut ii = 0;
258 while ii < D {
259 let i = D - 1 - ii;
260 let mut sum = x[i];
261 let mut j = i + 1;
262 while j < D {
263 sum = (-self.factors.entry(j, i)).mul_add(x[j], sum);
264 j += 1;
265 }
266 if !sum.is_finite() {
267 cold_path();
268 return Err(LaError::non_finite_at(i));
269 }
270 x[i] = sum;
271 ii += 1;
272 }
273
274 Ok(FiniteVector::new_unchecked(Vector::new_unchecked(x)))
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    use crate::DEFAULT_SINGULAR_TOL;
    use crate::matrix::FiniteMatrix;

    use core::assert_matches;
    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Returns the first `N` entries of `values` as an array (callers use `N <= 5`).
    fn leading_values<const N: usize>(values: [f64; 5]) -> [f64; N] {
        core::array::from_fn(|i| values[i])
    }

    macro_rules! gen_public_api_ldlt_identity_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
                    let ldlt = Matrix::<$d>::identity()
                        .ldlt(DEFAULT_SINGULAR_TOL)
                        .unwrap();

                    assert_abs_diff_eq!(ldlt.det().unwrap(), 1.0, epsilon = 1e-12);

                    let rhs: [f64; $d] = leading_values([1.0, 2.0, 3.0, 4.0, 5.0]);
                    let solution = ldlt
                        .solve_vec(Vector::<$d>::new(black_box(rhs)))
                        .unwrap()
                        .into_array();

                    // Solving against the identity must return the right-hand side.
                    for (got, want) in solution.iter().zip(&rhs) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_identity_tests!(2);
    gen_public_api_ldlt_identity_tests!(3);
    gen_public_api_ldlt_identity_tests!(4);
    gen_public_api_ldlt_identity_tests!(5);

    macro_rules! gen_public_api_ldlt_diagonal_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
                    let pivots: [f64; $d] = leading_values([1.0, 2.0, 3.0, 4.0, 5.0]);
                    let rows: [[f64; $d]; $d] = core::array::from_fn(|r| {
                        core::array::from_fn(|c| if r == c { pivots[r] } else { 0.0 })
                    });

                    let ldlt = Matrix::<$d>::from_rows(black_box(rows))
                        .ldlt(DEFAULT_SINGULAR_TOL)
                        .unwrap();

                    // For a diagonal matrix the determinant is the product of its diagonal.
                    let expected_det: f64 = pivots.iter().product();
                    assert_abs_diff_eq!(ldlt.det().unwrap(), expected_det, epsilon = 1e-12);

                    let rhs: [f64; $d] = leading_values([5.0, 4.0, 3.0, 2.0, 1.0]);
                    let solution = ldlt
                        .solve_vec(Vector::<$d>::new(black_box(rhs)))
                        .unwrap()
                        .into_array();

                    for ((got, num), den) in solution.iter().zip(&rhs).zip(&pivots) {
                        assert_abs_diff_eq!(*got, num / den, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_diagonal_tests!(2);
    gen_public_api_ldlt_diagonal_tests!(3);
    gen_public_api_ldlt_diagonal_tests!(4);
    gen_public_api_ldlt_diagonal_tests!(5);

    #[test]
    fn solve_2x2_known_spd() {
        let finite =
            FiniteMatrix::new(Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]])))
                .unwrap();
        let ldlt = finite.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let [x0, x1] = ldlt
            .solve_vec(Vector::<2>::new(black_box([1.0, 2.0])))
            .unwrap()
            .into_array();

        assert_abs_diff_eq!(x0, -0.125, epsilon = 1e-12);
        assert_abs_diff_eq!(x1, 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(ldlt.det().unwrap(), 8.0, epsilon = 1e-12);
    }

    #[test]
    fn solve_3x3_spd_tridiagonal_smoke() {
        let tridiagonal = Matrix::<3>::from_rows(black_box([
            [2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0],
            [0.0, -1.0, 2.0],
        ]));
        let ldlt = tridiagonal.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        // The tridiagonal matrix times the all-ones vector equals [1, 0, 1].
        let solution = ldlt
            .solve_vec(Vector::<3>::new(black_box([1.0, 0.0, 1.0])))
            .unwrap()
            .into_array();

        for value in solution {
            assert_abs_diff_eq!(value, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn singular_detected_for_degenerate_psd() {
        // Rank one: the second pivot becomes 1 - 1*1 = 0.
        let rank_one = Matrix::<2>::from_rows(black_box([[1.0, 1.0], [1.0, 1.0]]));
        let expected = LaError::Singular { pivot_col: 1 };
        assert_eq!(rank_one.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(), expected);
    }

    #[test]
    fn negative_initial_diagonal_reports_not_positive_semidefinite() {
        let input = Matrix::<2>::from_rows(black_box([[-1.0, 0.0], [0.0, 1.0]]));
        let expected = LaError::NotPositiveSemidefinite {
            pivot_col: 0,
            value: -1.0,
        };
        assert_eq!(input.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(), expected);
    }

    #[test]
    fn negative_updated_diagonal_reports_not_positive_semidefinite() {
        // Multiplier 2, so the second pivot becomes 1 - 2*1*2 = -3.
        let input = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 1.0]]));
        let expected = LaError::NotPositiveSemidefinite {
            pivot_col: 1,
            value: -3.0,
        };
        assert_eq!(input.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(), expected);
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_diagonal() {
        let expected = LaError::NonFinite {
            row: Some(0),
            col: 0,
        };
        assert_eq!(
            Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).unwrap_err(),
            expected
        );
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_offdiagonal_before_asymmetry() {
        let expected = LaError::NonFinite {
            row: Some(0),
            col: 1,
        };
        assert_eq!(
            Matrix::<2>::try_from_rows([[1.0, f64::NAN], [0.0, 1.0]]).unwrap_err(),
            expected
        );
    }

    #[test]
    fn nonfinite_l_multiplier_overflow() {
        // 1e300 / 1e-11 overflows to infinity while forming the multiplier.
        // Assumes DEFAULT_SINGULAR_TOL is below 1e-11 so the first pivot is accepted.
        let input = Matrix::<2>::from_rows([[1e-11, 1e300], [1e300, 1.0]]);
        let expected = LaError::NonFinite {
            row: Some(1),
            col: 0,
        };
        assert_eq!(input.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(), expected);
    }

    #[test]
    fn nonfinite_trailing_submatrix_overflow() {
        // The update 1 - 1e200 * 1e200 overflows, so pivot (1, 1) is non-finite.
        let input = Matrix::<2>::from_rows([[1.0, 1e200], [1e200, 1.0]]);
        let expected = LaError::NonFinite {
            row: Some(1),
            col: 1,
        };
        assert_eq!(input.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err(), expected);
    }

    #[test]
    fn nonfinite_solve_vec_forward_substitution_overflow() {
        // L[1][0] = 1e153, so y[1] = 0 - 1e153 * 1e156 overflows.
        let input = Matrix::<3>::from_rows([
            [1.0, 1e153, 0.0],
            [1e153, 1e306 + 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]);
        let ldlt = input.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let rhs = Vector::<3>::new([1e156, 0.0, 0.0]);
        assert_eq!(
            ldlt.solve_vec(rhs).unwrap_err(),
            LaError::NonFinite { row: None, col: 1 }
        );
    }

    #[test]
    fn nonfinite_solve_vec_back_substitution_overflow() {
        // L[2][1] = 2 and all pivots are 1, so x[1] = 0 - 2 * 1e308 overflows.
        let input = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 2.0, 5.0]]);
        let ldlt = input.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let rhs = Vector::<3>::new([0.0, 0.0, 1e308]);
        assert_eq!(
            ldlt.solve_vec(rhs).unwrap_err(),
            LaError::NonFinite { row: None, col: 1 }
        );
    }

    #[test]
    fn nonfinite_solve_vec_diagonal_solve_overflow() {
        // 1e300 / 1e-11 overflows in the diagonal pass.
        // Assumes DEFAULT_SINGULAR_TOL is below 1e-11.
        let input = Matrix::<2>::from_rows([[1.0, 0.0], [0.0, 1.0e-11]]);
        let ldlt = input.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let rhs = Vector::<2>::new([0.0, 1.0e300]);
        assert_eq!(
            ldlt.solve_vec(rhs).unwrap_err(),
            LaError::NonFinite { row: None, col: 1 }
        );
    }

    #[test]
    fn det_rejects_product_overflow() {
        // Running product: 1e100, 1e200, 1e300, then 1e400 overflows at index 3.
        let input = Matrix::<5>::from_rows([
            [1.0e100, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0e100, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0e100, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0e100, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0e100],
        ]);
        let ldlt = input.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
        assert_eq!(ldlt.det(), Err(LaError::NonFinite { row: None, col: 3 }));
    }

    #[test]
    fn asymmetric_input_returns_typed_error() {
        let input = Matrix::<3>::from_rows([[4.0, 2.0, 0.0], [-2.0, 5.0, 1.0], [0.0, 1.0, 3.0]]);
        let expected = Err(LaError::Asymmetric {
            row: 0,
            col: 1,
            dim: 3,
        });
        assert_eq!(input.ldlt(DEFAULT_SINGULAR_TOL), expected);
    }

    #[test]
    fn invalid_tolerance_rejected() {
        assert_eq!(
            Tolerance::new(-1.0),
            Err(LaError::InvalidTolerance { value: -1.0 })
        );

        // NaN never compares equal, so match on the variant and test the payload.
        assert_matches!(
            Tolerance::new(f64::NAN),
            Err(LaError::InvalidTolerance { value }) if value.is_nan()
        );

        assert_eq!(
            Tolerance::new(f64::INFINITY),
            Err(LaError::InvalidTolerance {
                value: f64::INFINITY,
            })
        );
    }

    macro_rules! gen_solve_vec_boundary_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_vec_rhs_boundary_rejects_non_finite_ $d d>]() {
                    let last = $d - 1;
                    let mut rhs = [1.0; $d];
                    rhs[last] = f64::NAN;

                    assert_eq!(
                        Vector::<$d>::try_new(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: last,
                        })
                    );
                }

                #[test]
                fn [<solve_vec_revalidates_unchecked_rhs_storage_ $d d>]() {
                    let last = $d - 1;
                    let ldlt = Matrix::<$d>::identity().ldlt(DEFAULT_SINGULAR_TOL).unwrap();

                    let mut raw = [1.0; $d];
                    raw[last] = f64::NAN;
                    // Bypass construction-time validation so solve_vec must catch the NaN.
                    let rhs = Vector::<$d>::new_unchecked(raw);

                    assert_eq!(
                        ldlt.solve_vec(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: last,
                        })
                    );
                }
            }
        };
    }

    gen_solve_vec_boundary_tests!(2);
    gen_solve_vec_boundary_tests!(3);
    gen_solve_vec_boundary_tests!(4);
    gen_solve_vec_boundary_tests!(5);

    macro_rules! gen_ldlt_const_eval_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<ldlt_det_const_eval_ $d d>]() {
                    // Evaluated at compile time: identity factors with the first pivot set to 2.
                    const DET: Result<f64, LaError> = {
                        let mut storage = Matrix::<$d>::identity();
                        storage.rows[0][0] = 2.0;
                        let ldlt = Ldlt::<$d> {
                            factors: LdltFactors::new_unchecked(storage),
                        };
                        ldlt.det()
                    };
                    assert_eq!(DET, Ok(2.0));
                }

                #[test]
                fn [<ldlt_solve_vec_const_eval_ $d d>]() {
                    // Loop indices are tiny, so the usize -> f64 conversion is exact.
                    #[allow(clippy::cast_precision_loss)]
                    const X: [f64; $d] = {
                        let ldlt = Ldlt::<$d> {
                            factors: LdltFactors::new_unchecked(Matrix::<$d>::identity()),
                        };
                        let mut rhs = [0.0f64; $d];
                        let mut k = 0;
                        while k < $d {
                            rhs[k] = k as f64 + 1.0;
                            k += 1;
                        }
                        match ldlt.solve_vec(Vector::<$d>::new(rhs)) {
                            Ok(v) => v.into_array(),
                            Err(_) => [0.0f64; $d],
                        }
                    };
                    // Same exact small-index conversion as above.
                    #[allow(clippy::cast_precision_loss)]
                    for (k, &got) in X.iter().enumerate() {
                        let expected = k as f64 + 1.0;
                        assert!((got - expected).abs() <= 1e-12);
                    }
                }
            }
        };
    }

    gen_ldlt_const_eval_tests!(2);
    gen_ldlt_const_eval_tests!(3);
    gen_ldlt_const_eval_tests!(4);
    gen_ldlt_const_eval_tests!(5);
}