Skip to main content

la_stack/
lu.rs

1#![forbid(unsafe_code)]
2
3//! LU decomposition and solves.
4
5use core::hint::cold_path;
6
7use crate::matrix::Matrix;
8use crate::vector::Vector;
9use crate::{LaError, Tolerance};
10
11/// LU decomposition (PA = LU) with partial pivoting.
12///
13/// `Lu<0>` represents the empty factorization. Its determinant is the empty
14/// product `1.0`, and solving against [`Vector<0>`] returns [`Vector<0>`].
15#[must_use]
16#[derive(Clone, Copy, Debug, PartialEq)]
17pub struct Lu<const D: usize> {
18    factors: LuFactors<D>,
19    piv: [usize; D],
20    piv_sign: f64,
21}
22
23/// In-place LU factor storage whose `U` diagonal is finite and usable.
24///
25/// Construction through [`Lu::factor_finite`] proves every stored entry is
26/// finite and every `U[i,i]` satisfies the factorization tolerance.
27#[derive(Clone, Copy, Debug, PartialEq)]
28struct LuFactors<const D: usize> {
29    storage: Matrix<D>,
30}
31
32impl<const D: usize> LuFactors<D> {
33    /// Construct factors after LU factorization has proven the storage invariant.
34    #[inline]
35    const fn new_unchecked(storage: Matrix<D>) -> Self {
36        Self { storage }
37    }
38
39    /// Borrow a factor row.
40    #[inline]
41    #[must_use]
42    const fn row(&self, index: usize) -> &[f64; D] {
43        &self.storage.rows()[index]
44    }
45
46    /// Return a diagonal entry of `U`.
47    #[inline]
48    #[must_use]
49    const fn diag(&self, index: usize) -> f64 {
50        self.storage.rows()[index][index]
51    }
52}
53
54impl<const D: usize> Lu<D> {
55    /// Factor a finite square matrix into in-place LU storage for
56    /// [`Matrix::lu`].
57    ///
58    /// The input has already proven finite entries, so LU construction rejects
59    /// numerically singular pivots and non-finite elimination intermediates
60    /// before callers can observe a [`Lu`] value. Completed factor storage is
61    /// checked before return so successful factors do not contain a non-finite
62    /// value produced during elimination.
63    #[inline]
64    #[allow(clippy::needless_range_loop)]
65    pub(crate) fn factor_finite(a: Matrix<D>, tol: Tolerance) -> Result<Self, LaError> {
66        let mut lu = a;
67        let tol = tol.get();
68
69        let mut piv = [0usize; D];
70        for (i, p) in piv.iter_mut().enumerate() {
71            *p = i;
72        }
73
74        let mut piv_sign = 1.0;
75
76        {
77            let rows = lu.rows_mut_unchecked();
78
79            for k in 0..D {
80                // Choose pivot row.
81                let mut pivot_row = k;
82                let mut pivot_abs = rows[k][k].abs();
83
84                for r in (k + 1)..D {
85                    let v = rows[r][k].abs();
86                    if v > pivot_abs {
87                        pivot_abs = v;
88                        pivot_row = r;
89                    }
90                }
91
92                if pivot_abs <= tol {
93                    cold_path();
94                    return Err(LaError::Singular { pivot_col: k });
95                }
96
97                if pivot_row != k {
98                    rows.swap(k, pivot_row);
99                    piv.swap(k, pivot_row);
100                    piv_sign = -piv_sign;
101                }
102
103                let pivot = rows[k][k];
104
105                // Eliminate below pivot.
106                for r in (k + 1)..D {
107                    let mult = rows[r][k] / pivot;
108                    rows[r][k] = mult;
109
110                    for c in (k + 1)..D {
111                        let updated = (-mult).mul_add(rows[k][c], rows[r][c]);
112                        rows[r][c] = updated;
113                    }
114                }
115            }
116        }
117
118        let lu = lu.validate_finite()?;
119
120        Ok(Self {
121            factors: LuFactors::new_unchecked(lu),
122            piv,
123            piv_sign,
124        })
125    }
126
127    /// Solve `A x = b` using this LU factorization.
128    ///
129    /// [`Vector`] is finite by construction, so this method only checks computed
130    /// substitution overflows. It performs floating-point forward/back
131    /// substitution and does not provide a certified absolute rounding-error
132    /// bound for the returned solution.
133    ///
134    /// # Examples
135    /// ```
136    /// use la_stack::prelude::*;
137    ///
138    /// # fn main() -> Result<(), LaError> {
139    /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?;
140    /// let lu = a.lu(DEFAULT_SINGULAR_TOL)?;
141    ///
142    /// let b = Vector::<2>::try_new([5.0, 11.0])?;
143    /// let x = lu.solve(b)?.into_array();
144    ///
145    /// assert!((x[0] - 1.0).abs() <= 1e-12);
146    /// assert!((x[1] - 2.0).abs() <= 1e-12);
147    /// # Ok(())
148    /// # }
149    /// ```
150    ///
151    /// # Errors
152    /// Returns [`LaError::NonFinite`] if a computed substitution intermediate
153    /// overflows to NaN or infinity.
154    #[inline]
155    pub const fn solve(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
156        self.solve_finite(b)
157    }
158
159    /// Solve `A x = b` using this LU factorization and a finite right-hand side.
160    ///
161    /// The right-hand side entries and stored factors are known finite, so this
162    /// path only checks computed substitution overflows.
163    ///
164    /// # Errors
165    /// Returns [`LaError::NonFinite`] if a computed substitution intermediate
166    /// overflows to NaN or infinity.
167    #[inline]
168    pub(crate) const fn solve_finite(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
169        let mut x = [0.0; D];
170        let b = b.as_array();
171        let mut i = 0;
172
173        if D <= 4 {
174            while i < D {
175                x[i] = b[self.piv[i]];
176                i += 1;
177            }
178
179            // Tiny matrices benchmark better when pivoted RHS materialization
180            // stays separate from forward substitution.
181            i = 0;
182            while i < D {
183                let mut sum = x[i];
184                let row = self.factors.row(i);
185                let mut j = 0;
186                while j < i {
187                    sum = (-row[j]).mul_add(x[j], sum);
188                    j += 1;
189                }
190                if !sum.is_finite() {
191                    cold_path();
192                    return Err(LaError::non_finite_at(i));
193                }
194                x[i] = sum;
195                i += 1;
196            }
197        } else {
198            // Larger fixed dimensions avoid an extra pass by reading the
199            // pivoted right-hand side directly into forward substitution.
200            while i < D {
201                let mut sum = b[self.piv[i]];
202                let row = self.factors.row(i);
203                let mut j = 0;
204                while j < i {
205                    sum = (-row[j]).mul_add(x[j], sum);
206                    j += 1;
207                }
208                if !sum.is_finite() {
209                    cold_path();
210                    return Err(LaError::non_finite_at(i));
211                }
212                x[i] = sum;
213                i += 1;
214            }
215        }
216
217        // Back substitution for U.
218        let mut ii = 0;
219        while ii < D {
220            let i = D - 1 - ii;
221            let mut sum = x[i];
222            let row = self.factors.row(i);
223            let mut j = i + 1;
224            while j < D {
225                sum = (-row[j]).mul_add(x[j], sum);
226                j += 1;
227            }
228
229            let diag = row[i];
230            if !sum.is_finite() {
231                cold_path();
232                return Err(LaError::non_finite_at(i));
233            }
234
235            let quotient = sum / diag;
236            if !quotient.is_finite() {
237                cold_path();
238                return Err(LaError::non_finite_at(i));
239            }
240            x[i] = quotient;
241            ii += 1;
242        }
243
244        Ok(Vector::new_unchecked(x))
245    }
246
247    /// Determinant of the original matrix.
248    ///
249    /// # Examples
250    /// ```
251    /// use la_stack::prelude::*;
252    ///
253    /// # fn main() -> Result<(), LaError> {
254    /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?;
255    /// let lu = a.lu(DEFAULT_SINGULAR_TOL)?;
256    ///
257    /// let det = lu.det()?;
258    /// assert!((det - (-2.0)).abs() <= 1e-12);
259    /// # Ok(())
260    /// # }
261    /// ```
262    ///
263    /// # Errors
264    /// Returns [`LaError::NonFinite`] if the determinant product overflows to
265    /// NaN or infinity.
266    #[inline]
267    pub const fn det(&self) -> Result<f64, LaError> {
268        let mut det = self.piv_sign;
269        let mut i = 0;
270        while i < D {
271            det *= self.factors.diag(i);
272            if !det.is_finite() {
273                cold_path();
274                return Err(LaError::non_finite_at(i));
275            }
276            i += 1;
277        }
278        Ok(det)
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_SINGULAR_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    macro_rules! gen_public_api_pivoting_solve_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    // This forces pivoting in column 0 for any D >= 2.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    // Pick a simple RHS with unique entries, so the expected swap is obvious.
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    // Row swap ⇒ determinant sign flip.
                    let det_fn: fn(&Lu<$d>) -> Result<f64, LaError> =
                        black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_and_det_tests!(2);
    gen_public_api_pivoting_solve_and_det_tests!(3);
    gen_public_api_pivoting_solve_and_det_tests!(4);
    gen_public_api_pivoting_solve_and_det_tests!(5);

    macro_rules! gen_public_api_tridiagonal_smoke_solve_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    // The largest instantiation is 64x64 f64 (32 KiB); keeping it
                    // on the stack is intentional for these smoke tests.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    // Choose x = 1, so b = A x is simple: [1, 0, 0, ..., 0, 1].
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    // Determinant is known exactly: det = D + 1.
                    // The largest instantiation is 64x64 f64 (32 KiB); keeping it
                    // on the stack is intentional for these smoke tests.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::try_from_rows(black_box(rows)).unwrap();
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_SINGULAR_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> Result<f64, LaError> =
                        black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu).unwrap(), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_and_det_tests!(64);

    #[test]
    fn solve_0x0_returns_empty_vector_and_unit_det() {
        let a = Matrix::<0>::zero();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        assert_eq!(lu.det(), Ok(1.0));
        assert!(
            lu.solve(Vector::<0>::zero())
                .unwrap()
                .into_array()
                .is_empty()
        );
    }

    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::try_from_rows(black_box([[2.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        let det_fn: fn(&Lu<1>) -> Result<f64, LaError> = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [3.0, 4.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [3.0, 4.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> Result<f64, LaError> = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // Row swap ⇒ determinant sign flip.
        let a = Matrix::<2>::try_from_rows(black_box([[0.0, 1.0], [1.0, 0.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> Result<f64, LaError> = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::try_from_rows(black_box([[0.0, 1.0], [1.0, 0.0]])).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        let a = Matrix::<2>::try_from_rows(black_box([[1.0, 2.0], [2.0, 4.0]])).unwrap();
        let err = a.lu(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // Not exactly singular, but below DEFAULT_SINGULAR_TOL.
        let a = Matrix::<2>::try_from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]])).unwrap();
        let err = a.lu(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_entry() {
        let err = Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_column_entry() {
        let err = Matrix::<2>::try_from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_detected_in_trailing_update() {
        // The -1 multiplier turns row 1's MAX into MAX + MAX = inf during the
        // trailing update; the final factor validation reports it.
        let a = Matrix::<3>::try_from_rows([
            [1.0, f64::MAX, 0.0],
            [-1.0, f64::MAX, 0.0],
            [0.0, 0.0, 1.0],
        ])
        .unwrap();

        let err = a.lu(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 1,
            }
        );
    }

    #[test]
    fn solve_nonfinite_forward_substitution_overflow() {
        // L has a -1 multiplier, and a large RHS makes forward substitution overflow.
        let a = Matrix::<3>::try_from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
            .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_nonfinite_forward_substitution_overflow_fused_branch_5d() {
        // Exercises the D >= 5 fused pivot/forward-substitution branch with the
        // same overflowing L multiplier as
        // `solve_nonfinite_forward_substitution_overflow`.
        let a = Matrix::<5>::try_from_rows([
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [-1.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ])
        .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<5>::new([1.0e308, 1.0e308, 0.0, 0.0, 0.0]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_nonfinite_back_substitution_overflow() {
        // The small pivot 2e-12 still passes DEFAULT_SINGULAR_TOL. In back
        // substitution, x[1] = 1e300 / 2e-12 overflows to infinity. The
        // quotient check reports it on that same row (1).
        let a = Matrix::<2>::try_from_rows([[1.0, 1.0], [0.0, 2.0e-12]]).unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_nonfinite_back_substitution_sum_overflow() {
        // Upper-triangular U with a very large off-diagonal in row 1 and a
        // very large x[2] produced by the RHS.
        //
        // The back-substitution accumulator
        // `sum = (-row[j]).mul_add(x[j], sum)` overflows while reducing row 1.
        // The failure is therefore detected by the `!sum.is_finite()` check
        // that runs before dividing by the pivot. This is distinct from the
        // quotient-overflow path covered by
        // `solve_nonfinite_back_substitution_overflow`.
        let a = Matrix::<3>::try_from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0e200], [0.0, 0.0, 1.0]])
            .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new([0.0, 0.0, 1.0e200]);
        let err = lu.solve(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn det_rejects_product_overflow() {
        // 1e100^4 = 1e400 exceeds f64::MAX, so the product overflows at pivot 3.
        let a = Matrix::<5>::try_from_rows([
            [1.0e100, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0e100, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0e100, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0e100, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0e100],
        ])
        .unwrap();
        let lu = a.lu(DEFAULT_SINGULAR_TOL).unwrap();
        assert_eq!(lu.det(), Err(LaError::NonFinite { row: None, col: 3 }));
    }

    macro_rules! gen_solve_boundary_tests {
        ($d:literal) => {
            paste! {
                /// Raw non-finite right-hand sides are rejected before a
                /// public caller can construct a `Vector`.
                #[test]
                fn [<solve_rhs_constructor_rejects_non_finite_ $d d>]() {
                    let mut rhs = [1.0; $d];
                    rhs[$d - 1] = f64::NAN;

                    assert_eq!(
                        Vector::<$d>::try_new(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: $d - 1,
                        })
                    );
                }
            }
        };
    }

    gen_solve_boundary_tests!(2);
    gen_solve_boundary_tests!(3);
    gen_solve_boundary_tests!(4);
    gen_solve_boundary_tests!(5);

    // -----------------------------------------------------------------------
    // Const-evaluability tests.
    //
    // These prove that `Lu::det` and `Lu::solve` are truly `const fn` by
    // forcing the compiler to evaluate them inside a `const` initializer.
    //
    // `Lu::factor_finite` is not `const fn`: it uses `for` range loops and the
    // `?` operator, neither of which is usable in `const fn` on stable Rust.
    // We therefore construct `Lu<D>` directly.
    // -----------------------------------------------------------------------

    #[test]
    fn lu_det_const_eval_d2() {
        const DET: Result<f64, LaError> = {
            // Triangular factors with diag [2.0, 3.0] and no row swaps.
            let factors = Matrix::<2>::from_rows_unchecked([[2.0, 0.0], [0.0, 3.0]]);
            let lu = Lu::<2> {
                factors: LuFactors::new_unchecked(factors),
                piv: [0, 1],
                piv_sign: 1.0,
            };
            lu.det()
        };
        assert_eq!(DET, Ok(6.0));
    }

    #[test]
    fn lu_det_const_eval_d3_row_swap() {
        const DET: Result<f64, LaError> = {
            // Identity factors but `piv_sign = -1.0` encoding a single row swap;
            // the determinant magnitude is 1 but the sign flips.
            let lu = Lu::<3> {
                factors: LuFactors::new_unchecked(Matrix::<3>::identity()),
                piv: [1, 0, 2],
                piv_sign: -1.0,
            };
            lu.det()
        };
        assert_eq!(DET, Ok(-1.0));
    }

    #[test]
    fn lu_solve_const_eval_d2() {
        // Identity LU with the identity permutation ⇒ solve returns the RHS
        // untouched.
        const X: [f64; 2] = {
            let lu = Lu::<2> {
                factors: LuFactors::new_unchecked(Matrix::<2>::identity()),
                piv: [0, 1],
                piv_sign: 1.0,
            };
            let b = Vector::<2>::new([1.0, 2.0]);
            match lu.solve(b) {
                Ok(v) => v.into_array(),
                Err(_) => [0.0, 0.0],
            }
        };
        assert_abs_diff_eq!(X[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(X[1], 2.0, epsilon = 1e-12);
    }
}