Skip to main content

trueno_solve/
qr.rs

//! QR factorization via Householder reflections.
//!
//! # Contract: solve-qr-v1.yaml
//!
//! A = QR where Q is orthogonal (Q^T Q = I) and R is upper triangular.
//!
//! ## Proof obligations
//! - ||Q^T Q - I||_F < √n · u
//! - ||A - QR||_F / ||A||_F < √n · u
//! - R is upper triangular
//!
//! where `u` denotes the unit roundoff of the floating-point format.
11
12use crate::error::SolverError;
13
/// QR factorization result.
///
/// The factored matrix is stored row-major: entry `(i, j)` of the `m × n`
/// buffer lives at `qr[i * n + j]`.
#[derive(Debug, Clone, PartialEq)]
pub struct QrFactorization {
    /// Number of rows of the factored matrix.
    pub m: usize,
    /// Number of columns of the factored matrix.
    pub n: usize,
    /// Packed `m × n` factors: `R` on and above the diagonal, Householder
    /// vectors (with an implicit unit leading entry) strictly below it.
    pub qr: Vec<f32>,
    /// Householder scalar factors (`tau`), one per reflector.
    pub tau: Vec<f32>,
}
26
27/// QR factorization via Householder reflections.
28///
29/// # Contract: solve-qr-v1.yaml / qr_factorization
30///
31/// # Errors
32///
33/// Returns error if m < n (not tall-skinny).
34#[allow(clippy::cast_precision_loss)]
35pub fn qr_factorize(a: &[f32], m: usize, n: usize) -> Result<QrFactorization, SolverError> {
36    if m < n {
37        return Err(SolverError::QrNotTallSkinny { m, n });
38    }
39    if a.len() != m * n {
40        return Err(SolverError::NotSquare { rows: m, cols: n });
41    }
42
43    let mut qr = a.to_vec();
44    let min_mn = m.min(n);
45    let mut tau = vec![0.0f32; min_mn];
46
47    for k in 0..min_mn {
48        let norm = householder_column_norm(&qr, k, m, n);
49        if norm < f64::from(f32::EPSILON) {
50            tau[k] = 0.0;
51            continue;
52        }
53        let beta = build_householder_vector(&mut qr, &mut tau, k, norm, m, n);
54        apply_householder_to_trailing(&mut qr, tau[k], k, m, n);
55        qr[k * n + k] = beta as f32;
56    }
57
58    Ok(QrFactorization { m, n, qr, tau })
59}
60
/// Euclidean norm of the sub-column `qr[k..m, k]` of a row-major `m × n`
/// buffer.
///
/// Squares are accumulated in `f64` so the norm of `f32` data does not lose
/// precision or overflow during summation.
fn householder_column_norm(qr: &[f32], k: usize, m: usize, n: usize) -> f64 {
    (k..m)
        .map(|row| f64::from(qr[row * n + k]))
        .fold(0.0f64, |acc, x| acc + x * x)
        .sqrt()
}
70
/// Turn the sub-column `qr[k..m, k]` into a Householder vector.
///
/// `norm` must be the (non-zero) 2-norm of that sub-column. With
/// `alpha = qr[k, k]`, this sets `beta = -norm` when `alpha >= 0` and `+norm`
/// otherwise. It writes `tau[k] = (beta - alpha) / beta` and rescales the
/// entries strictly below the diagonal by `1 / (alpha - beta)`, so the stored
/// vector has an implicit unit entry at row `k`. The diagonal entry itself is
/// not modified; the returned `beta` is the value that belongs there after the
/// reflection.
#[allow(clippy::cast_possible_truncation)]
fn build_householder_vector(
    qr: &mut [f32],
    tau: &mut [f32],
    k: usize,
    norm: f64,
    m: usize,
    n: usize,
) -> f64 {
    let diag = k * n + k;
    let alpha = f64::from(qr[diag]);

    // beta takes the sign opposite to alpha so that `alpha - beta` cannot cancel.
    let beta = if alpha >= 0.0 { -norm } else { norm };
    tau[k] = ((beta - alpha) / beta) as f32;

    let inv_lead = 1.0 / (alpha - beta);
    for row in (k + 1)..m {
        let idx = row * n + k;
        let scaled = f64::from(qr[idx]) * inv_lead;
        qr[idx] = scaled as f32;
    }

    beta
}
90
/// Apply the reflector `H_k = I - tau_k · v vᵀ` from the left to columns
/// `k+1..n` of the row-major `m × n` buffer.
///
/// `v` is read from column `k` strictly below the diagonal, with an implicit
/// `1.0` at row `k`. Products are accumulated in `f64` before each result is
/// narrowed back to `f32`.
#[allow(clippy::cast_possible_truncation)]
fn apply_householder_to_trailing(qr: &mut [f32], tau_k: f32, k: usize, m: usize, n: usize) {
    let tau = f64::from(tau_k);
    let pivot_row = k * n;

    for col in (k + 1)..n {
        // w = tau · vᵀ a_col; the implicit v[k] = 1 term seeds the fold.
        let w = ((k + 1)..m).fold(f64::from(qr[pivot_row + col]), |acc, row| {
            acc + f64::from(qr[row * n + k]) * f64::from(qr[row * n + col])
        }) * tau;

        qr[pivot_row + col] -= w as f32;
        for row in (k + 1)..m {
            let v = f64::from(qr[row * n + k]);
            qr[row * n + col] -= (v * w) as f32;
        }
    }
}
107
108impl QrFactorization {
109    /// Extract R (upper triangular, n×n).
110    pub fn extract_r(&self) -> Vec<f32> {
111        let n = self.n;
112        let mut r = vec![0.0f32; n * n];
113        for i in 0..n {
114            for j in i..n {
115                r[i * n + j] = self.qr[i * self.n + j];
116            }
117        }
118        r
119    }
120
121    /// Extract Q (m×m orthogonal matrix).
122    ///
123    /// Builds Q by applying Householder reflectors in reverse.
124    pub fn extract_q(&self) -> Vec<f32> {
125        let m = self.m;
126        let n = self.n;
127
128        // Start with identity
129        let mut q = vec![0.0f32; m * m];
130        for i in 0..m {
131            q[i * m + i] = 1.0;
132        }
133
134        let min_mn = m.min(n);
135        for k in (0..min_mn).rev() {
136            if self.tau[k].abs() < f32::EPSILON {
137                continue;
138            }
139
140            // Build Householder vector v
141            // v[k] = 1.0, v[k+1..m] from qr storage
142            for j in k..m {
143                let mut dot = f64::from(q[k * m + j]);
144                // v[k] = 1.0 implicitly
145                for i in (k + 1)..m {
146                    let vi = f64::from(self.qr[i * n + k]);
147                    dot += vi * f64::from(q[i * m + j]);
148                }
149                dot *= f64::from(self.tau[k]);
150
151                q[k * m + j] -= dot as f32;
152                for i in (k + 1)..m {
153                    let vi = f64::from(self.qr[i * n + k]);
154                    q[i * m + j] -= (vi * dot) as f32;
155                }
156            }
157        }
158
159        q
160    }
161
162    /// Solve Ax = b using QR: x = R^{-1} Q^T b (least-squares for overdetermined).
163    ///
164    /// # Errors
165    ///
166    /// Returns error on dimension mismatch.
167    pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
168        if b.len() != self.m {
169            return Err(SolverError::DimensionMismatch {
170                matrix_n: self.m,
171                rhs_len: b.len(),
172            });
173        }
174
175        let m = self.m;
176        let n = self.n;
177
178        // Apply Q^T to b
179        let mut qtb = b.to_vec();
180        let min_mn = m.min(n);
181        for k in 0..min_mn {
182            if self.tau[k].abs() < f32::EPSILON {
183                continue;
184            }
185            let mut dot = f64::from(qtb[k]);
186            for i in (k + 1)..m {
187                dot += f64::from(self.qr[i * n + k]) * f64::from(qtb[i]);
188            }
189            dot *= f64::from(self.tau[k]);
190            qtb[k] -= dot as f32;
191            for i in (k + 1)..m {
192                qtb[i] -= (f64::from(self.qr[i * n + k]) * dot) as f32;
193            }
194        }
195
196        // Back-substitution with R (upper n×n block)
197        let mut x = vec![0.0f32; n];
198        for i in (0..n).rev() {
199            let mut sum = f64::from(qtb[i]);
200            for j in (i + 1)..n {
201                sum -= f64::from(self.qr[i * n + j]) * f64::from(x[j]);
202            }
203            let diag = f64::from(self.qr[i * n + i]);
204            if diag.abs() < f64::from(f32::EPSILON) {
205                return Err(SolverError::SingularMatrix(i));
206            }
207            x[i] = (sum / diag) as f32;
208        }
209
210        Ok(x)
211    }
212}