Skip to main content

echidna_optim/
linalg.rs

1use num_traits::Float;
2
/// Result of LU factorization with partial pivoting.
///
/// Produced by [`lu_factor`] and consumed by [`lu_back_solve`]. The L and U
/// factors share a single `n x n` matrix (L strictly below the diagonal with
/// an implicit unit diagonal, U on and above it). Alongside it is the row
/// permutation `P` applied while pivoting, so that `P * A = L * U`.
#[derive(Debug, Clone)]
pub struct LuFactors<F> {
    /// Combined L/U matrix: L is below the diagonal (unit diagonal implicit),
    /// U is on and above the diagonal.
    lu: Vec<Vec<F>>,
    /// Row permutation: `perm[i]` is the original row index for factored row `i`.
    perm: Vec<usize>,
    /// Dimension of the factored square matrix (number of rows of `lu`).
    n: usize,
}
15
16/// Factorize an `n x n` matrix via LU decomposition with partial pivoting.
17///
18/// Returns `None` if the matrix is singular (zero or near-zero pivot).
19// Explicit indexing is clearer for pivoted LU: row/col indices drive pivot search and elimination
20#[allow(clippy::needless_range_loop)]
21pub fn lu_factor<F: Float>(a: &[Vec<F>]) -> Option<LuFactors<F>> {
22    let n = a.len();
23    debug_assert!(a.iter().all(|row| row.len() == n));
24
25    let mut lu: Vec<Vec<F>> = a.to_vec();
26    let mut perm: Vec<usize> = (0..n).collect();
27
28    // Use a relative singularity threshold: eps_mach * n * max_pivot.
29    // This adapts to both f32 and f64, and scales with matrix magnitude.
30    let eps_mach = F::epsilon();
31    let n_f = F::from(n).unwrap();
32    let mut max_pivot_seen = F::zero();
33
34    for col in 0..n {
35        // Find pivot
36        let mut max_val = lu[col][col].abs();
37        let mut max_row = col;
38        for row in (col + 1)..n {
39            let v = lu[row][col].abs();
40            if v > max_val {
41                max_val = v;
42                max_row = row;
43            }
44        }
45
46        max_pivot_seen = max_pivot_seen.max(max_val);
47        let tol = eps_mach * n_f * max_pivot_seen;
48        // Also catch all-zero columns where the relative threshold is zero
49        if max_val == F::zero() || max_val < tol {
50            return None; // Singular
51        }
52
53        // Swap rows
54        if max_row != col {
55            lu.swap(col, max_row);
56            perm.swap(col, max_row);
57        }
58
59        let pivot = lu[col][col];
60
61        // Eliminate below, storing L factors in-place
62        for row in (col + 1)..n {
63            let factor = lu[row][col] / pivot;
64            lu[row][col] = factor; // Store L factor
65            for j in (col + 1)..n {
66                let val = lu[col][j];
67                lu[row][j] = lu[row][j] - factor * val;
68            }
69        }
70    }
71
72    Some(LuFactors { lu, perm, n })
73}
74
75/// Solve `A * x = b` using a pre-computed LU factorization.
76///
77/// This avoids re-factorizing when solving multiple right-hand sides
78/// against the same matrix.
79// Explicit indexing is clearer for forward/back substitution with permuted indices
80#[allow(clippy::needless_range_loop)]
81pub fn lu_back_solve<F: Float>(factors: &LuFactors<F>, b: &[F]) -> Vec<F> {
82    let n = factors.n;
83    debug_assert_eq!(b.len(), n);
84
85    // Apply permutation to b
86    let mut y = vec![F::zero(); n];
87    for i in 0..n {
88        y[i] = b[factors.perm[i]];
89    }
90
91    // Forward substitution (L * y' = permuted_b), L has unit diagonal
92    for i in 1..n {
93        for j in 0..i {
94            let l_ij = factors.lu[i][j];
95            let y_j = y[j];
96            y[i] = y[i] - l_ij * y_j;
97        }
98    }
99
100    // Back substitution (U * x = y')
101    let mut x = vec![F::zero(); n];
102    for i in (0..n).rev() {
103        let mut sum = y[i];
104        for j in (i + 1)..n {
105            sum = sum - factors.lu[i][j] * x[j];
106        }
107        x[i] = sum / factors.lu[i][i];
108    }
109
110    x
111}
112
113/// Solve `A * x = b` via LU factorization with partial pivoting.
114///
115/// `a` is an `n x n` matrix stored as `a[row][col]`.
116/// Returns `None` if the matrix is singular (zero or near-zero pivot).
117pub fn lu_solve<F: Float>(a: &[Vec<F>], b: &[F]) -> Option<Vec<F>> {
118    let factors = lu_factor(a)?;
119    Some(lu_back_solve(&factors, b))
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts element-wise agreement between `actual` and `expected` within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "x[{}] = {}, expected {}", i, a, e);
        }
    }

    #[test]
    fn lu_solve_identity() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![3.0, 7.0];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[3.0, 7.0], 1e-12);
    }

    #[test]
    fn lu_solve_2x2() {
        // [2 1] [x0]   [5]
        // [1 3] [x1] = [7]
        // Solution: x0 = 8/5, x1 = 9/5
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let b = vec![5.0, 7.0];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[1.6, 1.8], 1e-12);
    }

    #[test]
    fn lu_solve_singular() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        let b = vec![3.0, 6.0];
        assert!(lu_solve(&a, &b).is_none());
    }

    #[test]
    fn lu_solve_needs_pivoting() {
        // First pivot is zero — requires row swap
        let a = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let b = vec![3.0, 7.0];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[7.0, 3.0], 1e-12);
    }

    #[test]
    fn lu_factor_records_row_permutation() {
        // Column 0 has its only nonzero in row 1, so rows 0 and 1 are swapped.
        let a = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let factors = lu_factor(&a).unwrap();
        assert_eq!(factors.perm, vec![1, 0]);
        assert_eq!(factors.n, 2);
    }

    #[test]
    fn lu_factor_then_back_solve_matches_lu_solve() {
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let b1 = vec![5.0, 7.0];
        let b2 = vec![1.0, 0.0];

        // Factorize once, then solve two different right-hand sides.
        let factors = lu_factor(&a).unwrap();
        let x1 = lu_back_solve(&factors, &b1);
        let x2 = lu_back_solve(&factors, &b2);

        // Exact solutions: A^{-1} = (1/5) * [[3, -1], [-1, 2]].
        assert_close(&x1, &[1.6, 1.8], 1e-12);
        assert_close(&x2, &[0.6, -0.2], 1e-12);

        // The one-shot solver must agree with the reused factorization.
        assert_close(&x1, &lu_solve(&a, &b1).unwrap(), 1e-12);
        assert_close(&x2, &lu_solve(&a, &b2).unwrap(), 1e-12);
    }

    #[test]
    fn lu_factor_then_back_solve_3x3() {
        // [1 2 3] [x]   [14]
        // [4 5 6] [y] = [32]
        // [7 8 0] [z]   [23]
        // Solution: x = 1, y = 2, z = 3
        let a = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 0.0],
        ];
        let b = vec![14.0, 32.0, 23.0];
        let factors = lu_factor(&a).unwrap();
        let x = lu_back_solve(&factors, &b);
        // Check against the known solution, not only against lu_solve
        // (which shares the same code path).
        assert_close(&x, &[1.0, 2.0, 3.0], 1e-10);
        assert_close(&x, &lu_solve(&a, &b).unwrap(), 1e-10);
    }

    #[test]
    fn lu_factor_singular_returns_none() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert!(lu_factor(&a).is_none());
    }

    #[test]
    fn lu_factor_zero_column_returns_none() {
        // The whole first column is zero: no usable pivot even after pivoting.
        let a = vec![vec![0.0, 1.0], vec![0.0, 2.0]];
        assert!(lu_factor(&a).is_none());
    }

    #[test]
    fn lu_solve_empty_system() {
        let a: Vec<Vec<f64>> = Vec::new();
        let b: Vec<f64> = Vec::new();
        assert_eq!(lu_solve(&a, &b), Some(Vec::new()));
    }

    #[test]
    fn lu_solve_tiny_scale_is_not_singular() {
        // The singularity threshold is relative, so a uniformly tiny but
        // perfectly conditioned matrix must still be solvable.
        let a = vec![vec![1e-200, 0.0], vec![0.0, 1e-200]];
        let b = vec![1e-200, 2e-200];
        let x = lu_solve(&a, &b).unwrap();
        assert_close(&x, &[1.0, 2.0], 1e-12);
    }

    #[test]
    fn lu_solve_f32() {
        // 4x + 3y = 10, 6x + 3y = 12  =>  x = 1, y = 2 (needs a row swap).
        let a = vec![vec![4.0f32, 3.0], vec![6.0, 3.0]];
        let b = vec![10.0f32, 12.0];
        let x = lu_solve(&a, &b).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-5);
        assert!((x[1] - 2.0).abs() < 1e-5);
    }
}