Skip to main content

math_audio_solvers/direct/
lu.rs

1//! LU decomposition solver
2//!
3//! Provides LU factorization with partial pivoting for solving dense linear systems.
4//! Uses BLAS/LAPACK when available (native feature), with a pure-Rust fallback.
5
6use crate::traits::ComplexField;
7use ndarray::{Array1, Array2};
8use num_traits::FromPrimitive;
9use thiserror::Error;
10
11#[cfg(feature = "ndarray-linalg")]
12use ndarray_linalg::Solve;
13
14/// Errors that can occur during LU factorization
15#[derive(Error, Debug)]
16pub enum LuError {
17    #[error("Matrix is singular or nearly singular")]
18    SingularMatrix,
19    #[error("Matrix dimensions mismatch: expected {expected}, got {got}")]
20    DimensionMismatch { expected: usize, got: usize },
21}
22
23/// LU factorization result
24///
25/// Stores L and U factors along with pivot information
26#[derive(Debug, Clone)]
27pub struct LuFactorization<T: ComplexField> {
28    /// Combined L and U matrices (L is unit lower triangular, stored below diagonal)
29    pub lu: Array2<T>,
30    /// Pivot indices
31    pub pivots: Vec<usize>,
32    /// Matrix dimension
33    pub n: usize,
34}
35
36impl<T: ComplexField> LuFactorization<T> {
37    /// Solve Ax = b using the pre-computed LU factorization
38    pub fn solve(&self, b: &Array1<T>) -> Result<Array1<T>, LuError> {
39        if b.len() != self.n {
40            return Err(LuError::DimensionMismatch {
41                expected: self.n,
42                got: b.len(),
43            });
44        }
45
46        let mut x = Array1::from_elem(self.n, T::zero());
47
48        // Apply row permutations Pb
49        for i in 0..self.n {
50            x[i] = b[self.pivots[i]];
51        }
52
53        // Forward substitution: Ly = Pb
54        for i in 0..self.n {
55            #[allow(clippy::needless_range_loop)]
56            for j in 0..i {
57                let l_ij = self.lu[[i, j]];
58                let x_j = x[j];
59                x[i] -= l_ij * x_j;
60            }
61        }
62
63        // Backward substitution: Ux = y
64        for i in (0..self.n).rev() {
65            #[allow(clippy::needless_range_loop)]
66            for j in (i + 1)..self.n {
67                let u_ij = self.lu[[i, j]];
68                let x_j = x[j];
69                x[i] -= u_ij * x_j;
70            }
71            let u_ii = self.lu[[i, i]];
72            if u_ii.is_zero_approx(T::Real::from_f64(1e-20).unwrap()) {
73                return Err(LuError::SingularMatrix);
74            }
75            x[i] *= u_ii.inv();
76        }
77
78        Ok(x)
79    }
80}
81
82/// Compute LU factorization with partial pivoting (pure Rust implementation)
83#[allow(dead_code)]
84pub fn lu_factorize<T: ComplexField>(a: &Array2<T>) -> Result<LuFactorization<T>, LuError> {
85    let n = a.nrows();
86    if n != a.ncols() {
87        return Err(LuError::DimensionMismatch {
88            expected: n,
89            got: a.ncols(),
90        });
91    }
92
93    let mut lu = a.clone();
94    let mut pivots: Vec<usize> = (0..n).collect();
95
96    for k in 0..n {
97        // Find pivot
98        let mut max_val = lu[[k, k]].norm();
99        let mut max_row = k;
100
101        for i in (k + 1)..n {
102            let val = lu[[i, k]].norm();
103            if val > max_val {
104                max_val = val;
105                max_row = i;
106            }
107        }
108
109        // Check for singularity
110        if max_val < T::Real::from_f64(1e-20).unwrap() {
111            return Err(LuError::SingularMatrix);
112        }
113
114        // Swap rows if needed
115        if max_row != k {
116            for j in 0..n {
117                let tmp = lu[[k, j]];
118                lu[[k, j]] = lu[[max_row, j]];
119                lu[[max_row, j]] = tmp;
120            }
121            pivots.swap(k, max_row);
122        }
123
124        // Compute multipliers and eliminate
125        let pivot = lu[[k, k]];
126        for i in (k + 1)..n {
127            let mult = lu[[i, k]] * pivot.inv();
128            lu[[i, k]] = mult; // Store multiplier in L part
129
130            for j in (k + 1)..n {
131                let update = mult * lu[[k, j]];
132                lu[[i, j]] -= update;
133            }
134        }
135    }
136
137    Ok(LuFactorization { lu, pivots, n })
138}
139
/// Solve `A x = b` with the LU-based solver from `ndarray-linalg` (LAPACK).
///
/// This is a convenience function that combines factorization and solve.
/// Shapes are validated before calling the backend. Shape problems are
/// therefore reported as [`LuError::DimensionMismatch`], exactly as the
/// pure-Rust build of this function reports them, rather than being folded
/// into [`LuError::SingularMatrix`].
///
/// # Errors
///
/// * [`LuError::DimensionMismatch`] if `a` is not square (`expected` = rows,
///   `got` = columns), or if `b.len()` differs from the dimension of `a`.
/// * [`LuError::SingularMatrix`] for any error returned by `ndarray-linalg`
///   (presumably an exactly-zero pivot reported by LAPACK's `getrf`).
#[cfg(feature = "ndarray-linalg")]
pub fn lu_solve<T: ComplexField + ndarray_linalg::Lapack>(
    a: &Array2<T>,
    b: &Array1<T>,
) -> Result<Array1<T>, LuError> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(LuError::DimensionMismatch {
            expected: n,
            got: a.ncols(),
        });
    }
    if b.len() != n {
        return Err(LuError::DimensionMismatch {
            expected: n,
            got: b.len(),
        });
    }

    // After the shape checks, backend failures are treated as numerical
    // (singular) failures; the backend error detail is not preserved.
    a.solve(b).map_err(|_| LuError::SingularMatrix)
}
147
148/// Solve Ax = b using LU decomposition
149///
150/// This is a convenience function that combines factorization and solve.
151#[cfg(not(feature = "ndarray-linalg"))]
152pub fn lu_solve<T: ComplexField>(a: &Array2<T>, b: &Array1<T>) -> Result<Array1<T>, LuError> {
153    let factorization = lu_factorize(a)?;
154    factorization.solve(b)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_lu_solve_real() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let b = array![1.0_f64, 2.0];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // Verify: Ax = b
        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!(ax[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_solve_complex() {
        let a = array![
            [Complex64::new(4.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, -1.0)],
        ];

        let b = array![Complex64::new(1.0, 1.0), Complex64::new(2.0, -1.0)];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // Verify: Ax ≈ b
        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!((ax[i] - b[i]).norm(), 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_identity() {
        let n = 5;
        let a = Array2::from_diag(&Array1::from_elem(n, 1.0_f64));
        let b = Array1::from_iter((1..=n).map(|i| i as f64));

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        for i in 0..n {
            assert_relative_eq!(x[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_singular() {
        // Singular matrix: the second row is twice the first.
        let a = array![[1.0_f64, 2.0], [2.0, 4.0],];

        let b = array![1.0_f64, 2.0];

        let result = lu_solve(&a, &b);
        assert!(
            matches!(result, Err(LuError::SingularMatrix)),
            "expected SingularMatrix, got {:?}",
            result
        );
    }

    #[test]
    fn test_lu_factorize_and_solve() {
        let a = array![[4.0_f64, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0],];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        // Solve multiple RHS
        let b1 = array![1.0_f64, 2.0, 3.0];
        let x1 = factorization.solve(&b1).expect("Solve should succeed");

        let ax1 = a.dot(&x1);
        for i in 0..3 {
            assert_relative_eq!(ax1[i], b1[i], epsilon = 1e-10);
        }

        let b2 = array![4.0_f64, 5.0, 6.0];
        let x2 = factorization.solve(&b2).expect("Solve should succeed");

        let ax2 = a.dot(&x2);
        for i in 0..3 {
            assert_relative_eq!(ax2[i], b2[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_requires_pivoting() {
        // A zero in the top-left corner forces a row swap.
        let a = array![[0.0_f64, 1.0], [1.0, 0.0]];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");
        assert_eq!(factorization.pivots, vec![1, 0]);

        let b = array![2.0_f64, 3.0];
        let x = factorization.solve(&b).expect("Solve should succeed");
        assert_relative_eq!(x[0], 3.0, epsilon = 1e-12);
        assert_relative_eq!(x[1], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_lu_factorize_non_square() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];

        assert!(matches!(
            lu_factorize(&a),
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_lu_solve_rhs_length_mismatch() {
        let a = array![[2.0_f64, 0.0], [0.0, 2.0]];
        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        let b = array![1.0_f64, 2.0, 3.0];
        assert!(matches!(
            factorization.solve(&b),
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_lu_empty_system() {
        let a = Array2::<f64>::zeros((0, 0));
        let factorization = lu_factorize(&a).expect("Empty factorization should succeed");

        let x = factorization
            .solve(&Array1::<f64>::zeros(0))
            .expect("Empty solve should succeed");
        assert!(x.is_empty());
    }
}