1use num_traits::Float;
2
/// LU factorization of a square matrix with partial (row) pivoting, as produced
/// by [`lu_factor`] and consumed by [`lu_back_solve`].
#[derive(Debug, Clone)]
pub struct LuFactors<F> {
    /// Packed factors: the strictly lower part holds the elimination multipliers
    /// of the unit lower-triangular `L` (its unit diagonal is implicit); the
    /// diagonal and upper part hold `U`.
    lu: Vec<Vec<F>>,
    /// Row permutation: row `i` of the factored matrix is row `perm[i]` of the
    /// original input.
    perm: Vec<usize>,
    /// Dimension of the (square) matrix.
    n: usize,
}
15
16#[allow(clippy::needless_range_loop)]
21pub fn lu_factor<F: Float>(a: &[Vec<F>]) -> Option<LuFactors<F>> {
22 let n = a.len();
23 debug_assert!(a.iter().all(|row| row.len() == n));
24
25 let mut lu: Vec<Vec<F>> = a.to_vec();
26 let mut perm: Vec<usize> = (0..n).collect();
27
28 let eps_mach = F::epsilon();
31 let n_f = F::from(n).unwrap();
32 let mut max_pivot_seen = F::zero();
33
34 for col in 0..n {
35 let mut max_val = lu[col][col].abs();
37 let mut max_row = col;
38 for row in (col + 1)..n {
39 let v = lu[row][col].abs();
40 if v > max_val {
41 max_val = v;
42 max_row = row;
43 }
44 }
45
46 max_pivot_seen = max_pivot_seen.max(max_val);
47 let tol = eps_mach * n_f * max_pivot_seen;
48 if max_val == F::zero() || max_val < tol {
50 return None; }
52
53 if max_row != col {
55 lu.swap(col, max_row);
56 perm.swap(col, max_row);
57 }
58
59 let pivot = lu[col][col];
60
61 for row in (col + 1)..n {
63 let factor = lu[row][col] / pivot;
64 lu[row][col] = factor; for j in (col + 1)..n {
66 let val = lu[col][j];
67 lu[row][j] = lu[row][j] - factor * val;
68 }
69 }
70 }
71
72 Some(LuFactors { lu, perm, n })
73}
74
75#[allow(clippy::needless_range_loop)]
81pub fn lu_back_solve<F: Float>(factors: &LuFactors<F>, b: &[F]) -> Vec<F> {
82 let n = factors.n;
83 debug_assert_eq!(b.len(), n);
84
85 let mut y = vec![F::zero(); n];
87 for i in 0..n {
88 y[i] = b[factors.perm[i]];
89 }
90
91 for i in 1..n {
93 for j in 0..i {
94 let l_ij = factors.lu[i][j];
95 let y_j = y[j];
96 y[i] = y[i] - l_ij * y_j;
97 }
98 }
99
100 let mut x = vec![F::zero(); n];
102 for i in (0..n).rev() {
103 let mut sum = y[i];
104 for j in (i + 1)..n {
105 sum = sum - factors.lu[i][j] * x[j];
106 }
107 x[i] = sum / factors.lu[i][i];
108 }
109
110 x
111}
112
113pub fn lu_solve<F: Float>(a: &[Vec<F>], b: &[F]) -> Option<Vec<F>> {
118 let factors = lu_factor(a)?;
119 Some(lu_back_solve(&factors, b))
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` agree element-wise to within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "x[{}] = {}, expected {}", i, a, e);
        }
    }

    #[test]
    fn lu_solve_identity() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![3.0, 7.0];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[3.0, 7.0], 1e-12);
    }

    #[test]
    fn lu_solve_2x2() {
        // 2x + y = 5, x + 3y = 7  =>  x = 1.6, y = 1.8
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let b = vec![5.0, 7.0];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[1.6, 1.8], 1e-12);
    }

    #[test]
    fn lu_solve_singular() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        let b = vec![3.0, 6.0];
        assert!(lu_solve(&a, &b).is_none());
    }

    #[test]
    fn lu_solve_needs_pivoting() {
        // A zero in the (0, 0) position is only solvable if rows are swapped.
        let a = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let b = vec![3.0, 7.0];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[7.0, 3.0], 1e-12);
    }

    #[test]
    fn lu_factor_then_back_solve_matches_lu_solve() {
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let b1 = vec![5.0, 7.0];
        let b2 = vec![1.0, 0.0];

        // One factorization, reused for two right-hand sides.
        let factors = lu_factor(&a).unwrap();
        let x1 = lu_back_solve(&factors, &b1);
        let x2 = lu_back_solve(&factors, &b2);

        assert_close(&x1, &lu_solve(&a, &b1).unwrap(), 1e-12);
        assert_close(&x2, &lu_solve(&a, &b2).unwrap(), 1e-12);

        // `lu_solve` shares this code path, so also check against the analytic
        // inverse A⁻¹ = (1/5)·[[3, -1], [-1, 2]].
        assert_close(&x1, &[1.6, 1.8], 1e-12);
        assert_close(&x2, &[0.6, -0.2], 1e-12);
    }

    #[test]
    fn lu_factor_then_back_solve_3x3() {
        // Exact solution is x = [1, 2, 3]; the larger lower rows force pivoting.
        let a = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 0.0],
        ];
        let b = vec![14.0, 32.0, 23.0];
        let factors = lu_factor(&a).unwrap();
        let x = lu_back_solve(&factors, &b);
        assert_close(&x, &lu_solve(&a, &b).unwrap(), 1e-10);
        assert_close(&x, &[1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn lu_factor_singular_returns_none() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert!(lu_factor(&a).is_none());
    }

    #[test]
    fn lu_solve_empty_system() {
        let a: Vec<Vec<f64>> = Vec::new();
        let b: Vec<f64> = Vec::new();
        assert_eq!(lu_solve(&a, &b), Some(Vec::new()));
    }

    #[test]
    fn lu_solve_tiny_scale_is_not_singular() {
        // The singularity tolerance is relative to the pivot magnitudes, so a
        // uniformly tiny but well-conditioned matrix must still be solvable.
        let a = vec![vec![1e-300, 0.0], vec![0.0, 1e-300]];
        let b = vec![1e-300, 2e-300];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[1.0, 2.0], 1e-12);
    }
}