Skip to main content

la_stack/
lu.rs

//! LU decomposition with partial pivoting, linear solves, and determinants.
2
3use core::hint::cold_path;
4
5use crate::matrix::{FiniteMatrix, Matrix};
6use crate::vector::{FiniteVector, Vector};
7use crate::{LaError, Tolerance};
8
/// LU decomposition (`PA = LU`) with partial pivoting.
///
/// Use [`Lu::solve_vec`] to solve `A x = b` and [`Lu::det`] to recover the
/// determinant of the original matrix `A`.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Lu<const D: usize> {
    /// Packed factors: multipliers of `L` strictly below the diagonal (the
    /// unit diagonal of `L` is implicit) and `U` on and above the diagonal.
    factors: LuFactors<D>,
    /// Row permutation: row `i` of `PA` is row `piv[i]` of `A`.
    piv: [usize; D],
    /// Sign of the permutation: starts at `1.0` and flips on every row swap.
    piv_sign: f64,
}
17
/// In-place LU factor storage whose `U` diagonal is finite and usable.
///
/// Layout: entries strictly below the diagonal hold the elimination
/// multipliers of `L` (its unit diagonal is not stored); entries on and above
/// the diagonal hold `U`.
///
/// Construction through [`Lu::factor_finite`] proves every stored entry is
/// finite and every `|U[i,i]|` is strictly greater than the factorization
/// tolerance.
#[derive(Clone, Copy, Debug, PartialEq)]
struct LuFactors<const D: usize> {
    // Packed `L`/`U` factors, laid out as described above.
    storage: Matrix<D>,
}
26
27impl<const D: usize> LuFactors<D> {
28    /// Construct factors after LU factorization has proven the storage invariant.
29    #[inline]
30    const fn new_unchecked(storage: Matrix<D>) -> Self {
31        Self { storage }
32    }
33
34    /// Borrow a factor row.
35    #[inline]
36    #[must_use]
37    const fn row(&self, index: usize) -> &[f64; D] {
38        &self.storage.rows[index]
39    }
40
41    /// Return a diagonal entry of `U`.
42    #[inline]
43    #[must_use]
44    const fn diag(&self, index: usize) -> f64 {
45        self.storage.rows[index][index]
46    }
47}
48
49impl<const D: usize> Lu<D> {
50    /// Factor a finite square matrix into in-place LU storage for
51    /// `FiniteMatrix::lu`.
52    ///
53    /// The input has already proven finite entries, so LU construction rejects
54    /// numerically singular pivots and non-finite elimination intermediates
55    /// before callers can observe a [`Lu`] value. Completed factor storage is
56    /// checked before return so successful factors do not contain a non-finite
57    /// value produced during elimination.
58    #[inline]
59    pub(crate) fn factor_finite(a: FiniteMatrix<D>, tol: Tolerance) -> Result<Self, LaError> {
60        let mut lu = a.into_matrix();
61        let tol = tol.get();
62
63        let mut piv = [0usize; D];
64        for (i, p) in piv.iter_mut().enumerate() {
65            *p = i;
66        }
67
68        let mut piv_sign = 1.0;
69
70        for k in 0..D {
71            // Choose pivot row.
72            let mut pivot_row = k;
73            let mut pivot_abs = lu.rows[k][k].abs();
74
75            for r in (k + 1)..D {
76                let v = lu.rows[r][k].abs();
77                if v > pivot_abs {
78                    pivot_abs = v;
79                    pivot_row = r;
80                }
81            }
82
83            if pivot_abs <= tol {
84                cold_path();
85                return Err(LaError::Singular { pivot_col: k });
86            }
87
88            if pivot_row != k {
89                lu.rows.swap(k, pivot_row);
90                piv.swap(k, pivot_row);
91                piv_sign = -piv_sign;
92            }
93
94            let pivot = lu.rows[k][k];
95
96            // Eliminate below pivot.
97            for r in (k + 1)..D {
98                let mult = lu.rows[r][k] / pivot;
99                lu.rows[r][k] = mult;
100
101                for c in (k + 1)..D {
102                    let updated = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
103                    lu.rows[r][c] = updated;
104                }
105            }
106        }
107
108        let lu = FiniteMatrix::new(lu)?.into_matrix();
109
110        Ok(Self {
111            factors: LuFactors::new_unchecked(lu),
112            piv,
113            piv_sign,
114        })
115    }
116
117    /// Solve `A x = b` using this LU factorization.
118    ///
119    /// `solve_vec` performs floating-point forward/back substitution and does
120    /// not provide a certified absolute rounding-error bound for the returned
121    /// solution. Callers should not expect an absolute error bound like
122    /// [`Matrix::det_errbound`](crate::Matrix::det_errbound) or the
123    /// `ERR_COEFF_*` determinant constants provide.
124    ///
125    /// # Examples
126    /// ```
127    /// use la_stack::prelude::*;
128    ///
129    /// # fn main() -> Result<(), LaError> {
130    /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?;
131    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
132    ///
133    /// let b = Vector::<2>::try_new([5.0, 11.0])?;
134    /// let x = lu.solve_vec(b)?.into_array();
135    ///
136    /// assert!((x[0] - 1.0).abs() <= 1e-12);
137    /// assert!((x[1] - 2.0).abs() <= 1e-12);
138    /// # Ok(())
139    /// # }
140    /// ```
141    ///
142    /// # Errors
143    /// Returns [`LaError::NonFinite`] if the right-hand side contains
144    /// NaN/infinity or a computed substitution intermediate overflows to NaN or
145    /// infinity. The right-hand side is revalidated at this boundary before
146    /// entering substitution; public callers normally encounter the same
147    /// rejection earlier through [`Vector::try_new`](crate::Vector::try_new).
148    #[inline]
149    pub const fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
150        let b = match FiniteVector::new(b) {
151            Ok(b) => b,
152            Err(err) => return Err(err),
153        };
154        match self.solve_finite_vec(b) {
155            Ok(x) => Ok(x.into_vector()),
156            Err(err) => Err(err),
157        }
158    }
159
160    /// Solve `A x = b` using this LU factorization and a finite right-hand side.
161    ///
162    /// The right-hand side entries and stored factors are known finite, so this
163    /// path only checks computed substitution overflows.
164    ///
165    /// # Errors
166    /// Returns [`LaError::NonFinite`] if a computed substitution intermediate
167    /// overflows to NaN or infinity.
168    #[inline]
169    pub(crate) const fn solve_finite_vec(
170        &self,
171        b: FiniteVector<D>,
172    ) -> Result<FiniteVector<D>, LaError> {
173        let mut x = [0.0; D];
174        let mut i = 0;
175        let b = b.as_array();
176        while i < D {
177            x[i] = b[self.piv[i]];
178            i += 1;
179        }
180
181        // Forward substitution for L (unit diagonal).
182        let mut i = 0;
183        while i < D {
184            let mut sum = x[i];
185            let row = self.factors.row(i);
186            let mut j = 0;
187            while j < i {
188                sum = (-row[j]).mul_add(x[j], sum);
189                j += 1;
190            }
191            if !sum.is_finite() {
192                cold_path();
193                return Err(LaError::non_finite_at(i));
194            }
195            x[i] = sum;
196            i += 1;
197        }
198
199        // Back substitution for U.
200        let mut ii = 0;
201        while ii < D {
202            let i = D - 1 - ii;
203            let mut sum = x[i];
204            let row = self.factors.row(i);
205            let mut j = i + 1;
206            while j < D {
207                sum = (-row[j]).mul_add(x[j], sum);
208                j += 1;
209            }
210
211            let diag = row[i];
212            if !sum.is_finite() {
213                cold_path();
214                return Err(LaError::non_finite_at(i));
215            }
216
217            let quotient = sum / diag;
218            if !quotient.is_finite() {
219                cold_path();
220                return Err(LaError::non_finite_at(i));
221            }
222            x[i] = quotient;
223            ii += 1;
224        }
225
226        Ok(FiniteVector::new_unchecked(Vector::new_unchecked(x)))
227    }
228
229    /// Determinant of the original matrix.
230    ///
231    /// # Examples
232    /// ```
233    /// use la_stack::prelude::*;
234    ///
235    /// # fn main() -> Result<(), LaError> {
236    /// let a = Matrix::<2>::try_from_rows([[1.0, 2.0], [3.0, 4.0]])?;
237    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
238    ///
239    /// let det = lu.det()?;
240    /// assert!((det - (-2.0)).abs() <= 1e-12);
241    /// # Ok(())
242    /// # }
243    /// ```
244    ///
245    /// # Errors
246    /// Returns [`LaError::NonFinite`] if the determinant product overflows to
247    /// NaN or infinity.
248    #[inline]
249    pub const fn det(&self) -> Result<f64, LaError> {
250        let mut det = self.piv_sign;
251        let mut i = 0;
252        while i < D {
253            det *= self.factors.diag(i);
254            if !det.is_finite() {
255                cold_path();
256                return Err(LaError::non_finite_at(i));
257            }
258            i += 1;
259        }
260        Ok(det)
261    }
262}
263
// Unit tests for LU factorization, substitution solves, and determinants.
// They exercise both the public `Matrix::lu` entry point and `Lu` values
// constructed directly.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::assert_matches;
    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Generates solve/det tests for the D x D permutation matrix that swaps
    // rows 0 and 1 (so column 0 always needs a pivot swap).
    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve_vec (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    // This forces pivoting in column 0 for any D >= 2.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Pick a simple RHS with unique entries, so the expected swap is obvious.
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    // A swap permutation is its own inverse, so x = A b is b
                    // with entries 0 and 1 exchanged.
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Row swap ⇒ determinant sign flip.
                    let det_fn: fn(&Lu<$d>) -> Result<f64, LaError> =
                        black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    // Generates solve/det smoke tests for the D x D tridiagonal matrix with 2
    // on the diagonal and -1 on the sub/super-diagonals.
    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve_vec (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    // For D = 64 this array is 64 * 64 * 8 bytes = 32 KiB of
                    // stack, which is acceptable in a test.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Choose x = 1, so b = A x is simple: [1, 0, 0, ..., 0, 1].
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    // Determinant is known exactly: det = D + 1.
                    // For D = 64 this array is 32 KiB of stack; acceptable in a test.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, Tolerance) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> Result<f64, LaError> =
                        black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu).unwrap(), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = FiniteMatrix::new(a).unwrap().lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        let det_fn: fn(&Lu<1>) -> Result<f64, LaError> = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = FiniteMatrix::new(a).unwrap().lu(DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> Result<f64, LaError> = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // Row swap ⇒ determinant sign flip.
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> Result<f64, LaError> = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        // A zero in the (0, 0) position cannot be used as a pivot, so the
        // factorization must swap rows before eliminating.
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        // Row 1 is twice row 0, so after eliminating column 0 the remaining
        // pivot in column 1 is exactly zero.
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // Not exactly singular, but below DEFAULT_PIVOT_TOL.
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    #[test]
    fn invalid_tolerance_rejected() {
        assert_eq!(
            Tolerance::new(-1.0),
            Err(LaError::InvalidTolerance { value: -1.0 })
        );

        // NaN never compares equal, so match on the variant and test the payload.
        assert_matches!(
            Tolerance::new(f64::NAN),
            Err(LaError::InvalidTolerance { value }) if value.is_nan()
        );
        assert_eq!(
            Tolerance::new(f64::INFINITY),
            Err(LaError::InvalidTolerance {
                value: f64::INFINITY,
            })
        );
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_entry() {
        let err = Matrix::<2>::try_from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    #[test]
    fn matrix_constructor_rejects_nonfinite_pivot_column_entry() {
        let err = Matrix::<2>::try_from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_detected_in_trailing_update() {
        // |1| and |-1| tie in column 0, so there is no row swap. The -1
        // multiplier then makes U[1][1] = MAX + MAX = +inf. The
        // post-factorization finiteness check reports that entry.
        let a =
            Matrix::<3>::from_rows([[1.0, f64::MAX, 0.0], [-1.0, f64::MAX, 0.0], [0.0, 0.0, 1.0]]);

        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 1,
            }
        );
    }

    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        // L has a -1 multiplier, and a large RHS makes forward substitution overflow.
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        // U[1][1] = 2e-12 is tiny but still accepted as a pivot (`lu` succeeds).
        // The first back-substitution step computes x[1] = 1e300 / 2e-12,
        // which overflows. The quotient check reports it on that same row
        // (col 1).
        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_sum_overflow() {
        // Upper-triangular U with a very large off-diagonal in row 1 and a
        // very large x[2] produced by the RHS. The back-substitution
        // accumulator `sum = (-row[j]).mul_add(x[j], sum)` overflows while
        // reducing row 1. The failure is therefore caught by the
        // `!sum.is_finite()` check that runs before dividing by the diagonal.
        // This path is distinct from the quotient-overflow path covered above.
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0e200], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([0.0, 0.0, 1.0e200]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn det_rejects_product_overflow() {
        // The running product reaches 1e300 after three pivots. Multiplying
        // by the fourth (index 3) overflows to infinity.
        let a = Matrix::<5>::from_rows([
            [1.0e100, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0e100, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0e100, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0e100, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0e100],
        ]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        assert_eq!(lu.det(), Err(LaError::NonFinite { row: None, col: 3 }));
    }

    // Generates right-hand-side validation tests for dimension D.
    macro_rules! gen_solve_vec_boundary_tests {
        ($d:literal) => {
            paste! {
                /// Raw non-finite right-hand sides are rejected before a
                /// public caller can construct a `Vector`.
                #[test]
                fn [<solve_vec_rhs_boundary_rejects_non_finite_ $d d>]() {
                    let mut rhs = [1.0; $d];
                    rhs[$d - 1] = f64::NAN;

                    assert_eq!(
                        Vector::<$d>::try_new(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: $d - 1,
                        })
                    );
                }

                // `Vector::new_unchecked` bypasses construction-time
                // validation, so `solve_vec` itself must reject the NaN.
                #[test]
                fn [<solve_vec_revalidates_unchecked_rhs_storage_ $d d>]() {
                    let lu = Matrix::<$d>::identity().lu(DEFAULT_PIVOT_TOL).unwrap();
                    let mut rhs = [1.0; $d];
                    rhs[$d - 1] = f64::NAN;
                    let rhs = Vector::<$d>::new_unchecked(rhs);

                    assert_eq!(
                        lu.solve_vec(rhs),
                        Err(LaError::NonFinite {
                            row: None,
                            col: $d - 1,
                        })
                    );
                }
            }
        };
    }

    gen_solve_vec_boundary_tests!(2);
    gen_solve_vec_boundary_tests!(3);
    gen_solve_vec_boundary_tests!(4);
    gen_solve_vec_boundary_tests!(5);

    // -----------------------------------------------------------------------
    // Const-evaluability tests.
    //
    // These prove that `Lu::det` and `Lu::solve_vec` are truly `const fn` by
    // forcing the compiler to evaluate them inside a `const` initializer.
    // `Lu::factor_finite` is not `const fn`: it uses `for` loops over ranges
    // and the `?` operator, among other non-const constructs. These tests
    // therefore construct `Lu<D>` directly.
    // -----------------------------------------------------------------------

    #[test]
    fn lu_det_const_eval_d2() {
        const DET: Result<f64, LaError> = {
            // Triangular factors with diag [2.0, 3.0] and no row swaps.
            let mut factors = Matrix::<2>::identity();
            factors.rows[0][0] = 2.0;
            factors.rows[1][1] = 3.0;
            let lu = Lu::<2> {
                factors: LuFactors::new_unchecked(factors),
                piv: [0, 1],
                piv_sign: 1.0,
            };
            lu.det()
        };
        assert_eq!(DET, Ok(6.0));
    }

    #[test]
    fn lu_det_const_eval_d3_row_swap() {
        const DET: Result<f64, LaError> = {
            // Identity factors but `piv_sign = -1.0` encoding a single row swap;
            // the determinant magnitude is 1 but the sign flips.
            let lu = Lu::<3> {
                factors: LuFactors::new_unchecked(Matrix::<3>::identity()),
                piv: [1, 0, 2],
                piv_sign: -1.0,
            };
            lu.det()
        };
        assert_eq!(DET, Ok(-1.0));
    }

    #[test]
    fn lu_solve_vec_const_eval_d2() {
        // Identity LU with the identity permutation ⇒ solve_vec returns the RHS
        // untouched.
        const X: [f64; 2] = {
            let lu = Lu::<2> {
                factors: LuFactors::new_unchecked(Matrix::<2>::identity()),
                piv: [0, 1],
                piv_sign: 1.0,
            };
            let b = Vector::<2>::new([1.0, 2.0]);
            match lu.solve_vec(b) {
                Ok(v) => v.into_array(),
                Err(_) => [0.0, 0.0],
            }
        };
        assert!((X[0] - 1.0).abs() <= 1e-12);
        assert!((X[1] - 2.0).abs() <= 1e-12);
    }
}