1use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
7#[must_use]
9#[derive(Clone, Copy, Debug, PartialEq)]
10pub struct Lu<const D: usize> {
11 factors: Matrix<D>,
12 piv: [usize; D],
13 piv_sign: f64,
14 tol: f64,
15}
16
17impl<const D: usize> Lu<D> {
18 #[inline]
19 pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20 let mut lu = a;
21
22 let mut piv = [0usize; D];
23 for (i, p) in piv.iter_mut().enumerate() {
24 *p = i;
25 }
26
27 let mut piv_sign = 1.0;
28
29 for k in 0..D {
30 let mut pivot_row = k;
32 let mut pivot_abs = lu.rows[k][k].abs();
33 if !pivot_abs.is_finite() {
34 return Err(LaError::NonFinite { pivot_col: k });
35 }
36
37 for r in (k + 1)..D {
38 let v = lu.rows[r][k].abs();
39 if !v.is_finite() {
40 return Err(LaError::NonFinite { pivot_col: k });
41 }
42 if v > pivot_abs {
43 pivot_abs = v;
44 pivot_row = r;
45 }
46 }
47
48 if pivot_abs <= tol {
49 return Err(LaError::Singular { pivot_col: k });
50 }
51
52 if pivot_row != k {
53 lu.rows.swap(k, pivot_row);
54 piv.swap(k, pivot_row);
55 piv_sign = -piv_sign;
56 }
57
58 let pivot = lu.rows[k][k];
59 if !pivot.is_finite() {
60 return Err(LaError::NonFinite { pivot_col: k });
61 }
62
63 for r in (k + 1)..D {
65 let mult = lu.rows[r][k] / pivot;
66 if !mult.is_finite() {
67 return Err(LaError::NonFinite { pivot_col: k });
68 }
69 lu.rows[r][k] = mult;
70
71 for c in (k + 1)..D {
72 lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
73 }
74 }
75 }
76
77 Ok(Self {
78 factors: lu,
79 piv,
80 piv_sign,
81 tol,
82 })
83 }
84
85 #[inline]
109 pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
110 let mut x = [0.0; D];
111 for (i, x_i) in x.iter_mut().enumerate() {
112 *x_i = b.data[self.piv[i]];
113 }
114
115 for i in 0..D {
117 let mut sum = x[i];
118 let row = self.factors.rows[i];
119 for (j, x_j) in x.iter().enumerate().take(i) {
120 sum = (-row[j]).mul_add(*x_j, sum);
121 }
122 if !sum.is_finite() {
123 return Err(LaError::NonFinite { pivot_col: i });
124 }
125 x[i] = sum;
126 }
127
128 for ii in 0..D {
130 let i = D - 1 - ii;
131 let mut sum = x[i];
132 let row = self.factors.rows[i];
133 for (j, x_j) in x.iter().enumerate().skip(i + 1) {
134 sum = (-row[j]).mul_add(*x_j, sum);
135 }
136
137 let diag = row[i];
138 if !diag.is_finite() || !sum.is_finite() {
139 return Err(LaError::NonFinite { pivot_col: i });
140 }
141 if diag.abs() <= self.tol {
142 return Err(LaError::Singular { pivot_col: i });
143 }
144
145 x[i] = sum / diag;
146 }
147
148 Ok(Vector::new(x))
149 }
150
151 #[inline]
168 #[must_use]
169 pub fn det(&self) -> f64 {
170 let mut det = self.piv_sign;
171 for i in 0..D {
172 det *= self.factors.rows[i][i];
173 }
174 det
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// `D×D` identity with rows 0 and 1 exchanged (one transposition, so the
    /// determinant is -1). Panics for `D < 2`.
    fn swapped_identity<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        rows.swap(0, 1);
        rows
    }

    /// `D×D` tridiagonal matrix with 2 on the diagonal and -1 on both
    /// neighbouring diagonals; its determinant is `D + 1`.
    // The largest instantiation (64×64 f64) is 32 KiB, acceptable on a test stack.
    #[allow(clippy::large_stack_arrays)]
    fn tridiagonal<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 2.0;
            if i > 0 {
                row[i - 1] = -1.0;
            }
            if i + 1 < D {
                row[i + 1] = -1.0;
            }
        }
        rows
    }

    /// Factors `rows` through the public `Matrix::lu` entry point, calling it
    /// via a black-boxed function pointer so the call is really emitted.
    fn lu_via_public_api<const D: usize>(rows: [[f64; D]; D]) -> Lu<D> {
        let matrix = Matrix::<D>::from_rows(black_box(rows));
        let factor: fn(Matrix<D>, f64) -> Result<Lu<D>, LaError> = black_box(Matrix::<D>::lu);
        factor(matrix, DEFAULT_PIVOT_TOL).unwrap()
    }

    /// Solves with `lu` through a black-boxed `Lu::solve_vec` pointer and
    /// returns the solution as a plain array.
    fn solve_via_public_api<const D: usize>(lu: &Lu<D>, rhs: [f64; D]) -> [f64; D] {
        let rhs = Vector::<D>::new(black_box(rhs));
        let solve: fn(&Lu<D>, Vector<D>) -> Result<Vector<D>, LaError> =
            black_box(Lu::<D>::solve_vec);
        solve(lu, rhs).unwrap().into_array()
    }

    /// Evaluates the determinant through a black-boxed `Lu::det` pointer.
    fn det_via_public_api<const D: usize>(lu: &Lu<D>) -> f64 {
        let det: fn(&Lu<D>) -> f64 = black_box(Lu::<D>::det);
        det(lu)
    }

    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    let lu = lu_via_public_api::<$d>(swapped_identity::<$d>());

                    // b = [1, 2, ..., D]; the permutation is its own inverse,
                    // so the solution is b with its first two entries swapped.
                    let mut rhs = [0.0f64; $d];
                    let counting = core::iter::successors(Some(1.0f64), |v| Some(v + 1.0));
                    for (slot, value) in rhs.iter_mut().zip(counting) {
                        *slot = value;
                    }
                    let mut want = rhs;
                    want.swap(0, 1);

                    let got = solve_via_public_api(&lu, rhs);
                    for (g, w) in got.iter().zip(want.iter()) {
                        assert_abs_diff_eq!(*g, *w, epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    let lu = lu_via_public_api::<$d>(swapped_identity::<$d>());
                    assert_abs_diff_eq!(det_via_public_api(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    let lu = lu_via_public_api::<$d>(tridiagonal::<$d>());

                    // With b = e_0 + e_{D-1} the exact solution is all ones.
                    let mut rhs = [0.0f64; $d];
                    rhs[0] = 1.0;
                    rhs[$d - 1] = 1.0;

                    let got = solve_via_public_api(&lu, rhs);
                    for &g in &got {
                        assert_abs_diff_eq!(g, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    let lu = lu_via_public_api::<$d>(tridiagonal::<$d>());
                    assert_abs_diff_eq!(
                        det_via_public_api(&lu),
                        f64::from($d) + 1.0,
                        epsilon = 1e-8
                    );
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();

        let got = solve_via_public_api(&lu, [6.0]);
        assert_abs_diff_eq!(got[0], 3.0, epsilon = 1e-12);

        assert_abs_diff_eq!(det_via_public_api(&lu), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();

        let got = solve_via_public_api(&lu, [5.0, 11.0]);
        assert_abs_diff_eq!(got[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(got[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        assert_abs_diff_eq!(det_via_public_api(&lu), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // Row-swapped identity: only the pivot sign makes the determinant -1.
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        assert_abs_diff_eq!(det_via_public_api(&lu), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let got = solve_via_public_api(&lu, [1.0, 2.0]);
        assert_abs_diff_eq!(got[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(got[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        assert_eq!(
            a.lu(DEFAULT_PIVOT_TOL).unwrap_err(),
            LaError::Singular { pivot_col: 1 }
        );
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // The first pivot is nonzero but falls below the default tolerance.
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        assert_eq!(
            a.lu(DEFAULT_PIVOT_TOL).unwrap_err(),
            LaError::Singular { pivot_col: 0 }
        );
    }

    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        assert_eq!(
            a.lu(DEFAULT_PIVOT_TOL).unwrap_err(),
            LaError::NonFinite { pivot_col: 0 }
        );
    }

    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        assert_eq!(
            a.lu(DEFAULT_PIVOT_TOL).unwrap_err(),
            LaError::NonFinite { pivot_col: 0 }
        );
    }

    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        // Row 1 of L adds 1e308 to 1e308 during forward substitution.
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let rhs = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        assert_eq!(
            lu.solve_vec(rhs).unwrap_err(),
            LaError::NonFinite { pivot_col: 1 }
        );
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        // Dividing 1e300 by the tiny pivot 2e-12 overflows; the resulting
        // infinity surfaces while back-substituting row 0.
        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let rhs = Vector::<2>::new([0.0, 1.0e300]);
        assert_eq!(
            lu.solve_vec(rhs).unwrap_err(),
            LaError::NonFinite { pivot_col: 0 }
        );
    }
}