1use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
/// LU factorization with partial (row) pivoting of a `D × D` matrix.
///
/// Represents `P A = L U` in compact form. Values are constructed by `Lu::factor`; the tests
/// obtain them through `Matrix::lu`.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Lu<const D: usize> {
    /// Packed factors: the multipliers of `L` strictly below the diagonal (`L`'s unit diagonal
    /// is implicit and not stored) and `U` on and above the diagonal.
    factors: Matrix<D>,
    /// Row permutation: row `i` of `factors` corresponds to row `piv[i]` of the original matrix.
    piv: [usize; D],
    /// `1.0` or `-1.0`, flipped on every row swap; the sign applied by `det`.
    piv_sign: f64,
    /// Pivot tolerance used during factorization and re-checked on the diagonal in `solve_vec`.
    tol: f64,
}
16
17impl<const D: usize> Lu<D> {
18 #[inline]
19 pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20 let mut lu = a;
21
22 let mut piv = [0usize; D];
23 for (i, p) in piv.iter_mut().enumerate() {
24 *p = i;
25 }
26
27 let mut piv_sign = 1.0;
28
29 for k in 0..D {
30 let mut pivot_row = k;
32 let mut pivot_abs = lu.rows[k][k].abs();
33 if !pivot_abs.is_finite() {
34 return Err(LaError::NonFinite {
35 row: Some(k),
36 col: k,
37 });
38 }
39
40 for r in (k + 1)..D {
41 let v = lu.rows[r][k].abs();
42 if !v.is_finite() {
43 return Err(LaError::NonFinite {
44 row: Some(r),
45 col: k,
46 });
47 }
48 if v > pivot_abs {
49 pivot_abs = v;
50 pivot_row = r;
51 }
52 }
53
54 if pivot_abs <= tol {
55 return Err(LaError::Singular { pivot_col: k });
56 }
57
58 if pivot_row != k {
59 lu.rows.swap(k, pivot_row);
60 piv.swap(k, pivot_row);
61 piv_sign = -piv_sign;
62 }
63
64 let pivot = lu.rows[k][k];
65 if !pivot.is_finite() {
66 return Err(LaError::NonFinite {
67 row: Some(k),
68 col: k,
69 });
70 }
71
72 for r in (k + 1)..D {
74 let mult = lu.rows[r][k] / pivot;
75 if !mult.is_finite() {
76 return Err(LaError::NonFinite {
77 row: Some(r),
78 col: k,
79 });
80 }
81 lu.rows[r][k] = mult;
82
83 for c in (k + 1)..D {
84 lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
85 }
86 }
87 }
88
89 Ok(Self {
90 factors: lu,
91 piv,
92 piv_sign,
93 tol,
94 })
95 }
96
97 #[inline]
121 pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
122 let mut x = [0.0; D];
123 for (i, x_i) in x.iter_mut().enumerate() {
124 *x_i = b.data[self.piv[i]];
125 }
126
127 for i in 0..D {
129 let mut sum = x[i];
130 let row = self.factors.rows[i];
131 for (j, x_j) in x.iter().enumerate().take(i) {
132 sum = (-row[j]).mul_add(*x_j, sum);
133 }
134 if !sum.is_finite() {
135 return Err(LaError::NonFinite { row: None, col: i });
136 }
137 x[i] = sum;
138 }
139
140 for ii in 0..D {
142 let i = D - 1 - ii;
143 let mut sum = x[i];
144 let row = self.factors.rows[i];
145 for (j, x_j) in x.iter().enumerate().skip(i + 1) {
146 sum = (-row[j]).mul_add(*x_j, sum);
147 }
148
149 let diag = row[i];
150 if !diag.is_finite() || !sum.is_finite() {
151 return Err(LaError::NonFinite { row: None, col: i });
152 }
153 if diag.abs() <= self.tol {
154 return Err(LaError::Singular { pivot_col: i });
155 }
156
157 let q = sum / diag;
158 if !q.is_finite() {
159 return Err(LaError::NonFinite { row: None, col: i });
160 }
161 x[i] = q;
162 }
163
164 Ok(Vector::new(x))
165 }
166
167 #[inline]
183 #[must_use]
184 pub fn det(&self) -> f64 {
185 let mut det = self.piv_sign;
186 for i in 0..D {
187 det *= self.factors.rows[i][i];
188 }
189 det
190 }
191}
192
#[cfg(test)]
mod tests {
    //! Unit tests for `Lu`: pivoting, determinant sign, tridiagonal smoke tests at larger `D`,
    //! singularity/tolerance detection, and non-finite detection in factor and solve.

    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Generates, for dimension `$d`, a solve test and a det test on a permutation matrix
    // (the identity with rows 0 and 1 exchanged), which forces a row swap at column 0.
    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // Permutation matrix P: identity with rows 0 and 1 swapped.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    // The `black_box`ed function pointers presumably stop the optimizer from
                    // inlining/const-folding these public API calls (e.g. so coverage attributes
                    // them) — NOTE(review): confirm the intended reason.
                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // b = [1, 2, ..., $d].
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    // P is its own inverse, so x = P b is b with entries 0 and 1 swapped.
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Same permutation matrix: one transposition, so det = -1.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    // Generates smoke tests on the (-1, 2, -1) tridiagonal matrix of size `$d`.
    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // At $d = 64 this array is 64 * 64 * 8 bytes = 32 KiB of stack, which
                    // trips clippy's large_stack_arrays lint; acceptable in a test.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // With b = e_0 + e_{n-1}, the all-ones vector solves the system: interior
                    // rows give -1 + 2 - 1 = 0 and the two end rows give 2 - 1 = 1.
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Same 32 KiB-at-$d=64 stack array as above.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // The n x n (-1, 2, -1) tridiagonal matrix has determinant n + 1.
                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    // Exercises `Lu::factor` directly (not via `Matrix::lu`) on the trivial 1x1 case.
    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        // epsilon = 0.0: the determinant of a 1x1 matrix must be exact.
        let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0);
    }

    // [[1, 2], [3, 4]] x = [5, 11] has solution x = [1, 2].
    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2; factoring swaps rows, so piv_sign matters.
    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12);
    }

    // The zero at (0, 0) forces a swap; U is then the identity, so only the sign yields -1.
    #[test]
    fn det_requires_pivot_sign() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0);
    }

    // Same swap matrix: solving without applying `piv` to b would return [1, 2].
    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    // Row 1 = 2 * row 0, so after eliminating column 0 the second pivot is exactly zero.
    #[test]
    fn singular_detected() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    // The matrix is non-singular, but its first pivot 1e-13 is at or below the default
    // tolerance, so factorization rejects it.
    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    // A NaN on the diagonal is reported at its own position before any row scanning.
    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(0),
                col: 0
            }
        );
    }

    // An infinity below the diagonal is caught while searching for the column-0 pivot.
    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(
            err,
            LaError::NonFinite {
                row: Some(1),
                col: 0
            }
        );
    }

    // |1| ties |-1| in column 0, so no swap happens and L[1][0] = -1. Forward substitution
    // then computes y1 = 1e308 + 1e308, which overflows to +inf at component 1.
    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }

    // U[1][1] = 2e-12 must pass the pivot tolerance for the `unwrap` below to succeed (i.e.
    // this assumes DEFAULT_PIVOT_TOL < 2e-12). Back substitution then computes
    // 1e300 / 2e-12, which overflows at component 1.
    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
    }
}