1use crate::error::SolverError;
13
/// Packed result of a Householder QR factorization, as produced by
/// `qr_factorize`.
///
/// The factors are not stored as explicit `Q` and `R` matrices; the
/// `extract_q`, `extract_r` and `solve` methods decode the packed form.
#[derive(Debug)]
pub struct QrFactorization {
    /// Number of rows of the factored matrix.
    pub m: usize,
    /// Number of columns of the factored matrix; also the row stride of `qr`.
    pub n: usize,
    /// Row-major `m x n` packed storage: `R` occupies the diagonal and the
    /// entries above it, while the entries below the diagonal of column `k`
    /// hold the tail of the `k`-th reflector vector (its leading component is
    /// an implicit `1`).
    pub qr: Vec<f32>,
    /// Scalar factor of each reflector, one per factored column. A value of
    /// `0.0` marks a column that was skipped during factorization.
    pub tau: Vec<f32>,
}
26
27#[allow(clippy::cast_precision_loss)]
35pub fn qr_factorize(a: &[f32], m: usize, n: usize) -> Result<QrFactorization, SolverError> {
36 if m < n {
37 return Err(SolverError::QrNotTallSkinny { m, n });
38 }
39 if a.len() != m * n {
40 return Err(SolverError::NotSquare { rows: m, cols: n });
41 }
42
43 let mut qr = a.to_vec();
44 let min_mn = m.min(n);
45 let mut tau = vec![0.0f32; min_mn];
46
47 for k in 0..min_mn {
48 let mut norm_sq = 0.0f64;
50 for i in k..m {
51 let v = f64::from(qr[i * n + k]);
52 norm_sq += v * v;
53 }
54 let norm = norm_sq.sqrt();
55
56 if norm < f64::from(f32::EPSILON) {
57 tau[k] = 0.0;
58 continue;
59 }
60
61 let alpha = f64::from(qr[k * n + k]);
63 let beta = if alpha >= 0.0 { -norm } else { norm };
64
65 tau[k] = ((beta - alpha) / beta) as f32;
66 let scale = 1.0 / (alpha - beta);
67
68 for i in (k + 1)..m {
70 qr[i * n + k] = (f64::from(qr[i * n + k]) * scale) as f32;
71 }
72 qr[k * n + k] = beta as f32;
73
74 for j in (k + 1)..n {
76 let mut dot = f64::from(qr[k * n + j]);
77 for i in (k + 1)..m {
78 dot += f64::from(qr[i * n + k]) * f64::from(qr[i * n + j]);
79 }
80 dot *= f64::from(tau[k]);
81
82 qr[k * n + j] -= dot as f32;
83 for i in (k + 1)..m {
84 qr[i * n + j] -= (f64::from(qr[i * n + k]) * dot) as f32;
85 }
86 }
87 }
88
89 Ok(QrFactorization { m, n, qr, tau })
90}
91
92impl QrFactorization {
93 pub fn extract_r(&self) -> Vec<f32> {
95 let n = self.n;
96 let mut r = vec![0.0f32; n * n];
97 for i in 0..n {
98 for j in i..n {
99 r[i * n + j] = self.qr[i * self.n + j];
100 }
101 }
102 r
103 }
104
105 pub fn extract_q(&self) -> Vec<f32> {
109 let m = self.m;
110 let n = self.n;
111
112 let mut q = vec![0.0f32; m * m];
114 for i in 0..m {
115 q[i * m + i] = 1.0;
116 }
117
118 let min_mn = m.min(n);
119 for k in (0..min_mn).rev() {
120 if self.tau[k].abs() < f32::EPSILON {
121 continue;
122 }
123
124 for j in k..m {
127 let mut dot = f64::from(q[k * m + j]);
128 for i in (k + 1)..m {
130 let vi = f64::from(self.qr[i * n + k]);
131 dot += vi * f64::from(q[i * m + j]);
132 }
133 dot *= f64::from(self.tau[k]);
134
135 q[k * m + j] -= dot as f32;
136 for i in (k + 1)..m {
137 let vi = f64::from(self.qr[i * n + k]);
138 q[i * m + j] -= (vi * dot) as f32;
139 }
140 }
141 }
142
143 q
144 }
145
146 pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
152 if b.len() != self.m {
153 return Err(SolverError::DimensionMismatch {
154 matrix_n: self.m,
155 rhs_len: b.len(),
156 });
157 }
158
159 let m = self.m;
160 let n = self.n;
161
162 let mut qtb = b.to_vec();
164 let min_mn = m.min(n);
165 for k in 0..min_mn {
166 if self.tau[k].abs() < f32::EPSILON {
167 continue;
168 }
169 let mut dot = f64::from(qtb[k]);
170 for i in (k + 1)..m {
171 dot += f64::from(self.qr[i * n + k]) * f64::from(qtb[i]);
172 }
173 dot *= f64::from(self.tau[k]);
174 qtb[k] -= dot as f32;
175 for i in (k + 1)..m {
176 qtb[i] -= (f64::from(self.qr[i * n + k]) * dot) as f32;
177 }
178 }
179
180 let mut x = vec![0.0f32; n];
182 for i in (0..n).rev() {
183 let mut sum = f64::from(qtb[i]);
184 for j in (i + 1)..n {
185 sum -= f64::from(self.qr[i * n + j]) * f64::from(x[j]);
186 }
187 let diag = f64::from(self.qr[i * n + i]);
188 if diag.abs() < f64::from(f32::EPSILON) {
189 return Err(SolverError::SingularMatrix(i));
190 }
191 x[i] = (sum / diag) as f32;
192 }
193
194 Ok(x)
195 }
196}