math_audio_solvers/direct/
lu.rs1use crate::traits::ComplexField;
7use ndarray::{Array1, Array2};
8use num_traits::FromPrimitive;
9use thiserror::Error;
10
11#[derive(Error, Debug)]
13pub enum LuError {
14 #[error("Matrix is singular or nearly singular")]
15 SingularMatrix,
16 #[error("Matrix dimensions mismatch: expected {expected}, got {got}")]
17 DimensionMismatch { expected: usize, got: usize },
18}
19
20#[derive(Debug, Clone)]
24pub struct LuFactorization<T: ComplexField> {
25 pub lu: Array2<T>,
27 pub pivots: Vec<usize>,
29 pub n: usize,
31}
32
33impl<T: ComplexField> LuFactorization<T> {
34 pub fn solve(&self, b: &Array1<T>) -> Result<Array1<T>, LuError> {
36 if b.len() != self.n {
37 return Err(LuError::DimensionMismatch {
38 expected: self.n,
39 got: b.len(),
40 });
41 }
42
43 let mut x = b.clone();
44
45 for i in 0..self.n {
47 let pivot = self.pivots[i];
48 if pivot != i {
49 x.swap(i, pivot);
50 }
51 }
52
53 for i in 0..self.n {
55 for j in 0..i {
56 let l_ij = self.lu[[i, j]];
57 x[i] = x[i] - l_ij * x[j];
58 }
59 }
60
61 for i in (0..self.n).rev() {
63 for j in (i + 1)..self.n {
64 let u_ij = self.lu[[i, j]];
65 x[i] = x[i] - u_ij * x[j];
66 }
67 let u_ii = self.lu[[i, i]];
68 if u_ii.norm() < T::Real::from_f64(1e-30).unwrap() {
69 return Err(LuError::SingularMatrix);
70 }
71 x[i] *= u_ii.inv();
72 }
73
74 Ok(x)
75 }
76}
77
78pub fn lu_factorize<T: ComplexField>(a: &Array2<T>) -> Result<LuFactorization<T>, LuError> {
80 let n = a.nrows();
81 if n != a.ncols() {
82 return Err(LuError::DimensionMismatch {
83 expected: n,
84 got: a.ncols(),
85 });
86 }
87
88 let mut lu = a.clone();
89 let mut pivots: Vec<usize> = (0..n).collect();
90
91 for k in 0..n {
92 let mut max_val = lu[[k, k]].norm();
94 let mut max_row = k;
95
96 for i in (k + 1)..n {
97 let val = lu[[i, k]].norm();
98 if val > max_val {
99 max_val = val;
100 max_row = i;
101 }
102 }
103
104 if max_val < T::Real::from_f64(1e-30).unwrap() {
106 return Err(LuError::SingularMatrix);
107 }
108
109 if max_row != k {
111 for j in 0..n {
112 let tmp = lu[[k, j]];
113 lu[[k, j]] = lu[[max_row, j]];
114 lu[[max_row, j]] = tmp;
115 }
116 pivots.swap(k, max_row);
117 }
118
119 let pivot = lu[[k, k]];
121 for i in (k + 1)..n {
122 let mult = lu[[i, k]] * pivot.inv();
123 lu[[i, k]] = mult; for j in (k + 1)..n {
126 let update = mult * lu[[k, j]];
127 lu[[i, j]] -= update;
128 }
129 }
130 }
131
132 Ok(LuFactorization { lu, pivots, n })
133}
134
135pub fn lu_solve<T: ComplexField>(a: &Array2<T>, b: &Array1<T>) -> Result<Array1<T>, LuError> {
139 let factorization = lu_factorize(a)?;
140 factorization.solve(b)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Compares `actual` against `expected` component by component using
    /// `assert_relative_eq!` with `epsilon = 1e-10`.
    fn assert_all_close(actual: &Array1<f64>, expected: &Array1<f64>) {
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_solve_real() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let b = array![1.0_f64, 2.0];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // Verify through the residual A x ≈ b rather than a hand-derived solution.
        assert_all_close(&a.dot(&x), &b);
    }

    #[test]
    fn test_lu_solve_complex() {
        let a = array![
            [Complex64::new(4.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, -1.0)],
        ];
        let b = array![Complex64::new(1.0, 1.0), Complex64::new(2.0, -1.0)];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        let ax = a.dot(&x);
        for (got, want) in ax.iter().zip(b.iter()) {
            assert_relative_eq!((*got - *want).norm(), 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_identity() {
        let n = 5;
        let a = Array2::<f64>::eye(n);
        let b: Array1<f64> = (1..=n).map(|i| i as f64).collect();

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // I x = b, so the solution must be b itself.
        assert_all_close(&x, &b);
    }

    #[test]
    fn test_lu_singular() {
        // The second row is twice the first, so the matrix has rank 1.
        let a = array![[1.0_f64, 2.0], [2.0, 4.0]];
        let b = array![1.0_f64, 2.0];

        assert!(lu_solve(&a, &b).is_err());
    }

    #[test]
    fn test_lu_factorize_and_solve() {
        let a = array![[4.0_f64, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        // One factorization, reused for several right-hand sides.
        for rhs in &[array![1.0_f64, 2.0, 3.0], array![4.0_f64, 5.0, 6.0]] {
            let x = factorization.solve(rhs).expect("Solve should succeed");
            assert_all_close(&a.dot(&x), rhs);
        }
    }
}