1use crate::LaError;
8use crate::matrix::Matrix;
9use crate::vector::Vector;
10
11#[must_use]
22#[derive(Clone, Copy, Debug, PartialEq)]
23pub struct Ldlt<const D: usize> {
24 factors: Matrix<D>,
25 tol: f64,
26}
27
28impl<const D: usize> Ldlt<D> {
29 #[inline]
30 pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
31 debug_assert!(tol >= 0.0, "tol must be non-negative");
32
33 #[cfg(debug_assertions)]
34 debug_assert_symmetric(&a);
35
36 let mut f = a;
37
38 for j in 0..D {
40 let d = f.rows[j][j];
41 if !d.is_finite() {
42 return Err(LaError::NonFinite {
43 row: Some(j),
44 col: j,
45 });
46 }
47 if d <= tol {
48 return Err(LaError::Singular { pivot_col: j });
49 }
50
51 for i in (j + 1)..D {
53 let l = f.rows[i][j] / d;
54 if !l.is_finite() {
55 return Err(LaError::NonFinite {
56 row: Some(i),
57 col: j,
58 });
59 }
60 f.rows[i][j] = l;
61 }
62
63 for i in (j + 1)..D {
65 let l_i = f.rows[i][j];
66 let l_i_d = l_i * d;
67
68 for k in (j + 1)..=i {
69 let l_k = f.rows[k][j];
70 let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]);
71 if !new_val.is_finite() {
72 return Err(LaError::NonFinite {
73 row: Some(i),
74 col: k,
75 });
76 }
77 f.rows[i][k] = new_val;
78 }
79 }
80 }
81
82 Ok(Self { factors: f, tol })
83 }
84
85 #[inline]
100 #[must_use]
101 pub fn det(&self) -> f64 {
102 let mut det = 1.0;
103 for i in 0..D {
104 det *= self.factors.rows[i][i];
105 }
106 det
107 }
108
109 #[inline]
133 pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
134 let mut x = b.data;
135
136 for i in 0..D {
138 let mut sum = x[i];
139 let row = self.factors.rows[i];
140 for (j, x_j) in x.iter().enumerate().take(i) {
141 sum = (-row[j]).mul_add(*x_j, sum);
142 }
143 if !sum.is_finite() {
144 return Err(LaError::NonFinite { row: None, col: i });
145 }
146 x[i] = sum;
147 }
148
149 for (i, x_i) in x.iter_mut().enumerate().take(D) {
151 let diag = self.factors.rows[i][i];
152 if !diag.is_finite() {
153 return Err(LaError::NonFinite { row: None, col: i });
154 }
155 if diag <= self.tol {
156 return Err(LaError::Singular { pivot_col: i });
157 }
158
159 let v = *x_i / diag;
160 if !v.is_finite() {
161 return Err(LaError::NonFinite { row: None, col: i });
162 }
163 *x_i = v;
164 }
165
166 for ii in 0..D {
168 let i = D - 1 - ii;
169 let mut sum = x[i];
170 for (j, x_j) in x.iter().enumerate().skip(i + 1) {
171 sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum);
172 }
173 if !sum.is_finite() {
174 return Err(LaError::NonFinite { row: None, col: i });
175 }
176 x[i] = sum;
177 }
178
179 Ok(Vector::new(x))
180 }
181}
182
183#[cfg(debug_assertions)]
184fn debug_assert_symmetric<const D: usize>(a: &Matrix<D>) {
185 let scale = a.inf_norm().max(1.0);
186 let eps = 1e-12 * scale;
187
188 for r in 0..D {
189 for c in (r + 1)..D {
190 let diff = (a.rows[r][c] - a.rows[c][r]).abs();
191 debug_assert!(
192 diff <= eps,
193 "matrix must be symmetric (diff={diff}, eps={eps}) at ({r}, {c})"
194 );
195 }
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    use crate::DEFAULT_SINGULAR_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Returns an `N`-element array holding the leading entries of `values`;
    /// positions past the fifth are filled with `0.0`.
    fn first_n<const N: usize>(values: [f64; 5]) -> [f64; N] {
        core::array::from_fn(|i| values.get(i).copied().unwrap_or(0.0))
    }

    /// Factors `rows` with `DEFAULT_SINGULAR_TOL` and returns the resulting
    /// error, panicking if the factorization unexpectedly succeeds.
    fn factor_err<const N: usize>(rows: [[f64; N]; N]) -> LaError {
        Matrix::<N>::from_rows(rows)
            .ldlt(DEFAULT_SINGULAR_TOL)
            .unwrap_err()
    }

    macro_rules! gen_public_api_ldlt_identity_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_identity_ $d d>]() {
                    let ldlt = Matrix::<$d>::identity()
                        .ldlt(DEFAULT_SINGULAR_TOL)
                        .unwrap();
                    assert_abs_diff_eq!(ldlt.det(), 1.0, epsilon = 1e-12);

                    // For the identity system the solution equals the right-hand side.
                    let rhs: [f64; $d] = first_n([1.0, 2.0, 3.0, 4.0, 5.0]);
                    let x = ldlt
                        .solve_vec(Vector::<$d>::new(black_box(rhs)))
                        .unwrap()
                        .into_array();
                    for (got, want) in x.iter().zip(rhs.iter()) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_identity_tests!(2);
    gen_public_api_ldlt_identity_tests!(3);
    gen_public_api_ldlt_identity_tests!(4);
    gen_public_api_ldlt_identity_tests!(5);

    macro_rules! gen_public_api_ldlt_diagonal_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_ldlt_det_and_solve_diagonal_spd_ $d d>]() {
                    let diag: [f64; $d] = first_n([1.0, 2.0, 3.0, 4.0, 5.0]);
                    let rows: [[f64; $d]; $d] = core::array::from_fn(|r| {
                        core::array::from_fn(|c| if r == c { diag[r] } else { 0.0 })
                    });

                    let ldlt = Matrix::<$d>::from_rows(black_box(rows))
                        .ldlt(DEFAULT_SINGULAR_TOL)
                        .unwrap();

                    // The determinant of a diagonal matrix is the product of its diagonal.
                    let expected_det: f64 = diag.iter().product();
                    assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-12);

                    let rhs: [f64; $d] = first_n([5.0, 4.0, 3.0, 2.0, 1.0]);
                    let x = ldlt
                        .solve_vec(Vector::<$d>::new(black_box(rhs)))
                        .unwrap()
                        .into_array();
                    for ((got, b_i), d_i) in x.iter().zip(&rhs).zip(&diag) {
                        assert_abs_diff_eq!(*got, b_i / d_i, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    gen_public_api_ldlt_diagonal_tests!(2);
    gen_public_api_ldlt_diagonal_tests!(3);
    gen_public_api_ldlt_diagonal_tests!(4);
    gen_public_api_ldlt_diagonal_tests!(5);

    #[test]
    fn solve_2x2_known_spd() {
        let a = Matrix::<2>::from_rows(black_box([[4.0, 2.0], [2.0, 3.0]]));
        // Exercise the crate-private constructor directly, routed through `black_box`.
        let factor = black_box(Ldlt::<2>::factor);
        let ldlt = factor(a, DEFAULT_SINGULAR_TOL).unwrap();

        let x = ldlt
            .solve_vec(Vector::<2>::new(black_box([1.0, 2.0])))
            .unwrap()
            .into_array();

        // A⁻¹ = (1/8)·[[3, -2], [-2, 4]], so x = [-1/8, 3/4] and det(A) = 8.
        assert_abs_diff_eq!(x[0], -0.125, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(ldlt.det(), 8.0, epsilon = 1e-12);
    }

    #[test]
    fn solve_3x3_spd_tridiagonal_smoke() {
        let ldlt = Matrix::<3>::from_rows(black_box([
            [2.0, -1.0, 0.0],
            [-1.0, 2.0, -1.0],
            [0.0, -1.0, 2.0],
        ]))
        .ldlt(DEFAULT_SINGULAR_TOL)
        .unwrap();

        // The row sums of A are [1, 0, 1], so x = [1, 1, 1] solves A·x = b.
        let b = Vector::<3>::new(black_box([1.0, 0.0, 1.0]));
        for x_i in ldlt.solve_vec(b).unwrap().into_array() {
            assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn singular_detected_for_degenerate_psd() {
        // Rank-one matrix: the second pivot cancels to zero.
        assert_eq!(
            factor_err(black_box([[1.0, 1.0], [1.0, 1.0]])),
            LaError::Singular { pivot_col: 1 }
        );
    }

    #[test]
    fn nonfinite_detected() {
        assert_eq!(
            factor_err([[f64::NAN, 0.0], [0.0, 1.0]]),
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_l_multiplier_overflow() {
        // Forming L[1][0] = 1e300 / 1e-11 overflows.
        assert_eq!(
            factor_err([[1e-11, 1e300], [1e300, 1.0]]),
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    #[test]
    fn nonfinite_trailing_submatrix_overflow() {
        // L[1][0] = 1e200, so updating A[1][1] evaluates 1 - 1e200 * 1e200.
        assert_eq!(
            factor_err([[1.0, 1e200], [1e200, 1.0]]),
            LaError::NonFinite {
                row: Some(1),
                col: 1
            }
        );
    }

    #[test]
    fn nonfinite_solve_vec_forward_substitution_overflow() {
        let ldlt = Matrix::<3>::from_rows([
            [1.0, 1e153, 0.0],
            [1e153, 1e306 + 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ])
        .ldlt(DEFAULT_SINGULAR_TOL)
        .unwrap();

        // L[1][0] = 1e153, so the forward sweep evaluates 0 - 1e153 * 1e156.
        let result = ldlt.solve_vec(Vector::<3>::new([1e156, 0.0, 0.0]));
        assert_eq!(result.unwrap_err(), LaError::NonFinite { row: None, col: 1 });
    }

    #[test]
    fn nonfinite_solve_vec_back_substitution_overflow() {
        let ldlt = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 2.0, 5.0]])
            .ldlt(DEFAULT_SINGULAR_TOL)
            .unwrap();

        // L[2][1] = 2 and every pivot is 1, so the back sweep evaluates 0 - 2 * 1e308.
        let result = ldlt.solve_vec(Vector::<3>::new([0.0, 0.0, 1e308]));
        assert_eq!(result.unwrap_err(), LaError::NonFinite { row: None, col: 1 });
    }
}
410}