1#![forbid(unsafe_code)]
2
3use core::hint::cold_path;
6
7use crate::matrix::Matrix;
8use crate::vector::Vector;
9use crate::{LaError, Tolerance};
10
/// LU factorization of a `D`×`D` matrix computed with row pivoting.
///
/// A value of this type bundles everything `solve` and `det` need:
/// the packed triangular factors, the row permutation chosen while pivoting,
/// and the sign contributed by that permutation.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Lu<const D: usize> {
    // Packed factors: elimination multipliers below the diagonal, the
    // eliminated (upper-triangular) entries on and above it.
    factors: LuFactors<D>,
    // Row permutation: starts as the identity and is swapped together with
    // the matrix rows; `solve_finite` reads the right-hand side as `b[piv[i]]`.
    piv: [usize; D],
    // Starts at `1.0` and is negated on every row swap; `det` uses it as the
    // initial value of the diagonal product.
    piv_sign: f64,
}
22
/// Private wrapper around the matrix that stores both triangular factors in
/// one `D`×`D` array.
///
/// The wrapper exposes only row and diagonal reads, which is all the
/// substitution and determinant code in this file uses.
#[derive(Clone, Copy, Debug, PartialEq)]
struct LuFactors<const D: usize> {
    // Entries below the diagonal are the elimination multipliers; entries on
    // and above the diagonal are the values left after elimination.
    storage: Matrix<D>,
}
31
32impl<const D: usize> LuFactors<D> {
33 #[inline]
35 const fn new_unchecked(storage: Matrix<D>) -> Self {
36 Self { storage }
37 }
38
39 #[inline]
41 #[must_use]
42 const fn row(&self, index: usize) -> &[f64; D] {
43 &self.storage.rows()[index]
44 }
45
46 #[inline]
48 #[must_use]
49 const fn diag(&self, index: usize) -> f64 {
50 self.storage.rows()[index][index]
51 }
52}
53
54impl<const D: usize> Lu<D> {
55 #[inline]
64 #[allow(clippy::needless_range_loop)]
65 pub(crate) fn factor_finite(a: Matrix<D>, tol: Tolerance) -> Result<Self, LaError> {
66 let mut lu = a;
67 let tol = tol.get();
68
69 let mut piv = [0usize; D];
70 for (i, p) in piv.iter_mut().enumerate() {
71 *p = i;
72 }
73
74 let mut piv_sign = 1.0;
75
76 {
77 let rows = lu.rows_mut_unchecked();
78
79 for k in 0..D {
80 let mut pivot_row = k;
82 let mut pivot_abs = rows[k][k].abs();
83
84 for r in (k + 1)..D {
85 let v = rows[r][k].abs();
86 if v > pivot_abs {
87 pivot_abs = v;
88 pivot_row = r;
89 }
90 }
91
92 if pivot_abs <= tol {
93 cold_path();
94 return Err(LaError::Singular { pivot_col: k });
95 }
96
97 if pivot_row != k {
98 rows.swap(k, pivot_row);
99 piv.swap(k, pivot_row);
100 piv_sign = -piv_sign;
101 }
102
103 let pivot = rows[k][k];
104
105 for r in (k + 1)..D {
107 let mult = rows[r][k] / pivot;
108 rows[r][k] = mult;
109
110 for c in (k + 1)..D {
111 let updated = (-mult).mul_add(rows[k][c], rows[r][c]);
112 rows[r][c] = updated;
113 }
114 }
115 }
116 }
117
118 let lu = lu.validate_finite()?;
119
120 Ok(Self {
121 factors: LuFactors::new_unchecked(lu),
122 piv,
123 piv_sign,
124 })
125 }
126
127 #[inline]
155 pub const fn solve(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
156 self.solve_finite(b)
157 }
158
159 #[inline]
168 pub(crate) const fn solve_finite(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
169 let mut x = [0.0; D];
170 let b = b.as_array();
171 let mut i = 0;
172
173 if D <= 4 {
174 while i < D {
175 x[i] = b[self.piv[i]];
176 i += 1;
177 }
178
179 i = 0;
182 while i < D {
183 let mut sum = x[i];
184 let row = self.factors.row(i);
185 let mut j = 0;
186 while j < i {
187 sum = (-row[j]).mul_add(x[j], sum);
188 j += 1;
189 }
190 if !sum.is_finite() {
191 cold_path();
192 return Err(LaError::non_finite_at(i));
193 }
194 x[i] = sum;
195 i += 1;
196 }
197 } else {
198 while i < D {
201 let mut sum = b[self.piv[i]];
202 let row = self.factors.row(i);
203 let mut j = 0;
204 while j < i {
205 sum = (-row[j]).mul_add(x[j], sum);
206 j += 1;
207 }
208 if !sum.is_finite() {
209 cold_path();
210 return Err(LaError::non_finite_at(i));
211 }
212 x[i] = sum;
213 i += 1;
214 }
215 }
216
217 let mut ii = 0;
219 while ii < D {
220 let i = D - 1 - ii;
221 let mut sum = x[i];
222 let row = self.factors.row(i);
223 let mut j = i + 1;
224 while j < D {
225 sum = (-row[j]).mul_add(x[j], sum);
226 j += 1;
227 }
228
229 let diag = row[i];
230 if !sum.is_finite() {
231 cold_path();
232 return Err(LaError::non_finite_at(i));
233 }
234
235 let quotient = sum / diag;
236 if !quotient.is_finite() {
237 cold_path();
238 return Err(LaError::non_finite_at(i));
239 }
240 x[i] = quotient;
241 ii += 1;
242 }
243
244 Ok(Vector::new_unchecked(x))
245 }
246
247 #[inline]
267 pub const fn det(&self) -> Result<f64, LaError> {
268 let mut det = self.piv_sign;
269 let mut i = 0;
270 while i < D {
271 det *= self.factors.diag(i);
272 if !det.is_finite() {
273 cold_path();
274 return Err(LaError::non_finite_at(i));
275 }
276 i += 1;
277 }
278 Ok(det)
279 }
280}
281
// Unit tests for the LU factorization: public-API solves and determinants,
// singular / non-finite error reporting, and compile-time (`const`) evaluation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_SINGULAR_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Generates, for dimension `$d`, one solve test and one determinant test
    // on a matrix that cannot be factored without a row exchange.
    //
    // NOTE(review): calls go through `black_box`-wrapped function pointers,
    // presumably so the optimizer cannot fold the call away — the intent is
    // not stated in this file.
    macro_rules! gen_public_api_pivoting_solve_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_pivoting_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];
                    // Identity matrix ...
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    // ... with the first two rows exchanged, so entry (0, 0)
                    // is zero and the first pivot must come from row 1.
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    // Right-hand side 1.0, 2.0, 3.0, ...
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    // The matrix only swaps the first two components, so the
                    // solution is `b` with those two entries exchanged.
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];
                    // Same row-swapped identity as in the solve test.
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    // A single row exchange of the identity has determinant -1.
                    let det_fn: fn(&Lu<$d>) -> Result<f64, LaError> =
                        black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    // Dimensions 2..=4 take the `D <= 4` forward-substitution path in
    // `solve_finite`; dimension 5 takes the other path.
    gen_public_api_pivoting_solve_and_det_tests!(2);
    gen_public_api_pivoting_solve_and_det_tests!(3);
    gen_public_api_pivoting_solve_and_det_tests!(4);
    gen_public_api_pivoting_solve_and_det_tests!(5);

    // Generates smoke tests on a tridiagonal matrix with 2.0 on the diagonal
    // and -1.0 on both neighbouring diagonals, for larger dimensions.
    macro_rules! gen_public_api_tridiagonal_smoke_solve_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_tridiagonal_smoke_ $d d>]() {
                    // The `$d`×`$d` array literal is intentionally built on
                    // the stack for these dimensions.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    // Interior rows of the matrix sum to 0 and the two end
                    // rows sum to 1, so this right-hand side corresponds to
                    // the all-ones solution.
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Same tridiagonal matrix as in the solve test.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    // The expected determinant of this matrix is `$d + 1`.
                    let det_fn: fn(&Lu<$d>) -> Result<f64, LaError> =
                        black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu).unwrap(), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_and_det_tests!(64);

    // Degenerate dimension: no loop in `det` or `solve_finite` runs, so the
    // determinant is the initial sign (1.0) and the solution is empty.
    #[test]
    fn solve_0x0_returns_empty_vector_and_unit_det() {
        let a = Matrix::<0>::zero();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        assert_eq!(lu.det(), Ok(1.0));
        assert!(
            lu.solve(Vector::<0>::zero())
                .unwrap()
                .into_array()
                .is_empty()
        );
    }

    // Scalar system 2x = 6, plus its determinant (exactly 2.0).
    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::try_from_rows(black_box([[2.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        let det_fn: fn(&Lu<1>) -> Result<f64, LaError> = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), 2.0, epsilon = 0.0);
    }

    // [[1, 2], [3, 4]] · [1, 2] = [5, 11].
    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [3.0, 4.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    // det [[1, 2], [3, 4]] = 1·4 − 2·3 = −2.
    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [3.0, 4.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> Result<f64, LaError> = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), -2.0, epsilon = 1e-12);
    }

    // Anti-diagonal matrix: after the forced row swap both diagonal entries
    // are 1.0, so only the stored permutation sign makes the result −1.
    #[test]
    fn det_requires_pivot_sign() {
        let a = Matrix::<2>::try_from_rows(black_box([[0.0, 1.0], [1.0, 0.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> Result<f64, LaError> = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 0.0);
    }

    // Same anti-diagonal matrix: entry (0, 0) is zero, so solving succeeds
    // only because the factorization exchanges rows.
    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::try_from_rows(black_box([[0.0, 1.0], [1.0, 0.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    // The second row is twice the first, so elimination leaves no usable
    // pivot in column 1.
    #[test]
    fn singular_detected() {
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [2.0, 4.0]])).unwrap();
        let err = a.lu(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    // The largest magnitude in column 0 is 1e-13, which the default
    // tolerance is expected to reject as a pivot.
    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        let a = Matrix::<2>::try_from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]])).unwrap();
        let err = a.lu(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    // A NaN at (0, 0) is rejected by the matrix constructor, before any
    // factorization is attempted.
    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_entry() {
        let err = Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    // An infinity at (1, 0) is likewise rejected by the matrix constructor.
    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_column_entry() {
        let err = Matrix::<2>::try_from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    // All inputs are finite, but eliminating column 0 with multiplier −1
    // updates entry (1, 1) to MAX + MAX, which overflows; the post-
    // factorization finiteness check must report that entry.
    #[test]
    fn nonfinite_detected_in_trailing_update() {
        let a = Matrix::<3>::try_from_rows([
            [1.0, f64::MAX, 0.0],
            [-1.0, f64::MAX, 0.0],
            [0.0, 0.0, 1.0],
        ])
        .unwrap();

        let err = a.lu(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 1,
            }
        );
    }

    // The multiplier at (1, 0) is −1, so forward substitution computes
    // 1e308 + 1e308 for component 1, which overflows.
    #[test]
    fn solve_nonfinite_forward_substitution_overflow() {
        let a = Matrix::<3>::try_from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
            .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // Same overflow as above at dimension 5, which exercises the `D > 4`
    // forward-substitution path of `solve_finite`.
    #[test]
    fn solve_nonfinite_forward_substitution_overflow_fused_branch_5d() {
        let a = Matrix::<5>::try_from_rows([
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [-1.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ])
        .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<5>::new([1.0e308, 1.0e308, 0.0, 0.0, 0.0]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // The diagonal entry 2e-12 is accepted as a pivot here, but dividing
    // 1e300 by it during back substitution overflows.
    #[test]
    fn solve_nonfinite_back_substitution_overflow() {
        let a = Matrix::<2>::try_from_rows([[1.0, 1.0], [0.0, 2.0e-12]]).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // Back substitution for component 1 subtracts 1e200 · 1e200, so the
    // accumulated sum overflows before the division is reached.
    #[test]
    fn solve_nonfinite_back_substitution_sum_overflow() {
        let a = Matrix::<3>::try_from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0e200], [0.0, 0.0, 1.0]])
            .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new([0.0, 0.0, 1.0e200]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // Each diagonal entry is finite, but the running product exceeds the
    // `f64` range when the fourth factor (index 3) is multiplied in.
    #[test]
    fn det_rejects_product_overflow() {
        let a = Matrix::<5>::try_from_rows([
            [1.0e100, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0e100, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0e100, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0e100, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0e100],
        ])
        .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();
        assert_eq!(lu.det(), Err(LaError::NonFinite { row: None, col: 3 }));
    }

    // Generates, for dimension `$d`, a test that a right-hand side with a NaN
    // in its last slot is rejected by `Vector::try_new`, i.e. before it could
    // ever reach `solve`.
    macro_rules! gen_solve_boundary_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_rhs_constructor_rejects_non_finite_ $d d>]() {
                    let mut rhs = [1.0; $d];
                    rhs[$d - 1] = f64::NAN;

                    assert_eq!(
                        Vector::<$d>::try_new(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: $d - 1,
                        })
                    );
                }
            }
        };
    }

    gen_solve_boundary_tests!(2);
    gen_solve_boundary_tests!(3);
    gen_solve_boundary_tests!(4);
    gen_solve_boundary_tests!(5);

    // `det` is a `const fn`: build the factorization by hand inside a `const`
    // initializer so the determinant is evaluated at compile time.
    #[test]
    fn lu_det_const_eval_d2() {
        const DET: Result<f64, LaError> = {
            // Diagonal factors 2 and 3 with no row swaps: determinant 6.
            let factors = Matrix::<2>::from_rows_unchecked([[2.0, 0.0], [0.0, 3.0]]);
            let lu = Lu::<2> {
                factors: LuFactors::new_unchecked(factors),
                piv: [0, 1],
                piv_sign: 1.0,
            };
            lu.det()
        };
        assert_eq!(DET, Ok(6.0));
    }

    // Compile-time determinant with a recorded row swap: identity factors and
    // a negative permutation sign give −1.
    #[test]
    fn lu_det_const_eval_d3_row_swap() {
        const DET: Result<f64, LaError> = {
            let lu = Lu::<3> {
                factors: LuFactors::new_unchecked(Matrix::<3>::identity()),
                piv: [1, 0, 2],
                piv_sign: -1.0,
            };
            lu.det()
        };
        assert_eq!(DET, Ok(-1.0));
    }

    // `solve` is a `const fn` as well: solve against identity factors inside
    // a `const` initializer.
    #[test]
    fn lu_solve_const_eval_d2() {
        const X: [f64; 2] = {
            let lu = Lu::<2> {
                factors: LuFactors::new_unchecked(Matrix::<2>::identity()),
                piv: [0, 1],
                piv_sign: 1.0,
            };
            let b = Vector::<2>::new([1.0, 2.0]);
            // An error maps to zeros, which the assertions below would reject.
            match lu.solve(b) {
                Ok(v) => v.into_array(),
                Err(_) => [0.0, 0.0],
            }
        };
        assert!((X[0] - 1.0).abs() <= 1e-12);
        assert!((X[1] - 2.0).abs() <= 1e-12);
    }
}