1use crate::error::SolverError;
13
/// Householder QR factorization in compact (packed) form, as produced by
/// [`qr_factorize`] (which requires `m >= n`).
///
/// `qr` is a row-major `m × n` buffer.
/// - On and above the diagonal it holds the upper-triangular factor `R`.
/// - Strictly below the diagonal, column `k` holds the tail of the `k`-th
///   Householder vector `v_k`; its leading entry is an implicit `1`.
///
/// The orthogonal factor is `Q = H_0 · H_1 · … · H_{p-1}`, where
/// `H_k = I - tau[k] · v_k · v_kᵀ` and `p = min(m, n)`.
#[derive(Debug, Clone, PartialEq)]
pub struct QrFactorization {
    /// Number of rows of the factored matrix.
    pub m: usize,
    /// Number of columns of the factored matrix.
    pub n: usize,
    /// Packed `R` factor and Householder vector tails, row-major, `m * n` entries.
    pub qr: Vec<f32>,
    /// Householder scalar factors, one per reflector (`min(m, n)` entries).
    /// A value of `0.0` means column `k` was skipped because its norm was negligible,
    /// so its reflector is the identity.
    pub tau: Vec<f32>,
}
26
27#[allow(clippy::cast_precision_loss)]
35pub fn qr_factorize(a: &[f32], m: usize, n: usize) -> Result<QrFactorization, SolverError> {
36 if m < n {
37 return Err(SolverError::QrNotTallSkinny { m, n });
38 }
39 if a.len() != m * n {
40 return Err(SolverError::NotSquare { rows: m, cols: n });
41 }
42
43 let mut qr = a.to_vec();
44 let min_mn = m.min(n);
45 let mut tau = vec![0.0f32; min_mn];
46
47 for k in 0..min_mn {
48 let norm = householder_column_norm(&qr, k, m, n);
49 if norm < f64::from(f32::EPSILON) {
50 tau[k] = 0.0;
51 continue;
52 }
53 let beta = build_householder_vector(&mut qr, &mut tau, k, norm, m, n);
54 apply_householder_to_trailing(&mut qr, tau[k], k, m, n);
55 qr[k * n + k] = beta as f32;
56 }
57
58 Ok(QrFactorization { m, n, qr, tau })
59}
60
/// Euclidean norm of rows `k..m` of column `k` in a row-major matrix with `n` columns.
///
/// Squares are accumulated in f64 to limit rounding error from the f32 data.
fn householder_column_norm(qr: &[f32], k: usize, m: usize, n: usize) -> f64 {
    let sum_of_squares = (k..m)
        .map(|row| f64::from(qr[row * n + k]))
        .fold(0.0f64, |acc, x| acc + x * x);
    sum_of_squares.sqrt()
}
70
/// Builds the Householder reflector that zeroes rows `k+1..m` of column `k`.
///
/// Let `alpha = qr[k][k]` and let `norm` be the norm of rows `k..m` of column `k`.
/// The function:
/// - picks `beta` with the sign opposite to `alpha`, so `alpha - beta` adds
///   magnitudes instead of cancelling;
/// - stores `tau[k] = (beta - alpha) / beta`;
/// - scales the entries below the diagonal by `1 / (alpha - beta)`, so the stored
///   vector has an implicit leading `1`.
///
/// The diagonal entry itself is not modified. Returns `beta`, the value that
/// belongs on the diagonal of `R`.
#[allow(clippy::cast_possible_truncation)] // f64 intermediates are stored back as f32.
fn build_householder_vector(
    qr: &mut [f32],
    tau: &mut [f32],
    k: usize,
    norm: f64,
    m: usize,
    n: usize,
) -> f64 {
    let alpha = f64::from(qr[k * n + k]);
    let beta = if alpha >= 0.0 { -norm } else { norm };
    let tau_k = (beta - alpha) / beta;
    let inv_denominator = 1.0 / (alpha - beta);

    tau[k] = tau_k as f32;
    for row in (k + 1)..m {
        let slot = row * n + k;
        qr[slot] = (f64::from(qr[slot]) * inv_denominator) as f32;
    }
    beta
}
90
/// Applies the reflector `H_k = I - tau_k · v vᵀ` to columns `k+1..n` of `qr`.
///
/// `v` is read from column `k`: its leading entry (row `k`) is an implicit `1`,
/// and its tail is rows `k+1..m` of that column. For each trailing column `a`,
/// the function computes `w = tau_k · (vᵀ a)` in f64 and then updates `a -= w · v`.
#[allow(clippy::cast_possible_truncation)] // f64 intermediates are stored back as f32.
fn apply_householder_to_trailing(qr: &mut [f32], tau_k: f32, k: usize, m: usize, n: usize) {
    let tau = f64::from(tau_k);
    for col in (k + 1)..n {
        let head = k * n + col;
        // vᵀ a: the implicit leading 1 contributes a[k][col] unscaled.
        let projection = ((k + 1)..m).fold(f64::from(qr[head]), |acc, row| {
            acc + f64::from(qr[row * n + k]) * f64::from(qr[row * n + col])
        });
        let weight = projection * tau;

        qr[head] -= weight as f32;
        for row in (k + 1)..m {
            let v_row = f64::from(qr[row * n + k]);
            qr[row * n + col] -= (v_row * weight) as f32;
        }
    }
}
107
108impl QrFactorization {
109 pub fn extract_r(&self) -> Vec<f32> {
111 let n = self.n;
112 let mut r = vec![0.0f32; n * n];
113 for i in 0..n {
114 for j in i..n {
115 r[i * n + j] = self.qr[i * self.n + j];
116 }
117 }
118 r
119 }
120
121 pub fn extract_q(&self) -> Vec<f32> {
125 let m = self.m;
126 let n = self.n;
127
128 let mut q = vec![0.0f32; m * m];
130 for i in 0..m {
131 q[i * m + i] = 1.0;
132 }
133
134 let min_mn = m.min(n);
135 for k in (0..min_mn).rev() {
136 if self.tau[k].abs() < f32::EPSILON {
137 continue;
138 }
139
140 for j in k..m {
143 let mut dot = f64::from(q[k * m + j]);
144 for i in (k + 1)..m {
146 let vi = f64::from(self.qr[i * n + k]);
147 dot += vi * f64::from(q[i * m + j]);
148 }
149 dot *= f64::from(self.tau[k]);
150
151 q[k * m + j] -= dot as f32;
152 for i in (k + 1)..m {
153 let vi = f64::from(self.qr[i * n + k]);
154 q[i * m + j] -= (vi * dot) as f32;
155 }
156 }
157 }
158
159 q
160 }
161
162 pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
168 if b.len() != self.m {
169 return Err(SolverError::DimensionMismatch {
170 matrix_n: self.m,
171 rhs_len: b.len(),
172 });
173 }
174
175 let m = self.m;
176 let n = self.n;
177
178 let mut qtb = b.to_vec();
180 let min_mn = m.min(n);
181 for k in 0..min_mn {
182 if self.tau[k].abs() < f32::EPSILON {
183 continue;
184 }
185 let mut dot = f64::from(qtb[k]);
186 for i in (k + 1)..m {
187 dot += f64::from(self.qr[i * n + k]) * f64::from(qtb[i]);
188 }
189 dot *= f64::from(self.tau[k]);
190 qtb[k] -= dot as f32;
191 for i in (k + 1)..m {
192 qtb[i] -= (f64::from(self.qr[i * n + k]) * dot) as f32;
193 }
194 }
195
196 let mut x = vec![0.0f32; n];
198 for i in (0..n).rev() {
199 let mut sum = f64::from(qtb[i]);
200 for j in (i + 1)..n {
201 sum -= f64::from(self.qr[i * n + j]) * f64::from(x[j]);
202 }
203 let diag = f64::from(self.qr[i * n + i]);
204 if diag.abs() < f64::from(f32::EPSILON) {
205 return Err(SolverError::SingularMatrix(i));
206 }
207 x[i] = (sum / diag) as f32;
208 }
209
210 Ok(x)
211 }
212}