1use core::hint::cold_path;
4
5use crate::LaError;
6use crate::matrix::Matrix;
7use crate::vector::Vector;
8
/// LU factorization of a `D`×`D` matrix, computed with partial (row) pivoting.
///
/// Values are produced by `Lu::factor` and consumed by `Lu::solve_vec` and
/// `Lu::det`.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Lu<const D: usize> {
    /// Both factors packed into one matrix: entries strictly below the diagonal
    /// are the elimination multipliers (the lower factor, whose unit diagonal is
    /// not stored); the diagonal and everything above it form the upper factor.
    factors: Matrix<D>,
    /// Row permutation chosen while pivoting: `piv[i]` is the index of the
    /// original row that ended up in row `i` of `factors`.
    piv: [usize; D],
    /// `1.0` or `-1.0`; negated on every row exchange and used as the starting
    /// value of the determinant product.
    piv_sign: f64,
    /// Pivot tolerance: a pivot whose absolute value is `<= tol` is reported as
    /// singular, both while factoring and while solving.
    tol: f64,
}
18
19impl<const D: usize> Lu<D> {
20 #[inline]
21 pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
22 let mut lu = a;
23
24 let mut piv = [0usize; D];
25 for (i, p) in piv.iter_mut().enumerate() {
26 *p = i;
27 }
28
29 let mut piv_sign = 1.0;
30
31 for k in 0..D {
32 let mut pivot_row = k;
34 let mut pivot_abs = lu.rows[k][k].abs();
35 if !pivot_abs.is_finite() {
36 cold_path();
37 return Err(LaError::non_finite_cell(k, k));
38 }
39
40 for r in (k + 1)..D {
41 let v = lu.rows[r][k].abs();
42 if !v.is_finite() {
43 cold_path();
44 return Err(LaError::non_finite_cell(r, k));
45 }
46 if v > pivot_abs {
47 pivot_abs = v;
48 pivot_row = r;
49 }
50 }
51
52 if pivot_abs <= tol {
53 cold_path();
54 return Err(LaError::Singular { pivot_col: k });
55 }
56
57 if pivot_row != k {
58 lu.rows.swap(k, pivot_row);
59 piv.swap(k, pivot_row);
60 piv_sign = -piv_sign;
61 }
62
63 let pivot = lu.rows[k][k];
64 if !pivot.is_finite() {
65 cold_path();
66 return Err(LaError::non_finite_cell(k, k));
67 }
68
69 for r in (k + 1)..D {
71 let mult = lu.rows[r][k] / pivot;
72 if !mult.is_finite() {
73 cold_path();
74 return Err(LaError::non_finite_cell(r, k));
75 }
76 lu.rows[r][k] = mult;
77
78 for c in (k + 1)..D {
79 lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
80 }
81 }
82 }
83
84 Ok(Self {
85 factors: lu,
86 piv,
87 piv_sign,
88 tol,
89 })
90 }
91
92 #[inline]
124 pub const fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
125 let mut x = [0.0; D];
126 let mut i = 0;
127 while i < D {
128 x[i] = b.data[self.piv[i]];
129 i += 1;
130 }
131
132 let mut i = 0;
134 while i < D {
135 let mut sum = x[i];
136 let row = self.factors.rows[i];
137 let mut j = 0;
138 while j < i {
139 sum = (-row[j]).mul_add(x[j], sum);
140 j += 1;
141 }
142 if !sum.is_finite() {
143 cold_path();
144 return Err(LaError::non_finite_at(i));
145 }
146 x[i] = sum;
147 i += 1;
148 }
149
150 let mut ii = 0;
152 while ii < D {
153 let i = D - 1 - ii;
154 let mut sum = x[i];
155 let row = self.factors.rows[i];
156 let mut j = i + 1;
157 while j < D {
158 sum = (-row[j]).mul_add(x[j], sum);
159 j += 1;
160 }
161
162 let diag = row[i];
163 if !diag.is_finite() {
167 cold_path();
168 return Err(LaError::non_finite_cell(i, i));
169 }
170 if !sum.is_finite() {
171 cold_path();
172 return Err(LaError::non_finite_at(i));
173 }
174 if diag.abs() <= self.tol {
175 cold_path();
176 return Err(LaError::Singular { pivot_col: i });
177 }
178
179 let quotient = sum / diag;
180 if !quotient.is_finite() {
181 cold_path();
182 return Err(LaError::non_finite_at(i));
183 }
184 x[i] = quotient;
185 ii += 1;
186 }
187
188 Ok(Vector::new(x))
189 }
190
191 #[inline]
207 #[must_use]
208 pub const fn det(&self) -> f64 {
209 let mut det = self.piv_sign;
210 let mut i = 0;
211 while i < D {
212 det *= self.factors.rows[i][i];
213 i += 1;
214 }
215 det
216 }
217}
218
// Unit tests for `Lu`: factorization, `solve_vec`, `det`, error reporting, and
// const-evaluation of the `const fn` methods.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Generates, for dimension `$d`, one `solve_vec` test and one `det` test on
    // an identity matrix whose first two rows are exchanged. Column 0 then has a
    // zero on the diagonal, so factoring has to perform a row swap.
    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // Identity with rows 0 and 1 exchanged.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    // Inputs and function pointers go through `black_box`,
                    // presumably so the optimizer cannot specialize or fold the
                    // calls under test.
                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Right-hand side 1.0, 2.0, ..., d.
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    // The matrix only exchanges components 0 and 1, so the
                    // solution is `b` with those two entries exchanged.
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Identity with rows 0 and 1 exchanged: a single row swap,
                    // so the determinant is expected to be -1.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    // Generates, for dimension `$d`, smoke tests on the tridiagonal matrix with
    // 2.0 on the diagonal and -1.0 on both neighbouring diagonals.
    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // The macro is instantiated with `$d` up to 64, which makes
                    // this a sizeable stack array; the lint is silenced on purpose.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // With 1.0 in the first and last slots and 0.0 elsewhere, the
                    // all-ones vector solves the system (every interior row of the
                    // matrix sums to zero, the two end rows sum to one).
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Same matrix as the solve smoke test above; see the note
                    // there about the silenced lint.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // The expected determinant of this matrix is `d + 1`.
                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    // Smallest case: 2x = 6 gives x = 3, and the determinant is the single entry.
    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0);
    }

    // [[1, 2], [3, 4]] x = [5, 11] has the solution x = [1, 2].
    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2.
    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12);
    }

    // A pure row exchange: the result is only correct if the pivot sign is
    // folded into the determinant.
    #[test]
    fn det_requires_pivot_sign() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0);
    }

    // Zero in the top-left corner: solving only works if rows are exchanged and
    // the right-hand side is permuted to match.
    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    // The second row is twice the first, so elimination leaves a zero pivot in
    // column 1.
    #[test]
    fn singular_detected() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    // The only non-zero entry in column 0 is 1e-13, which this test expects to
    // be at or below `DEFAULT_PIVOT_TOL`.
    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    // NaN sitting on the diagonal is caught by the initial pivot check.
    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    // Infinity below the diagonal is caught while scanning the pivot column.
    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    // The multiplier for row 1 is -1, so forward substitution computes
    // 1e308 + 1e308, which overflows to infinity at index 1.
    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // The last diagonal entry (2e-12) is expected to pass the tolerance check,
    // but 1e300 / 2e-12 overflows when the quotient is formed.
    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // x[2] becomes 1e200; row 1 then accumulates -(1e200 * 1e200), which
    // overflows before any division takes place.
    #[test]
    fn solve_vec_nonfinite_back_substitution_sum_overflow() {
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0e200], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([0.0, 0.0, 1.0e200]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // Generates tests for the defensive diagonal checks inside `solve_vec`. The
    // `Lu` values are assembled by hand with a bad last diagonal entry, bypassing
    // `factor`, which rejects such pivots itself.
    macro_rules! gen_solve_vec_defensive_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_vec_defensive_sub_tolerance_diagonal_ $d d>]() {
                    // A zero diagonal entry is expected to fail the tolerance check.
                    let mut factors = Matrix::<$d>::identity();
                    factors.rows[$d - 1][$d - 1] = 0.0;

                    // Identity permutation.
                    let mut piv = [0usize; $d];
                    for (i, p) in piv.iter_mut().enumerate() {
                        *p = i;
                    }

                    let lu = Lu::<$d> {
                        factors,
                        piv,
                        piv_sign: 1.0,
                        tol: DEFAULT_PIVOT_TOL,
                    };
                    let b = Vector::<$d>::new([0.0; $d]);
                    let err = lu.solve_vec(b).unwrap_err();
                    assert_eq!(err, LaError::Singular { pivot_col: $d - 1 });
                }

                #[test]
                fn [<solve_vec_defensive_non_finite_diagonal_ $d d>]() {
                    // A NaN diagonal entry must be reported as a non-finite cell.
                    let mut factors = Matrix::<$d>::identity();
                    factors.rows[$d - 1][$d - 1] = f64::NAN;

                    // Identity permutation.
                    let mut piv = [0usize; $d];
                    for (i, p) in piv.iter_mut().enumerate() {
                        *p = i;
                    }

                    let lu = Lu::<$d> {
                        factors,
                        piv,
                        piv_sign: 1.0,
                        tol: DEFAULT_PIVOT_TOL,
                    };
                    let b = Vector::<$d>::new([1.0; $d]);
                    let err = lu.solve_vec(b).unwrap_err();
                    assert_eq!(
                        err,
                        LaError::NonFinite {
                            row: Some($d - 1),
                            col: $d - 1,
                        }
                    );
                }
            }
        };
    }

    gen_solve_vec_defensive_tests!(2);
    gen_solve_vec_defensive_tests!(3);
    gen_solve_vec_defensive_tests!(4);
    gen_solve_vec_defensive_tests!(5);

    // `det` must be usable in a const context: the determinant of
    // diag(2, 3) is computed inside a `const` initializer.
    #[test]
    fn lu_det_const_eval_d2() {
        const DET: f64 = {
            let mut factors = Matrix::<2>::identity();
            factors.rows[0][0] = 2.0;
            factors.rows[1][1] = 3.0;
            let lu = Lu::<2> {
                factors,
                piv: [0, 1],
                piv_sign: 1.0,
                tol: DEFAULT_PIVOT_TOL,
            };
            lu.det()
        };
        assert!((DET - 6.0).abs() <= 1e-12);
    }

    // Const-evaluated determinant where only the pivot sign (-1) contributes.
    #[test]
    fn lu_det_const_eval_d3_row_swap() {
        const DET: f64 = {
            let lu = Lu::<3> {
                factors: Matrix::<3>::identity(),
                piv: [1, 0, 2],
                piv_sign: -1.0,
                tol: DEFAULT_PIVOT_TOL,
            };
            lu.det()
        };
        assert!((DET - (-1.0)).abs() <= 1e-12);
    }

    // `solve_vec` must be usable in a const context: with identity factors the
    // solution equals the right-hand side.
    #[test]
    fn lu_solve_vec_const_eval_d2() {
        const X: [f64; 2] = {
            let lu = Lu::<2> {
                factors: Matrix::<2>::identity(),
                piv: [0, 1],
                piv_sign: 1.0,
                tol: DEFAULT_PIVOT_TOL,
            };
            let b = Vector::<2>::new([1.0, 2.0]);
            // An error is mapped to zeros here and then caught by the
            // assertions below.
            match lu.solve_vec(b) {
                Ok(v) => v.into_array(),
                Err(_) => [0.0, 0.0],
            }
        };
        assert!((X[0] - 1.0).abs() <= 1e-12);
        assert!((X[1] - 2.0).abs() <= 1e-12);
    }
}