math_audio_solvers/direct/
lu.rs1use crate::traits::ComplexField;
7use ndarray::{Array1, Array2};
8use num_traits::FromPrimitive;
9use thiserror::Error;
10
11#[cfg(feature = "ndarray-linalg")]
12use ndarray_linalg::Solve;
13
14#[derive(Error, Debug)]
16pub enum LuError {
17 #[error("Matrix is singular or nearly singular")]
18 SingularMatrix,
19 #[error("Matrix dimensions mismatch: expected {expected}, got {got}")]
20 DimensionMismatch { expected: usize, got: usize },
21}
22
23#[derive(Debug, Clone)]
27pub struct LuFactorization<T: ComplexField> {
28 pub lu: Array2<T>,
30 pub pivots: Vec<usize>,
32 pub n: usize,
34}
35
36impl<T: ComplexField> LuFactorization<T> {
37 pub fn solve(&self, b: &Array1<T>) -> Result<Array1<T>, LuError> {
39 if b.len() != self.n {
40 return Err(LuError::DimensionMismatch {
41 expected: self.n,
42 got: b.len(),
43 });
44 }
45
46 let mut x = b.clone();
47
48 for i in 0..self.n {
50 let pivot = self.pivots[i];
51 if pivot != i {
52 x.swap(i, pivot);
53 }
54 }
55
56 for i in 0..self.n {
58 for j in 0..i {
59 let l_ij = self.lu[[i, j]];
60 x[i] = x[i] - l_ij * x[j];
61 }
62 }
63
64 for i in (0..self.n).rev() {
66 for j in (i + 1)..self.n {
67 let u_ij = self.lu[[i, j]];
68 x[i] = x[i] - u_ij * x[j];
69 }
70 let u_ii = self.lu[[i, i]];
71 if u_ii.norm() < T::Real::from_f64(1e-30).unwrap() {
72 return Err(LuError::SingularMatrix);
73 }
74 x[i] *= u_ii.inv();
75 }
76
77 Ok(x)
78 }
79}
80
81pub fn lu_factorize<T: ComplexField>(a: &Array2<T>) -> Result<LuFactorization<T>, LuError> {
83 let n = a.nrows();
84 if n != a.ncols() {
85 return Err(LuError::DimensionMismatch {
86 expected: n,
87 got: a.ncols(),
88 });
89 }
90
91 let mut lu = a.clone();
92 let mut pivots: Vec<usize> = (0..n).collect();
93
94 for k in 0..n {
95 let mut max_val = lu[[k, k]].norm();
97 let mut max_row = k;
98
99 for i in (k + 1)..n {
100 let val = lu[[i, k]].norm();
101 if val > max_val {
102 max_val = val;
103 max_row = i;
104 }
105 }
106
107 if max_val < T::Real::from_f64(1e-30).unwrap() {
109 return Err(LuError::SingularMatrix);
110 }
111
112 if max_row != k {
114 for j in 0..n {
115 let tmp = lu[[k, j]];
116 lu[[k, j]] = lu[[max_row, j]];
117 lu[[max_row, j]] = tmp;
118 }
119 pivots.swap(k, max_row);
120 }
121
122 let pivot = lu[[k, k]];
124 for i in (k + 1)..n {
125 let mult = lu[[i, k]] * pivot.inv();
126 lu[[i, k]] = mult; for j in (k + 1)..n {
129 let update = mult * lu[[k, j]];
130 lu[[i, j]] -= update;
131 }
132 }
133 }
134
135 Ok(LuFactorization { lu, pivots, n })
136}
137
138pub fn lu_solve<T: ComplexField>(a: &Array2<T>, b: &Array1<T>) -> Result<Array1<T>, LuError> {
142 #[cfg(feature = "ndarray-linalg")]
143 {
144 a.solve_into(b.clone()).map_err(|_| LuError::SingularMatrix)
145 }
146
147 #[cfg(not(feature = "ndarray-linalg"))]
148 {
149 let factorization = lu_factorize(a)?;
150 factorization.solve(b)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_lu_solve_real() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0],];
        let b = array![1.0_f64, 2.0];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!(ax[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_solve_complex() {
        let a = array![
            [Complex64::new(4.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, -1.0)],
        ];
        let b = array![Complex64::new(1.0, 1.0), Complex64::new(2.0, -1.0)];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!((ax[i] - b[i]).norm(), 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_identity() {
        let n = 5;
        let a = Array2::from_diag(&Array1::from_elem(n, 1.0_f64));
        let b = Array1::from_iter((1..=n).map(|i| i as f64));

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        for i in 0..n {
            assert_relative_eq!(x[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_singular() {
        let a = array![[1.0_f64, 2.0], [2.0, 4.0],];
        let b = array![1.0_f64, 2.0];

        let result = lu_solve(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_lu_factorize_singular() {
        let a = array![[1.0_f64, 2.0], [2.0, 4.0],];

        let err = lu_factorize(&a).expect_err("singular factorization should fail");
        assert!(matches!(err, LuError::SingularMatrix));
    }

    #[test]
    fn test_lu_factorize_non_square() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let err = lu_factorize(&a).expect_err("non-square factorization should fail");
        assert!(matches!(
            err,
            LuError::DimensionMismatch {
                expected: 2,
                got: 3
            }
        ));
    }

    #[test]
    fn test_factorization_solve_rhs_length_mismatch() {
        let a = array![[2.0_f64, 0.0], [0.0, 3.0]];
        let factorization = lu_factorize(&a).expect("Factorization should succeed");
        let b = array![1.0_f64, 2.0, 3.0];

        let err = factorization
            .solve(&b)
            .expect_err("mismatched right-hand side should fail");
        assert!(matches!(
            err,
            LuError::DimensionMismatch {
                expected: 2,
                got: 3
            }
        ));
    }

    #[test]
    fn test_lu_factorize_and_solve() {
        let a = array![[4.0_f64, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0],];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        let b1 = array![1.0_f64, 2.0, 3.0];
        let x1 = factorization.solve(&b1).expect("Solve should succeed");

        let ax1 = a.dot(&x1);
        for i in 0..3 {
            assert_relative_eq!(ax1[i], b1[i], epsilon = 1e-10);
        }

        let b2 = array![4.0_f64, 5.0, 6.0];
        let x2 = factorization.solve(&b2).expect("Solve should succeed");

        let ax2 = a.dot(&x2);
        for i in 0..3 {
            assert_relative_eq!(ax2[i], b2[i], epsilon = 1e-10);
        }
    }
}