1use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
/// LU factorization with partial (row) pivoting of a `D x D` matrix.
///
/// Constructed by `Lu::factor`; use [`Lu::solve_vec`] to solve linear systems and
/// [`Lu::det`] to obtain the determinant of the factored matrix.
#[must_use]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Lu<const D: usize> {
    /// Packed factors: the elimination multipliers (the strictly lower part of the unit lower
    /// triangular `L`, whose unit diagonal is implicit) below the diagonal, and `U` on and above
    /// the diagonal.
    factors: Matrix<D>,
    /// Row permutation: row `i` of `factors` originates from row `piv[i]` of the input matrix.
    piv: [usize; D],
    /// `1.0` after an even number of row swaps, `-1.0` after an odd number.
    piv_sign: f64,
    /// Pivot tolerance given to the factorization; `solve_vec` reuses it when checking the
    /// diagonal of `U`.
    tol: f64,
}
16
17impl<const D: usize> Lu<D> {
18 #[inline]
19 pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20 let mut lu = a;
21
22 let mut piv = [0usize; D];
23 for (i, p) in piv.iter_mut().enumerate() {
24 *p = i;
25 }
26
27 let mut piv_sign = 1.0;
28
29 for k in 0..D {
30 let mut pivot_row = k;
32 let mut pivot_abs = lu.rows[k][k].abs();
33 if !pivot_abs.is_finite() {
34 return Err(LaError::NonFinite { pivot_col: k });
35 }
36
37 for r in (k + 1)..D {
38 let v = lu.rows[r][k].abs();
39 if !v.is_finite() {
40 return Err(LaError::NonFinite { pivot_col: k });
41 }
42 if v > pivot_abs {
43 pivot_abs = v;
44 pivot_row = r;
45 }
46 }
47
48 if pivot_abs <= tol {
49 return Err(LaError::Singular { pivot_col: k });
50 }
51
52 if pivot_row != k {
53 lu.rows.swap(k, pivot_row);
54 piv.swap(k, pivot_row);
55 piv_sign = -piv_sign;
56 }
57
58 let pivot = lu.rows[k][k];
59 if !pivot.is_finite() {
60 return Err(LaError::NonFinite { pivot_col: k });
61 }
62
63 for r in (k + 1)..D {
65 let mult = lu.rows[r][k] / pivot;
66 if !mult.is_finite() {
67 return Err(LaError::NonFinite { pivot_col: k });
68 }
69 lu.rows[r][k] = mult;
70
71 for c in (k + 1)..D {
72 lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
73 }
74 }
75 }
76
77 Ok(Self {
78 factors: lu,
79 piv,
80 piv_sign,
81 tol,
82 })
83 }
84
85 #[inline]
109 pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
110 let mut x = [0.0; D];
111 for (i, x_i) in x.iter_mut().enumerate() {
112 *x_i = b.data[self.piv[i]];
113 }
114
115 for i in 0..D {
117 let mut sum = x[i];
118 let row = self.factors.rows[i];
119 for (j, x_j) in x.iter().enumerate().take(i) {
120 sum = (-row[j]).mul_add(*x_j, sum);
121 }
122 if !sum.is_finite() {
123 return Err(LaError::NonFinite { pivot_col: i });
124 }
125 x[i] = sum;
126 }
127
128 for ii in 0..D {
130 let i = D - 1 - ii;
131 let mut sum = x[i];
132 let row = self.factors.rows[i];
133 for (j, x_j) in x.iter().enumerate().skip(i + 1) {
134 sum = (-row[j]).mul_add(*x_j, sum);
135 }
136
137 let diag = row[i];
138 if !diag.is_finite() || !sum.is_finite() {
139 return Err(LaError::NonFinite { pivot_col: i });
140 }
141 if diag.abs() <= self.tol {
142 return Err(LaError::Singular { pivot_col: i });
143 }
144
145 x[i] = sum / diag;
146 }
147
148 Ok(Vector::new(x))
149 }
150
151 #[inline]
167 #[must_use]
168 pub fn det(&self) -> f64 {
169 let mut det = self.piv_sign;
170 for i in 0..D {
171 det *= self.factors.rows[i][i];
172 }
173 det
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Identity matrix of order `D` with rows 0 and 1 exchanged (a single transposition).
    fn swapped_identity_rows<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        rows.swap(0, 1);
        rows
    }

    /// `tridiag(-1, 2, -1)` of order `D`.
    // The 64x64 instantiation is a 32 KiB array, large enough that clippy's
    // `large_stack_arrays` lint can fire on it; the stack usage is acceptable in tests.
    #[allow(clippy::large_stack_arrays)]
    fn tridiagonal_rows<const D: usize>() -> [[f64; D]; D] {
        let mut rows = [[0.0f64; D]; D];
        for (i, row) in rows.iter_mut().enumerate() {
            row[i] = 2.0;
            if let Some(left) = i.checked_sub(1) {
                row[left] = -1.0;
            }
            if let Some(right) = row.get_mut(i + 1) {
                *right = -1.0;
            }
        }
        rows
    }

    /// Factors `rows` through the public `Matrix::lu` entry point, routed through `black_box`
    /// so the call cannot be constant-folded away.
    fn lu_via_public_api<const D: usize>(rows: [[f64; D]; D]) -> Lu<D> {
        let matrix = Matrix::<D>::from_rows(black_box(rows));
        let lu_fn: fn(Matrix<D>, f64) -> Result<Lu<D>, LaError> = black_box(Matrix::<D>::lu);
        lu_fn(matrix, DEFAULT_PIVOT_TOL).unwrap()
    }

    /// Solves `lu * x = b` through a `black_box`ed `Lu::solve_vec` and returns `x` as an array.
    fn solve_via_fn_ptr<const D: usize>(lu: &Lu<D>, b: [f64; D]) -> [f64; D] {
        let solve_fn: fn(&Lu<D>, Vector<D>) -> Result<Vector<D>, LaError> =
            black_box(Lu::<D>::solve_vec);
        solve_fn(lu, Vector::<D>::new(black_box(b)))
            .unwrap()
            .into_array()
    }

    /// Evaluates `Lu::det` through a `black_box`ed function pointer.
    fn det_via_fn_ptr<const D: usize>(lu: &Lu<D>) -> f64 {
        let det_fn: fn(&Lu<D>) -> f64 = black_box(Lu::<D>::det);
        det_fn(lu)
    }

    /// Factors `rows` with the default tolerance, expecting failure, and returns the error.
    fn lu_error<const D: usize>(rows: [[f64; D]; D]) -> LaError {
        Matrix::<D>::from_rows(rows)
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap_err()
    }

    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // A is the identity with rows 0 and 1 swapped, so A^-1 b is b with its
                    // first two entries swapped.
                    let lu = lu_via_public_api(swapped_identity_rows::<$d>());

                    let mut b = [0.0f64; $d];
                    let mut next = 0.0f64;
                    for slot in &mut b {
                        next += 1.0;
                        *slot = next;
                    }
                    let mut expected = b;
                    expected.swap(0, 1);

                    let x = solve_via_fn_ptr(&lu, b);
                    for (got, want) in x.iter().zip(expected.iter()) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // One transposition of the identity has determinant -1.
                    let lu = lu_via_public_api(swapped_identity_rows::<$d>());
                    assert_abs_diff_eq!(det_via_fn_ptr(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    let lu = lu_via_public_api(tridiagonal_rows::<$d>());

                    // Multiplying A by the all-ones vector gives 0 in every interior row and 1
                    // in the two end rows, so b = e_0 + e_{d-1} has the all-ones solution.
                    let mut b = [0.0f64; $d];
                    b[0] = 1.0;
                    b[$d - 1] = 1.0;

                    for x_i in solve_via_fn_ptr(&lu, b) {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // det(tridiag(-1, 2, -1)) of order n equals n + 1.
                    let lu = lu_via_public_api(tridiagonal_rows::<$d>());
                    assert_abs_diff_eq!(det_via_fn_ptr(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    #[test]
    fn solve_1x1() {
        let matrix = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = (black_box(Lu::<1>::factor))(matrix, DEFAULT_PIVOT_TOL).unwrap();

        let x = solve_via_fn_ptr(&lu, [6.0]);
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        assert_abs_diff_eq!(det_via_fn_ptr(&lu), 2.0, epsilon = 0.0);
    }

    #[test]
    fn solve_2x2_basic() {
        let matrix = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = (black_box(Lu::<2>::factor))(matrix, DEFAULT_PIVOT_TOL).unwrap();

        let x = solve_via_fn_ptr(&lu, [5.0, 11.0]);
        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_2x2_basic() {
        let lu = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        assert_abs_diff_eq!(det_via_fn_ptr(&lu), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn det_requires_pivot_sign() {
        // The zero in position (0, 0) forces a row swap, which must flip the sign.
        let lu = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();
        assert_abs_diff_eq!(det_via_fn_ptr(&lu), -1.0, epsilon = 0.0);
    }

    #[test]
    fn solve_requires_pivoting() {
        let lu = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]))
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let x = solve_via_fn_ptr(&lu, [1.0, 2.0]);
        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn singular_detected() {
        // Row 1 is twice row 0, so elimination leaves an exactly zero second pivot.
        let err = lu_error(black_box([[1.0, 2.0], [2.0, 4.0]]));
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        // Column 0 below the diagonal is zero, so the tiny 1e-13 entry is the only pivot
        // candidate and it falls within the default tolerance.
        let err = lu_error(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let err = lu_error([[f64::NAN, 0.0], [0.0, 1.0]]);
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let err = lu_error([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        // Forward substitution computes y[1] = 1e308 + 1e308, which overflows.
        let lu = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let err = lu.solve_vec(Vector::<3>::new([1.0e308, 1.0e308, 0.0])).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 1 });
    }

    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        // x[1] = 1e300 / 2e-12 overflows; the infinity then makes row 0's sum non-finite.
        let lu = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]])
            .lu(DEFAULT_PIVOT_TOL)
            .unwrap();

        let err = lu.solve_vec(Vector::<2>::new([0.0, 1.0e300])).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }
}