trueno_solve/lu.rs
1//! LU factorization with partial pivoting.
2//!
3//! # Contract: solve-lu-v1.yaml
4//!
5//! PA = LU where L is unit lower triangular, U is upper triangular,
6//! P is a permutation matrix (stored as pivot vector).
7//!
8//! ## Proof obligations
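//! - ||PA - LU||_F / ||A||_F < n · u
//! - L is unit lower triangular, U is upper triangular
//! - Solution accuracy: ||Ax - b|| / (||A|| · ||x||) < κ(A) · n · u
//!
//! where `n` is the matrix dimension, `u` is the unit roundoff of `f32`
//! (`f32::EPSILON / 2`), and κ(A) is the condition number of `A`.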
12
13use crate::error::SolverError;
14
/// LU factorization result (in-place: L and U stored in the same matrix).
///
/// After factorization, the `lu` buffer (row-major, `n * n` entries) contains:
/// - Upper triangle (including diagonal): U
/// - Strict lower triangle: L (unit diagonal implicit, not stored)
///
/// Together with `pivot` this represents `PA = LU`, where row `i` of `PA`
/// is row `pivot[i]` of the original matrix `A`.
#[derive(Debug, Clone, PartialEq)]
pub struct LuFactorization {
    /// Dimension of the square matrix (number of rows and of columns).
    pub n: usize,
    /// Factored matrix (L\U packed), `n * n` entries in row-major order.
    pub lu: Vec<f32>,
    /// Row permutation vector: row `i` of the factored matrix came from row
    /// `pivot[i]` of the original matrix, i.e. `(PA)[i] = A[pivot[i]]`.
    /// This is a full permutation, not a LAPACK-style sequence of swaps.
    pub pivot: Vec<usize>,
}
29
30/// LU factorization with partial pivoting.
31///
32/// # Contract: solve-lu-v1.yaml / lu_factorization
33///
34/// Stores L and U in-place: U in upper triangle, L (unit diagonal) in strict lower.
35///
36/// # Errors
37///
38/// Returns `SingularMatrix` if a zero pivot is encountered.
39#[allow(clippy::cast_precision_loss)]
40pub fn lu_factorize(a: &[f32], n: usize) -> Result<LuFactorization, SolverError> {
41 if a.len() != n * n {
42 return Err(SolverError::NotSquare {
43 rows: n,
44 cols: a.len() / n.max(1),
45 });
46 }
47
48 let mut lu = a.to_vec();
49 let mut pivot: Vec<usize> = (0..n).collect();
50
51 for k in 0..n {
52 // Partial pivoting: find max |a[i][k]| for i >= k
53 let mut max_val = lu[k * n + k].abs();
54 let mut max_row = k;
55 for i in (k + 1)..n {
56 let val = lu[i * n + k].abs();
57 if val > max_val {
58 max_val = val;
59 max_row = i;
60 }
61 }
62
63 if max_val < f32::EPSILON * 1e3 {
64 return Err(SolverError::SingularMatrix(k));
65 }
66
67 // Swap rows k and max_row
68 if max_row != k {
69 pivot.swap(k, max_row);
70 for j in 0..n {
71 lu.swap(k * n + j, max_row * n + j);
72 }
73 }
74
75 // Eliminate below diagonal
76 let pivot_val = lu[k * n + k];
77 for i in (k + 1)..n {
78 let factor = lu[i * n + k] / pivot_val;
79 lu[i * n + k] = factor; // Store L factor
80
81 for j in (k + 1)..n {
82 lu[i * n + j] -= factor * lu[k * n + j];
83 }
84 }
85 }
86
87 Ok(LuFactorization { n, lu, pivot })
88}
89
90impl LuFactorization {
91 /// Solve Ax = b using the LU factorization.
92 ///
93 /// # Contract: solve-lu-v1.yaml / solution_accuracy
94 ///
95 /// # Errors
96 ///
97 /// Returns error on dimension mismatch.
98 pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
99 if b.len() != self.n {
100 return Err(SolverError::DimensionMismatch {
101 matrix_n: self.n,
102 rhs_len: b.len(),
103 });
104 }
105
106 let n = self.n;
107 let mut x = vec![0.0f32; n];
108
109 // Apply permutation to b
110 for i in 0..n {
111 x[i] = b[self.pivot[i]];
112 }
113
114 // Forward substitution (L * y = Pb)
115 for i in 1..n {
116 let mut sum = x[i];
117 for j in 0..i {
118 sum -= self.lu[i * n + j] * x[j];
119 }
120 x[i] = sum;
121 }
122
123 // Back substitution (U * x = y)
124 for i in (0..n).rev() {
125 let mut sum = x[i];
126 for j in (i + 1)..n {
127 sum -= self.lu[i * n + j] * x[j];
128 }
129 x[i] = sum / self.lu[i * n + i];
130 }
131
132 Ok(x)
133 }
134
135 /// Extract L matrix (unit lower triangular).
136 pub fn extract_l(&self) -> Vec<f32> {
137 let n = self.n;
138 let mut l = vec![0.0f32; n * n];
139 for i in 0..n {
140 l[i * n + i] = 1.0; // Unit diagonal
141 for j in 0..i {
142 l[i * n + j] = self.lu[i * n + j];
143 }
144 }
145 l
146 }
147
148 /// Extract U matrix (upper triangular).
149 pub fn extract_u(&self) -> Vec<f32> {
150 let n = self.n;
151 let mut u = vec![0.0f32; n * n];
152 for i in 0..n {
153 for j in i..n {
154 u[i * n + j] = self.lu[i * n + j];
155 }
156 }
157 u
158 }
159
160 /// Extract permutation matrix.
161 pub fn extract_p(&self) -> Vec<f32> {
162 let n = self.n;
163 let mut p = vec![0.0f32; n * n];
164 for (i, &pi) in self.pivot.iter().enumerate() {
165 p[i * n + pi] = 1.0;
166 }
167 p
168 }
169}