1#![forbid(unsafe_code)]
2
3use core::hint::cold_path;
17
18use crate::matrix::{Matrix, SymmetricMatrix};
19use crate::vector::Vector;
20use crate::{LaError, Tolerance};
21
22#[must_use]
42#[derive(Clone, Copy, Debug, PartialEq)]
43pub struct Ldlt<const D: usize> {
44 factors: LdltFactors<D>,
45}
46
47#[derive(Clone, Copy, Debug, PartialEq)]
52struct LdltFactors<const D: usize> {
53 storage: Matrix<D>,
54}
55
56impl<const D: usize> LdltFactors<D> {
57 #[inline]
59 const fn new_unchecked(storage: Matrix<D>) -> Self {
60 Self { storage }
61 }
62
63 #[inline]
65 #[must_use]
66 const fn row(&self, index: usize) -> &[f64; D] {
67 &self.storage.rows()[index]
68 }
69
70 #[inline]
72 #[must_use]
73 const fn entry(&self, row: usize, col: usize) -> f64 {
74 self.storage.rows()[row][col]
75 }
76
77 #[inline]
79 #[must_use]
80 const fn diag(&self, index: usize) -> f64 {
81 self.storage.rows()[index][index]
82 }
83}
84
85impl<const D: usize> Ldlt<D> {
86 #[inline]
88 #[allow(clippy::needless_range_loop)]
89 pub(crate) fn factor_symmetric(a: SymmetricMatrix<D>, tol: Tolerance) -> Result<Self, LaError> {
90 let mut f = a.into_matrix();
91 let tol = tol.get();
92
93 {
94 let rows = f.rows_mut_unchecked();
95
96 for j in 0..D {
98 let d = rows[j][j];
99 if !d.is_finite() {
100 cold_path();
101 return Err(LaError::non_finite_cell(j, j));
102 }
103 if d < 0.0 {
104 cold_path();
105 return Err(LaError::not_positive_semidefinite(j, d));
106 }
107 if d <= tol {
108 cold_path();
109 return Err(LaError::Singular { pivot_col: j });
110 }
111
112 if D <= 5 {
113 for i in (j + 1)..D {
116 let l = rows[i][j] / d;
117 if !l.is_finite() {
118 cold_path();
119 return Err(LaError::non_finite_cell(i, j));
120 }
121 rows[i][j] = l;
122 }
123
124 for i in (j + 1)..D {
125 let l_i = rows[i][j];
126 let l_i_d = l_i * d;
127
128 for k in (j + 1)..=i {
129 let l_k = rows[k][j];
130 let new_val = (-l_i_d).mul_add(l_k, rows[i][k]);
131 rows[i][k] = new_val;
132 }
133 }
134 } else {
135 for i in (j + 1)..D {
138 let l_i = rows[i][j] / d;
139 if !l_i.is_finite() {
140 cold_path();
141 return Err(LaError::non_finite_cell(i, j));
142 }
143 rows[i][j] = l_i;
144
145 let l_i_d = l_i * d;
146
147 for k in (j + 1)..=i {
148 let l_k = rows[k][j];
149 let new_val = (-l_i_d).mul_add(l_k, rows[i][k]);
150 rows[i][k] = new_val;
151 }
152 }
153 }
154 }
155 }
156
157 Ok(Self {
160 factors: LdltFactors::new_unchecked(f),
161 })
162 }
163
164 #[inline]
186 pub const fn det(&self) -> Result<f64, LaError> {
187 let mut det = 1.0;
188 let mut i = 0;
189 while i < D {
190 det *= self.factors.diag(i);
191 if !det.is_finite() {
192 cold_path();
193 return Err(LaError::non_finite_at(i));
194 }
195 i += 1;
196 }
197 Ok(det)
198 }
199
200 #[inline]
228 pub const fn solve(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
229 self.solve_finite(b)
230 }
231
232 #[inline]
241 pub(crate) const fn solve_finite(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
242 let mut x = b.into_array();
243
244 let mut i = 0;
246 while i < D {
247 let mut sum = x[i];
248 let row = self.factors.row(i);
249 let mut j = 0;
250 while j < i {
251 sum = (-row[j]).mul_add(x[j], sum);
252 j += 1;
253 }
254 if !sum.is_finite() {
255 cold_path();
256 return Err(LaError::non_finite_at(i));
257 }
258 x[i] = sum;
259 i += 1;
260 }
261
262 let mut i = 0;
264 while i < D {
265 let diag = self.factors.diag(i);
266
267 let quotient = x[i] / diag;
268 if !quotient.is_finite() {
269 cold_path();
270 return Err(LaError::non_finite_at(i));
271 }
272 x[i] = quotient;
273 i += 1;
274 }
275
276 if D <= 4 {
277 let mut ii = 0;
280 while ii < D {
281 let i = D - 1 - ii;
282 let mut sum = x[i];
283 let mut j = i + 1;
284 while j < D {
285 sum = (-self.factors.entry(j, i)).mul_add(x[j], sum);
286 j += 1;
287 }
288 if !sum.is_finite() {
289 cold_path();
290 return Err(LaError::non_finite_at(i));
291 }
292 x[i] = sum;
293 ii += 1;
294 }
295 } else {
296 let mut jj = D;
300 while jj > 0 {
301 jj -= 1;
302
303 let x_j = x[jj];
304 if !x_j.is_finite() {
305 cold_path();
306 return Err(LaError::non_finite_at(jj));
307 }
308
309 let row = self.factors.row(jj);
310 let mut i = 0;
311 while i < jj {
312 x[i] = (-row[i]).mul_add(x_j, x[i]);
313 i += 1;
314 }
315 }
316 }
317
318 Ok(Vector::new_unchecked(x))
319 }
320}
321
// Unit tests for the LDL^T factorization.
//
// Several fixtures are tuned to specific floating-point thresholds or to the
// dimension cut-offs that select between the two factorization loops
// (D <= 5 vs. larger) and the two back-substitution loops (D <= 4 vs. larger).
// Change inputs with care.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::DEFAULT_SINGULAR_TOL;
    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Generates `public_api_ldlt_det_and_solve_identity_{d}d`: factoring the
    // identity must give det == 1 and `solve(b) == b`.
    macro_rules! gen_public_api_ldlt_identity_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
                    let a = Matrix::<$d>::identity();
                    let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

                    assert_abs_diff_eq!(ldlt.det().unwrap(), 1.0, epsilon = 1e-12);

                    // The first `$d` values of [1, 2, 3, 4, 5]; `zip` stops at
                    // the shorter side.
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };
                    // `black_box` keeps the input opaque to the optimizer.
                    let b = Vector::<$d>::new(black_box(b_arr));
                    let x = ldlt.solve(b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], b_arr[i], epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_identity_tests!(2);
    gen_public_api_ldlt_identity_tests!(3);
    gen_public_api_ldlt_identity_tests!(4);
    gen_public_api_ldlt_identity_tests!(5);

    // Generates `public_api_ldlt_det_and_solve_diagonal_spd_{d}d`: for
    // diag(1, 2, ..., d), det is the product of the diagonal and the solve
    // divides `b` component-wise by the diagonal.
    macro_rules! gen_public_api_ldlt_diagonal_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
                    let diag = {
                        let mut arr = [0.0f64; $d];
                        let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };

                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = diag[i];
                    }

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

                    let expected_det = {
                        let mut acc = 1.0;
                        for i in 0..$d {
                            acc *= diag[i];
                        }
                        acc
                    };
                    assert_abs_diff_eq!(ldlt.det().unwrap(), expected_det, epsilon = 1e-12);

                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let values = [5.0f64, 4.0, 3.0, 2.0, 1.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };

                    let b = Vector::<$d>::new(black_box(b_arr));
                    let x = ldlt.solve(b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], b_arr[i] / diag[i], epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_diagonal_tests!(2);
    gen_public_api_ldlt_diagonal_tests!(3);
    gen_public_api_ldlt_diagonal_tests!(4);
    gen_public_api_ldlt_diagonal_tests!(5);

    // Edge case: the empty factorization has det 1 and solves to an empty vector.
    #[test]
    fn solve_0x0_returns_empty_vector_and_unit_det() {
        let a = Matrix::<0>::zero();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        assert_eq!(ldlt.det(), Ok(1.0));
        assert!(
            ldlt.solve(Vector::<0>::zero())
                .unwrap()
                .into_array()
                .is_empty()
        );
    }

    // [[4, 2], [2, 3]] has det 8 and inverse (1/8) [[3, -2], [-2, 4]], so for
    // b = [1, 2] the solution is [-1/8, 6/8].
    #[test]
    fn solve_2x2_known_spd() {
        let a = Matrix::<2>::try_from_rows(black_box([[4.0, 2.0], [2.0, 3.0]])).unwrap();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<2>::new(black_box([1.0, 2.0]));
        let x = ldlt.solve(b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], -0.125, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(ldlt.det().unwrap(), 8.0, epsilon = 1e-12);
    }

    #[test]
    fn solve_3x3_spd_tridiagonal_smoke() {
        let a = Matrix::<3>::try_from_rows(black_box([
            [2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0],
            [0.0, -1.0, 2.0],
        ]))
        .unwrap();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        // The row sums of A are [1, 0, 1], so x = [1, 1, 1] solves A x = b.
        let b = Vector::<3>::new(black_box([1.0, 0.0, 1.0]));
        let x = ldlt.solve(b).unwrap().into_array();

        for &x_i in &x {
            assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn singular_detected_for_degenerate_psd() {
        // Rank-1 PSD matrix: the second pivot is 1 - 1 * 1 * 1 = 0 <= tol.
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 1.0], [1.0, 1.0]])).unwrap();
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    // A negative pivot in the input itself is reported before the tolerance check.
    #[test]
    fn negative_initial_diagonal_reports_not_positive_semidefinite() {
        let a = Matrix::<2>::try_from_rows(black_box([[-1.0, 0.0], [0.0, 1.0]])).unwrap();
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NotPositiveSemidefinite {
                pivot_col: 0,
                value: -1.0,
            }
        );
    }

    // Indefinite matrix: the updated second pivot is 1 - 2 * 1 * 2 = -3.
    #[test]
    fn negative_updated_diagonal_reports_not_positive_semidefinite() {
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [2.0, 1.0]])).unwrap();
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NotPositiveSemidefinite {
                pivot_col: 1,
                value: -3.0,
            }
        );
    }

    // Non-finite inputs never reach the factorization: the matrix constructor
    // rejects them first.
    #[test]
    fn matrix_constructor_rejects_nonfinite_diagonal() {
        let err = Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    // The non-finite check wins over the (also violated) symmetry check.
    #[test]
    fn matrix_constructor_rejects_nonfinite_offdiagonal_before_asymmetry() {
        let err = Matrix::<2>::try_from_rows([[1.0, f64::NAN], [0.0, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 1,
            }
        );
    }

    #[test]
    fn nonfinite_l_multiplier_overflow() {
        // The pivot 1e-11 is accepted (assumes DEFAULT_SINGULAR_TOL < 1e-11).
        // The multiplier 1e300 / 1e-11 = 1e311 then overflows at (1, 0).
        let a = Matrix::<2>::try_from_rows([[1e-11, 1e300], [1e300, 1.0]]).unwrap();
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_l_multiplier_overflow_fused_branch_6d() {
        // Same overflow as above, but D = 6 selects the fused
        // multiplier-and-update loop.
        let mut rows = [[0.0; 6]; 6];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        rows[0][0] = 1e-11;
        rows[0][5] = 1e300;
        rows[5][0] = 1e300;

        let a = Matrix::<6>::try_from_rows(rows).unwrap();
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(5),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_trailing_submatrix_overflow() {
        // The multiplier 1e200 is finite, but the update 1 - 1e200 * 1e200
        // overflows to -inf. That value is reported as a non-finite pivot at (1, 1).
        let a = Matrix::<2>::try_from_rows([[1.0, 1e200], [1e200, 1.0]]).unwrap();
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 1
            }
        );
    }

    #[test]
    fn nonfinite_trailing_submatrix_overflow_fused_branch_6d() {
        // Same overflow through the D = 6 fused loop. The -inf written into
        // (5, 5) at column 0 survives the zero updates from columns 1..=4 and
        // is rejected as pivot 5.
        let mut rows = [[0.0; 6]; 6];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        rows[0][5] = 1e200;
        rows[5][0] = 1e200;

        let a = Matrix::<6>::try_from_rows(rows).unwrap();
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(5),
                col: 5
            }
        );
    }

    #[test]
    fn nonfinite_solve_forward_substitution_overflow() {
        // NOTE: in f64, `1e306 + 1.0` rounds to `1e306`. The factorization must
        // still succeed (it is unwrapped below); the remaining pivot is
        // whatever rounding residue the fused update leaves.
        let a = Matrix::<3>::try_from_rows([
            [1.0, 1e153, 0.0],
            [1e153, 1e306 + 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ])
        .unwrap();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        // Forward step: y1 = 0 - 1e153 * 1e156 = -1e309, which overflows at index 1.
        let b = Vector::<3>::new([1e156, 0.0, 0.0]);
        let err = ldlt.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn nonfinite_solve_back_substitution_overflow() {
        // Factors: L[2][1] = 2, diagonal = [1, 1, 1]. The forward and diagonal
        // phases leave [0, 0, 1e308]. The back step x1 = 0 - 2 * 1e308 then
        // overflows at index 1 (D = 3 uses the gather loop).
        let a = Matrix::<3>::try_from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 2.0, 5.0]])
            .unwrap();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new([0.0, 0.0, 1e308]);
        let err = ldlt.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn nonfinite_solve_back_substitution_overflow_scatter_branch_5d() {
        // Same pattern embedded in the last two rows, but D = 5 selects the
        // scatter loop. Scattering x4 = 1e308 makes x3 = -inf, which is
        // reported when index 3 is visited.
        let a = Matrix::<5>::try_from_rows([
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 2.0],
            [0.0, 0.0, 0.0, 2.0, 5.0],
        ])
        .unwrap();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<5>::new([0.0, 0.0, 0.0, 0.0, 1e308]);
        let err = ldlt.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 3 });
    }

    #[test]
    fn nonfinite_solve_diagonal_solve_overflow() {
        // The pivot 1e-11 is accepted (assumes DEFAULT_SINGULAR_TOL < 1e-11).
        // The diagonal phase then computes 1e300 / 1e-11 = 1e311, which
        // overflows at index 1.
        let a = Matrix::<2>::try_from_rows([[1.0, 0.0], [0.0, 1.0e-11]]).unwrap();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = ldlt.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn det_rejects_product_overflow() {
        // The running product is 1e300 after index 2 and 1e400 (> f64::MAX) at index 3.
        let a = Matrix::<5>::try_from_rows([
            [1.0e100, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0e100, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0e100, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0e100, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0e100],
        ])
        .unwrap();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
        assert_eq!(ldlt.det(), Err(LaError::NonFinite { row: None, col: 3 }));
    }

    #[test]
    fn asymmetric_input_returns_typed_error() {
        // A[0][1] = 2 but A[1][0] = -2; `Matrix::ldlt` reports the first
        // mismatching pair.
        let a = Matrix::<3>::try_from_rows([[4.0, 2.0, 0.0], [-2.0, 5.0, 1.0], [0.0, 1.0, 3.0]])
            .unwrap();
        assert_eq!(
            a.ldlt(DEFAULT_SINGULAR_TOL),
            Err(LaError::Asymmetric {
                row: 0,
                col: 1,
                dim: 3,
            })
        );
    }

    // Generates `solve_rhs_constructor_rejects_non_finite_{d}d`: a right-hand
    // side containing NaN is rejected by `Vector::try_new`, before any solve.
    macro_rules! gen_solve_boundary_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_rhs_constructor_rejects_non_finite_ $d d>]() {
                    let mut rhs = [1.0; $d];
                    rhs[$d - 1] = f64::NAN;

                    assert_eq!(
                        Vector::<$d>::try_new(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: $d - 1,
                        })
                    );
                }
            }
        };
    }

    gen_solve_boundary_tests!(2);
    gen_solve_boundary_tests!(3);
    gen_solve_boundary_tests!(4);
    gen_solve_boundary_tests!(5);

    // Generates compile-time evaluation tests for `det` and `solve`. Both
    // build an `Ldlt` directly from hand-made packed factors (bypassing
    // `factor_symmetric`), so they only check the `const fn` query paths.
    macro_rules! gen_ldlt_const_eval_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<ldlt_det_const_eval_ $d d>]() {
                    // Packed factors with diagonal [2, 1, ..., 1]: det == 2.
                    const DET: Result<f64, LaError> = {
                        let mut rows = [[0.0f64; $d]; $d];
                        let mut i = 0;
                        while i < $d {
                            rows[i][i] = 1.0;
                            i += 1;
                        }
                        rows[0][0] = 2.0;
                        let factors = Matrix::<$d>::from_rows_unchecked(rows);
                        let ldlt = Ldlt::<$d> {
                            factors: LdltFactors::new_unchecked(factors),
                        };
                        ldlt.det()
                    };
                    assert_eq!(DET, Ok(2.0));
                }

                #[test]
                fn [<ldlt_solve_const_eval_ $d d>]() {
                    // Identity factors, so `solve` must return b = [1, 2, ..., d]
                    // unchanged. An `Err` maps to zeros, which fails the check below.
                    #[allow(clippy::cast_precision_loss)]
                    const X: [f64; $d] = {
                        let ldlt = Ldlt::<$d> {
                            factors: LdltFactors::new_unchecked(Matrix::<$d>::identity()),
                        };
                        let mut b_arr = [0.0f64; $d];
                        let mut i = 0;
                        while i < $d {
                            b_arr[i] = i as f64 + 1.0;
                            i += 1;
                        }
                        let b = Vector::<$d>::new(b_arr);
                        match ldlt.solve(b) {
                            Ok(v) => v.into_array(),
                            Err(_) => [0.0f64; $d],
                        }
                    };
                    #[allow(clippy::cast_precision_loss)]
                    for i in 0..$d {
                        let expected = i as f64 + 1.0;
                        assert!((X[i] - expected).abs() <= 1e-12);
                    }
                }
            }
        };
    }

    gen_ldlt_const_eval_tests!(2);
    gen_ldlt_const_eval_tests!(3);
    gen_ldlt_const_eval_tests!(4);
    gen_ldlt_const_eval_tests!(5);
}