1use crate::LaError;
4use crate::matrix::Matrix;
5use crate::vector::Vector;
6
7#[must_use]
9#[derive(Clone, Copy, Debug, PartialEq)]
10pub struct Lu<const D: usize> {
11 factors: Matrix<D>,
12 piv: [usize; D],
13 piv_sign: f64,
14 tol: f64,
15}
16
17impl<const D: usize> Lu<D> {
18 #[inline]
19 pub(crate) fn factor(a: Matrix<D>, tol: f64) -> Result<Self, LaError> {
20 let mut lu = a;
21
22 let mut piv = [0usize; D];
23 for (i, p) in piv.iter_mut().enumerate() {
24 *p = i;
25 }
26
27 let mut piv_sign = 1.0;
28
29 for k in 0..D {
30 let mut pivot_row = k;
32 let mut pivot_abs = lu.rows[k][k].abs();
33 if !pivot_abs.is_finite() {
34 return Err(LaError::NonFinite { pivot_col: k });
35 }
36
37 for r in (k + 1)..D {
38 let v = lu.rows[r][k].abs();
39 if !v.is_finite() {
40 return Err(LaError::NonFinite { pivot_col: k });
41 }
42 if v > pivot_abs {
43 pivot_abs = v;
44 pivot_row = r;
45 }
46 }
47
48 if pivot_abs <= tol {
49 return Err(LaError::Singular { pivot_col: k });
50 }
51
52 if pivot_row != k {
53 lu.rows.swap(k, pivot_row);
54 piv.swap(k, pivot_row);
55 piv_sign = -piv_sign;
56 }
57
58 let pivot = lu.rows[k][k];
59 if !pivot.is_finite() {
60 return Err(LaError::NonFinite { pivot_col: k });
61 }
62
63 for r in (k + 1)..D {
65 let mult = lu.rows[r][k] / pivot;
66 if !mult.is_finite() {
67 return Err(LaError::NonFinite { pivot_col: k });
68 }
69 lu.rows[r][k] = mult;
70
71 for c in (k + 1)..D {
72 lu.rows[r][c] = (-mult).mul_add(lu.rows[k][c], lu.rows[r][c]);
73 }
74 }
75 }
76
77 Ok(Self {
78 factors: lu,
79 piv,
80 piv_sign,
81 tol,
82 })
83 }
84
85 #[inline]
108 pub fn solve_vec(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
109 let mut x = [0.0; D];
110 for (i, x_i) in x.iter_mut().enumerate() {
111 *x_i = b.data[self.piv[i]];
112 }
113
114 for i in 0..D {
116 let mut sum = x[i];
117 let row = self.factors.rows[i];
118 for (j, x_j) in x.iter().enumerate().take(i) {
119 sum = (-row[j]).mul_add(*x_j, sum);
120 }
121 if !sum.is_finite() {
122 return Err(LaError::NonFinite { pivot_col: i });
123 }
124 x[i] = sum;
125 }
126
127 for ii in 0..D {
129 let i = D - 1 - ii;
130 let mut sum = x[i];
131 let row = self.factors.rows[i];
132 for (j, x_j) in x.iter().enumerate().skip(i + 1) {
133 sum = (-row[j]).mul_add(*x_j, sum);
134 }
135
136 let diag = row[i];
137 if !diag.is_finite() || !sum.is_finite() {
138 return Err(LaError::NonFinite { pivot_col: i });
139 }
140 if diag.abs() <= self.tol {
141 return Err(LaError::Singular { pivot_col: i });
142 }
143
144 x[i] = sum / diag;
145 }
146
147 Ok(Vector::new(x))
148 }
149
150 #[inline]
166 #[must_use]
167 pub fn det(&self) -> f64 {
168 let mut det = self.piv_sign;
169 for i in 0..D {
170 det *= self.factors.rows[i][i];
171 }
172 det
173 }
174}
175
// Unit tests for `Lu`: solve/det through both the crate-internal `Lu::factor`
// and the `Matrix::lu` entry point, plus singularity and non-finite detection.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use core::hint::black_box;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // For a literal dimension `$d`, generates a solve test and a det test on the
    // `$d x $d` identity with rows 0 and 1 swapped, which forces a row exchange at
    // the first pivot. Inputs and the called API functions are routed through
    // `black_box` so the optimizer treats them as opaque.
    macro_rules! gen_public_api_pivoting_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_pivoting_ $d d>]() {
                    // Identity matrix with rows 0 and 1 exchanged.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let b_arr = {
                        // b = [1.0, 2.0, ..., $d].
                        let mut arr = [0.0f64; $d];
                        let mut val = 1.0f64;
                        for dst in arr.iter_mut() {
                            *dst = val;
                            val += 1.0;
                        }
                        arr
                    };
                    // The matrix only exchanges entries 0 and 1, so the solution
                    // is b with those two entries swapped.
                    let mut expected = b_arr;
                    expected.swap(0, 1);
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    for i in 0..$d {
                        assert_abs_diff_eq!(x[i], expected[i], epsilon = 1e-12);
                    }
                }

                #[test]
                fn [<public_api_lu_det_pivoting_ $d d>]() {
                    // Same swapped-identity matrix as the solve test.
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 1.0;
                    }
                    rows.swap(0, 1);

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    // One row transposition of the identity has determinant -1.
                    assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12);
                }
            }
        };
    }

    gen_public_api_pivoting_solve_vec_and_det_tests!(2);
    gen_public_api_pivoting_solve_vec_and_det_tests!(3);
    gen_public_api_pivoting_solve_vec_and_det_tests!(4);
    gen_public_api_pivoting_solve_vec_and_det_tests!(5);

    // For a literal dimension `$d`, generates solve/det smoke tests on the
    // tridiagonal matrix with 2 on the diagonal and -1 on both off-diagonals.
    // Its determinant is `$d + 1`, and with b = e_0 + e_{$d-1} the exact solution
    // is the all-ones vector.
    macro_rules! gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_lu_solve_vec_tridiagonal_smoke_ $d d>]() {
                    // With `$d` up to 64 this f64 array is up to 64 * 64 * 8 bytes
                    // (32 KiB) on the stack, which that lint would flag.
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    // b is 1 in the first and last entries, 0 elsewhere.
                    let mut b_arr = [0.0f64; $d];
                    b_arr[0] = 1.0;
                    b_arr[$d - 1] = 1.0;
                    let b = Vector::<$d>::new(black_box(b_arr));

                    let solve_fn: fn(&Lu<$d>, Vector<$d>) -> Result<Vector<$d>, LaError> =
                        black_box(Lu::<$d>::solve_vec);
                    let x = solve_fn(&lu, b).unwrap().into_array();

                    // Looser tolerance than the small tests: error accumulates with size.
                    for &x_i in &x {
                        assert_abs_diff_eq!(x_i, 1.0, epsilon = 1e-9);
                    }
                }

                #[test]
                fn [<public_api_lu_det_tridiagonal_smoke_ $d d>]() {
                    // Up to 32 KiB on the stack for `$d` = 64 (see the solve test).
                    #[allow(clippy::large_stack_arrays)]
                    let mut rows = [[0.0f64; $d]; $d];
                    for i in 0..$d {
                        rows[i][i] = 2.0;
                        if i > 0 {
                            rows[i][i - 1] = -1.0;
                        }
                        if i + 1 < $d {
                            rows[i][i + 1] = -1.0;
                        }
                    }

                    let a = Matrix::<$d>::from_rows(black_box(rows));
                    let lu_fn: fn(Matrix<$d>, f64) -> Result<Lu<$d>, LaError> =
                        black_box(Matrix::<$d>::lu);
                    let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap();

                    let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det);
                    assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8);
                }
            }
        };
    }

    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(16);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(32);
    gen_public_api_tridiagonal_smoke_solve_vec_and_det_tests!(64);

    // 1x1 system: x = 6 / 2, and det is exactly the single entry.
    #[test]
    fn solve_1x1() {
        let a = Matrix::<1>::from_rows(black_box([[2.0]]));
        let lu = (black_box(Lu::<1>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<1>::new(black_box([6.0]));
        let solve_fn: fn(&Lu<1>, Vector<1>) -> Result<Vector<1>, LaError> =
            black_box(Lu::<1>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();
        assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12);

        let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det);
        assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0);
    }

    // [[1, 2], [3, 4]] x = [5, 11] has solution x = [1, 2].
    #[test]
    fn solve_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = (black_box(Lu::<2>::factor))(a, DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([5.0, 11.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    // det([[1, 2], [3, 4]]) = 1 * 4 - 2 * 3 = -2.
    #[test]
    fn det_2x2_basic() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12);
    }

    // The zero at (0, 0) forces a row swap; det is only correct (-1, exactly)
    // if the swap's sign is applied.
    #[test]
    fn det_requires_pivot_sign() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det);
        assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0);
    }

    // Same permutation matrix: solving requires applying the row permutation to b.
    #[test]
    fn solve_requires_pivoting() {
        let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]]));
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();
        let b = Vector::<2>::new(black_box([1.0, 2.0]));

        let solve_fn: fn(&Lu<2>, Vector<2>) -> Result<Vector<2>, LaError> =
            black_box(Lu::<2>::solve_vec);
        let x = solve_fn(&lu, b).unwrap().into_array();

        assert_abs_diff_eq!(x[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 1.0, epsilon = 1e-12);
    }

    // Row 1 is twice row 0, so elimination leaves a zero pivot in column 1.
    #[test]
    fn singular_detected() {
        let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [2.0, 4.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 1 });
    }

    // Column 0's largest entry is 1e-13; assumes DEFAULT_PIVOT_TOL >= 1e-13
    // (TODO confirm), so the matrix is rejected despite being invertible.
    #[test]
    fn singular_due_to_tolerance_at_first_pivot() {
        let a = Matrix::<2>::from_rows(black_box([[1e-13, 0.0], [0.0, 1.0]]));
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::Singular { pivot_col: 0 });
    }

    // A NaN on the diagonal is rejected before any pivot selection.
    #[test]
    fn nonfinite_detected_on_pivot_entry() {
        let a = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    // An infinity below the diagonal is caught while scanning the pivot column.
    #[test]
    fn nonfinite_detected_in_pivot_column_scan() {
        let a = Matrix::<2>::from_rows([[1.0, 0.0], [f64::INFINITY, 1.0]]);
        let err = a.lu(DEFAULT_PIVOT_TOL).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }

    // The multiplier -1 makes row 1's forward-substitution sum 1e308 + 1e308,
    // which overflows to infinity and is reported at component 1.
    #[test]
    fn solve_vec_nonfinite_forward_substitution_overflow() {
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<3>::new([1.0e308, 1.0e308, 0.0]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 1 });
    }

    // x[1] = 1e300 / 2e-12 overflows; the infinity then poisons row 0's sum, so
    // the error is reported at component 0. Assumes 2e-12 is above
    // DEFAULT_PIVOT_TOL so factorization succeeds (TODO confirm).
    #[test]
    fn solve_vec_nonfinite_back_substitution_overflow() {
        let a = Matrix::<2>::from_rows([[1.0, 1.0], [0.0, 2.0e-12]]);
        let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

        let b = Vector::<2>::new([0.0, 1.0e300]);
        let err = lu.solve_vec(b).unwrap_err();
        assert_eq!(err, LaError::NonFinite { pivot_col: 0 });
    }
}