1use core::hint::cold_path;
4
5use crate::matrix::{FiniteMatrix, Matrix};
6use crate::vector::{FiniteVector, Vector};
7use crate::{LaError, Tolerance};
8
9#[must_use]
11#[derive(Clone, Copy, Debug, PartialEq)]
12pub struct Lu<const D: usize> {
13 factors: LuFactors<D>,
14 piv: [usize; D],
15 piv_sign: f64,
16}
17
18#[derive(Clone, Copy, Debug, PartialEq)]
23struct LuFactors<const D: usize> {
24 storage: Matrix<D>,
25}
26
27impl<const D: usize> LuFactors<D> {
28 #[inline]
30 const fn new_unchecked(storage: Matrix<D>) -> Self {
31 Self { storage }
32 }
33
34 #[inline]
36 #[must_use]
37 const fn row(&self, index: usize) -> &[f64; D] {
38 &self.storage.rows[index]
39 }
40
41 #[inline]
43 #[must_use]
44 const fn diag(&self, index: usize) -> f64 {
45 self.storage.rows[index][index]
46 }
47}
48
49impl<const D: usize> Lu<D> {
50 #[inline]
59 pub(crate) fn factor_finite(a: FiniteMatrix<D>, tol: Tolerance) -> Result<Self, LaError> {
60 let mut lu = a.into_matrix();
61 let tol = tol.get();
62
63 let mut piv = [0usize; D];
64 for (i, p) in piv.iter_mut().enumerate() {
65 *p = i;
66 }
67
68 let mut piv_sign = 1.0;
69
70 for k in 0..D {
71 let mut pivot_row = k;
73 let mut pivot_abs = lu.rows[k][k].abs();
74
75 for r in (k + 1)..D {
76 let v = lu.rows[r][k].abs();
77 if v > pivot_abs {
78 pivot_abs = v;
79 pivot_row = r;
80 }
81 }
82
83 if pivot_abs <= tol {
84 cold_path();
85 return Err(LaError::Singular { pivot_col: k });
86 }
87
88 if pivot_row != k {
89 lu.rows.swap(k, pivot_row);
90 piv.swap(k, pivot_row);
91 piv_sign = -piv_sign;
92 }
93
94 let pivot = lu.rows[k][k];
95
96 for r in (k + 1)..D {
98 let mult = lu.rows[r][k] / pivot;
99 lu.rows[r][k] = mult;
100
101 for c in (k + 1)..D {
102 let updated = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
103 lu.rows[r][c] = updated;
104 }
105 }
106 }
107
108 let lu = FiniteMatrix::new(lu)?.into_matrix();
109
110 Ok(Self {
111 factors: LuFactors::new_unchecked(lu),
112 piv,
113 piv_sign,
114 })
115 }
116
117 #[inline]
149 pub const fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
150 let b = match FiniteVector::new(b) {
151 Ok(b) => b,
152 Err(err) => return Err(err),
153 };
154 match self.solve_finite_vec(b) {
155 Ok(x) => Ok(x.into_vector()),
156 Err(err) => Err(err),
157 }
158 }
159
160 #[inline]
169 pub(crate) const fn solve_finite_vec(
170 &self,
171 b: FiniteVector<D>,
172 ) -> Result<FiniteVector<D>, LaError> {
173 let mut x = [0.0; D];
174 let mut i = 0;
175 let b = b.as_array();
176 while i < D {
177 x[i] = b[self.piv[i]];
178 i += 1;
179 }
180
181 let mut i = 0;
183 while i < D {
184 let mut sum = x[i];
185 let row = self.factors.row(i);
186 let mut j = 0;
187 while j < i {
188 sum = (-row[j]).mul_add(x[j], sum);
189 j += 1;
190 }
191 if !sum.is_finite() {
192 cold_path();
193 return Err(LaError::non_finite_at(i));
194 }
195 x[i] = sum;
196 i += 1;
197 }
198
199 let mut ii = 0;
201 while ii < D {
202 let i = D - 1 - ii;
203 let mut sum = x[i];
204 let row = self.factors.row(i);
205 let mut j = i + 1;
206 while j < D {
207 sum = (-row[j]).mul_add(x[j], sum);
208 j += 1;
209 }
210
211 let diag = row[i];
212 if !sum.is_finite() {
213 cold_path();
214 return Err(LaError::non_finite_at(i));
215 }
216
217 let quotient = sum / diag;
218 if !quotient.is_finite() {
219 cold_path();
220 return Err(LaError::non_finite_at(i));
221 }
222 x[i] = quotient;
223 ii += 1;
224 }
225
226 Ok(FiniteVector::new_unchecked(Vector::new_unchecked(x)))
227 }
228
229 #[inline]
249 pub const fn det(&self) -> Result<f64, LaError> {
250 let mut det = self.piv_sign;
251 let mut i = 0;
252 while i < D {
253 det *= self.factors.diag(i);
254 if !det.is_finite() {
255 cold_path();
256 return Err(LaError::non_finite_at(i));
257 }
258 i += 1;
259 }
260 Ok(det)
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::assert_matches;
    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Shared fixtures and public-API drivers. The drivers reach `Matrix::lu`,
    // `Lu::solve_vec` and `Lu::det` through `black_box`ed function pointers, so the
    // called functions and their inputs stay opaque to the optimizer.

    /// Rows of the `N x N` identity with rows 0 and 1 exchanged (requires `N >= 2`).
    fn transposition_rows<const N: usize>() -> [[f64; N]; N] {
        let mut rows = [[0.0f64; N]; N];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        rows.swap(0, 1);
        rows
    }

    /// Rows of the `N x N` tridiagonal matrix with `2` on the diagonal and `-1` beside it.
    // The 64 x 64 instantiation is a deliberately large (32 KiB) stack array.
    #[allow(clippy::large_stack_arrays)]
    fn tridiagonal_rows<const N: usize>() -> [[f64; N]; N] {
        let mut rows = [[0.0f64; N]; N];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 2.0;
            if i > 0 {
                row[i - 1] = -1.0;
            }
            if i + 1 < N {
                row[i + 1] = -1.0;
            }
        }
        rows
    }

    /// The right-hand side `[1.0, 2.0, ..., N]`.
    fn ascending_rhs<const N: usize>() -> [f64; N] {
        let mut out = [0.0f64; N];
        let mut next = 1.0f64;
        for slot in &mut out {
            *slot = next;
            next += 1.0;
        }
        out
    }

    /// Factors `rows` via `Matrix::lu` (through an opaque fn pointer), unwrapping the result.
    fn lu_through_public_api<const N: usize>(rows: &[[f64; N]; N]) -> Lu<N> {
        let a = Matrix::<N>::from_rows(black_box(*rows));
        let factor: fn(Matrix<N>, Tolerance) -> Result<Lu<N>, LaError> =
            black_box(Matrix::<N>::lu);
        factor(a, DEFAULT_PIVOT_TOL).unwrap()
    }

    /// Solves with `Lu::solve_vec` (through an opaque fn pointer), unwrapping the result.
    fn solve_through_public_api<const N: usize>(lu: &Lu<N>, rhs: &[f64; N]) -> [f64; N] {
        let b = Vector::<N>::new(black_box(*rhs));
        let solve: fn(&Lu<N>, Vector<N>) -> Result<Vector<N>, LaError> =
            black_box(Lu::<N>::solve_vec);
        solve(lu, b).unwrap().into_array()
    }

    /// Computes `Lu::det` (through an opaque fn pointer), unwrapping the result.
    fn det_through_public_api<const N: usize>(lu: &Lu<N>) -> f64 {
        let det: fn(&Lu<N>) -> Result<f64, LaError> = black_box(Lu::<N>::det);
        det(lu).unwrap()
    }

    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // A is a single row transposition of I, so x is b with
                    // entries 0 and 1 exchanged.
                    let lu = lu_through_public_api::<$d>(&transposition_rows::<$d>());
                    let rhs = ascending_rhs::<$d>();
                    let x = solve_through_public_api(&lu, &rhs);

                    let mut expected = rhs;
                    expected.swap(0, 1);
                    for (got, want) in x.iter().zip(&expected) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // One row transposition of the identity has determinant -1.
                    let lu = lu_through_public_api::<$d>(&transposition_rows::<$d>());
                    assert_abs_diff_eq!(det_through_public_api(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // With b = e_0 + e_{n-1}, the exact solution is the all-ones vector.
                    let lu = lu_through_public_api::<$d>(&tridiagonal_rows::<$d>());
                    let mut rhs = [0.0f64; $d];
                    rhs[0] = 1.0;
                    rhs[$d - 1] = 1.0;

                    let x = solve_through_public_api(&lu, &rhs);
                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // The n x n (2, -1) tridiagonal matrix has determinant n + 1.
                    let lu = lu_through_public_api::<$d>(&tridiagonal_rows::<$d>());
                    assert_abs_diff_eq!(
                        det_through_public_api(&lu),
                        f64::from($d) + 1.0,
                        epsilon = 1e-8
                    );
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    #[test]
    fn solve_1x1() {
        let lu = FiniteMatrix::new(Matrix::<1>::from_rows(black_box([[2.0]])))
            .unwrap()
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let x = solve_through_public_api(&lu, &[6.0]);
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(det_through_public_api(&lu), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        // [[1, 2], [3, 4]] * [1, 2] = [5, 11].
        let lu = FiniteMatrix::new(Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]])))
            .unwrap()
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let x = solve_through_public_api(&lu, &[5.0, 11.0]);
        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let lu = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        assert_abs_diff_eq!(det_through_public_api(&lu), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // The anti-diagonal permutation needs one row swap, so det = -1 exactly.
        let lu = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        assert_abs_diff_eq!(det_through_public_api(&lu), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        // A zero in the (0, 0) position forces a row swap before elimination.
        let lu = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let x = solve_through_public_api(&lu, &[1.0, 2.0]);
        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        // The second row is twice the first, so the second pivot vanishes.
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        assert_eq!(
            a.lu(DEFAULT_PIVOT_TOL),
            Err(LaError::Singular { pivot_col: 1 })
        );
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // 1e-13 is the largest candidate in column 0 but does not exceed the tolerance.
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        assert_eq!(
            a.lu(DEFAULT_PIVOT_TOL),
            Err(LaError::Singular { pivot_col: 0 })
        );
    }

    #[test]
    fn invalid_tolerance_rejected() {
        assert_eq!(
            Tolerance::new(-1.0),
            Err(LaError::InvalidTolerance { value: -1.0 })
        );

        // NaN != NaN, so match on the variant and test the payload instead.
        assert_matches!(
            Tolerance::new(f64::NAN),
            Err(LaError::InvalidTolerance { value }) if value.is_nan()
        );

        assert_eq!(
            Tolerance::new(f64::INFINITY),
            Err(LaError::InvalidTolerance {
                value: f64::INFINITY,
            })
        );
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_entry() {
        assert_eq!(
            Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]),
            Err(LaError::NonFinite {
                row: Some(0),
                col: 0
            })
        );
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_column_entry() {
        assert_eq!(
            Matrix::<2>::try_from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]),
            Err(LaError::NonFinite {
                row: Some(1),
                col: 0
            })
        );
    }

    #[test]
    fn nonfinite_detected_in_trailing_update() {
        // Eliminating row 1 adds MAX to MAX in column 1, which overflows.
        let a =
            Matrix::<3>::from_rows([[1.0, f64::MAX, 0.0], [-1.0, f64::MAX, 0.0], [0.0, 0.0, 1.0]]);
        assert_eq!(
            a.lu(DEFAULT_PIVOT_TOL),
            Err(LaError::NonFinite {
                row: Some(1),
                col: 1,
            })
        );
    }

    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        // y[1] = 1e308 + 1e308 overflows during the L solve.
        let lu = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        assert_eq!(lu.solve_vec(b), Err(LaError::NonFinite { row: None, col: 1 }));
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        // x[1] = 1e300 / 2e-12 overflows in the division step of the U solve.
        let lu = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]])
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        let b = Vector::<2>::new([0.0, 1.0e300]);
        assert_eq!(lu.solve_vec(b), Err(LaError::NonFinite { row: None, col: 1 }));
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_sum_overflow() {
        // The accumulated sum for x[1] is -1e200 * 1e200, which overflows before
        // the division.
        let lu = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0e200], [0.0, 0.0, 1.0]])
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        let b = Vector::<3>::new([0.0, 0.0, 1.0e200]);
        assert_eq!(lu.solve_vec(b), Err(LaError::NonFinite { row: None, col: 1 }));
    }

    #[test]
    fn det_rejects_product_overflow() {
        // Diagonal of 1e100: the running product first overflows at index 3 (1e400).
        let mut rows = [[0.0f64; 5]; 5];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0e100;
        }
        let lu = Matrix::<5>::from_rows(rows).lu(DEFAULT_PIVOT_TOL).unwrap();
        assert_eq!(lu.det(), Err(LaError::NonFinite { row: None, col: 3 }));
    }

    macro_rules! gen_solve_vec_boundary_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_vec_rhs_boundary_rejects_non_finite_ $d d>]() {
                    let mut rhs = [1.0; $d];
                    rhs[$d - 1] = f64::NAN;

                    assert_eq!(
                        Vector::<$d>::try_new(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: $d - 1,
                        })
                    );
                }

                #[test]
                fn [<solve_vec_revalidates_unchecked_rhs_storage_ $d d>]() {
                    let lu = Matrix::<$d>::identity().lu(DEFAULT_PIVOT_TOL).unwrap();
                    let mut rhs = [1.0; $d];
                    rhs[$d - 1] = f64::NAN;

                    // Bypass constructor validation so `solve_vec` must catch the NaN.
                    let unchecked = Vector::<$d>::new_unchecked(rhs);
                    assert_eq!(
                        lu.solve_vec(unchecked),
                        Err(LaError::NonFinite {
                            row: None,
                            col: $d - 1,
                        })
                    );
                }
            }
        };
    }

    gen_solve_vec_boundary_tests!(2);
    gen_solve_vec_boundary_tests!(3);
    gen_solve_vec_boundary_tests!(4);
    gen_solve_vec_boundary_tests!(5);

    #[test]
    fn lu_det_const_eval_d2() {
        // Hand-built factors with U = diag(2, 3) and no row swaps.
        const DET: Result<f64, LaError> = {
            let mut storage = Matrix::<2>::identity();
            storage.rows[0][0] = 2.0;
            storage.rows[1][1] = 3.0;
            Lu::<2> {
                factors: LuFactors::new_unchecked(storage),
                piv: [0, 1],
                piv_sign: 1.0,
            }
            .det()
        };
        assert_eq!(DET, Ok(6.0));
    }

    #[test]
    fn lu_det_const_eval_d3_row_swap() {
        // Identity factors with a single recorded swap: only the sign contributes.
        const DET: Result<f64, LaError> = Lu::<3> {
            factors: LuFactors::new_unchecked(Matrix::<3>::identity()),
            piv: [1, 0, 2],
            piv_sign: -1.0,
        }
        .det();
        assert_eq!(DET, Ok(-1.0));
    }

    #[test]
    fn lu_solve_vec_const_eval_d2() {
        // With identity factors and no swaps, the solution equals the right-hand side.
        const X: [f64; 2] = {
            let lu = Lu::<2> {
                factors: LuFactors::new_unchecked(Matrix::<2>::identity()),
                piv: [0, 1],
                piv_sign: 1.0,
            };
            match lu.solve_vec(Vector::<2>::new([1.0, 2.0])) {
                Ok(v) => v.into_array(),
                Err(_) => [0.0, 0.0],
            }
        };
        assert_abs_diff_eq!(X[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(X[1], 2.0, epsilon = 1e-12);
    }
}