la_stack/
lu.rs

1//! LU decomposition and solves.
2
3use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
7/// LU decomposition (PA = LU) with partial pivoting.
8#[must_use]
9#[derive(Clone, Copy, Debug, PartialEq)]
10pub struct Lu<const D: usize> {
11    factors: Matrix<D>,
12    piv: [usize; D],
13    piv_sign: f64,
14    tol: f64,
15}
16
17impl<const D: usize> Lu<D> {
18    #[inline]
19    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20        let mut lu = a;
21
22        let mut piv = [0usize; D];
23        for (i, p) in piv.iter_mut().enumerate() {
24            *p = i;
25        }
26
27        let mut piv_sign = 1.0;
28
29        for k in 0..D {
30            // Choose pivot row.
31            let mut pivot_row = k;
32            let mut pivot_abs = lu.rows[k][k].abs();
33            if !pivot_abs.is_finite() {
34                return Err(LaError::NonFinite { pivot_col: k });
35            }
36
37            for r in (k + 1)..D {
38                let v = lu.rows[r][k].abs();
39                if !v.is_finite() {
40                    return Err(LaError::NonFinite { pivot_col: k });
41                }
42                if v > pivot_abs {
43                    pivot_abs = v;
44                    pivot_row = r;
45                }
46            }
47
48            if pivot_abs <= tol {
49                return Err(LaError::Singular { pivot_col: k });
50            }
51
52            if pivot_row != k {
53                lu.rows.swap(k, pivot_row);
54                piv.swap(k, pivot_row);
55                piv_sign = -piv_sign;
56            }
57
58            let pivot = lu.rows[k][k];
59            if !pivot.is_finite() {
60                return Err(LaError::NonFinite { pivot_col: k });
61            }
62
63            // Eliminate below pivot.
64            for r in (k + 1)..D {
65                let mult = lu.rows[r][k] / pivot;
66                if !mult.is_finite() {
67                    return Err(LaError::NonFinite { pivot_col: k });
68                }
69                lu.rows[r][k] = mult;
70
71                for c in (k + 1)..D {
72                    lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
73                }
74            }
75        }
76
77        Ok(Self {
78            factors: lu,
79            piv,
80            piv_sign,
81            tol,
82        })
83    }
84
85    /// Solve `A x = b` using this LU factorization.
86    ///
87    /// # Examples
88    /// ```
89    /// #![allow(unused_imports)]
90    /// use la_stack::prelude::*;
91    ///
92    /// # fn main() -> Result<(), LaError> {
93    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
94    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
95    ///
96    /// let b = Vector::<2>::new([5.0, 11.0]);
97    /// let x = lu.solve_vec(b)?.into_array();
98    ///
99    /// assert!((x[0] - 1.0).abs() <= 1e-12);
100    /// assert!((x[1] - 2.0).abs() <= 1e-12);
101    /// # Ok(())
102    /// # }
103    /// ```
104    ///
105    /// # Errors
106    /// Returns [`LaError::Singular`] if a diagonal of `U` is (numerically) zero.
107    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected.
108    #[inline]
109    pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
110        let mut x = [0.0; D];
111        for (i, x_i) in x.iter_mut().enumerate() {
112            *x_i = b.data[self.piv[i]];
113        }
114
115        // Forward substitution for L (unit diagonal).
116        for i in 0..D {
117            let mut sum = x[i];
118            let row = self.factors.rows[i];
119            for (j, x_j) in x.iter().enumerate().take(i) {
120                sum = (-row[j]).mul_add(*x_j, sum);
121            }
122            if !sum.is_finite() {
123                return Err(LaError::NonFinite { pivot_col: i });
124            }
125            x[i] = sum;
126        }
127
128        // Back substitution for U.
129        for ii in 0..D {
130            let i = D - 1 - ii;
131            let mut sum = x[i];
132            let row = self.factors.rows[i];
133            for (j, x_j) in x.iter().enumerate().skip(i + 1) {
134                sum = (-row[j]).mul_add(*x_j, sum);
135            }
136
137            let diag = row[i];
138            if !diag.is_finite() || !sum.is_finite() {
139                return Err(LaError::NonFinite { pivot_col: i });
140            }
141            if diag.abs() <= self.tol {
142                return Err(LaError::Singular { pivot_col: i });
143            }
144
145            x[i] = sum / diag;
146        }
147
148        Ok(Vector::new(x))
149    }
150
151    /// Determinant of the original matrix.
152    ///
153    /// # Examples
154    /// ```
155    /// #![allow(unused_imports)]
156    /// use la_stack::prelude::*;
157    ///
158    /// # fn main() -> Result<(), LaError> {
159    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
160    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
161    ///
162    /// let det = lu.det();
163    /// assert!((det - (-2.0)).abs() <= 1e-12);
164    /// # Ok(())
165    /// # }
166    /// ```
167    #[inline]
168    #[must_use]
169    pub fn det(&self) -> f64 {
170        let mut det = self.piv_sign;
171        for i in 0..D {
172            det *= self.factors.rows[i][i];
173        }
174        det
175    }
176}
177
178#[cfg(test)]
179mod tests {
180    use super::*;
181    use crate::DEFAULT_PIVOT_TOL;
182
183    use core::hint::black_box;
184
185    use approx::assert_abs_diff_eq;
186    use pastey::paste;
187
188    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
189        ($d:literal) => {
190            paste! {
191                #[test]
192                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
193                    // Public API path under test:
194                    // Matrix::lu (pub) -> Lu::solve_vec (pub).
195
196                    // Permutation matrix that swaps the first two basis vectors.
197                    // This forces pivoting in column 0 for any D >= 2.
198                    let mut rows = [[0.0f64; $d]; $d];
199                    for i in 0..$d {
200                        rows[i][i] = 1.0;
201                    }
202                    rows.swap(0, 1);
203
204                    let a = Matrix::<$d>::from_rows(black_box(rows));
205                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
206                        black_box(Matrix::<$d>::lu);
207                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
208
209                    // Pick a simple RHS with unique entries, so the expected swap is obvious.
210                    let b_arr = {
211                        let mut arr = [0.0f64; $d];
212                        let mut val = 1.0f64;
213                        for dst in arr.iter_mut() {
214                            *dst = val;
215                            val += 1.0;
216                        }
217                        arr
218                    };
219                    let mut expected = b_arr;
220                    expected.swap(0, 1);
221                    let b = Vector::<$d>::new(black_box(b_arr));
222
223                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
224                        black_box(Lu::<$d>::solve_vec);
225                    let x = solve_fn(&lu, b).unwrap().into_array();
226
227                    for i in 0..$d {
228                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
229                    }
230                }
231
232                #[test]
233                fn [<public_api_lu_det_pivoting_ $d d>]() {
234                    // Public API path under test:
235                    // Matrix::lu (pub) -> Lu::det (pub).
236
237                    // Permutation matrix that swaps the first two basis vectors.
238                    let mut rows = [[0.0f64; $d]; $d];
239                    for i in 0..$d {
240                        rows[i][i] = 1.0;
241                    }
242                    rows.swap(0, 1);
243
244                    let a = Matrix::<$d>::from_rows(black_box(rows));
245                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
246                        black_box(Matrix::<$d>::lu);
247                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
248
249                    // Row swap ⇒ determinant sign flip.
250                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
251                    assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12);
252                }
253            }
254        };
255    }
256
257    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
258    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
259    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
260    gen_public_api_pivoting_solve_vec_and_det_tests!(5);
261
262    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
263        ($d:literal) => {
264            paste! {
265                #[test]
266                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
267                    // Public API path under test:
268                    // Matrix::lu (pub) -> Lu::solve_vec (pub).
269
270                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
271                    #[allow(clippy::large_stack_arrays)]
272                    let mut rows = [[0.0f64; $d]; $d];
273                    for i in 0..$d {
274                        rows[i][i] = 2.0;
275                        if i > 0 {
276                            rows[i][i - 1] = -1.0;
277                        }
278                        if i + 1 < $d {
279                            rows[i][i + 1] = -1.0;
280                        }
281                    }
282
283                    let a = Matrix::<$d>::from_rows(black_box(rows));
284                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
285                        black_box(Matrix::<$d>::lu);
286                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
287
288                    // Choose x = 1, so b = A x is simple: [1, 0, 0, ..., 0, 1].
289                    let mut b_arr = [0.0f64; $d];
290                    b_arr[0] = 1.0;
291                    b_arr[$d - 1] = 1.0;
292                    let b = Vector::<$d>::new(black_box(b_arr));
293
294                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
295                        black_box(Lu::<$d>::solve_vec);
296                    let x = solve_fn(&lu, b).unwrap().into_array();
297
298                    for &x_i in &x {
299                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
300                    }
301                }
302
303                #[test]
304                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
305                    // Public API path under test:
306                    // Matrix::lu (pub) -> Lu::det (pub).
307
308                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
309                    // Determinant is known exactly: det = D + 1.
310                    #[allow(clippy::large_stack_arrays)]
311                    let mut rows = [[0.0f64; $d]; $d];
312                    for i in 0..$d {
313                        rows[i][i] = 2.0;
314                        if i > 0 {
315                            rows[i][i - 1] = -1.0;
316                        }
317                        if i + 1 < $d {
318                            rows[i][i + 1] = -1.0;
319                        }
320                    }
321
322                    let a = Matrix::<$d>::from_rows(black_box(rows));
323                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
324                        black_box(Matrix::<$d>::lu);
325                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
326
327                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
328                    assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
329                }
330            }
331        };
332    }
333
334    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
335    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
336    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);
337
338    #[test]
339    fn solve_1x1() {
340        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
341        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
342
343        let b = Vector::<1>::new(black_box([6.0]));
344        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
345            black_box(Lu::<1>::solve_vec);
346        let x = solve_fn(&lu, b).unwrap().into_array();
347        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);
348
349        let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det);
350        assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0);
351    }
352
353    #[test]
354    fn solve_2x2_basic() {
355        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
356        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
357        let b = Vector::<2>::new(black_box([5.0, 11.0]));
358
359        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
360            black_box(Lu::<2>::solve_vec);
361        let x = solve_fn(&lu, b).unwrap().into_array();
362
363        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
364        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
365    }
366
367    #[test]
368    fn det_2x2_basic() {
369        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
370        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
371
372        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
373        assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12);
374    }
375
376    #[test]
377    fn det_requires_pivot_sign() {
378        // Row swap ⇒ determinant sign flip.
379        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
380        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
381
382        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
383        assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0);
384    }
385
386    #[test]
387    fn solve_requires_pivoting() {
388        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
389        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
390        let b = Vector::<2>::new(black_box([1.0, 2.0]));
391
392        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
393            black_box(Lu::<2>::solve_vec);
394        let x = solve_fn(&lu, b).unwrap().into_array();
395
396        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
397        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
398    }
399
400    #[test]
401    fn singular_detected() {
402        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
403        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
404        assert_eq!(err, LaError::Singular { pivot_col: 1 });
405    }
406
407    #[test]
408    fn singular_due_to_tolerance_at_first_pivot() {
409        // Not exactly singular, but below DEFAULT_PIVOT_TOL.
410        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
411        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
412        assert_eq!(err, LaError::Singular { pivot_col: 0 });
413    }
414
415    #[test]
416    fn nonfinite_detected_on_pivot_entry() {
417        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
418        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
419        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
420    }
421
422    #[test]
423    fn nonfinite_detected_in_pivot_column_scan() {
424        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
425        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
426        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
427    }
428
429    #[test]
430    fn solve_vec_nonfinite_forward_substitution_overflow() {
431        // L has a -1 multiplier, and a large RHS makes forward substitution overflow.
432        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
433        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
434
435        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
436        let err = lu.solve_vec(b).unwrap_err();
437        assert_eq!(err, LaError::NonFinite { pivot_col: 1 });
438    }
439
440    #[test]
441    fn solve_vec_nonfinite_back_substitution_overflow() {
442        // Make x[1] overflow during back substitution, then ensure it is detected on the next row.
443        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
444        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
445
446        let b = Vector::<2>::new([0.0, 1.0e300]);
447        let err = lu.solve_vec(b).unwrap_err();
448        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
449    }
450}