la_stack/
lu.rs

//! LU decomposition with partial pivoting, linear solves, and determinants.
2
3use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
7/// LU decomposition (PA = LU) with partial pivoting.
8#[must_use]
9#[derive(Clone, Copy, Debug, PartialEq)]
10pub struct Lu<const D: usize> {
11    factors: Matrix<D>,
12    piv: [usize; D],
13    piv_sign: f64,
14    tol: f64,
15}
16
17impl<const D: usize> Lu<D> {
18    #[inline]
19    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20        let mut lu = a;
21
22        let mut piv = [0usize; D];
23        for (i, p) in piv.iter_mut().enumerate() {
24            *p = i;
25        }
26
27        let mut piv_sign = 1.0;
28
29        for k in 0..D {
30            // Choose pivot row.
31            let mut pivot_row = k;
32            let mut pivot_abs = lu.rows[k][k].abs();
33            if !pivot_abs.is_finite() {
34                return Err(LaError::NonFinite { pivot_col: k });
35            }
36
37            for r in (k + 1)..D {
38                let v = lu.rows[r][k].abs();
39                if !v.is_finite() {
40                    return Err(LaError::NonFinite { pivot_col: k });
41                }
42                if v > pivot_abs {
43                    pivot_abs = v;
44                    pivot_row = r;
45                }
46            }
47
48            if pivot_abs <= tol {
49                return Err(LaError::Singular { pivot_col: k });
50            }
51
52            if pivot_row != k {
53                lu.rows.swap(k, pivot_row);
54                piv.swap(k, pivot_row);
55                piv_sign = -piv_sign;
56            }
57
58            let pivot = lu.rows[k][k];
59            if !pivot.is_finite() {
60                return Err(LaError::NonFinite { pivot_col: k });
61            }
62
63            // Eliminate below pivot.
64            for r in (k + 1)..D {
65                let mult = lu.rows[r][k] / pivot;
66                if !mult.is_finite() {
67                    return Err(LaError::NonFinite { pivot_col: k });
68                }
69                lu.rows[r][k] = mult;
70
71                for c in (k + 1)..D {
72                    lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
73                }
74            }
75        }
76
77        Ok(Self {
78            factors: lu,
79            piv,
80            piv_sign,
81            tol,
82        })
83    }
84
85    /// Solve `A x = b` using this LU factorization.
86    ///
87    /// # Examples
88    /// ```
89    /// use la_stack::prelude::*;
90    ///
91    /// # fn main() -> Result<(), LaError> {
92    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
93    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
94    ///
95    /// let b = Vector::<2>::new([5.0, 11.0]);
96    /// let x = lu.solve_vec(b)?.into_array();
97    ///
98    /// assert!((x[0] - 1.0).abs() <= 1e-12);
99    /// assert!((x[1] - 2.0).abs() <= 1e-12);
100    /// # Ok(())
101    /// # }
102    /// ```
103    ///
104    /// # Errors
105    /// Returns [`LaError::Singular`] if a diagonal of `U` is (numerically) zero.
106    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected.
107    #[inline]
108    pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
109        let mut x = [0.0; D];
110        for (i, x_i) in x.iter_mut().enumerate() {
111            *x_i = b.data[self.piv[i]];
112        }
113
114        // Forward substitution for L (unit diagonal).
115        for i in 0..D {
116            let mut sum = x[i];
117            let row = self.factors.rows[i];
118            for (j, x_j) in x.iter().enumerate().take(i) {
119                sum = (-row[j]).mul_add(*x_j, sum);
120            }
121            if !sum.is_finite() {
122                return Err(LaError::NonFinite { pivot_col: i });
123            }
124            x[i] = sum;
125        }
126
127        // Back substitution for U.
128        for ii in 0..D {
129            let i = D - 1 - ii;
130            let mut sum = x[i];
131            let row = self.factors.rows[i];
132            for (j, x_j) in x.iter().enumerate().skip(i + 1) {
133                sum = (-row[j]).mul_add(*x_j, sum);
134            }
135
136            let diag = row[i];
137            if !diag.is_finite() || !sum.is_finite() {
138                return Err(LaError::NonFinite { pivot_col: i });
139            }
140            if diag.abs() <= self.tol {
141                return Err(LaError::Singular { pivot_col: i });
142            }
143
144            x[i] = sum / diag;
145        }
146
147        Ok(Vector::new(x))
148    }
149
150    /// Determinant of the original matrix.
151    ///
152    /// # Examples
153    /// ```
154    /// use la_stack::prelude::*;
155    ///
156    /// # fn main() -> Result<(), LaError> {
157    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
158    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
159    ///
160    /// let det = lu.det();
161    /// assert!((det - (-2.0)).abs() <= 1e-12);
162    /// # Ok(())
163    /// # }
164    /// ```
165    #[inline]
166    #[must_use]
167    pub fn det(&self) -> f64 {
168        let mut det = self.piv_sign;
169        for i in 0..D {
170            det *= self.factors.rows[i][i];
171        }
172        det
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Function-pointer shapes of the entry points under test. The callees are passed
    // through `black_box` so the calls stay opaque to the optimizer.
    type LuFn<const D: usize> = fn(Matrix<D>, f64) -> Result<Lu<D>, LaError>;
    type SolveFn<const D: usize> = fn(&Lu<D>, Vector<D>) -> Result<Vector<D>, LaError>;
    type DetFn<const D: usize> = fn(&Lu<D>) -> f64;

    /// Identity matrix with its first two rows exchanged (a single transposition; D >= 2).
    fn swapped_identity<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        rows.swap(0, 1);
        rows
    }

    /// Classic SPD tridiagonal matrix: 2 on the diagonal, -1 on the sub/super-diagonals.
    // The 64x64 instantiation is a sizeable stack array; that is intentional in tests.
    #[allow(clippy::large_stack_arrays)]
    fn tridiagonal<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 2.0;
            if i > 0 {
                row[i - 1] = -1.0;
            }
            if i + 1 < D {
                row[i + 1] = -1.0;
            }
        }
        rows
    }

    /// Factor `rows` through the public `Matrix::lu` entry point with the default tolerance.
    fn factor_via_public_api<const D: usize>(rows: [[f64; D]; D]) -> Lu<D> {
        let matrix = Matrix::<D>::from_rows(black_box(rows));
        let lu_fn: LuFn<D> = black_box(Matrix::<D>::lu);
        lu_fn(matrix, DEFAULT_PIVOT_TOL).unwrap()
    }

    /// Solve with the public `Lu::solve_vec` entry point, returning the raw solution array.
    fn solve_via_public_api<const D: usize>(lu: &Lu<D>, rhs: [f64; D]) -> [f64; D] {
        let solve_fn: SolveFn<D> = black_box(Lu::<D>::solve_vec);
        solve_fn(lu, Vector::<D>::new(black_box(rhs)))
            .unwrap()
            .into_array()
    }

    /// Evaluate the public `Lu::det` entry point.
    fn det_via_public_api<const D: usize>(lu: &Lu<D>) -> f64 {
        let det_fn: DetFn<D> = black_box(Lu::<D>::det);
        det_fn(lu)
    }

    /// Factor `rows` with the default tolerance and return the (expected) error.
    fn factor_error<const D: usize>(rows: [[f64; D]; D]) -> LaError {
        Matrix::<D>::from_rows(rows)
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap_err()
    }

    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // Matrix::lu -> Lu::solve_vec on a row-swapped identity, which forces a
                    // pivot in column 0 for every D >= 2.
                    let lu = factor_via_public_api(swapped_identity::<$d>());

                    // Distinct entries 1, 2, ..., D make the expected swap easy to read.
                    let mut rhs = [0.0f64; $d];
                    let mut next = 0.0f64;
                    for slot in &mut rhs {
                        next += 1.0;
                        *slot = next;
                    }
                    let mut want = rhs;
                    want.swap(0, 1);

                    let got = solve_via_public_api(&lu, rhs);
                    for (g, w) in got.iter().zip(want.iter()) {
                        assert_abs_diff_eq!(*g, *w, epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Matrix::lu -> Lu::det: exactly one row swap, so det = -1.
                    let lu = factor_via_public_api(swapped_identity::<$d>());
                    assert_abs_diff_eq!(det_via_public_api(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // Matrix::lu -> Lu::solve_vec. With x = (1, ..., 1) every interior row of
                    // A x cancels to 0, so b = (1, 0, ..., 0, 1).
                    let lu = factor_via_public_api(tridiagonal::<$d>());

                    let mut rhs = [0.0f64; $d];
                    rhs[0] = 1.0;
                    rhs[$d - 1] = 1.0;

                    let solved = solve_via_public_api(&lu, rhs);
                    for x_i in &solved {
                        assert_abs_diff_eq!(*x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Matrix::lu -> Lu::det. This matrix has the exact determinant D + 1.
                    let lu = factor_via_public_api(tridiagonal::<$d>());
                    assert_abs_diff_eq!(
                        det_via_public_api(&lu),
                        f64::from($d) + 1.0,
                        epsilon = 1e-8
                    );
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let factor_fn = black_box(Lu::<1>::factor);
        let lu = factor_fn(a, DEFAULT_PIVOT_TOL).unwrap();

        let [x0] = solve_via_public_api(&lu, [6.0]);
        assert_abs_diff_eq!(x0, 3.0, epsilon = 1e-12);

        assert_abs_diff_eq!(det_via_public_api(&lu), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let factor_fn = black_box(Lu::<2>::factor);
        let lu = factor_fn(a, DEFAULT_PIVOT_TOL).unwrap();

        let [x0, x1] = solve_via_public_api(&lu, [5.0, 11.0]);
        assert_abs_diff_eq!(x0, 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x1, 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let lu = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        assert_abs_diff_eq!(det_via_public_api(&lu), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // A single row swap flips the sign of the determinant.
        let lu = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        assert_abs_diff_eq!(det_via_public_api(&lu), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        let lu = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let [x0, x1] = solve_via_public_api(&lu, [1.0, 2.0]);
        assert_abs_diff_eq!(x0, 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x1, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        let err = factor_error(black_box([[1.0, 2.0], [2.0, 4.0]]));
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // Not exactly singular, but the first pivot is below DEFAULT_PIVOT_TOL.
        let err = factor_error(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let err = factor_error([[f64::NAN, 0.0], [0.0, 1.0]]);
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let err = factor_error([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        // L carries a -1 multiplier; with a huge RHS, forward substitution overflows.
        let lu = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let err = lu
            .solve_vec(Vector::<3>::new([1.0e308, 1.0e308, 0.0]))
            .unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 1 });
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        // x[1] overflows during back substitution; it must be caught on the next row up.
        let lu = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]])
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let err = lu.solve_vec(Vector::<2>::new([0.0, 1.0e300])).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }
}