la_stack/
ldlt.rs

1//! LDLT factorization and solves.
2//!
3//! This module provides a stack-allocated LDLT factorization (`A = L D Lᵀ`) intended for
4//! symmetric positive definite (SPD) and positive semi-definite (PSD) matrices (e.g. Gram
5//! matrices) without pivoting.
6
7use crate::LaError;
8use crate::matrix::Matrix;
9use crate::vector::Vector;
10
11/// LDLT factorization (`A = L D Lᵀ`) for symmetric positive (semi)definite matrices.
12///
13/// This factorization is **not** a general-purpose symmetric-indefinite LDLT (no pivoting).
14/// It assumes the input matrix is symmetric and (numerically) SPD/PSD.
15///
16/// # Storage
17/// The factors are stored in a single [`Matrix`]:
18/// - `D` is stored on the diagonal.
19/// - The strict lower triangle stores the multipliers of `L`.
20/// - The diagonal of `L` is implicit ones.
21#[must_use]
22#[derive(Clone, Copy, Debug, PartialEq)]
23pub struct Ldlt<const D: usize> {
24    factors: Matrix<D>,
25    tol: f64,
26}
27
28impl<const D: usize> Ldlt<D> {
29    #[inline]
30    pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
31        debug_assert!(tol >= 0.0, "tol must be non-negative");
32
33        #[cfg(debug_assertions)]
34        debug_assert_symmetric(&a);
35
36        let mut f = a;
37
38        // LDLT via symmetric rank-1 updates, using only the lower triangle.
39        for j in 0..D {
40            let d = f.rows[j][j];
41            if !d.is_finite() {
42                return Err(LaError::NonFinite { pivot_col: j });
43            }
44            if d <= tol {
45                return Err(LaError::Singular { pivot_col: j });
46            }
47
48            // Compute L multipliers below the diagonal in column j.
49            for i in (j + 1)..D {
50                let l = f.rows[i][j] / d;
51                if !l.is_finite() {
52                    return Err(LaError::NonFinite { pivot_col: j });
53                }
54                f.rows[i][j] = l;
55            }
56
57            // Update the trailing submatrix (lower triangle): A := A - (L_col * d) * L_col^T.
58            for i in (j + 1)..D {
59                let l_i = f.rows[i][j];
60                let l_i_d = l_i * d;
61
62                for k in (j + 1)..=i {
63                    let l_k = f.rows[k][j];
64                    let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]);
65                    if !new_val.is_finite() {
66                        return Err(LaError::NonFinite { pivot_col: j });
67                    }
68                    f.rows[i][k] = new_val;
69                }
70            }
71        }
72
73        Ok(Self { factors: f, tol })
74    }
75
76    /// Determinant of the original matrix.
77    ///
78    /// For SPD/PSD matrices, this is the product of the diagonal terms of `D`.
79    ///
80    /// # Examples
81    /// ```
82    /// use la_stack::prelude::*;
83    ///
84    /// // Symmetric SPD matrix.
85    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
86    /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();
87    ///
88    /// assert!((ldlt.det() - 8.0).abs() <= 1e-12);
89    /// ```
90    #[inline]
91    #[must_use]
92    pub fn det(&self) -> f64 {
93        let mut det = 1.0;
94        for i in 0..D {
95            det *= self.factors.rows[i][i];
96        }
97        det
98    }
99
100    /// Solve `A x = b` using this LDLT factorization.
101    ///
102    /// # Examples
103    /// ```
104    /// use la_stack::prelude::*;
105    ///
106    /// # fn main() -> Result<(), LaError> {
107    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
108    /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?;
109    ///
110    /// let b = Vector::<2>::new([1.0, 2.0]);
111    /// let x = ldlt.solve_vec(b)?.into_array();
112    ///
113    /// assert!((x[0] - (-0.125)).abs() <= 1e-12);
114    /// assert!((x[1] - 0.75).abs() <= 1e-12);
115    /// # Ok(())
116    /// # }
117    /// ```
118    ///
119    /// # Errors
120    /// Returns [`LaError::Singular`] if a diagonal entry `d = D[i,i]` satisfies `d <= tol`
121    /// (non-positive or too small), where `tol` is the tolerance that was used during factorization.
122    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected.
123    #[inline]
124    pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
125        let mut x = b.data;
126
127        // Forward substitution: L y = b (L has unit diagonal).
128        for i in 0..D {
129            let mut sum = x[i];
130            let row = self.factors.rows[i];
131            for (j, x_j) in x.iter().enumerate().take(i) {
132                sum = (-row[j]).mul_add(*x_j, sum);
133            }
134            if !sum.is_finite() {
135                return Err(LaError::NonFinite { pivot_col: i });
136            }
137            x[i] = sum;
138        }
139
140        // Diagonal solve: D z = y.
141        for (i, x_i) in x.iter_mut().enumerate().take(D) {
142            let diag = self.factors.rows[i][i];
143            if !diag.is_finite() {
144                return Err(LaError::NonFinite { pivot_col: i });
145            }
146            if diag <= self.tol {
147                return Err(LaError::Singular { pivot_col: i });
148            }
149
150            let v = *x_i / diag;
151            if !v.is_finite() {
152                return Err(LaError::NonFinite { pivot_col: i });
153            }
154            *x_i = v;
155        }
156
157        // Back substitution: Lᵀ x = z.
158        for ii in 0..D {
159            let i = D - 1 - ii;
160            let mut sum = x[i];
161            for (j, x_j) in x.iter().enumerate().skip(i + 1) {
162                sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum);
163            }
164            if !sum.is_finite() {
165                return Err(LaError::NonFinite { pivot_col: i });
166            }
167            x[i] = sum;
168        }
169
170        Ok(Vector::new(x))
171    }
172}
173
174#[cfg(debug_assertions)]
175fn debug_assert_symmetric<const D: usize>(a: &Matrix<D>) {
176    let scale = a.inf_norm().max(1.0);
177    let eps = 1e-12 * scale;
178
179    for r in 0..D {
180        for c in (r + 1)..D {
181            let diff = (a.rows[r][c] - a.rows[c][r]).abs();
182            debug_assert!(
183                diff <= eps,
184                "matrix must be symmetric (diff={diff}, eps={eps}) at ({r}, {c})"
185            );
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    use crate::DEFAULT_SINGULAR_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Builds an `[f64; N]` from the leading entries of `values`.
    ///
    /// Positions past the end of `values` are `0.0`; surplus `values` are ignored.
    fn leading<const N: usize>(values: &[f64]) -> [f64; N] {
        core::array::from_fn(|i| values.get(i).copied().unwrap_or(0.0))
    }

    // Identity matrix: det must be 1 and the solve must hand back `b` unchanged.
    macro_rules! gen_public_api_ldlt_identity_tests {
        ($($d:literal),+ $(,)?) => {
            paste! {
                $(
                    #[test]
                    fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
                        let ldlt = Matrix::<$d>::identity()
                            .ldlt(DEFAULT_SINGULAR_TOL)
                            .unwrap();

                        assert_abs_diff_eq!(ldlt.det(), 1.0, epsilon = 1e-12);

                        let rhs: [f64; $d] = leading(&[1.0, 2.0, 3.0, 4.0, 5.0]);
                        let solution = ldlt
                            .solve_vec(Vector::<$d>::new(black_box(rhs)))
                            .unwrap()
                            .into_array();

                        for (got, want) in solution.iter().zip(rhs.iter()) {
                            assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                        }
                    }
                )+
            }
        };
    }

    gen_public_api_ldlt_identity_tests!(2, 3, 4, 5);

    // Diagonal SPD matrix: det is the product of the diagonal, the solve divides entrywise.
    macro_rules! gen_public_api_ldlt_diagonal_tests {
        ($($d:literal),+ $(,)?) => {
            paste! {
                $(
                    #[test]
                    fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
                        let diag: [f64; $d] = leading(&[1.0, 2.0, 3.0, 4.0, 5.0]);

                        let rows: [[f64; $d]; $d] = core::array::from_fn(|r| {
                            core::array::from_fn(|c| if r == c { diag[r] } else { 0.0 })
                        });

                        let ldlt = Matrix::<$d>::from_rows(black_box(rows))
                            .ldlt(DEFAULT_SINGULAR_TOL)
                            .unwrap();

                        let expected_det = diag.iter().fold(1.0_f64, |acc, d| acc * *d);
                        assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-12);

                        let rhs: [f64; $d] = leading(&[5.0, 4.0, 3.0, 2.0, 1.0]);
                        let solution = ldlt
                            .solve_vec(Vector::<$d>::new(black_box(rhs)))
                            .unwrap()
                            .into_array();

                        for ((got, b_i), d_i) in solution.iter().zip(rhs.iter()).zip(diag.iter()) {
                            assert_abs_diff_eq!(*got, *b_i / *d_i, epsilon = 1e-12);
                        }
                    }
                )+
            }
        };
    }

    gen_public_api_ldlt_diagonal_tests!(2, 3, 4, 5);

    #[test]
    fn solve_2x2_known_spd() {
        let a = Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]]));
        // Calls `factor` directly, through `black_box`, rather than via `Matrix::ldlt`.
        let factor = black_box(Ldlt::<2>::factor);
        let ldlt = factor(a, DEFAULT_SINGULAR_TOL).unwrap();

        let rhs = Vector::<2>::new(black_box([1.0, 2.0]));
        let [x0, x1] = ldlt.solve_vec(rhs).unwrap().into_array();

        assert_abs_diff_eq!(x0, -0.125, epsilon = 1e-12);
        assert_abs_diff_eq!(x1, 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(ldlt.det(), 8.0, epsilon = 1e-12);
    }

    #[test]
    fn solve_3x3_spd_tridiagonal_smoke() {
        let rows = [[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]];
        let ldlt = Matrix::<3>::from_rows(black_box(rows))
            .ldlt(DEFAULT_SINGULAR_TOL)
            .unwrap();

        // With x = (1, 1, 1) the right-hand side b = A x is [1, 0, 1].
        let rhs = Vector::<3>::new(black_box([1.0, 0.0, 1.0]));
        let solution = ldlt.solve_vec(rhs).unwrap().into_array();

        for component in solution {
            assert_abs_diff_eq!(component, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn singular_detected_for_degenerate_psd() {
        // Rank-1 Gram-like matrix: the second pivot collapses to zero.
        let outcome = Matrix::<2>::from_rows(black_box([[1.0, 1.0], [1.0, 1.0]]))
            .ldlt(DEFAULT_SINGULAR_TOL);
        assert_eq!(outcome.unwrap_err(), LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn nonfinite_detected() {
        // A NaN on the very first pivot must be reported for column 0.
        let outcome =
            Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]).ldlt(DEFAULT_SINGULAR_TOL);
        assert_eq!(outcome.unwrap_err(), LaError::NonFinite { pivot_col: 0 });
    }
}