Skip to main content

trueno_solve/
lu.rs

//! LU factorization with partial pivoting.
//!
//! # Contract: solve-lu-v1.yaml
//!
//! PA = LU where L is unit lower triangular, U is upper triangular,
//! P is a permutation matrix (stored as pivot vector).
//!
//! ## Proof obligations
//! - ||PA - LU||_F / ||A||_F < n · u
//! - L is unit lower triangular, U is upper triangular
//! - Solution accuracy: ||Ax - b|| / (||A|| · ||x||) < κ(A) · n · u
12
13use crate::error::SolverError;
14
/// LU factorization result (in-place: L and U stored in the same matrix).
///
/// `lu` is a row-major `n × n` buffer (element `(i, j)` lives at
/// `lu[i * n + j]`) that contains:
/// - Upper triangle (including diagonal): U
/// - Strict lower triangle: L (unit diagonal implicit)
#[derive(Debug, Clone, PartialEq)]
pub struct LuFactorization {
    /// Matrix dimension: the factored matrix is `n × n`.
    pub n: usize,
    /// Factored matrix (L\U packed), row-major, `n * n` entries.
    pub lu: Vec<f32>,
    /// Row permutation vector: row `i` of `PA` is row `pivot[i]` of the
    /// original `A`. This is a full permutation, not a sequence of swaps.
    pub pivot: Vec<usize>,
}
29
30/// LU factorization with partial pivoting.
31///
32/// # Contract: solve-lu-v1.yaml / lu_factorization
33///
34/// Stores L and U in-place: U in upper triangle, L (unit diagonal) in strict lower.
35///
36/// # Errors
37///
38/// Returns `SingularMatrix` if a zero pivot is encountered.
39#[allow(clippy::cast_precision_loss)]
40pub fn lu_factorize(a: &[f32], n: usize) -> Result<LuFactorization, SolverError> {
41    if a.len() != n * n {
42        return Err(SolverError::NotSquare {
43            rows: n,
44            cols: a.len() / n.max(1),
45        });
46    }
47
48    let mut lu = a.to_vec();
49    let mut pivot: Vec<usize> = (0..n).collect();
50
51    for k in 0..n {
52        // Partial pivoting: find max |a[i][k]| for i >= k
53        let mut max_val = lu[k * n + k].abs();
54        let mut max_row = k;
55        for i in (k + 1)..n {
56            let val = lu[i * n + k].abs();
57            if val > max_val {
58                max_val = val;
59                max_row = i;
60            }
61        }
62
63        if max_val < f32::EPSILON * 1e3 {
64            return Err(SolverError::SingularMatrix(k));
65        }
66
67        // Swap rows k and max_row
68        if max_row != k {
69            pivot.swap(k, max_row);
70            for j in 0..n {
71                lu.swap(k * n + j, max_row * n + j);
72            }
73        }
74
75        // Eliminate below diagonal
76        let pivot_val = lu[k * n + k];
77        for i in (k + 1)..n {
78            let factor = lu[i * n + k] / pivot_val;
79            lu[i * n + k] = factor; // Store L factor
80
81            for j in (k + 1)..n {
82                lu[i * n + j] -= factor * lu[k * n + j];
83            }
84        }
85    }
86
87    Ok(LuFactorization { n, lu, pivot })
88}
89
90impl LuFactorization {
91    /// Solve Ax = b using the LU factorization.
92    ///
93    /// # Contract: solve-lu-v1.yaml / solution_accuracy
94    ///
95    /// # Errors
96    ///
97    /// Returns error on dimension mismatch.
98    pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
99        if b.len() != self.n {
100            return Err(SolverError::DimensionMismatch {
101                matrix_n: self.n,
102                rhs_len: b.len(),
103            });
104        }
105
106        let n = self.n;
107        let mut x = vec![0.0f32; n];
108
109        // Apply permutation to b
110        for i in 0..n {
111            x[i] = b[self.pivot[i]];
112        }
113
114        // Forward substitution (L * y = Pb)
115        for i in 1..n {
116            let mut sum = x[i];
117            for j in 0..i {
118                sum -= self.lu[i * n + j] * x[j];
119            }
120            x[i] = sum;
121        }
122
123        // Back substitution (U * x = y)
124        for i in (0..n).rev() {
125            let mut sum = x[i];
126            for j in (i + 1)..n {
127                sum -= self.lu[i * n + j] * x[j];
128            }
129            x[i] = sum / self.lu[i * n + i];
130        }
131
132        Ok(x)
133    }
134
135    /// Extract L matrix (unit lower triangular).
136    pub fn extract_l(&self) -> Vec<f32> {
137        let n = self.n;
138        let mut l = vec![0.0f32; n * n];
139        for i in 0..n {
140            l[i * n + i] = 1.0; // Unit diagonal
141            for j in 0..i {
142                l[i * n + j] = self.lu[i * n + j];
143            }
144        }
145        l
146    }
147
148    /// Extract U matrix (upper triangular).
149    pub fn extract_u(&self) -> Vec<f32> {
150        let n = self.n;
151        let mut u = vec![0.0f32; n * n];
152        for i in 0..n {
153            for j in i..n {
154                u[i * n + j] = self.lu[i * n + j];
155            }
156        }
157        u
158    }
159
160    /// Extract permutation matrix.
161    pub fn extract_p(&self) -> Vec<f32> {
162        let n = self.n;
163        let mut p = vec![0.0f32; n * n];
164        for (i, &pi) in self.pivot.iter().enumerate() {
165            p[i * n + pi] = 1.0;
166        }
167        p
168    }
169}