math_audio_solvers/direct/
lu.rs

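//! LU decomposition solver
//!
//! Provides LU factorization with partial pivoting for solving dense linear systems.
//! [`lu_solve`] uses BLAS/LAPACK through `ndarray-linalg` when the `ndarray-linalg`
//! cargo feature is enabled, and a pure-Rust fallback otherwise; [`lu_factorize`] and
//! [`LuFactorization::solve`] are always pure Rust.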
5
6use crate::traits::ComplexField;
7use ndarray::{Array1, Array2};
8use num_traits::FromPrimitive;
9use thiserror::Error;
10
11#[cfg(feature = "ndarray-linalg")]
12use ndarray_linalg::Solve;
13
14/// Errors that can occur during LU factorization
15#[derive(Error, Debug)]
16pub enum LuError {
17    #[error("Matrix is singular or nearly singular")]
18    SingularMatrix,
19    #[error("Matrix dimensions mismatch: expected {expected}, got {got}")]
20    DimensionMismatch { expected: usize, got: usize },
21}
22
23/// LU factorization result
24///
25/// Stores L and U factors along with pivot information
26#[derive(Debug, Clone)]
27pub struct LuFactorization<T: ComplexField> {
28    /// Combined L and U matrices (L is unit lower triangular, stored below diagonal)
29    pub lu: Array2<T>,
30    /// Pivot indices
31    pub pivots: Vec<usize>,
32    /// Matrix dimension
33    pub n: usize,
34}
35
36impl<T: ComplexField> LuFactorization<T> {
37    /// Solve Ax = b using the pre-computed LU factorization
38    pub fn solve(&self, b: &Array1<T>) -> Result<Array1<T>, LuError> {
39        if b.len() != self.n {
40            return Err(LuError::DimensionMismatch {
41                expected: self.n,
42                got: b.len(),
43            });
44        }
45
46        let mut x = b.clone();
47
48        // Apply row permutations (forward substitution with L)
49        for i in 0..self.n {
50            let pivot = self.pivots[i];
51            if pivot != i {
52                x.swap(i, pivot);
53            }
54        }
55
56        // Forward substitution: Ly = Pb
57        for i in 0..self.n {
58            for j in 0..i {
59                let l_ij = self.lu[[i, j]];
60                x[i] = x[i] - l_ij * x[j];
61            }
62        }
63
64        // Backward substitution: Ux = y
65        for i in (0..self.n).rev() {
66            for j in (i + 1)..self.n {
67                let u_ij = self.lu[[i, j]];
68                x[i] = x[i] - u_ij * x[j];
69            }
70            let u_ii = self.lu[[i, i]];
71            if u_ii.norm() < T::Real::from_f64(1e-30).unwrap() {
72                return Err(LuError::SingularMatrix);
73            }
74            x[i] *= u_ii.inv();
75        }
76
77        Ok(x)
78    }
79}
80
81/// Compute LU factorization with partial pivoting (pure Rust implementation)
82#[allow(dead_code)]
83pub fn lu_factorize<T: ComplexField>(a: &Array2<T>) -> Result<LuFactorization<T>, LuError> {
84    let n = a.nrows();
85    if n != a.ncols() {
86        return Err(LuError::DimensionMismatch {
87            expected: n,
88            got: a.ncols(),
89        });
90    }
91
92    let mut lu = a.clone();
93    let mut pivots: Vec<usize> = (0..n).collect();
94
95    for k in 0..n {
96        // Find pivot
97        let mut max_val = lu[[k, k]].norm();
98        let mut max_row = k;
99
100        for i in (k + 1)..n {
101            let val = lu[[i, k]].norm();
102            if val > max_val {
103                max_val = val;
104                max_row = i;
105            }
106        }
107
108        // Check for singularity
109        if max_val < T::Real::from_f64(1e-30).unwrap() {
110            return Err(LuError::SingularMatrix);
111        }
112
113        // Swap rows if needed
114        if max_row != k {
115            for j in 0..n {
116                let tmp = lu[[k, j]];
117                lu[[k, j]] = lu[[max_row, j]];
118                lu[[max_row, j]] = tmp;
119            }
120            pivots.swap(k, max_row);
121        }
122
123        // Compute multipliers and eliminate
124        let pivot = lu[[k, k]];
125        for i in (k + 1)..n {
126            let mult = lu[[i, k]] * pivot.inv();
127            lu[[i, k]] = mult; // Store multiplier in L part
128
129            for j in (k + 1)..n {
130                let update = mult * lu[[k, j]];
131                lu[[i, j]] -= update;
132            }
133        }
134    }
135
136    Ok(LuFactorization { lu, pivots, n })
137}
138
139/// Solve Ax = b using LU decomposition
140///
141/// This is a convenience function that combines factorization and solve.
142pub fn lu_solve<T: ComplexField>(a: &Array2<T>, b: &Array1<T>) -> Result<Array1<T>, LuError> {
143    #[cfg(feature = "ndarray-linalg")]
144    {
145        a.solve_into(b.clone()).map_err(|_| LuError::SingularMatrix)
146    }
147
148    #[cfg(not(feature = "ndarray-linalg"))]
149    {
150        let factorization = lu_factorize(a)?;
151        factorization.solve(b)
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_lu_solve_real() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let b = array![1.0_f64, 2.0];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // Verify: Ax = b
        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!(ax[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_solve_complex() {
        let a = array![
            [Complex64::new(4.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, -1.0)],
        ];

        let b = array![Complex64::new(1.0, 1.0), Complex64::new(2.0, -1.0)];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // Verify: Ax ≈ b
        let ax = a.dot(&x);
        for i in 0..2 {
            assert_relative_eq!((ax[i] - b[i]).norm(), 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_identity() {
        let n = 5;
        let a = Array2::from_diag(&Array1::from_elem(n, 1.0_f64));
        let b = Array1::from_iter((1..=n).map(|i| i as f64));

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        for i in 0..n {
            assert_relative_eq!(x[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_lu_singular() {
        let a = array![[1.0_f64, 2.0], [2.0, 4.0],]; // Singular matrix

        let b = array![1.0_f64, 2.0];

        let result = lu_solve(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_lu_factorize_and_solve() {
        let a = array![[4.0_f64, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0],];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        // Solve multiple RHS
        let b1 = array![1.0_f64, 2.0, 3.0];
        let x1 = factorization.solve(&b1).expect("Solve should succeed");

        let ax1 = a.dot(&x1);
        for i in 0..3 {
            assert_relative_eq!(ax1[i], b1[i], epsilon = 1e-10);
        }

        let b2 = array![4.0_f64, 5.0, 6.0];
        let x2 = factorization.solve(&b2).expect("Solve should succeed");

        let ax2 = a.dot(&x2);
        for i in 0..3 {
            assert_relative_eq!(ax2[i], b2[i], epsilon = 1e-10);
        }
    }

    /// The pure-Rust factorization reports the exact singular-pivot error.
    #[test]
    fn test_lu_factorize_singular() {
        let a = array![[1.0_f64, 2.0], [2.0, 4.0]];
        assert!(matches!(lu_factorize(&a), Err(LuError::SingularMatrix)));
    }

    /// Non-square input is rejected with the row count as `expected` and the
    /// column count as `got`.
    #[test]
    fn test_lu_factorize_rejects_non_square() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(matches!(
            lu_factorize(&a),
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    /// A right-hand side of the wrong length is rejected rather than indexed.
    #[test]
    fn test_solve_rejects_wrong_rhs_length() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        let b = array![1.0_f64, 2.0, 3.0];
        assert!(matches!(
            factorization.solve(&b),
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    /// A 0x0 system factorizes and solves to an empty vector.
    #[test]
    fn test_lu_empty_system() {
        let a = Array2::<f64>::zeros((0, 0));
        let factorization = lu_factorize(&a).expect("Empty factorization should succeed");

        let x = factorization
            .solve(&Array1::<f64>::zeros(0))
            .expect("Empty solve should succeed");
        assert!(x.is_empty());
    }
}