Skip to main content

trueno_solve/
qr.rs

1//! QR factorization via Householder reflections.
2//!
3//! # Contract: solve-qr-v1.yaml
4//!
5//! A = QR where Q is orthogonal (Q^T Q = I), R is upper triangular.
6//!
7//! ## Proof obligations
8//! - ||Q^T Q - I||_F < √n · u
9//! - ||A - QR||_F / ||A||_F < √n · u
10//! - R is upper triangular
11
12use crate::error::SolverError;
13
/// Packed result of a Householder QR factorization of an `m × n` matrix.
///
/// The factors are not stored explicitly: `qr` holds R together with the
/// reflector vectors, and `tau` holds the matching scalar of each reflector.
#[derive(Debug)]
pub struct QrFactorization {
    /// Number of rows of the factored matrix.
    pub m: usize,
    /// Number of columns of the factored matrix.
    pub n: usize,
    /// Row-major `m × n` buffer (`qr[i * n + j]`): R occupies the upper
    /// triangle, and the entries below the diagonal of column `k` hold
    /// Householder vector `k` (its leading 1 is implicit, not stored).
    pub qr: Vec<f32>,
    /// Householder scalar factors (tau), one per reflector.
    pub tau: Vec<f32>,
}
26
27/// QR factorization via Householder reflections.
28///
29/// # Contract: solve-qr-v1.yaml / qr_factorization
30///
31/// # Errors
32///
33/// Returns error if m < n (not tall-skinny).
34#[allow(clippy::cast_precision_loss)]
35pub fn qr_factorize(a: &[f32], m: usize, n: usize) -> Result<QrFactorization, SolverError> {
36    if m < n {
37        return Err(SolverError::QrNotTallSkinny { m, n });
38    }
39    if a.len() != m * n {
40        return Err(SolverError::NotSquare { rows: m, cols: n });
41    }
42
43    let mut qr = a.to_vec();
44    let min_mn = m.min(n);
45    let mut tau = vec![0.0f32; min_mn];
46
47    for k in 0..min_mn {
48        // Compute Householder vector for column k, rows k..m
49        let mut norm_sq = 0.0f64;
50        for i in k..m {
51            let v = f64::from(qr[i * n + k]);
52            norm_sq += v * v;
53        }
54        let norm = norm_sq.sqrt();
55
56        if norm < f64::from(f32::EPSILON) {
57            tau[k] = 0.0;
58            continue;
59        }
60
61        // Choose sign to avoid cancellation
62        let alpha = f64::from(qr[k * n + k]);
63        let beta = if alpha >= 0.0 { -norm } else { norm };
64
65        tau[k] = ((beta - alpha) / beta) as f32;
66        let scale = 1.0 / (alpha - beta);
67
68        // Scale the Householder vector
69        for i in (k + 1)..m {
70            qr[i * n + k] = (f64::from(qr[i * n + k]) * scale) as f32;
71        }
72        qr[k * n + k] = beta as f32;
73
74        // Apply Householder reflection to remaining columns
75        for j in (k + 1)..n {
76            let mut dot = f64::from(qr[k * n + j]);
77            for i in (k + 1)..m {
78                dot += f64::from(qr[i * n + k]) * f64::from(qr[i * n + j]);
79            }
80            dot *= f64::from(tau[k]);
81
82            qr[k * n + j] -= dot as f32;
83            for i in (k + 1)..m {
84                qr[i * n + j] -= (f64::from(qr[i * n + k]) * dot) as f32;
85            }
86        }
87    }
88
89    Ok(QrFactorization { m, n, qr, tau })
90}
91
92impl QrFactorization {
93    /// Extract R (upper triangular, n×n).
94    pub fn extract_r(&self) -> Vec<f32> {
95        let n = self.n;
96        let mut r = vec![0.0f32; n * n];
97        for i in 0..n {
98            for j in i..n {
99                r[i * n + j] = self.qr[i * self.n + j];
100            }
101        }
102        r
103    }
104
105    /// Extract Q (m×m orthogonal matrix).
106    ///
107    /// Builds Q by applying Householder reflectors in reverse.
108    pub fn extract_q(&self) -> Vec<f32> {
109        let m = self.m;
110        let n = self.n;
111
112        // Start with identity
113        let mut q = vec![0.0f32; m * m];
114        for i in 0..m {
115            q[i * m + i] = 1.0;
116        }
117
118        let min_mn = m.min(n);
119        for k in (0..min_mn).rev() {
120            if self.tau[k].abs() < f32::EPSILON {
121                continue;
122            }
123
124            // Build Householder vector v
125            // v[k] = 1.0, v[k+1..m] from qr storage
126            for j in k..m {
127                let mut dot = f64::from(q[k * m + j]);
128                // v[k] = 1.0 implicitly
129                for i in (k + 1)..m {
130                    let vi = f64::from(self.qr[i * n + k]);
131                    dot += vi * f64::from(q[i * m + j]);
132                }
133                dot *= f64::from(self.tau[k]);
134
135                q[k * m + j] -= dot as f32;
136                for i in (k + 1)..m {
137                    let vi = f64::from(self.qr[i * n + k]);
138                    q[i * m + j] -= (vi * dot) as f32;
139                }
140            }
141        }
142
143        q
144    }
145
146    /// Solve Ax = b using QR: x = R^{-1} Q^T b (least-squares for overdetermined).
147    ///
148    /// # Errors
149    ///
150    /// Returns error on dimension mismatch.
151    pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
152        if b.len() != self.m {
153            return Err(SolverError::DimensionMismatch {
154                matrix_n: self.m,
155                rhs_len: b.len(),
156            });
157        }
158
159        let m = self.m;
160        let n = self.n;
161
162        // Apply Q^T to b
163        let mut qtb = b.to_vec();
164        let min_mn = m.min(n);
165        for k in 0..min_mn {
166            if self.tau[k].abs() < f32::EPSILON {
167                continue;
168            }
169            let mut dot = f64::from(qtb[k]);
170            for i in (k + 1)..m {
171                dot += f64::from(self.qr[i * n + k]) * f64::from(qtb[i]);
172            }
173            dot *= f64::from(self.tau[k]);
174            qtb[k] -= dot as f32;
175            for i in (k + 1)..m {
176                qtb[i] -= (f64::from(self.qr[i * n + k]) * dot) as f32;
177            }
178        }
179
180        // Back-substitution with R (upper n×n block)
181        let mut x = vec![0.0f32; n];
182        for i in (0..n).rev() {
183            let mut sum = f64::from(qtb[i]);
184            for j in (i + 1)..n {
185                sum -= f64::from(self.qr[i * n + j]) * f64::from(x[j]);
186            }
187            let diag = f64::from(self.qr[i * n + i]);
188            if diag.abs() < f64::from(f32::EPSILON) {
189                return Err(SolverError::SingularMatrix(i));
190            }
191            x[i] = (sum / diag) as f32;
192        }
193
194        Ok(x)
195    }
196}