Skip to main content

la_stack/
lu.rs

1//! LU decomposition and solves.
2
3use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
7/// LU decomposition (PA = LU) with partial pivoting.
8#[must_use]
9#[derive(Clone, Copy, Debug, PartialEq)]
10pub struct Lu<const D: usize> {
11    factors: Matrix<D>,
12    piv: [usize; D],
13    piv_sign: f64,
14    tol: f64,
15}
16
17impl<const D: usize> Lu<D> {
18    #[inline]
19    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20        let mut lu = a;
21
22        let mut piv = [0usize; D];
23        for (i, p) in piv.iter_mut().enumerate() {
24            *p = i;
25        }
26
27        let mut piv_sign = 1.0;
28
29        for k in 0..D {
30            // Choose pivot row.
31            let mut pivot_row = k;
32            let mut pivot_abs = lu.rows[k][k].abs();
33            if !pivot_abs.is_finite() {
34                return Err(LaError::NonFinite {
35                    row: Some(k),
36                    col: k,
37                });
38            }
39
40            for r in (k + 1)..D {
41                let v = lu.rows[r][k].abs();
42                if !v.is_finite() {
43                    return Err(LaError::NonFinite {
44                        row: Some(r),
45                        col: k,
46                    });
47                }
48                if v > pivot_abs {
49                    pivot_abs = v;
50                    pivot_row = r;
51                }
52            }
53
54            if pivot_abs <= tol {
55                return Err(LaError::Singular { pivot_col: k });
56            }
57
58            if pivot_row != k {
59                lu.rows.swap(k, pivot_row);
60                piv.swap(k, pivot_row);
61                piv_sign = -piv_sign;
62            }
63
64            let pivot = lu.rows[k][k];
65            if !pivot.is_finite() {
66                return Err(LaError::NonFinite {
67                    row: Some(k),
68                    col: k,
69                });
70            }
71
72            // Eliminate below pivot.
73            for r in (k + 1)..D {
74                let mult = lu.rows[r][k] / pivot;
75                if !mult.is_finite() {
76                    return Err(LaError::NonFinite {
77                        row: Some(r),
78                        col: k,
79                    });
80                }
81                lu.rows[r][k] = mult;
82
83                for c in (k + 1)..D {
84                    lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
85                }
86            }
87        }
88
89        Ok(Self {
90            factors: lu,
91            piv,
92            piv_sign,
93            tol,
94        })
95    }
96
97    /// Solve `A x = b` using this LU factorization.
98    ///
99    /// # Examples
100    /// ```
101    /// use la_stack::prelude::*;
102    ///
103    /// # fn main() -> Result<(), LaError> {
104    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
105    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
106    ///
107    /// let b = Vector::<2>::new([5.0, 11.0]);
108    /// let x = lu.solve_vec(b)?.into_array();
109    ///
110    /// assert!((x[0] - 1.0).abs() <= 1e-12);
111    /// assert!((x[1] - 2.0).abs() <= 1e-12);
112    /// # Ok(())
113    /// # }
114    /// ```
115    ///
116    /// # Errors
117    /// Returns [`LaError::Singular`] if a diagonal entry of `U` satisfies `|u_ii| <= tol`, where
118    /// `tol` is the tolerance that was used during factorization.
119    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected.
120    #[inline]
121    pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
122        let mut x = [0.0; D];
123        for (i, x_i) in x.iter_mut().enumerate() {
124            *x_i = b.data[self.piv[i]];
125        }
126
127        // Forward substitution for L (unit diagonal).
128        for i in 0..D {
129            let mut sum = x[i];
130            let row = self.factors.rows[i];
131            for (j, x_j) in x.iter().enumerate().take(i) {
132                sum = (-row[j]).mul_add(*x_j, sum);
133            }
134            if !sum.is_finite() {
135                return Err(LaError::NonFinite { row: None, col: i });
136            }
137            x[i] = sum;
138        }
139
140        // Back substitution for U.
141        for ii in 0..D {
142            let i = D - 1 - ii;
143            let mut sum = x[i];
144            let row = self.factors.rows[i];
145            for (j, x_j) in x.iter().enumerate().skip(i + 1) {
146                sum = (-row[j]).mul_add(*x_j, sum);
147            }
148
149            let diag = row[i];
150            if !diag.is_finite() || !sum.is_finite() {
151                return Err(LaError::NonFinite { row: None, col: i });
152            }
153            if diag.abs() <= self.tol {
154                return Err(LaError::Singular { pivot_col: i });
155            }
156
157            let q = sum / diag;
158            if !q.is_finite() {
159                return Err(LaError::NonFinite { row: None, col: i });
160            }
161            x[i] = q;
162        }
163
164        Ok(Vector::new(x))
165    }
166
167    /// Determinant of the original matrix.
168    ///
169    /// # Examples
170    /// ```
171    /// use la_stack::prelude::*;
172    ///
173    /// # fn main() -> Result<(), LaError> {
174    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
175    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
176    ///
177    /// let det = lu.det();
178    /// assert!((det - (-2.0)).abs() <= 1e-12);
179    /// # Ok(())
180    /// # }
181    /// ```
182    #[inline]
183    #[must_use]
184    pub fn det(&self) -> f64 {
185        let mut det = self.piv_sign;
186        for i in 0..D {
187            det *= self.factors.rows[i][i];
188        }
189        det
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Rows of the `D x D` identity with rows 0 and 1 swapped: a permutation
    /// matrix with a zero in position (0, 0), so LU must pivot in column 0.
    /// Requires `D >= 2`.
    fn swap_first_two_rows_identity<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        rows.swap(0, 1);
        rows
    }

    /// Rows of the classic SPD tridiagonal matrix: 2 on the diagonal and -1 on the
    /// sub/super-diagonals. Its determinant is exactly `D + 1`.
    // Large `D` (64 → 32 KiB of f64s) is intentional: these tests exercise big
    // stack-allocated matrices.
    #[allow(clippy::large_stack_arrays)]
    fn tridiagonal_2_minus_1<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 2.0;
            if i > 0 {
                row[i - 1] = -1.0;
            }
            if i + 1 < D {
                row[i + 1] = -1.0;
            }
        }
        rows
    }

    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve_vec (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    // This forces pivoting in column 0 for any D >= 2.
                    let rows = swap_first_two_rows_identity::<$d>();

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Pick a simple RHS with unique entries (1, 2, ..., D). The matrix is
                    // its own inverse, so the solution is b with entries 0 and 1 swapped.
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for (&x_i, &e_i) in x.iter().zip(expected.iter()) {
                        assert_abs_diff_eq!(x_i, e_i, epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Permutation matrix that swaps the first two basis vectors.
                    let rows = swap_first_two_rows_identity::<$d>();

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Row swap ⇒ determinant sign flip.
                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::solve_vec (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    let rows = tridiagonal_2_minus_1::<$d>();

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // Choose x = 1, so b = A x is simple: [1, 0, 0, ..., 0, 1].
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Public API path under test:
                    // Matrix::lu (pub) -> Lu::det (pub).

                    // Classic SPD tridiagonal: 2 on diagonal, -1 on sub/super-diagonals.
                    // Determinant is known exactly: det = D + 1.
                    let rows = tridiagonal_2_minus_1::<$d>();

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // Row swap ⇒ determinant sign flip.
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        // Second row is twice the first, so elimination leaves an exact zero pivot in column 1.
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // Not exactly singular, but below DEFAULT_PIVOT_TOL.
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        // L has a -1 multiplier, and a large RHS makes forward substitution overflow
        // at row 1: 1e308 - (-1) * 1e308 = +inf.
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        // Back substitution on the last row computes 1e300 / 2e-12, which overflows to
        // +inf; the quotient check on that same row (col 1) must report it.
        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }
}