Skip to main content

la_stack/
ldlt.rs

//! LDLT factorization and solves.
//!
//! This module provides a stack-allocated LDLT factorization (`A = L D Lᵀ`), computed without
//! pivoting, intended for symmetric positive definite (SPD) and positive semi-definite (PSD)
//! matrices (e.g. Gram matrices). Because there is no pivoting, a PSD matrix is only accepted
//! while every pivot stays above the singularity tolerance.
6
7use crate::LaError;
8use crate::matrix::Matrix;
9use crate::vector::Vector;
10
11/// LDLT factorization (`A = L D Lᵀ`) for symmetric positive (semi)definite matrices.
12///
13/// This factorization is **not** a general-purpose symmetric-indefinite LDLT (no pivoting).
14/// It assumes the input matrix is symmetric and (numerically) SPD/PSD.
15///
16/// # Storage
17/// The factors are stored in a single [`Matrix`]:
18/// - `D` is stored on the diagonal.
19/// - The strict lower triangle stores the multipliers of `L`.
20/// - The diagonal of `L` is implicit ones.
21#[must_use]
22#[derive(Clone, Copy, Debug, PartialEq)]
23pub struct Ldlt<const D: usize> {
24    factors: Matrix<D>,
25    tol: f64,
26}
27
28impl<const D: usize> Ldlt<D> {
29    #[inline]
30    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
31        debug_assert!(tol >= 0.0, "tol must be non-negative");
32
33        #[cfg(debug_assertions)]
34        debug_assert_symmetric(&a);
35
36        let mut f = a;
37
38        // LDLT via symmetric rank-1 updates, using only the lower triangle.
39        for j in 0..D {
40            let d = f.rows[j][j];
41            if !d.is_finite() {
42                return Err(LaError::NonFinite {
43                    row: Some(j),
44                    col: j,
45                });
46            }
47            if d <= tol {
48                return Err(LaError::Singular { pivot_col: j });
49            }
50
51            // Compute L multipliers below the diagonal in column j.
52            for i in (j + 1)..D {
53                let l = f.rows[i][j] / d;
54                if !l.is_finite() {
55                    return Err(LaError::NonFinite {
56                        row: Some(i),
57                        col: j,
58                    });
59                }
60                f.rows[i][j] = l;
61            }
62
63            // Update the trailing submatrix (lower triangle): A := A - (L_col * d) * L_col^T.
64            for i in (j + 1)..D {
65                let l_i = f.rows[i][j];
66                let l_i_d = l_i * d;
67
68                for k in (j + 1)..=i {
69                    let l_k = f.rows[k][j];
70                    let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]);
71                    if !new_val.is_finite() {
72                        return Err(LaError::NonFinite {
73                            row: Some(i),
74                            col: k,
75                        });
76                    }
77                    f.rows[i][k] = new_val;
78                }
79            }
80        }
81
82        Ok(Self { factors: f, tol })
83    }
84
85    /// Determinant of the original matrix.
86    ///
87    /// For SPD/PSD matrices, this is the product of the diagonal terms of `D`.
88    ///
89    /// # Examples
90    /// ```
91    /// use la_stack::prelude::*;
92    ///
93    /// // Symmetric SPD matrix.
94    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
95    /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
96    ///
97    /// assert!((ldlt.det() - 8.0).abs() <= 1e-12);
98    /// ```
99    #[inline]
100    #[must_use]
101    pub fn det(&self) -> f64 {
102        let mut det = 1.0;
103        for i in 0..D {
104            det *= self.factors.rows[i][i];
105        }
106        det
107    }
108
109    /// Solve `A x = b` using this LDLT factorization.
110    ///
111    /// # Examples
112    /// ```
113    /// use la_stack::prelude::*;
114    ///
115    /// # fn main() -> Result<(), LaError> {
116    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
117    /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?;
118    ///
119    /// let b = Vector::<2>::new([1.0, 2.0]);
120    /// let x = ldlt.solve_vec(b)?.into_array();
121    ///
122    /// assert!((x[0] - (-0.125)).abs() <= 1e-12);
123    /// assert!((x[1] - 0.75).abs() <= 1e-12);
124    /// # Ok(())
125    /// # }
126    /// ```
127    ///
128    /// # Errors
129    /// Returns [`LaError::Singular`] if a diagonal entry `d = D[i,i]` satisfies `d <= tol`
130    /// (non-positive or too small), where `tol` is the tolerance that was used during factorization.
131    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected.
132    #[inline]
133    pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
134        let mut x = b.data;
135
136        // Forward substitution: L y = b (L has unit diagonal).
137        for i in 0..D {
138            let mut sum = x[i];
139            let row = self.factors.rows[i];
140            for (j, x_j) in x.iter().enumerate().take(i) {
141                sum = (-row[j]).mul_add(*x_j, sum);
142            }
143            if !sum.is_finite() {
144                return Err(LaError::NonFinite { row: None, col: i });
145            }
146            x[i] = sum;
147        }
148
149        // Diagonal solve: D z = y.
150        for (i, x_i) in x.iter_mut().enumerate().take(D) {
151            let diag = self.factors.rows[i][i];
152            if !diag.is_finite() {
153                return Err(LaError::NonFinite { row: None, col: i });
154            }
155            if diag <= self.tol {
156                return Err(LaError::Singular { pivot_col: i });
157            }
158
159            let v = *x_i / diag;
160            if !v.is_finite() {
161                return Err(LaError::NonFinite { row: None, col: i });
162            }
163            *x_i = v;
164        }
165
166        // Back substitution: Lᵀ x = z.
167        for ii in 0..D {
168            let i = D - 1 - ii;
169            let mut sum = x[i];
170            for (j, x_j) in x.iter().enumerate().skip(i + 1) {
171                sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum);
172            }
173            if !sum.is_finite() {
174                return Err(LaError::NonFinite { row: None, col: i });
175            }
176            x[i] = sum;
177        }
178
179        Ok(Vector::new(x))
180    }
181}
182
183#[cfg(debug_assertions)]
184fn debug_assert_symmetric<const D: usize>(a: &Matrix<D>) {
185    let scale = a.inf_norm().max(1.0);
186    let eps = 1e-12 * scale;
187
188    for r in 0..D {
189        for c in (r + 1)..D {
190            let diff = (a.rows[r][c] - a.rows[c][r]).abs();
191            debug_assert!(
192                diff <= eps,
193                "matrix must be symmetric (diff={diff}, eps={eps}) at ({r}, {c})"
194            );
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    use crate::DEFAULT_SINGULAR_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    macro_rules! gen_public_api_ldlt_identity_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
                    let a = Matrix::<$d>::identity();
                    let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

                    assert_abs_diff_eq!(ldlt.det(), 1.0, epsilon = 1e-12);

                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };
                    let b = Vector::<$d>::new(black_box(b_arr));
                    let x = ldlt.solve_vec(b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], b_arr[i], epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_identity_tests!(2);
    gen_public_api_ldlt_identity_tests!(3);
    gen_public_api_ldlt_identity_tests!(4);
    gen_public_api_ldlt_identity_tests!(5);

    macro_rules! gen_public_api_ldlt_diagonal_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
                    let diag = {
                        let mut arr = [0.0f64; $d];
                        let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };

                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = diag[i];
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

                    let expected_det = {
                        let mut acc = 1.0;
                        for i in 0..$d {
                            acc *= diag[i];
                        }
                        acc
                    };
                    assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-12);

                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let values = [5.0f64, 4.0, 3.0, 2.0, 1.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };

                    let b = Vector::<$d>::new(black_box(b_arr));
                    let x = ldlt.solve_vec(b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], b_arr[i] / diag[i], epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_diagonal_tests!(2);
    gen_public_api_ldlt_diagonal_tests!(3);
    gen_public_api_ldlt_diagonal_tests!(4);
    gen_public_api_ldlt_diagonal_tests!(5);

    // Smallest possible size: a single pivot, no multipliers, no trailing update.
    #[test]
    fn det_and_solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[4.0]]));
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        assert_abs_diff_eq!(ldlt.det(), 4.0, epsilon = 1e-12);

        let b = Vector::<1>::new(black_box([2.0]));
        let x = ldlt.solve_vec(b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 0.5, epsilon = 1e-12);
    }

    #[test]
    fn solve_2x2_known_spd() {
        let a = Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]]));
        let ldlt = (black_box(Ldlt::<2>::factor))(a, DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<2>::new(black_box([1.0, 2.0]));
        let x = ldlt.solve_vec(b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], -0.125, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(ldlt.det(), 8.0, epsilon = 1e-12);
    }

    #[test]
    fn solve_3x3_spd_tridiagonal_smoke() {
        let a = Matrix::<3>::from_rows(black_box([
            [2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0],
            [0.0, -1.0, 2.0],
        ]));
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        // Pivots are 2, 3/2, 4/3, so det(A) = 4.
        assert_abs_diff_eq!(ldlt.det(), 4.0, epsilon = 1e-12);

        // Choose x = 1 so b = A x is simple: [1, 0, 1].
        let b = Vector::<3>::new(black_box([1.0, 0.0, 1.0]));
        let x = ldlt.solve_vec(b).unwrap().into_array();

        for &x_i in &x {
            assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn singular_detected_for_degenerate_psd() {
        // Rank-1 Gram-like matrix.
        let a = Matrix::<2>::from_rows(black_box([[1.0, 1.0], [1.0, 1.0]]));
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    // Symmetric but indefinite: the unpivoted SPD/PSD factorization must reject a
    // negative pivot rather than factor it.
    #[test]
    fn singular_detected_for_negative_pivot() {
        let a = Matrix::<2>::from_rows(black_box([[-1.0, 0.0], [0.0, 1.0]]));
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_l_multiplier_overflow() {
        // d = 1e-11 > tol, but l = 1e300 / 1e-11 = 1e311 overflows f64.
        let a = Matrix::<2>::from_rows([[1e-11, 1e300], [1e300, 1.0]]);
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_trailing_submatrix_overflow() {
        // L multiplier is finite (1e200), but the rank-1 update
        // (-1e200 * 1.0) * 1e200 + 1.0 overflows.
        let a = Matrix::<2>::from_rows([[1.0, 1e200], [1e200, 1.0]]);
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 1
            }
        );
    }

    // A NaN in the right-hand side is reported by forward substitution at its index.
    #[test]
    fn nonfinite_rhs_detected_in_solve_vec() {
        let a = Matrix::<2>::identity();
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<2>::new([1.0, f64::NAN]);
        let err = ldlt.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn nonfinite_solve_vec_forward_substitution_overflow() {
        // SPD matrix with large L multiplier: L[1,0] = 1e153.
        // Forward substitution overflows: y[1] = 0 - 1e153 * 1e156 = -inf.
        let a = Matrix::<3>::from_rows([
            [1.0, 1e153, 0.0],
            [1e153, 1e306 + 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]);
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new([1e156, 0.0, 0.0]);
        let err = ldlt.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn nonfinite_solve_vec_back_substitution_overflow() {
        // SPD matrix: [[1,0,0],[0,1,2],[0,2,5]] has LDLT factors
        // D=[1,1,1], L[2,1]=2.  Forward sub and diagonal solve produce
        // z=[0,0,1e308].  Back-substitution: x[2]=1e308 then
        // x[1] = 0 - 2*1e308 = -inf (overflows f64).
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 2.0, 5.0]]);
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new([0.0, 0.0, 1e308]);
        let err = ldlt.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }
}