Skip to main content

echidna_optim/
linalg.rs

1use num_traits::Float;
2
/// Result of LU factorization with partial pivoting, produced by [`lu_factor`].
///
/// Stores the combined L/U factors in a single matrix (L strictly below the
/// diagonal, U on and above it) plus the row permutation, so that
/// `P * A = L * U`, where row `i` of `P * A` is row `perm[i]` of `A`.
///
/// Pass it to [`lu_back_solve`] to solve against any number of right-hand sides.
#[derive(Debug, Clone)]
pub struct LuFactors<F> {
    /// Combined L/U matrix: L is below the diagonal (unit diagonal implicit),
    /// U is on and above the diagonal.
    lu: Vec<Vec<F>>,
    /// Row permutation: `perm[i]` is the original row index for factored row `i`.
    perm: Vec<usize>,
    /// Dimension of the (square) factored matrix.
    n: usize,
}
15
16/// Factorize an `n x n` matrix via LU decomposition with partial pivoting.
17///
18/// Returns `None` if the matrix is singular (zero or near-zero pivot).
19// Explicit indexing is clearer for pivoted LU: row/col indices drive pivot search and elimination
20#[allow(clippy::needless_range_loop)]
21pub fn lu_factor<F: Float>(a: &[Vec<F>]) -> Option<LuFactors<F>> {
22    let n = a.len();
23    debug_assert!(a.iter().all(|row| row.len() == n));
24
25    let mut lu: Vec<Vec<F>> = a.to_vec();
26    let mut perm: Vec<usize> = (0..n).collect();
27
28    let eps = F::from(1e-12).unwrap_or_else(|| F::epsilon());
29
30    for col in 0..n {
31        // Find pivot
32        let mut max_val = lu[col][col].abs();
33        let mut max_row = col;
34        for row in (col + 1)..n {
35            let v = lu[row][col].abs();
36            if v > max_val {
37                max_val = v;
38                max_row = row;
39            }
40        }
41
42        if max_val < eps {
43            return None; // Singular
44        }
45
46        // Swap rows
47        if max_row != col {
48            lu.swap(col, max_row);
49            perm.swap(col, max_row);
50        }
51
52        let pivot = lu[col][col];
53
54        // Eliminate below, storing L factors in-place
55        for row in (col + 1)..n {
56            let factor = lu[row][col] / pivot;
57            lu[row][col] = factor; // Store L factor
58            for j in (col + 1)..n {
59                let val = lu[col][j];
60                lu[row][j] = lu[row][j] - factor * val;
61            }
62        }
63    }
64
65    Some(LuFactors { lu, perm, n })
66}
67
68/// Solve `A * x = b` using a pre-computed LU factorization.
69///
70/// This avoids re-factorizing when solving multiple right-hand sides
71/// against the same matrix.
72// Explicit indexing is clearer for forward/back substitution with permuted indices
73#[allow(clippy::needless_range_loop)]
74pub fn lu_back_solve<F: Float>(factors: &LuFactors<F>, b: &[F]) -> Vec<F> {
75    let n = factors.n;
76    debug_assert_eq!(b.len(), n);
77
78    // Apply permutation to b
79    let mut y = vec![F::zero(); n];
80    for i in 0..n {
81        y[i] = b[factors.perm[i]];
82    }
83
84    // Forward substitution (L * y' = permuted_b), L has unit diagonal
85    for i in 1..n {
86        for j in 0..i {
87            let l_ij = factors.lu[i][j];
88            let y_j = y[j];
89            y[i] = y[i] - l_ij * y_j;
90        }
91    }
92
93    // Back substitution (U * x = y')
94    let mut x = vec![F::zero(); n];
95    for i in (0..n).rev() {
96        let mut sum = y[i];
97        for j in (i + 1)..n {
98            sum = sum - factors.lu[i][j] * x[j];
99        }
100        x[i] = sum / factors.lu[i][i];
101    }
102
103    x
104}
105
106/// Solve `A * x = b` via LU factorization with partial pivoting.
107///
108/// `a` is an `n x n` matrix stored as `a[row][col]`.
109/// Returns `None` if the matrix is singular (zero or near-zero pivot).
110pub fn lu_solve<F: Float>(a: &[Vec<F>], b: &[F]) -> Option<Vec<F>> {
111    let factors = lu_factor(a)?;
112    Some(lu_back_solve(&factors, b))
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` equals `expected` element-wise to within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < tol,
                "x[{}] = {}, expected {}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn lu_solve_identity() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let x = lu_solve(&a, &[3.0, 7.0]).unwrap();
        assert_close(&x, &[3.0, 7.0], 1e-12);
    }

    #[test]
    fn lu_solve_2x2() {
        // [2 1] [x0]   [5]
        // [1 3] [x1] = [7]
        // Solution: x0 = 8/5, x1 = 9/5
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let x = lu_solve(&a, &[5.0, 7.0]).unwrap();
        assert_close(&x, &[1.6, 1.8], 1e-12);
    }

    #[test]
    fn lu_solve_singular() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        let b = vec![3.0, 6.0];
        assert!(lu_solve(&a, &b).is_none());
    }

    #[test]
    fn lu_solve_needs_pivoting() {
        // First pivot is zero — requires row swap
        let a = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let x = lu_solve(&a, &[3.0, 7.0]).unwrap();
        assert_close(&x, &[7.0, 3.0], 1e-12);
    }

    #[test]
    fn lu_solve_pivots_mid_factorization() {
        // After eliminating column 0 the (1,1) entry is exactly zero, so a
        // row swap is required while processing column 1.
        // [1 1 1]       [6]
        // [1 1 2] x  =  [9]   =>  x = (1, 2, 3)
        // [1 2 1]       [8]
        let a = vec![
            vec![1.0, 1.0, 1.0],
            vec![1.0, 1.0, 2.0],
            vec![1.0, 2.0, 1.0],
        ];
        let x = lu_solve(&a, &[6.0, 9.0, 8.0]).unwrap();
        assert_close(&x, &[1.0, 2.0, 3.0], 1e-12);
    }

    #[test]
    fn lu_solve_empty_system() {
        let a: Vec<Vec<f64>> = Vec::new();
        assert_eq!(lu_solve(&a, &[]), Some(Vec::new()));
    }

    #[test]
    fn lu_solve_f32() {
        let a = vec![vec![2.0_f32, 1.0], vec![1.0, 3.0]];
        let x = lu_solve(&a, &[5.0, 7.0]).unwrap();
        assert!((x[0] - 1.6).abs() < 1e-6);
        assert!((x[1] - 1.8).abs() < 1e-6);
    }

    #[test]
    fn lu_factor_then_back_solve_matches_exact_solutions() {
        // A = [[2, 1], [1, 3]], A^{-1} = (1/5) [[3, -1], [-1, 2]]
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];

        // Factorize once, then solve two different right-hand sides, each
        // checked against its exact solution.
        let factors = lu_factor(&a).unwrap();
        assert_close(&lu_back_solve(&factors, &[5.0, 7.0]), &[1.6, 1.8], 1e-12);
        assert_close(&lu_back_solve(&factors, &[1.0, 0.0]), &[0.6, -0.2], 1e-12);

        // The reusable path must also agree with the one-shot solver.
        let x1_ref = lu_solve(&a, &[5.0, 7.0]).unwrap();
        assert_close(&lu_back_solve(&factors, &[5.0, 7.0]), &x1_ref, 1e-12);
    }

    #[test]
    fn lu_factor_then_back_solve_3x3() {
        // [1 2 3] [x]   [14]
        // [4 5 6] [y] = [32]   =>  (x, y, z) = (1, 2, 3)
        // [7 8 0] [z]   [23]
        let a = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 0.0],
        ];
        let factors = lu_factor(&a).unwrap();
        let x = lu_back_solve(&factors, &[14.0, 32.0, 23.0]);
        assert_close(&x, &[1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn lu_factor_singular_returns_none() {
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert!(lu_factor(&a).is_none());
    }
}