1use num_traits::Float;
2
/// Packed LU factorization of a square matrix, as produced by [`lu_factor`].
///
/// Both triangular factors share one `n x n` matrix:
/// - entries strictly below the diagonal are the elimination multipliers of the
///   unit lower-triangular factor `L` (its unit diagonal is implicit, not stored);
/// - entries on and above the diagonal form the upper-triangular factor `U`.
///
/// Row `i` of the factorization came from row `perm[i]` of the original matrix.
#[derive(Debug, Clone)]
pub struct LuFactors<F> {
    /// Combined `L` (strictly lower part) and `U` (diagonal and upper part).
    lu: Vec<Vec<F>>,
    /// Row permutation: factored row `i` is original row `perm[i]`.
    perm: Vec<usize>,
    /// Dimension of the factored (square) matrix.
    n: usize,
}
15
16#[allow(clippy::needless_range_loop)]
21pub fn lu_factor<F: Float>(a: &[Vec<F>]) -> Option<LuFactors<F>> {
22 let n = a.len();
23 debug_assert!(a.iter().all(|row| row.len() == n));
24
25 let mut lu: Vec<Vec<F>> = a.to_vec();
26 let mut perm: Vec<usize> = (0..n).collect();
27
28 let eps = F::from(1e-12).unwrap_or_else(|| F::epsilon());
29
30 for col in 0..n {
31 let mut max_val = lu[col][col].abs();
33 let mut max_row = col;
34 for row in (col + 1)..n {
35 let v = lu[row][col].abs();
36 if v > max_val {
37 max_val = v;
38 max_row = row;
39 }
40 }
41
42 if max_val < eps {
43 return None; }
45
46 if max_row != col {
48 lu.swap(col, max_row);
49 perm.swap(col, max_row);
50 }
51
52 let pivot = lu[col][col];
53
54 for row in (col + 1)..n {
56 let factor = lu[row][col] / pivot;
57 lu[row][col] = factor; for j in (col + 1)..n {
59 let val = lu[col][j];
60 lu[row][j] = lu[row][j] - factor * val;
61 }
62 }
63 }
64
65 Some(LuFactors { lu, perm, n })
66}
67
68#[allow(clippy::needless_range_loop)]
74pub fn lu_back_solve<F: Float>(factors: &LuFactors<F>, b: &[F]) -> Vec<F> {
75 let n = factors.n;
76 debug_assert_eq!(b.len(), n);
77
78 let mut y = vec![F::zero(); n];
80 for i in 0..n {
81 y[i] = b[factors.perm[i]];
82 }
83
84 for i in 1..n {
86 for j in 0..i {
87 let l_ij = factors.lu[i][j];
88 let y_j = y[j];
89 y[i] = y[i] - l_ij * y_j;
90 }
91 }
92
93 let mut x = vec![F::zero(); n];
95 for i in (0..n).rev() {
96 let mut sum = y[i];
97 for j in (i + 1)..n {
98 sum = sum - factors.lu[i][j] * x[j];
99 }
100 x[i] = sum / factors.lu[i][i];
101 }
102
103 x
104}
105
106pub fn lu_solve<F: Float>(a: &[Vec<F>], b: &[F]) -> Option<Vec<F>> {
111 let factors = lu_factor(a)?;
112 Some(lu_back_solve(&factors, b))
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual[i] - expected[i]| < tol` for every `i`, naming the first
    /// mismatching index.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < tol,
                "x[{}] = {}, expected {}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn lu_solve_identity() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let x = lu_solve(&a, &[3.0, 7.0]).unwrap();
        assert_close(&x, &[3.0, 7.0], 1e-12);
    }

    #[test]
    fn lu_solve_2x2() {
        // [[2, 1], [1, 3]] x = [5, 7]  =>  x = [1.6, 1.8]
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let x = lu_solve(&a, &[5.0, 7.0]).unwrap();
        assert_close(&x, &[1.6, 1.8], 1e-12);
    }

    #[test]
    fn lu_solve_singular() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert!(lu_solve(&a, &[3.0, 6.0]).is_none());
    }

    #[test]
    fn lu_solve_needs_pivoting() {
        // The zero in the top-left corner forces a row swap.
        let a = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let x = lu_solve(&a, &[3.0, 7.0]).unwrap();
        assert_close(&x, &[7.0, 3.0], 1e-12);
    }

    #[test]
    fn lu_factor_then_back_solve_matches_lu_solve() {
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let b1 = [5.0, 7.0];
        let b2 = [1.0, 0.0];

        // One factorization reused for two right-hand sides.
        let factors = lu_factor(&a).unwrap();
        let x1 = lu_back_solve(&factors, &b1);
        let x2 = lu_back_solve(&factors, &b2);

        assert_close(&x1, &lu_solve(&a, &b1).unwrap(), 1e-12);
        assert_close(&x2, &lu_solve(&a, &b2).unwrap(), 1e-12);

        // Compare against the exact answers too: A^-1 = (1/5) [[3, -1], [-1, 2]].
        assert_close(&x1, &[1.6, 1.8], 1e-12);
        assert_close(&x2, &[0.6, -0.2], 1e-12);
    }

    #[test]
    fn lu_factor_then_back_solve_3x3() {
        let a = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 0.0],
        ];
        // b = A * [1, 2, 3]
        let b = [14.0, 32.0, 23.0];
        let factors = lu_factor(&a).unwrap();
        let x = lu_back_solve(&factors, &b);

        assert_close(&x, &lu_solve(&a, &b).unwrap(), 1e-10);
        assert_close(&x, &[1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn lu_factor_singular_returns_none() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert!(lu_factor(&a).is_none());
    }

    #[test]
    fn lu_factor_zero_matrix_returns_none() {
        let a = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        assert!(lu_factor(&a).is_none());
    }

    #[test]
    fn lu_solve_empty_system() {
        let a: Vec<Vec<f64>> = Vec::new();
        let x = lu_solve(&a, &[]).unwrap();
        assert!(x.is_empty());
    }
}