use crate::error::SolverError;
/// Compact Householder QR factorization of a row-major `m × n` matrix, as produced by
/// [`qr_factorize`] (which requires `m >= n`).
///
/// The reflectors are `H_k = I - tau[k] v_k v_kᵀ` for `k < min(m, n)`, and the factored
/// matrix is `A = H_0 H_1 ⋯ R`.
#[derive(Debug, Clone, PartialEq)]
pub struct QrFactorization {
    /// Number of rows of the factored matrix.
    pub m: usize,
    /// Number of columns of the factored matrix.
    pub n: usize,
    /// Row-major `m × n` buffer.
    ///
    /// Entries on and above the diagonal hold `R`. Entries below the diagonal of column `k`
    /// hold the tail of the Householder vector `v_k`, whose leading entry is an implicit `1`.
    pub qr: Vec<f32>,
    /// Reflector scalars, one per column `k < min(m, n)`; `0` encodes an identity reflector.
    pub tau: Vec<f32>,
}
/// Computes a Householder QR factorization of the row-major `m × n` matrix `a` (`m >= n`).
///
/// The result uses a compact layout:
///
/// * On and above the diagonal, `qr` holds `R`.
/// * Below the diagonal of column `k`, `qr` holds the tail of the Householder vector `v_k`,
///   whose leading entry is an implicit `1`.
/// * `tau[k]` is the scalar of the reflector `H_k = I - tau[k] v_k v_kᵀ`, so that
///   `A = H_0 H_1 ⋯ H_{n-1} R`.
///
/// A column whose remaining 2-norm (rows `k..m`) is below `f32::EPSILON` is skipped. It keeps
/// `tau[k] = 0` (an identity reflector) and its tiny sub-diagonal entries are not eliminated.
/// The threshold is absolute, not relative to the scale of `a`.
///
/// # Errors
///
/// * [`SolverError::QrNotTallSkinny`] if `m < n`.
/// * [`SolverError::NotSquare`] if `a.len() != m * n`, including when `m * n` would overflow
///   `usize`.
// `beta` is computed in f64 and intentionally narrowed back to the f32 storage type.
#[allow(clippy::cast_possible_truncation)]
pub fn qr_factorize(a: &[f32], m: usize, n: usize) -> Result<QrFactorization, SolverError> {
    if m < n {
        return Err(SolverError::QrNotTallSkinny { m, n });
    }
    // `checked_mul` rejects dimensions whose product overflows, instead of letting a wrapped
    // product accidentally match `a.len()`.
    // NOTE(review): `NotSquare` is a misleading variant for a length mismatch; it is kept so
    // existing callers matching on it keep working.
    if m.checked_mul(n) != Some(a.len()) {
        return Err(SolverError::NotSquare { rows: m, cols: n });
    }
    let mut qr = a.to_vec();
    let min_mn = m.min(n);
    // Zero-initialised: a column skipped below keeps the identity reflector (tau = 0).
    let mut tau = vec![0.0f32; min_mn];
    for k in 0..min_mn {
        let norm = householder_column_norm(&qr, k, m, n);
        if norm < f64::from(f32::EPSILON) {
            continue;
        }
        let beta = build_householder_vector(&mut qr, &mut tau, k, norm, m, n);
        apply_householder_to_trailing(&mut qr, tau[k], k, m, n);
        // The reflector maps the pivot column onto beta·e_k, which becomes R[k][k].
        qr[k * n + k] = beta as f32;
    }
    Ok(QrFactorization { m, n, qr, tau })
}
/// Returns the Euclidean norm of column `k`, restricted to rows `k..m`, of the row-major
/// `m × n` buffer `qr`.
///
/// Squares are accumulated in `f64`, so squaring large `f32` entries cannot overflow.
fn householder_column_norm(qr: &[f32], k: usize, m: usize, n: usize) -> f64 {
    (k..m)
        .map(|row| f64::from(qr[row * n + k]))
        .fold(0.0f64, |sum_sq, x| sum_sq + x * x)
        .sqrt()
}
/// Builds the Householder reflector for column `k` in place and returns `beta`, the value
/// the pivot column is mapped onto.
///
/// Let `alpha = qr[k][k]` and let `norm` be the 2-norm of `qr[k..m][k]`. Then:
///
/// * `beta = ∓norm`, with the opposite sign to `alpha`.
/// * `tau[k] = (beta - alpha) / beta`.
/// * Every sub-diagonal entry of column `k` is divided by `alpha - beta`, so the stored
///   vector has an implicit leading `1`.
///
/// `qr[k][k]` itself is not modified here.
// f64 intermediates are intentionally narrowed back to the f32 storage type.
#[allow(clippy::cast_possible_truncation)]
fn build_householder_vector(
    qr: &mut [f32],
    tau: &mut [f32],
    k: usize,
    norm: f64,
    m: usize,
    n: usize,
) -> f64 {
    let pivot = f64::from(qr[k * n + k]);
    // Choosing beta with the sign opposite to the pivot keeps `pivot - beta` free of
    // cancellation.
    let beta = if pivot >= 0.0 { -norm } else { norm };
    let denom = pivot - beta;
    tau[k] = ((beta - pivot) / beta) as f32;

    let inv_denom = 1.0 / denom;
    for row in (k + 1)..m {
        let slot = &mut qr[row * n + k];
        *slot = (f64::from(*slot) * inv_denom) as f32;
    }
    beta
}
/// Applies the reflector `H_k = I - tau_k v vᵀ` in place to every trailing column `j > k` of
/// the row-major `m × n` buffer `qr`.
///
/// The vector `v` is read from column `k`: an implicit `1` at row `k`, and the stored
/// entries at rows `k+1..m`. Only rows `k..m` of each trailing column change.
// f64 accumulations are intentionally narrowed back to the f32 storage type.
#[allow(clippy::cast_possible_truncation)]
fn apply_householder_to_trailing(qr: &mut [f32], tau_k: f32, k: usize, m: usize, n: usize) {
    let tau = f64::from(tau_k);
    let head = k * n;
    for col in (k + 1)..n {
        // w = tau · vᵀ a_col; the implicit leading 1 of v contributes a_col[k] directly.
        let w = ((k + 1)..m).fold(f64::from(qr[head + col]), |acc, row| {
            acc + f64::from(qr[row * n + k]) * f64::from(qr[row * n + col])
        }) * tau;
        qr[head + col] -= w as f32;
        for row in (k + 1)..m {
            let v = f64::from(qr[row * n + k]);
            qr[row * n + col] -= (v * w) as f32;
        }
    }
}
impl QrFactorization {
    /// Returns the `n × n` upper-triangular factor `R` as a row-major buffer.
    ///
    /// Entries on and above the diagonal are copied from the first `n` rows of `qr`. Entries
    /// below the diagonal are zero.
    pub fn extract_r(&self) -> Vec<f32> {
        let size = self.n;
        let mut r = vec![0.0f32; size * size];
        for row in 0..size {
            // Columns `row..size` of this row: the diagonal and everything to its right.
            let span = row * size + row..(row + 1) * size;
            r[span.clone()].copy_from_slice(&self.qr[span]);
        }
        r
    }

    /// Returns the full `m × m` factor `Q = H_0 H_1 ⋯ H_{p-1}` (with `p = min(m, n)`) as a
    /// row-major buffer.
    ///
    /// The reflectors are applied to the identity from last to first. Reflector `k` only
    /// touches rows `k..m`, and at that point every column left of `k` is still an identity
    /// column. So only columns `k..m` need updating.
    pub fn extract_q(&self) -> Vec<f32> {
        let dim = self.m;
        let mut q = vec![0.0f32; dim * dim];
        // Every (dim + 1)-th element of a row-major dim × dim buffer lies on the diagonal.
        q.iter_mut().step_by(dim + 1).for_each(|d| *d = 1.0);
        for k in (0..dim.min(self.n))
            .rev()
            .filter(|&k| !self.is_identity_reflector(k))
        {
            for col in k..dim {
                self.apply_reflector(k, &mut q, col, dim);
            }
        }
        q
    }

    /// Solves `A x = b` in the least-squares sense, using the stored factorization.
    ///
    /// It first forms `Qᵀ b` by applying `H_0, H_1, …` to a copy of `b`. It then
    /// back-substitutes with the leading `n × n` block of `R`.
    ///
    /// # Errors
    ///
    /// * [`SolverError::DimensionMismatch`] if `b.len() != m`.
    /// * [`SolverError::SingularMatrix`] with the row index `i` of the first diagonal entry
    ///   (scanning from the bottom) with `|R[i][i]| < f32::EPSILON`.
    pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
        let (rows, cols) = (self.m, self.n);
        if b.len() != rows {
            return Err(SolverError::DimensionMismatch {
                matrix_n: rows,
                rhs_len: b.len(),
            });
        }

        // Overwrite a copy of `b` with Qᵀ b = H_{p-1} ⋯ H_0 b.
        let mut rhs = b.to_vec();
        for k in (0..rows.min(cols)).filter(|&k| !self.is_identity_reflector(k)) {
            self.apply_reflector(k, &mut rhs, 0, 1);
        }

        // Back substitution on R x = (Qᵀ b)[..n].
        let mut x = vec![0.0f32; cols];
        for row in (0..cols).rev() {
            let residual = ((row + 1)..cols).fold(f64::from(rhs[row]), |acc, col| {
                acc - f64::from(self.qr[row * cols + col]) * f64::from(x[col])
            });
            let pivot = f64::from(self.qr[row * cols + row]);
            if pivot.abs() < f64::from(f32::EPSILON) {
                return Err(SolverError::SingularMatrix(row));
            }
            x[row] = (residual / pivot) as f32;
        }
        Ok(x)
    }

    /// Reports whether reflector `k` is treated as the identity and therefore skipped.
    fn is_identity_reflector(&self, k: usize) -> bool {
        self.tau[k].abs() < f32::EPSILON
    }

    /// Applies `H_k = I - tau[k] v_k v_kᵀ` in place to a strided vector.
    ///
    /// Logical element `i` of the vector is `target[offset + i * stride]`, for `i` in `0..m`.
    /// `v_k` has an implicit `1` at index `k`, and its remaining entries are read from below
    /// the diagonal of column `k` of `qr`.
    fn apply_reflector(&self, k: usize, target: &mut [f32], offset: usize, stride: usize) {
        let n = self.n;
        let at = |i: usize| offset + i * stride;
        let weight = ((k + 1)..self.m).fold(f64::from(target[at(k)]), |acc, i| {
            acc + f64::from(self.qr[i * n + k]) * f64::from(target[at(i)])
        }) * f64::from(self.tau[k]);
        target[at(k)] -= weight as f32;
        for i in (k + 1)..self.m {
            target[at(i)] -= (f64::from(self.qr[i * n + k]) * weight) as f32;
        }
    }
}