Skip to main content

la_stack/
lu.rs

1//! LU decomposition and solves.
2
3use core::hint::cold_path;
4
5use crate::LaError;
6use crate::matrix::Matrix;
7use crate::vector::Vector;
8
9/// LU decomposition (PA = LU) with partial pivoting.
10#[must_use]
11#[derive(Clone, Copy, Debug, PartialEq)]
12pub struct Lu<const D: usize> {
13    factors: Matrix<D>,
14    piv: [usize; D],
15    piv_sign: f64,
16    tol: f64,
17}
18
19impl<const D: usize> Lu<D> {
20    #[inline]
21    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
22        let mut lu = a;
23
24        let mut piv = [0usize; D];
25        for (i, p) in piv.iter_mut().enumerate() {
26            *p = i;
27        }
28
29        let mut piv_sign = 1.0;
30
31        for k in 0..D {
32            // Choose pivot row.
33            let mut pivot_row = k;
34            let mut pivot_abs = lu.rows[k][k].abs();
35            if !pivot_abs.is_finite() {
36                cold_path();
37                return Err(LaError::non_finite_cell(k, k));
38            }
39
40            for r in (k + 1)..D {
41                let v = lu.rows[r][k].abs();
42                if !v.is_finite() {
43                    cold_path();
44                    return Err(LaError::non_finite_cell(r, k));
45                }
46                if v > pivot_abs {
47                    pivot_abs = v;
48                    pivot_row = r;
49                }
50            }
51
52            if pivot_abs <= tol {
53                cold_path();
54                return Err(LaError::Singular { pivot_col: k });
55            }
56
57            if pivot_row != k {
58                lu.rows.swap(k, pivot_row);
59                piv.swap(k, pivot_row);
60                piv_sign = -piv_sign;
61            }
62
63            let pivot = lu.rows[k][k];
64            if !pivot.is_finite() {
65                cold_path();
66                return Err(LaError::non_finite_cell(k, k));
67            }
68
69            // Eliminate below pivot.
70            for r in (k + 1)..D {
71                let mult = lu.rows[r][k] / pivot;
72                if !mult.is_finite() {
73                    cold_path();
74                    return Err(LaError::non_finite_cell(r, k));
75                }
76                lu.rows[r][k] = mult;
77
78                for c in (k + 1)..D {
79                    lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
80                }
81            }
82        }
83
84        Ok(Self {
85            factors: lu,
86            piv,
87            piv_sign,
88            tol,
89        })
90    }
91
92    /// Solve `A x = b` using this LU factorization.
93    ///
94    /// # Examples
95    /// ```
96    /// use la_stack::prelude::*;
97    ///
98    /// # fn main() -> Result<(), LaError> {
99    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
100    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
101    ///
102    /// let b = Vector::<2>::new([5.0, 11.0]);
103    /// let x = lu.solve_vec(b)?.into_array();
104    ///
105    /// assert!((x[0] - 1.0).abs() <= 1e-12);
106    /// assert!((x[1] - 2.0).abs() <= 1e-12);
107    /// # Ok(())
108    /// # }
109    /// ```
110    ///
111    /// # Errors
112    /// Returns [`LaError::Singular`] if a diagonal entry of `U` satisfies `|u_ii| <= tol`, where
113    /// `tol` is the tolerance that was used during factorization.
114    ///
115    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected. The `row`/`col` coordinates
116    /// follow the convention documented on [`LaError::NonFinite`]:
117    ///
118    /// - `row: Some(i), col: i` — the stored `U` diagonal at `(i, i)` is non-finite
119    ///   (only reachable via direct `Lu` construction; [`Matrix::lu`](crate::Matrix::lu)
120    ///   rejects such factorizations).
121    /// - `row: None, col: i` — a computed intermediate (forward/back-substitution
122    ///   accumulator or the quotient `sum / diag`) overflowed to NaN/∞ at step `i`.
123    #[inline]
124    pub const fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
125        let mut x = [0.0; D];
126        let mut i = 0;
127        while i < D {
128            x[i] = b.data[self.piv[i]];
129            i += 1;
130        }
131
132        // Forward substitution for L (unit diagonal).
133        let mut i = 0;
134        while i < D {
135            let mut sum = x[i];
136            let row = self.factors.rows[i];
137            let mut j = 0;
138            while j < i {
139                sum = (-row[j]).mul_add(x[j], sum);
140                j += 1;
141            }
142            if !sum.is_finite() {
143                cold_path();
144                return Err(LaError::non_finite_at(i));
145            }
146            x[i] = sum;
147            i += 1;
148        }
149
150        // Back substitution for U.
151        let mut ii = 0;
152        while ii < D {
153            let i = D - 1 - ii;
154            let mut sum = x[i];
155            let row = self.factors.rows[i];
156            let mut j = i + 1;
157            while j < D {
158                sum = (-row[j]).mul_add(x[j], sum);
159                j += 1;
160            }
161
162            let diag = row[i];
163            // Distinguish a corrupt stored pivot (row: Some(i), col: i) from
164            // a computed intermediate overflow (row: None, col: i) so callers
165            // can diagnose the failure source without inspecting internals.
166            if !diag.is_finite() {
167                cold_path();
168                return Err(LaError::non_finite_cell(i, i));
169            }
170            if !sum.is_finite() {
171                cold_path();
172                return Err(LaError::non_finite_at(i));
173            }
174            if diag.abs() <= self.tol {
175                cold_path();
176                return Err(LaError::Singular { pivot_col: i });
177            }
178
179            let quotient = sum / diag;
180            if !quotient.is_finite() {
181                cold_path();
182                return Err(LaError::non_finite_at(i));
183            }
184            x[i] = quotient;
185            ii += 1;
186        }
187
188        Ok(Vector::new(x))
189    }
190
191    /// Determinant of the original matrix.
192    ///
193    /// # Examples
194    /// ```
195    /// use la_stack::prelude::*;
196    ///
197    /// # fn main() -> Result<(), LaError> {
198    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
199    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
200    ///
201    /// let det = lu.det();
202    /// assert!((det - (-2.0)).abs() <= 1e-12);
203    /// # Ok(())
204    /// # }
205    /// ```
206    #[inline]
207    #[must_use]
208    pub const fn det(&self) -> f64 {
209        let mut det = self.piv_sign;
210        let mut i = 0;
211        while i < D {
212            det *= self.factors.rows[i][i];
213            i += 1;
214        }
215        det
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve_vec (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    // This forces pivoting in column 0 for any D >= 2.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Pick a simple RHS with unique entries, so the expected swap is obvious.
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    // A is a single transposition and hence its own inverse, so
                    // x = A b: b with entries 0 and 1 swapped.
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Row swap ⇒ determinant sign flip.
                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve_vec (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    // The largest instantiation (D = 64) is a 32 KiB stack array by design.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Choose x = 1, so b = A x is simple: [1, 0, 0, ..., 0, 1].
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    // Determinant is known exactly: det = D + 1.
                    // The largest instantiation (D = 64) is a 32 KiB stack array by design.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // Row swap ⇒ determinant sign flip.
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // Not exactly singular, but below DEFAULT_PIVOT_TOL.
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        // L has a -1 multiplier, and a large RHS makes forward substitution overflow.
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        // Back substitution starts at row 1, where `quotient = 1e300 / 2e-12`
        // overflows to ∞. The quotient check reports it immediately at step 1,
        // before row 0 is processed.
        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_sum_overflow() {
        // Upper-triangular U with a very large off-diagonal in row 1 and a
        // very large x[2] produced by the RHS.  The back-substitution
        // accumulator `sum = (-row[j]).mul_add(x[j], sum)` overflows while
        // reducing row 1, so the failure is detected by the `!sum.is_finite()`
        // check (distinct from the `quotient = sum / diag` overflow path
        // covered above).
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0e200], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([0.0, 0.0, 1.0e200]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // -----------------------------------------------------------------------
    // Defensive-path coverage for `solve_vec`.
    //
    // `Lu::factor` guarantees that every stored U diagonal is finite and
    // satisfies `|U[i,i]| > tol`.  `solve_vec` still re-checks during
    // back-substitution as a safety net (see the `!diag.is_finite()` and
    // `diag.abs() <= self.tol` guards).  Those branches are unreachable
    // through the public API, so the only way to exercise them is to
    // construct `Lu` directly with a corrupt U.  The tests below document
    // and verify that the safety nets return the documented error variants
    // with coordinates that locate the offending stored cell.
    // -----------------------------------------------------------------------

    macro_rules! gen_solve_vec_defensive_tests {
        ($d:literal) => {
            paste! {
                /// `solve_vec` must surface `Singular` when a stored U
                /// diagonal is at or below the recorded tolerance, even
                /// though `factor` cannot produce such a factorization.
                #[test]
                fn [<solve_vec_defensive_sub_tolerance_diagonal_ $d d>]() {
                    let mut factors = Matrix::<$d>::identity();
                    factors.rows[$d - 1][$d - 1] = 0.0;

                    let mut piv = [0usize; $d];
                    for (i, p) in piv.iter_mut().enumerate() {
                        *p = i;
                    }

                    let lu = Lu::<$d> {
                        factors,
                        piv,
                        piv_sign: 1.0,
                        tol: DEFAULT_PIVOT_TOL,
                    };
                    let b = Vector::<$d>::new([0.0; $d]);
                    let err = lu.solve_vec(b).unwrap_err();
                    assert_eq!(err, LaError::Singular { pivot_col: $d - 1 });
                }

                /// `solve_vec` must surface `NonFinite` with the corrupt
                /// cell's coordinates when a stored U diagonal is NaN,
                /// even though `factor` cannot produce such a
                /// factorization. The error must pinpoint `(D-1, D-1)`
                /// per the [`LaError::NonFinite`] convention.
                #[test]
                fn [<solve_vec_defensive_non_finite_diagonal_ $d d>]() {
                    let mut factors = Matrix::<$d>::identity();
                    factors.rows[$d - 1][$d - 1] = f64::NAN;

                    let mut piv = [0usize; $d];
                    for (i, p) in piv.iter_mut().enumerate() {
                        *p = i;
                    }

                    let lu = Lu::<$d> {
                        factors,
                        piv,
                        piv_sign: 1.0,
                        tol: DEFAULT_PIVOT_TOL,
                    };
                    let b = Vector::<$d>::new([1.0; $d]);
                    let err = lu.solve_vec(b).unwrap_err();
                    assert_eq!(
                        err,
                        LaError::NonFinite {
                            row: Some($d - 1),
                            col: $d - 1,
                        }
                    );
                }
            }
        };
    }

    gen_solve_vec_defensive_tests!(2);
    gen_solve_vec_defensive_tests!(3);
    gen_solve_vec_defensive_tests!(4);
    gen_solve_vec_defensive_tests!(5);

    // -----------------------------------------------------------------------
    // Const-evaluability tests.
    //
    // These prove that `Lu::det` and `Lu::solve_vec` are truly `const fn` by
    // forcing the compiler to evaluate them inside a `const` initializer.
    // `Lu::factor` is not a `const fn` (its body uses `for` loops over
    // iterators, which are not permitted in const contexts), so we construct
    // `Lu<D>` directly.
    // -----------------------------------------------------------------------

    #[test]
    fn lu_det_const_eval_d2() {
        const DET: f64 = {
            // Triangular factors with diag [2.0, 3.0] and no row swaps.
            let mut factors = Matrix::<2>::identity();
            factors.rows[0][0] = 2.0;
            factors.rows[1][1] = 3.0;
            let lu = Lu::<2> {
                factors,
                piv: [0, 1],
                piv_sign: 1.0,
                tol: DEFAULT_PIVOT_TOL,
            };
            lu.det()
        };
        assert!((DET - 6.0).abs() <= 1e-12);
    }

    #[test]
    fn lu_det_const_eval_d3_row_swap() {
        const DET: f64 = {
            // Identity factors but `piv_sign = -1.0` encoding a single row swap;
            // the determinant magnitude is 1 but the sign flips.
            let lu = Lu::<3> {
                factors: Matrix::<3>::identity(),
                piv: [1, 0, 2],
                piv_sign: -1.0,
                tol: DEFAULT_PIVOT_TOL,
            };
            lu.det()
        };
        assert!((DET - (-1.0)).abs() <= 1e-12);
    }

    #[test]
    fn lu_solve_vec_const_eval_d2() {
        // Identity LU ⇒ solve_vec returns the permuted RHS untouched.
        const X: [f64; 2] = {
            let lu = Lu::<2> {
                factors: Matrix::<2>::identity(),
                piv: [0, 1],
                piv_sign: 1.0,
                tol: DEFAULT_PIVOT_TOL,
            };
            let b = Vector::<2>::new([1.0, 2.0]);
            match lu.solve_vec(b) {
                Ok(v) => v.into_array(),
                Err(_) => [0.0, 0.0],
            }
        };
        assert!((X[0] - 1.0).abs() <= 1e-12);
        assert!((X[1] - 2.0).abs() <= 1e-12);
    }
}