Skip to main content

math_audio_solvers/direct/
lu.rs

//! LU decomposition solver
//!
//! Provides LU factorization with partial pivoting for solving dense linear systems.
//! Uses BLAS/LAPACK when available (native feature; in this file `lu_solve` is gated on the
//! `oxiblas-ndarray` Cargo feature), with a pure-Rust fallback.
5
6use crate::traits::ComplexField;
7use ndarray::{Array1, Array2};
8use num_traits::FromPrimitive;
9use thiserror::Error;
10
11/// Errors that can occur during LU factorization
12#[derive(Error, Debug)]
13pub enum LuError {
14    #[error("Matrix is singular or nearly singular")]
15    SingularMatrix,
16    #[error("Matrix dimensions mismatch: expected {expected}, got {got}")]
17    DimensionMismatch { expected: usize, got: usize },
18}
19
20/// LU factorization result
21///
22/// Stores L and U factors along with pivot information
23#[derive(Debug, Clone)]
24pub struct LuFactorization<T: ComplexField> {
25    /// Combined L and U matrices (L is unit lower triangular, stored below diagonal)
26    pub lu: Array2<T>,
27    /// Pivot indices
28    pub pivots: Vec<usize>,
29    /// Matrix dimension
30    pub n: usize,
31}
32
33impl<T: ComplexField> LuFactorization<T> {
34    /// Solve Ax = b using the pre-computed LU factorization
35    pub fn solve(&self, b: &Array1<T>) -> Result<Array1<T>, LuError> {
36        if b.len() != self.n {
37            return Err(LuError::DimensionMismatch {
38                expected: self.n,
39                got: b.len(),
40            });
41        }
42
43        let mut x = Array1::from_elem(self.n, T::zero());
44
45        // Apply row permutations Pb
46        for i in 0..self.n {
47            x[i] = b[self.pivots[i]];
48        }
49
50        // Forward substitution: Ly = Pb
51        for i in 0..self.n {
52            #[allow(clippy::needless_range_loop)]
53            for j in 0..i {
54                let l_ij = self.lu[[i, j]];
55                let x_j = x[j];
56                x[i] -= l_ij * x_j;
57            }
58        }
59
60        // Backward substitution: Ux = y
61        for i in (0..self.n).rev() {
62            #[allow(clippy::needless_range_loop)]
63            for j in (i + 1)..self.n {
64                let u_ij = self.lu[[i, j]];
65                let x_j = x[j];
66                x[i] -= u_ij * x_j;
67            }
68            let u_ii = self.lu[[i, i]];
69            if u_ii.is_zero_approx(T::Real::from_f64(1e-20).unwrap()) {
70                return Err(LuError::SingularMatrix);
71            }
72            x[i] *= u_ii.inv();
73        }
74
75        Ok(x)
76    }
77}
78
79/// Compute LU factorization with partial pivoting (pure Rust implementation)
80#[allow(dead_code)]
81pub fn lu_factorize<T: ComplexField>(a: &Array2<T>) -> Result<LuFactorization<T>, LuError> {
82    let n = a.nrows();
83    if n != a.ncols() {
84        return Err(LuError::DimensionMismatch {
85            expected: n,
86            got: a.ncols(),
87        });
88    }
89
90    let mut lu = a.clone();
91    let mut pivots: Vec<usize> = (0..n).collect();
92
93    for k in 0..n {
94        // Find pivot
95        let mut max_val = lu[[k, k]].norm();
96        let mut max_row = k;
97
98        for i in (k + 1)..n {
99            let val = lu[[i, k]].norm();
100            if val > max_val {
101                max_val = val;
102                max_row = i;
103            }
104        }
105
106        // Check for singularity
107        if max_val < T::Real::from_f64(1e-20).unwrap() {
108            return Err(LuError::SingularMatrix);
109        }
110
111        // Swap rows if needed
112        if max_row != k {
113            for j in 0..n {
114                let tmp = lu[[k, j]];
115                lu[[k, j]] = lu[[max_row, j]];
116                lu[[max_row, j]] = tmp;
117            }
118            pivots.swap(k, max_row);
119        }
120
121        // Compute multipliers and eliminate
122        let pivot = lu[[k, k]];
123        for i in (k + 1)..n {
124            let mult = lu[[i, k]] * pivot.inv();
125            lu[[i, k]] = mult; // Store multiplier in L part
126
127            for j in (k + 1)..n {
128                let update = mult * lu[[k, j]];
129                lu[[i, j]] -= update;
130            }
131        }
132    }
133
134    Ok(LuFactorization { lu, pivots, n })
135}
136
/// Solve Ax = b using LU decomposition
///
/// This is a convenience function that combines factorization and solve.
/// This variant is compiled when the `oxiblas-ndarray` feature is enabled and
/// delegates the actual work to `oxiblas_ndarray::lapack::solve_ndarray`.
///
/// # Errors
///
/// * [`LuError::DimensionMismatch`] if `a` is not square or `b.len()` differs
///   from the matrix dimension. These are checked up front so shape errors are
///   reported the same way as in the pure-Rust variant.
/// * [`LuError::SingularMatrix`] for any failure reported by the backend (the
///   backend's error detail is discarded).
#[cfg(feature = "oxiblas-ndarray")]
pub fn lu_solve<T: ComplexField + oxiblas_core::Field + Clone + bytemuck::Zeroable>(
    a: &Array2<T>,
    b: &Array1<T>,
) -> Result<Array1<T>, LuError> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LuError::DimensionMismatch {
            expected: n,
            got: a.ncols(),
        });
    }
    if b.len() != n {
        return Err(LuError::DimensionMismatch {
            expected: n,
            got: b.len(),
        });
    }

    oxiblas_ndarray::lapack::solve_ndarray(a, b).map_err(|_| LuError::SingularMatrix)
}
147
148/// Solve Ax = b using LU decomposition
149///
150/// This is a convenience function that combines factorization and solve.
151#[cfg(not(feature = "oxiblas-ndarray"))]
152pub fn lu_solve<T: ComplexField>(a: &Array2<T>, b: &Array1<T>) -> Result<Array1<T>, LuError> {
153    let factorization = lu_factorize(a)?;
154    factorization.solve(b)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Tolerance used for every residual comparison in this module.
    const TOL: f64 = 1e-10;

    /// Asserts that `a * x` reproduces `b` entry by entry.
    fn assert_real_solution(a: &Array2<f64>, x: &Array1<f64>, b: &Array1<f64>) {
        let ax = a.dot(x);
        assert_eq!(ax.len(), b.len());
        for (lhs, rhs) in ax.iter().zip(b.iter()) {
            assert_relative_eq!(*lhs, *rhs, epsilon = TOL);
        }
    }

    #[test]
    fn test_lu_solve_real() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let b = array![1.0_f64, 2.0];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        assert_real_solution(&a, &x, &b);
    }

    #[test]
    fn test_lu_solve_complex() {
        let a = array![
            [Complex64::new(4.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, -1.0)],
        ];
        let b = array![Complex64::new(1.0, 1.0), Complex64::new(2.0, -1.0)];

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // The residual A x - b must vanish in magnitude for every component.
        let ax = a.dot(&x);
        for (lhs, rhs) in ax.iter().zip(b.iter()) {
            assert_relative_eq!((*lhs - *rhs).norm(), 0.0, epsilon = TOL);
        }
    }

    #[test]
    fn test_lu_identity() {
        let n = 5;
        let a = Array2::<f64>::eye(n);
        let b = Array1::from_iter((1..=n).map(|i| i as f64));

        let x = lu_solve(&a, &b).expect("LU solve should succeed");

        // Solving against the identity must hand back the right-hand side.
        for (x_i, b_i) in x.iter().zip(b.iter()) {
            assert_relative_eq!(*x_i, *b_i, epsilon = TOL);
        }
    }

    #[test]
    fn test_lu_singular() {
        // Second row is twice the first, so the matrix has rank 1.
        let a = array![[1.0_f64, 2.0], [2.0, 4.0]];
        let b = array![1.0_f64, 2.0];

        assert!(lu_solve(&a, &b).is_err());
    }

    #[test]
    fn test_lu_factorize_and_solve() {
        let a = array![[4.0_f64, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]];
        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        // One factorization, several right-hand sides.
        for b in [array![1.0_f64, 2.0, 3.0], array![4.0_f64, 5.0, 6.0]] {
            let x = factorization.solve(&b).expect("Solve should succeed");
            assert_real_solution(&a, &x, &b);
        }
    }

    #[test]
    fn test_lu_factorize_requires_pivoting() {
        // The leading entry is zero, so elimination cannot start without a row
        // swap; this exercises the permutation bookkeeping end to end.
        let a = array![[0.0_f64, 1.0, 2.0], [1.0, 0.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![1.0_f64, 2.0, 3.0];

        let factorization = lu_factorize(&a).expect("Factorization should succeed");
        assert_ne!(factorization.pivots, vec![0, 1, 2]);

        let x = factorization.solve(&b).expect("Solve should succeed");
        assert_real_solution(&a, &x, &b);
    }

    #[test]
    fn test_lu_factorize_rejects_non_square() {
        let a = Array2::<f64>::zeros((2, 3));

        let result = lu_factorize(&a);

        assert!(matches!(
            result,
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_solve_rejects_wrong_rhs_length() {
        let a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let factorization = lu_factorize(&a).expect("Factorization should succeed");

        let result = factorization.solve(&array![1.0_f64, 2.0, 3.0]);

        assert!(matches!(
            result,
            Err(LuError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }
}