la_stack/
lu.rs

1//! LU decomposition and solves.
2
3use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
/// LU decomposition (PA = LU) with partial pivoting.
///
/// Built by the crate-internal `factor` constructor; the stored data is consumed by
/// `solve_vec` and `det`.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Lu<const D: usize> {
    /// Packed factors: the elimination multipliers (the strictly lower part of `L`, whose
    /// unit diagonal is implicit) sit below the diagonal; `U` occupies the diagonal and above.
    factors: Matrix<D>,
    /// Row permutation: `piv[i]` is the index of the original row that ended up in row `i`.
    piv: [usize; D],
    /// Starts at `1.0` and is negated on every row swap; used as the sign of the determinant.
    piv_sign: f64,
    /// Pivot tolerance supplied at factorization time; reused by `solve_vec` for its
    /// singularity check.
    tol: f64,
}
16
17impl<const D: usize> Lu<D> {
18    #[inline]
19    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20        let mut lu = a;
21
22        let mut piv = [0usize; D];
23        for (i, p) in piv.iter_mut().enumerate() {
24            *p = i;
25        }
26
27        let mut piv_sign = 1.0;
28
29        for k in 0..D {
30            // Choose pivot row.
31            let mut pivot_row = k;
32            let mut pivot_abs = lu.rows[k][k].abs();
33            if !pivot_abs.is_finite() {
34                return Err(LaError::NonFinite { pivot_col: k });
35            }
36
37            for r in (k + 1)..D {
38                let v = lu.rows[r][k].abs();
39                if !v.is_finite() {
40                    return Err(LaError::NonFinite { pivot_col: k });
41                }
42                if v > pivot_abs {
43                    pivot_abs = v;
44                    pivot_row = r;
45                }
46            }
47
48            if pivot_abs <= tol {
49                return Err(LaError::Singular { pivot_col: k });
50            }
51
52            if pivot_row != k {
53                lu.rows.swap(k, pivot_row);
54                piv.swap(k, pivot_row);
55                piv_sign = -piv_sign;
56            }
57
58            let pivot = lu.rows[k][k];
59            if !pivot.is_finite() {
60                return Err(LaError::NonFinite { pivot_col: k });
61            }
62
63            // Eliminate below pivot.
64            for r in (k + 1)..D {
65                let mult = lu.rows[r][k] / pivot;
66                if !mult.is_finite() {
67                    return Err(LaError::NonFinite { pivot_col: k });
68                }
69                lu.rows[r][k] = mult;
70
71                for c in (k + 1)..D {
72                    lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
73                }
74            }
75        }
76
77        Ok(Self {
78            factors: lu,
79            piv,
80            piv_sign,
81            tol,
82        })
83    }
84
85    /// Solve `A x = b` using this LU factorization.
86    ///
87    /// # Examples
88    /// ```
89    /// use la_stack::prelude::*;
90    ///
91    /// # fn main() -> Result<(), LaError> {
92    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
93    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
94    ///
95    /// let b = Vector::<2>::new([5.0, 11.0]);
96    /// let x = lu.solve_vec(b)?.into_array();
97    ///
98    /// assert!((x[0] - 1.0).abs() <= 1e-12);
99    /// assert!((x[1] - 2.0).abs() <= 1e-12);
100    /// # Ok(())
101    /// # }
102    /// ```
103    ///
104    /// # Errors
105    /// Returns [`LaError::Singular`] if a diagonal entry of `U` satisfies `|u_ii| <= tol`, where
106    /// `tol` is the tolerance that was used during factorization.
107    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected.
108    #[inline]
109    pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
110        let mut x = [0.0; D];
111        for (i, x_i) in x.iter_mut().enumerate() {
112            *x_i = b.data[self.piv[i]];
113        }
114
115        // Forward substitution for L (unit diagonal).
116        for i in 0..D {
117            let mut sum = x[i];
118            let row = self.factors.rows[i];
119            for (j, x_j) in x.iter().enumerate().take(i) {
120                sum = (-row[j]).mul_add(*x_j, sum);
121            }
122            if !sum.is_finite() {
123                return Err(LaError::NonFinite { pivot_col: i });
124            }
125            x[i] = sum;
126        }
127
128        // Back substitution for U.
129        for ii in 0..D {
130            let i = D - 1 - ii;
131            let mut sum = x[i];
132            let row = self.factors.rows[i];
133            for (j, x_j) in x.iter().enumerate().skip(i + 1) {
134                sum = (-row[j]).mul_add(*x_j, sum);
135            }
136
137            let diag = row[i];
138            if !diag.is_finite() || !sum.is_finite() {
139                return Err(LaError::NonFinite { pivot_col: i });
140            }
141            if diag.abs() <= self.tol {
142                return Err(LaError::Singular { pivot_col: i });
143            }
144
145            x[i] = sum / diag;
146        }
147
148        Ok(Vector::new(x))
149    }
150
151    /// Determinant of the original matrix.
152    ///
153    /// # Examples
154    /// ```
155    /// use la_stack::prelude::*;
156    ///
157    /// # fn main() -> Result<(), LaError> {
158    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
159    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
160    ///
161    /// let det = lu.det();
162    /// assert!((det - (-2.0)).abs() <= 1e-12);
163    /// # Ok(())
164    /// # }
165    /// ```
166    #[inline]
167    #[must_use]
168    pub fn det(&self) -> f64 {
169        let mut det = self.piv_sign;
170        for i in 0..D {
171            det *= self.factors.rows[i][i];
172        }
173        det
174    }
175}
176
177#[cfg(test)]
178mod tests {
179    use super::*;
180    use crate::DEFAULT_PIVOT_TOL;
181
182    use core::hint::black_box;
183
184    use approx::assert_abs_diff_eq;
185    use pastey::paste;
186
187    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
188        ($d:literal) => {
189            paste! {
190                #[test]
191                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
192                    // Public API path under test:
193                    // Matrix::lu (pub) -> Lu::solve_vec (pub).
194
195                    // Permutation matrix that swaps the first two basis vectors.
196                    // This forces pivoting in column 0 for any D >= 2.
197                    let mut rows = [[0.0f64; $d]; $d];
198                    for i in 0..$d {
199                        rows[i][i] = 1.0;
200                    }
201                    rows.swap(0, 1);
202
203                    let a = Matrix::<$d>::from_rows(black_box(rows));
204                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
205                        black_box(Matrix::<$d>::lu);
206                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
207
208                    // Pick a simple RHS with unique entries, so the expected swap is obvious.
209                    let b_arr = {
210                        let mut arr = [0.0f64; $d];
211                        let mut val = 1.0f64;
212                        for dst in arr.iter_mut() {
213                            *dst = val;
214                            val += 1.0;
215                        }
216                        arr
217                    };
218                    let mut expected = b_arr;
219                    expected.swap(0, 1);
220                    let b = Vector::<$d>::new(black_box(b_arr));
221
222                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
223                        black_box(Lu::<$d>::solve_vec);
224                    let x = solve_fn(&lu, b).unwrap().into_array();
225
226                    for i in 0..$d {
227                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
228                    }
229                }
230
231                #[test]
232                fn [<public_api_lu_det_pivoting_ $d d>]() {
233                    // Public API path under test:
234                    // Matrix::lu (pub) -> Lu::det (pub).
235
236                    // Permutation matrix that swaps the first two basis vectors.
237                    let mut rows = [[0.0f64; $d]; $d];
238                    for i in 0..$d {
239                        rows[i][i] = 1.0;
240                    }
241                    rows.swap(0, 1);
242
243                    let a = Matrix::<$d>::from_rows(black_box(rows));
244                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
245                        black_box(Matrix::<$d>::lu);
246                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
247
248                    // Row swap ⇒ determinant sign flip.
249                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
250                    assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12);
251                }
252            }
253        };
254    }
255
256    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
257    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
258    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
259    gen_public_api_pivoting_solve_vec_and_det_tests!(5);
260
261    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
262        ($d:literal) => {
263            paste! {
264                #[test]
265                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
266                    // Public API path under test:
267                    // Matrix::lu (pub) -> Lu::solve_vec (pub).
268
269                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
270                    #[allow(clippy::large_stack_arrays)]
271                    let mut rows = [[0.0f64; $d]; $d];
272                    for i in 0..$d {
273                        rows[i][i] = 2.0;
274                        if i > 0 {
275                            rows[i][i - 1] = -1.0;
276                        }
277                        if i + 1 < $d {
278                            rows[i][i + 1] = -1.0;
279                        }
280                    }
281
282                    let a = Matrix::<$d>::from_rows(black_box(rows));
283                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
284                        black_box(Matrix::<$d>::lu);
285                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
286
287                    // Choose x = 1, so b = A x is simple: [1, 0, 0, ..., 0, 1].
288                    let mut b_arr = [0.0f64; $d];
289                    b_arr[0] = 1.0;
290                    b_arr[$d - 1] = 1.0;
291                    let b = Vector::<$d>::new(black_box(b_arr));
292
293                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
294                        black_box(Lu::<$d>::solve_vec);
295                    let x = solve_fn(&lu, b).unwrap().into_array();
296
297                    for &x_i in &x {
298                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
299                    }
300                }
301
302                #[test]
303                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
304                    // Public API path under test:
305                    // Matrix::lu (pub) -> Lu::det (pub).
306
307                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
308                    // Determinant is known exactly: det = D + 1.
309                    #[allow(clippy::large_stack_arrays)]
310                    let mut rows = [[0.0f64; $d]; $d];
311                    for i in 0..$d {
312                        rows[i][i] = 2.0;
313                        if i > 0 {
314                            rows[i][i - 1] = -1.0;
315                        }
316                        if i + 1 < $d {
317                            rows[i][i + 1] = -1.0;
318                        }
319                    }
320
321                    let a = Matrix::<$d>::from_rows(black_box(rows));
322                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
323                        black_box(Matrix::<$d>::lu);
324                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();
325
326                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
327                    assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
328                }
329            }
330        };
331    }
332
333    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
334    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
335    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);
336
337    #[test]
338    fn solve_1x1() {
339        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
340        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
341
342        let b = Vector::<1>::new(black_box([6.0]));
343        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
344            black_box(Lu::<1>::solve_vec);
345        let x = solve_fn(&lu, b).unwrap().into_array();
346        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);
347
348        let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det);
349        assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0);
350    }
351
352    #[test]
353    fn solve_2x2_basic() {
354        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
355        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
356        let b = Vector::<2>::new(black_box([5.0, 11.0]));
357
358        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
359            black_box(Lu::<2>::solve_vec);
360        let x = solve_fn(&lu, b).unwrap().into_array();
361
362        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
363        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
364    }
365
366    #[test]
367    fn det_2x2_basic() {
368        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
369        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
370
371        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
372        assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12);
373    }
374
375    #[test]
376    fn det_requires_pivot_sign() {
377        // Row swap ⇒ determinant sign flip.
378        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
379        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
380
381        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
382        assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0);
383    }
384
385    #[test]
386    fn solve_requires_pivoting() {
387        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
388        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
389        let b = Vector::<2>::new(black_box([1.0, 2.0]));
390
391        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
392            black_box(Lu::<2>::solve_vec);
393        let x = solve_fn(&lu, b).unwrap().into_array();
394
395        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
396        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
397    }
398
399    #[test]
400    fn singular_detected() {
401        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
402        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
403        assert_eq!(err, LaError::Singular { pivot_col: 1 });
404    }
405
406    #[test]
407    fn singular_due_to_tolerance_at_first_pivot() {
408        // Not exactly singular, but below DEFAULT_PIVOT_TOL.
409        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
410        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
411        assert_eq!(err, LaError::Singular { pivot_col: 0 });
412    }
413
414    #[test]
415    fn nonfinite_detected_on_pivot_entry() {
416        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
417        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
418        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
419    }
420
421    #[test]
422    fn nonfinite_detected_in_pivot_column_scan() {
423        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
424        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
425        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
426    }
427
428    #[test]
429    fn solve_vec_nonfinite_forward_substitution_overflow() {
430        // L has a -1 multiplier, and a large RHS makes forward substitution overflow.
431        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
432        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
433
434        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
435        let err = lu.solve_vec(b).unwrap_err();
436        assert_eq!(err, LaError::NonFinite { pivot_col: 1 });
437    }
438
439    #[test]
440    fn solve_vec_nonfinite_back_substitution_overflow() {
441        // Make x[1] overflow during back substitution, then ensure it is detected on the next row.
442        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
443        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
444
445        let b = Vector::<2>::new([0.0, 1.0e300]);
446        let err = lu.solve_vec(b).unwrap_err();
447        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
448    }
449}