solvers/direct/
lu.rs

1//! LU decomposition solver
2//!
3//! Provides LU factorization with partial pivoting for solving dense linear systems.
4//! Uses BLAS/LAPACK when available (native feature), with a pure-Rust fallback.
5
6use crate::traits::ComplexField;
7use ndarray::{Array1, Array2};
8use num_traits::FromPrimitive;
9use thiserror::Error;
10
11/// Errors that can occur during LU factorization
12#[derive(Error, Debug)]
13pub enum LuError {
14    #[error("Matrix is singular or nearly singular")]
15    SingularMatrix,
16    #[error("Matrix dimensions mismatch: expected {expected}, got {got}")]
17    DimensionMismatch { expected: usize, got: usize },
18}
19
20/// LU factorization result
21///
22/// Stores L and U factors along with pivot information
23#[derive(Debug, Clone)]
24pub struct LuFactorization<T: ComplexField> {
25    /// Combined L and U matrices (L is unit lower triangular, stored below diagonal)
26    pub lu: Array2<T>,
27    /// Pivot indices
28    pub pivots: Vec<usize>,
29    /// Matrix dimension
30    pub n: usize,
31}
32
33impl<T: ComplexField> LuFactorization<T> {
34    /// Solve Ax = b using the pre-computed LU factorization
35    pub fn solve(&self, b: &Array1<T>) -> Result<Array1<T>, LuError> {
36        if b.len() != self.n {
37            return Err(LuError::DimensionMismatch {
38                expected: self.n,
39                got: b.len(),
40            });
41        }
42
43        let mut x = b.clone();
44
45        // Apply row permutations (forward substitution with L)
46        for i in 0..self.n {
47            let pivot = self.pivots[i];
48            if pivot != i {
49                x.swap(i, pivot);
50            }
51        }
52
53        // Forward substitution: Ly = Pb
54        for i in 0..self.n {
55            for j in 0..i {
56                let l_ij = self.lu[[i, j]];
57                x[i] = x[i] - l_ij * x[j];
58            }
59        }
60
61        // Backward substitution: Ux = y
62        for i in (0..self.n).rev() {
63            for j in (i + 1)..self.n {
64                let u_ij = self.lu[[i, j]];
65                x[i] = x[i] - u_ij * x[j];
66            }
67            let u_ii = self.lu[[i, i]];
68            if u_ii.norm() < T::Real::from_f64(1e-30).unwrap() {
69                return Err(LuError::SingularMatrix);
70            }
71            x[i] *= u_ii.inv();
72        }
73
74        Ok(x)
75    }
76}
77
78/// Compute LU factorization with partial pivoting (pure Rust implementation)
79pub fn lu_factorize<T: ComplexField>(a: &Array2<T>) -> Result<LuFactorization<T>, LuError> {
80    let n = a.nrows();
81    if n != a.ncols() {
82        return Err(LuError::DimensionMismatch {
83            expected: n,
84            got: a.ncols(),
85        });
86    }
87
88    let mut lu = a.clone();
89    let mut pivots: Vec<usize> = (0..n).collect();
90
91    for k in 0..n {
92        // Find pivot
93        let mut max_val = lu[[k, k]].norm();
94        let mut max_row = k;
95
96        for i in (k + 1)..n {
97            let val = lu[[i, k]].norm();
98            if val > max_val {
99                max_val = val;
100                max_row = i;
101            }
102        }
103
104        // Check for singularity
105        if max_val < T::Real::from_f64(1e-30).unwrap() {
106            return Err(LuError::SingularMatrix);
107        }
108
109        // Swap rows if needed
110        if max_row != k {
111            for j in 0..n {
112                let tmp = lu[[k, j]];
113                lu[[k, j]] = lu[[max_row, j]];
114                lu[[max_row, j]] = tmp;
115            }
116            pivots.swap(k, max_row);
117        }
118
119        // Compute multipliers and eliminate
120        let pivot = lu[[k, k]];
121        for i in (k + 1)..n {
122            let mult = lu[[i, k]] * pivot.inv();
123            lu[[i, k]] = mult; // Store multiplier in L part
124
125            for j in (k + 1)..n {
126                let update = mult * lu[[k, j]];
127                lu[[i, j]] -= update;
128            }
129        }
130    }
131
132    Ok(LuFactorization { lu, pivots, n })
133}
134
135/// Solve Ax = b using LU decomposition
136///
137/// This is a convenience function that combines factorization and solve.
138pub fn lu_solve<T: ComplexField>(a: &Array2<T>, b: &Array1<T>) -> Result<Array1<T>, LuError> {
139    let factorization = lu_factorize(a)?;
140    factorization.solve(b)
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Absolute tolerance shared by every comparison in this module.
    const TOL: f64 = 1e-10;

    /// Asserts that two real vectors agree element by element within `TOL`.
    fn assert_vectors_close(actual: &Array1<f64>, expected: &Array1<f64>) {
        assert_eq!(actual.len(), expected.len());
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = TOL);
        }
    }

    /// Real 2x2 system: the residual `A x - b` must vanish.
    #[test]
    fn test_lu_solve_real() {
        let matrix = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let rhs = array![1.0_f64, 2.0];

        let solution = lu_solve(&matrix, &rhs).expect("LU solve should succeed");

        assert_vectors_close(&matrix.dot(&solution), &rhs);
    }

    /// Complex 2x2 system: the residual norm must vanish component-wise.
    #[test]
    fn test_lu_solve_complex() {
        let matrix = array![
            [Complex64::new(4.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, -1.0)],
        ];
        let rhs = array![Complex64::new(1.0, 1.0), Complex64::new(2.0, -1.0)];

        let solution = lu_solve(&matrix, &rhs).expect("LU solve should succeed");

        let product = matrix.dot(&solution);
        for (got, want) in product.iter().zip(rhs.iter()) {
            assert_relative_eq!((*got - *want).norm(), 0.0, epsilon = TOL);
        }
    }

    /// Solving against the identity must return the right-hand side itself.
    #[test]
    fn test_lu_identity() {
        let order = 5;
        let identity = Array2::<f64>::eye(order);
        let rhs: Array1<f64> = (1..=order).map(|k| k as f64).collect();

        let solution = lu_solve(&identity, &rhs).expect("LU solve should succeed");

        assert_vectors_close(&solution, &rhs);
    }

    /// A rank-deficient matrix must be rejected with an error.
    #[test]
    fn test_lu_singular() {
        // Second row is twice the first, so the matrix is singular.
        let matrix = array![[1.0_f64, 2.0], [2.0, 4.0]];
        let rhs = array![1.0_f64, 2.0];

        assert!(lu_solve(&matrix, &rhs).is_err());
    }

    /// One factorization must serve several right-hand sides.
    #[test]
    fn test_lu_factorize_and_solve() {
        let matrix = array![[4.0_f64, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]];

        let factors = lu_factorize(&matrix).expect("Factorization should succeed");

        let first_rhs = array![1.0_f64, 2.0, 3.0];
        let first_solution = factors.solve(&first_rhs).expect("Solve should succeed");
        assert_vectors_close(&matrix.dot(&first_solution), &first_rhs);

        let second_rhs = array![4.0_f64, 5.0, 6.0];
        let second_solution = factors.solve(&second_rhs).expect("Solve should succeed");
        assert_vectors_close(&matrix.dot(&second_solution), &second_rhs);
    }
}