1use crate::LaError;
8use crate::matrix::Matrix;
9use crate::vector::Vector;
10
/// LDLᵀ factorization `A = L·D·Lᵀ` of a symmetric positive-definite matrix, where `L`
/// is unit lower-triangular and `D` is diagonal.
///
/// Built by `Ldlt::factor`, which rejects any pivot that is not strictly above the
/// tolerance. Use [`Ldlt::det`] and [`Ldlt::solve_vec`] to work with the result.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Ldlt<const D: usize> {
    // Packed factors:
    // - the strictly-lower triangle holds the multipliers of `L` (its unit diagonal is implicit);
    // - the diagonal holds the pivots of `D`;
    // - the strictly-upper triangle keeps the input matrix's values and is never read.
    factors: Matrix<D>,
    // Pivot tolerance used at factorization time; `solve_vec` re-checks the stored
    // pivots against it.
    tol: f64,
}
27
28impl<const D: usize> Ldlt<D> {
29 #[inline]
30 pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
31 debug_assert!(tol >= 0.0, "tol must be non-negative");
32
33 #[cfg(debug_assertions)]
34 debug_assert_symmetric(&a);
35
36 let mut f = a;
37
38 for j in 0..D {
40 let d = f.rows[j][j];
41 if !d.is_finite() {
42 return Err(LaError::NonFinite { pivot_col: j });
43 }
44 if d <= tol {
45 return Err(LaError::Singular { pivot_col: j });
46 }
47
48 for i in (j + 1)..D {
50 let l = f.rows[i][j] / d;
51 if !l.is_finite() {
52 return Err(LaError::NonFinite { pivot_col: j });
53 }
54 f.rows[i][j] = l;
55 }
56
57 for i in (j + 1)..D {
59 let l_i = f.rows[i][j];
60 let l_i_d = l_i * d;
61
62 for k in (j + 1)..=i {
63 let l_k = f.rows[k][j];
64 let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]);
65 if !new_val.is_finite() {
66 return Err(LaError::NonFinite { pivot_col: j });
67 }
68 f.rows[i][k] = new_val;
69 }
70 }
71 }
72
73 Ok(Self { factors: f, tol })
74 }
75
76 #[inline]
91 #[must_use]
92 pub fn det(&self) -> f64 {
93 let mut det = 1.0;
94 for i in 0..D {
95 det *= self.factors.rows[i][i];
96 }
97 det
98 }
99
100 #[inline]
124 pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
125 let mut x = b.data;
126
127 for i in 0..D {
129 let mut sum = x[i];
130 let row = self.factors.rows[i];
131 for (j, x_j) in x.iter().enumerate().take(i) {
132 sum = (-row[j]).mul_add(*x_j, sum);
133 }
134 if !sum.is_finite() {
135 return Err(LaError::NonFinite { pivot_col: i });
136 }
137 x[i] = sum;
138 }
139
140 for (i, x_i) in x.iter_mut().enumerate().take(D) {
142 let diag = self.factors.rows[i][i];
143 if !diag.is_finite() {
144 return Err(LaError::NonFinite { pivot_col: i });
145 }
146 if diag <= self.tol {
147 return Err(LaError::Singular { pivot_col: i });
148 }
149
150 let v = *x_i / diag;
151 if !v.is_finite() {
152 return Err(LaError::NonFinite { pivot_col: i });
153 }
154 *x_i = v;
155 }
156
157 for ii in 0..D {
159 let i = D - 1 - ii;
160 let mut sum = x[i];
161 for (j, x_j) in x.iter().enumerate().skip(i + 1) {
162 sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum);
163 }
164 if !sum.is_finite() {
165 return Err(LaError::NonFinite { pivot_col: i });
166 }
167 x[i] = sum;
168 }
169
170 Ok(Vector::new(x))
171 }
172}
173
174#[cfg(debug_assertions)]
175fn debug_assert_symmetric<const D: usize>(a: &Matrix<D>) {
176 let scale = a.inf_norm().max(1.0);
177 let eps = 1e-12 * scale;
178
179 for r in 0..D {
180 for c in (r + 1)..D {
181 let diff = (a.rows[r][c] - a.rows[c][r]).abs();
182 debug_assert!(
183 diff <= eps,
184 "matrix must be symmetric (diff={diff}, eps={eps}) at ({r}, {c})"
185 );
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    use crate::DEFAULT_SINGULAR_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Returns an `N`-element array filled from the leading entries of `values`.
    /// It is zero-padded when `values` is shorter than `N` and truncated when longer.
    fn leading<const N: usize>(values: &[f64]) -> [f64; N] {
        let mut arr = [0.0f64; N];
        for (dst, src) in arr.iter_mut().zip(values) {
            *dst = *src;
        }
        arr
    }

    /// Generates a test that factors the `$d × $d` identity, then checks its
    /// determinant and that solving returns the right-hand side unchanged.
    macro_rules! gen_public_api_ldlt_identity_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
                    let a = Matrix::<$d>::identity();
                    let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

                    assert_abs_diff_eq!(ldlt.det(), 1.0, epsilon = 1e-12);

                    let b_arr: [f64; $d] = leading(&[1.0, 2.0, 3.0, 4.0, 5.0]);
                    let b = Vector::<$d>::new(black_box(b_arr));
                    let x = ldlt.solve_vec(b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], b_arr[i], epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_identity_tests!(2);
    gen_public_api_ldlt_identity_tests!(3);
    gen_public_api_ldlt_identity_tests!(4);
    gen_public_api_ldlt_identity_tests!(5);

    /// Generates a test on a positive diagonal matrix, where the determinant is the
    /// product of the diagonal and the solution is the element-wise quotient.
    macro_rules! gen_public_api_ldlt_diagonal_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
                    let diag: [f64; $d] = leading(&[1.0, 2.0, 3.0, 4.0, 5.0]);

                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = diag[i];
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

                    let expected_det: f64 = diag.iter().product();
                    assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-12);

                    let b_arr: [f64; $d] = leading(&[5.0, 4.0, 3.0, 2.0, 1.0]);
                    let b = Vector::<$d>::new(black_box(b_arr));
                    let x = ldlt.solve_vec(b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], b_arr[i] / diag[i], epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_diagonal_tests!(2);
    gen_public_api_ldlt_diagonal_tests!(3);
    gen_public_api_ldlt_diagonal_tests!(4);
    gen_public_api_ldlt_diagonal_tests!(5);

    #[test]
    fn solve_2x2_known_spd() {
        let a = Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]]));
        let ldlt = (black_box(Ldlt::<2>::factor))(a, DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<2>::new(black_box([1.0, 2.0]));
        let x = ldlt.solve_vec(b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], -0.125, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(ldlt.det(), 8.0, epsilon = 1e-12);
    }

    #[test]
    fn solve_3x3_spd_tridiagonal_smoke() {
        let a = Matrix::<3>::from_rows(black_box([
            [2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0],
            [0.0, -1.0, 2.0],
        ]));
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let b = Vector::<3>::new(black_box([1.0, 0.0, 1.0]));
        let x = ldlt.solve_vec(b).unwrap().into_array();

        for &x_i in &x {
            assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn det_3x3_spd_tridiagonal() {
        // The n×n [-1, 2, -1] tridiagonal matrix has determinant n + 1.
        // The pivots here are 2, 3/2 and 4/3.
        let a = Matrix::<3>::from_rows(black_box([
            [2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0],
            [0.0, -1.0, 2.0],
        ]));
        let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        assert_abs_diff_eq!(ldlt.det(), 4.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected_for_degenerate_psd() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 1.0], [1.0, 1.0]]));
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn singular_detected_for_indefinite() {
        // Symmetric but indefinite (eigenvalues 3 and -1).
        // The second pivot is 1 - 2·2 = -3, which is rejected as not above the tolerance.
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 1.0]]));
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn pivot_at_or_below_tol_is_singular() {
        let a = Matrix::<2>::from_rows(black_box([[1e-3, 0.0], [0.0, 1.0]]));
        assert_eq!(
            Ldlt::<2>::factor(a, 1e-2).unwrap_err(),
            LaError::Singular { pivot_col: 0 }
        );
        assert!(Ldlt::<2>::factor(a, 1e-4).is_ok());
    }

    #[test]
    fn nonfinite_detected() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected_in_trailing_update() {
        // The multiplier 1e200 is finite, but 1 - 1e200·1e200 overflows to -∞.
        let a = Matrix::<2>::from_rows(black_box([[1.0, 1e200], [1e200, 1.0]]));
        let err = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    #[test]
    fn solve_nonfinite_rhs_detected() {
        let ldlt = Matrix::<2>::identity().ldlt(DEFAULT_SINGULAR_TOL).unwrap();

        let nan_b = Vector::<2>::new(black_box([f64::NAN, 1.0]));
        assert_eq!(
            ldlt.solve_vec(nan_b).err(),
            Some(LaError::NonFinite { pivot_col: 0 })
        );

        let inf_b = Vector::<2>::new(black_box([1.0, f64::INFINITY]));
        assert_eq!(
            ldlt.solve_vec(inf_b).err(),
            Some(LaError::NonFinite { pivot_col: 1 })
        );
    }
}