math_audio_solvers/direct/
lu.rs1use crate::traits::ComplexField;
7use ndarray::{Array1, Array2};
8use num_traits::FromPrimitive;
9use thiserror::Error;
10
11#[derive(Error, Debug)]
13pub enum LuError {
14 #[error("Matrix is singular or nearly singular")]
15 SingularMatrix,
16 #[error("Matrix dimensions mismatch: expected {expected}, got {got}")]
17 DimensionMismatch { expected: usize, got: usize },
18}
19
20#[derive(Debug, Clone)]
24pub struct LuFactorization<T: ComplexField> {
25 pub lu: Array2<T>,
27 pub pivots: Vec<usize>,
29 pub n: usize,
31}
32
33impl<T: ComplexField> LuFactorization<T> {
34 pub fn solve(&self, b: &Array1<T>) -> Result<Array1<T>, LuError> {
36 if b.len() != self.n {
37 return Err(LuError::DimensionMismatch {
38 expected: self.n,
39 got: b.len(),
40 });
41 }
42
43 let mut x = Array1::from_elem(self.n, T::zero());
44
45 for i in 0..self.n {
47 x[i] = b[self.pivots[i]];
48 }
49
50 for i in 0..self.n {
52 #[allow(clippy::needless_range_loop)]
53 for j in 0..i {
54 let l_ij = self.lu[[i, j]];
55 let x_j = x[j];
56 x[i] -= l_ij * x_j;
57 }
58 }
59
60 for i in (0..self.n).rev() {
62 #[allow(clippy::needless_range_loop)]
63 for j in (i + 1)..self.n {
64 let u_ij = self.lu[[i, j]];
65 let x_j = x[j];
66 x[i] -= u_ij * x_j;
67 }
68 let u_ii = self.lu[[i, i]];
69 if u_ii.is_zero_approx(T::Real::from_f64(1e-20).unwrap()) {
70 return Err(LuError::SingularMatrix);
71 }
72 x[i] *= u_ii.inv();
73 }
74
75 Ok(x)
76 }
77}
78
79#[allow(dead_code)]
81pub fn lu_factorize<T: ComplexField>(a: &Array2<T>) -> Result<LuFactorization<T>, LuError> {
82 let n = a.nrows();
83 if n != a.ncols() {
84 return Err(LuError::DimensionMismatch {
85 expected: n,
86 got: a.ncols(),
87 });
88 }
89
90 let mut lu = a.clone();
91 let mut pivots: Vec<usize> = (0..n).collect();
92
93 for k in 0..n {
94 let mut max_val = lu[[k, k]].norm();
96 let mut max_row = k;
97
98 for i in (k + 1)..n {
99 let val = lu[[i, k]].norm();
100 if val > max_val {
101 max_val = val;
102 max_row = i;
103 }
104 }
105
106 if max_val < T::Real::from_f64(1e-20).unwrap() {
108 return Err(LuError::SingularMatrix);
109 }
110
111 if max_row != k {
113 for j in 0..n {
114 let tmp = lu[[k, j]];
115 lu[[k, j]] = lu[[max_row, j]];
116 lu[[max_row, j]] = tmp;
117 }
118 pivots.swap(k, max_row);
119 }
120
121 let pivot = lu[[k, k]];
123 for i in (k + 1)..n {
124 let mult = lu[[i, k]] * pivot.inv();
125 lu[[i, k]] = mult; for j in (k + 1)..n {
128 let update = mult * lu[[k, j]];
129 lu[[i, j]] -= update;
130 }
131 }
132 }
133
134 Ok(LuFactorization { lu, pivots, n })
135}
136
/// Solves the square linear system `A x = b` using the oxiblas LAPACK backend.
///
/// Input shapes are validated before delegating. Callers therefore get
/// [`LuError::DimensionMismatch`] for bad shapes, exactly as with the built-in
/// solver, instead of depending on the backend, whose errors are all mapped to
/// [`LuError::SingularMatrix`].
///
/// # Errors
///
/// * [`LuError::DimensionMismatch`] if `a` is not square, or if `b.len()`
///   differs from the number of rows of `a`.
/// * [`LuError::SingularMatrix`] for any failure reported by the backend (its
///   original error is discarded).
#[cfg(feature = "oxiblas-ndarray")]
pub fn lu_solve<T: ComplexField + oxiblas_core::Field + Clone + bytemuck::Zeroable>(
    a: &Array2<T>,
    b: &Array1<T>,
) -> Result<Array1<T>, LuError> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LuError::DimensionMismatch {
            expected: n,
            got: a.ncols(),
        });
    }
    if b.len() != n {
        return Err(LuError::DimensionMismatch {
            expected: n,
            got: b.len(),
        });
    }
    oxiblas_ndarray::lapack::solve_ndarray(a, b).map_err(|_| LuError::SingularMatrix)
}
147
148#[cfg(not(feature = "oxiblas-ndarray"))]
152pub fn lu_solve<T: ComplexField>(a: &Array2<T>, b: &Array1<T>) -> Result<Array1<T>, LuError> {
153 let factorization = lu_factorize(a)?;
154 factorization.solve(b)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_lu_solve_real() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let b = array![1.0_f64, 2.0];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!(ax[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_solve_complex() {
        let a = array![
            [Complex64::new(4.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, -1.0)],
        ];

        let b = array![Complex64::new(1.0, 1.0), Complex64::new(2.0, -1.0)];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!((ax[i] - b[i]).norm(), 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_identity() {
        let n = 5;
        let a = Array2::from_diag(&Array1::from_elem(n, 1.0_f64));
        let b = Array1::from_iter((1..=n).map(|i| i as f64));

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        for i in 0..n {
            assert_relative_eq!(x[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_singular() {
        let a = array![[1.0_f64, 2.0], [2.0, 4.0],];
        let b = array![1.0_f64, 2.0];

        let result = lu_solve(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_lu_factorize_and_solve() {
        let a = array![[4.0_f64, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0],];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        let b1 = array![1.0_f64, 2.0, 3.0];
        let x1 = factorization.solve(&b1).expect("Solve should succeed");

        let ax1 = a.dot(&x1);
        for i in 0..3 {
            assert_relative_eq!(ax1[i], b1[i], epsilon = 1e-10);
        }

        let b2 = array![4.0_f64, 5.0, 6.0];
        let x2 = factorization.solve(&b2).expect("Solve should succeed");

        let ax2 = a.dot(&x2);
        for i in 0..3 {
            assert_relative_eq!(ax2[i], b2[i], epsilon = 1e-10);
        }
    }

    /// A permutation matrix cannot be factorized without a row swap.
    #[test]
    fn test_lu_factorize_permutation_matrix() {
        let a = array![[0.0_f64, 1.0], [1.0, 0.0]];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");
        assert_eq!(factorization.pivots, vec![1, 0]);

        let x = factorization
            .solve(&array![2.0_f64, 3.0])
            .expect("Solve should succeed");
        assert_relative_eq!(x[0], 3.0, epsilon = 1e-12);
        assert_relative_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    /// A zero leading entry forces two row swaps. No other test checks a
    /// solution that required pivoting.
    #[test]
    fn test_lu_factorize_pivots_zero_leading_entry() {
        let a = array![[0.0_f64, 2.0, 1.0], [1.0, 1.0, 1.0], [2.0, 1.0, 0.0]];
        let b = array![3.0_f64, 3.0, 3.0];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");
        // Column 0 selects original row 2 (|2| is largest). Column 1 then
        // selects original row 0 (|2| beats the eliminated |0.5|).
        assert_eq!(factorization.pivots, vec![2, 0, 1]);

        let x = factorization.solve(&b).expect("Solve should succeed");
        // The exact solution is x = (1, 1, 1).
        for i in 0..3 {
            assert_relative_eq!(x[i], 1.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_lu_factorize_singular() {
        let a = array![[1.0_f64, 2.0], [2.0, 4.0]];
        assert!(matches!(lu_factorize(&a), Err(LuError::SingularMatrix)));
    }

    #[test]
    fn test_lu_factorize_non_square() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(matches!(
            lu_factorize(&a),
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_solve_rhs_length_mismatch() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        let b = array![1.0_f64, 2.0, 3.0];
        assert!(matches!(
            factorization.solve(&b),
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_lu_empty_system() {
        let a = Array2::<f64>::zeros((0, 0));
        let b = Array1::<f64>::zeros(0);

        let factorization = lu_factorize(&a).expect("Empty factorization should succeed");
        let x = factorization.solve(&b).expect("Empty solve should succeed");
        assert_eq!(x.len(), 0);
    }
}