use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{s, Array2};
use super::DistribConfig;
/// One elementary Householder reflector, stored as the pair `(v, beta)`.
///
/// The code in this file applies it as `H = I - beta * v * v^T`, acting on a
/// contiguous run of `v.len()` rows of the target matrix.
#[derive(Debug, Clone)]
pub struct HouseholderReflector {
/// Householder vector `v`; its length is the number of rows the reflector touches.
pub v: Vec<f64>,
/// Scalar `beta` in `H = I - beta * v * v^T`; a zero value means "no reflection".
pub beta: f64,
/// Row range the reflector was built for. `local_qr_panel` records
/// `col..nrows` of the panel it factorised, and `build_q_from_reflectors`
/// uses `applied_to.start` as the first row when re-applying the reflector.
pub applied_to: std::ops::Range<usize>,
}
impl HouseholderReflector {
    /// Builds the Householder pair `(v, beta)` that annihilates all but the
    /// first entry of `x`, i.e. `(I - beta * v * v^T) * x` is a multiple of
    /// the first unit vector.
    ///
    /// # Returns
    ///
    /// * `(vec![], 0.0)` for an empty slice.
    /// * `(x.to_vec(), 0.0)` when the tail of `x` is exactly zero and
    ///   `x[0] >= 0`: `x` already has the target shape, so no reflection is
    ///   needed.
    /// * Otherwise a vector normalised so that `v[0] == 1.0`, together with
    ///   `beta = 2 / (v^T v)`.
    ///
    /// Because of the normalisation `beta` lies in (roughly) `[1, 2]` for
    /// every non-trivial reflector, independent of the magnitude of `x`.
    /// Callers in this module treat `|beta| < f64::EPSILON` as "identity", so
    /// an unnormalised `beta = 2 / ||v||^2` would silently disable the
    /// reflector for large inputs (and an absolute threshold on `||v||^2`
    /// would do the same for small ones).
    ///
    /// Squares of the entries are formed directly, so inputs close to the
    /// overflow/underflow limits of `f64` are not specially protected.
    pub fn from_vector(x: &[f64]) -> (Vec<f64>, f64) {
        let mut v = x.to_vec();
        let head = match x.first() {
            Some(&h) => h,
            None => return (v, 0.0),
        };

        // Squared norm of everything below the leading entry.
        let sigma: f64 = x[1..].iter().map(|xi| xi * xi).sum();
        if sigma == 0.0 && head >= 0.0 {
            return (v, 0.0);
        }

        // Pick the sign that adds magnitudes, avoiding cancellation in the
        // leading entry. Past the early return above the pivot is never zero.
        let x_norm = (head * head + sigma).sqrt();
        let pivot = if head <= 0.0 {
            head - x_norm
        } else {
            head + x_norm
        };

        // Normalise by the pivot: every scaled tail entry has magnitude <= 1,
        // so `tail_sq` cannot overflow and `beta` stays well away from zero.
        v[0] = 1.0;
        let mut tail_sq = 0.0f64;
        for vi in v.iter_mut().skip(1) {
            *vi /= pivot;
            tail_sq += *vi * *vi;
        }
        let beta = 2.0 / (1.0 + tail_sq);
        (v, beta)
    }

    /// Applies the reflector from the left to a trailing block of `a`, in place:
    /// rows `row_offset .. row_offset + v.len()` and columns `col_offset ..`
    /// are replaced by `(I - beta * v * v^T)` times that block.
    ///
    /// The call is a silent no-op when the reflector does not fit inside `a`
    /// (`row_offset + v.len() > a.nrows()`) or when `col_offset` is not a
    /// valid column index.
    pub fn apply_left_to(&self, a: &mut Array2<f64>, row_offset: usize, col_offset: usize) {
        let n = self.v.len();
        if row_offset + n > a.nrows() || col_offset >= a.ncols() {
            return;
        }
        let width = a.ncols() - col_offset;

        // w = v^T * A_block, one dot product per affected column.
        let mut w = vec![0.0f64; width];
        for (j, wj) in w.iter_mut().enumerate() {
            let mut dot = 0.0f64;
            for (i, &vi) in self.v.iter().enumerate() {
                dot += vi * a[[row_offset + i, col_offset + j]];
            }
            *wj = dot;
        }

        // Rank-one update: A_block -= beta * v * w.
        for (i, &vi) in self.v.iter().enumerate() {
            for (j, &wj) in w.iter().enumerate() {
                a[[row_offset + i, col_offset + j]] -= self.beta * vi * wj;
            }
        }
    }
}
/// Householder QR of a single panel, performed in place.
///
/// For each of the first `min(nrows, ncols)` columns a reflector is built from
/// the column entries on and below the diagonal. Unless its `beta` is below
/// `f64::EPSILON` in magnitude, the reflector is applied to the trailing
/// columns and the sub-diagonal entries of the pivot column are overwritten
/// with exact zeros. On return `a_panel` holds the resulting R factor.
///
/// Every reflector (including the skipped, negligible ones) is returned in
/// column order, with `applied_to` set to `col..nrows`.
pub fn local_qr_panel(a_panel: &mut Array2<f64>) -> Vec<HouseholderReflector> {
    let rows = a_panel.nrows();
    let steps = rows.min(a_panel.ncols());
    let mut reflectors = Vec::with_capacity(steps);

    for pivot in 0..steps {
        // Copy out the part of the pivot column that the reflector acts on.
        let column: Vec<f64> = (pivot..rows).map(|r| a_panel[[r, pivot]]).collect();
        let (v, beta) = HouseholderReflector::from_vector(&column);
        let refl = HouseholderReflector {
            v,
            beta,
            applied_to: pivot..rows,
        };

        let negligible = beta.abs() < f64::EPSILON;
        if !negligible {
            refl.apply_left_to(a_panel, pivot, pivot);
            // Replace the rounding residue below the diagonal with exact zeros.
            for r in (pivot + 1)..rows {
                a_panel[[r, pivot]] = 0.0;
            }
        }
        reflectors.push(refl);
    }

    reflectors
}
/// Accumulates the explicit `m x m` orthogonal factor from a reflector list.
///
/// Starting from the identity, the reflectors are applied from the left in
/// reverse order, so for `reflectors = [H_1, ..., H_k]` the result is
/// `H_1 * H_2 * ... * H_k`. Each reflector acts on rows
/// `applied_to.start .. applied_to.start + v.len()`.
///
/// Reflectors with `|beta| < f64::EPSILON`, and reflectors whose rows would
/// extend past `m`, are skipped.
pub fn build_q_from_reflectors(reflectors: &[HouseholderReflector], m: usize) -> Array2<f64> {
    let mut q = Array2::<f64>::eye(m);

    for h in reflectors.iter().rev() {
        let negligible = h.beta.abs() < f64::EPSILON;
        let first_row = h.applied_to.start;
        if negligible || first_row + h.v.len() > m {
            continue;
        }

        // projection[c] = v^T * q[first_row.., c]
        let mut projection = vec![0.0f64; m];
        for (c, slot) in projection.iter_mut().enumerate() {
            let mut acc = 0.0f64;
            for (k, &vk) in h.v.iter().enumerate() {
                acc += vk * q[[first_row + k, c]];
            }
            *slot = acc;
        }

        // q[first_row.., :] -= beta * v * projection
        for (k, &vk) in h.v.iter().enumerate() {
            for (c, &pc) in projection.iter().enumerate() {
                q[[first_row + k, c]] -= h.beta * vk * pc;
            }
        }
    }

    q
}
/// Combines two R factors into one by stacking them vertically and running a
/// Householder QR on the stacked matrix.
///
/// # Returns
///
/// The leading `min(n, nt + nb)` rows of the factorised stack (where `n` is
/// the shared column count and `nt`, `nb` are the row counts of the inputs),
/// together with the reflectors produced by `local_qr_panel`. The reflectors'
/// row ranges are relative to the stacked matrix, not to any global matrix.
///
/// # Errors
///
/// `LinalgError::DimensionError` if the two inputs have different column counts.
fn tournament_combine_pair(
    r_top: &Array2<f64>,
    r_bottom: &Array2<f64>,
) -> LinalgResult<(Array2<f64>, Vec<HouseholderReflector>)> {
    let n = r_top.ncols();
    if r_bottom.ncols() != n {
        return Err(LinalgError::DimensionError(
            "tournament_combine_pair: R matrices must have same number of columns".to_string(),
        ));
    }

    let nt = r_top.nrows();
    let nb = r_bottom.nrows();
    let mut stacked = Array2::<f64>::zeros((nt + nb, n));
    stacked.slice_mut(s![..nt, ..]).assign(r_top);
    stacked.slice_mut(s![nt.., ..]).assign(r_bottom);

    let reflectors = local_qr_panel(&mut stacked);

    // The stack may have fewer than `n` rows (two short, wide panels); slicing
    // `..n` would then panic, so keep only the rows that actually exist.
    let keep = n.min(nt + nb);
    let r_new = stacked.slice(s![..keep, ..]).to_owned();
    Ok((r_new, reflectors))
}
/// Reduces a list of panels to a single R factor with a pairwise tournament.
///
/// In every round neighbouring panels are combined two at a time via
/// `tournament_combine_pair`; with an odd count the last panel advances to the
/// next round unchanged. Rounds repeat until one matrix is left.
///
/// # Returns
///
/// The final matrix and all reflectors generated along the way, in the order
/// the pairs were combined. A single input panel is returned as-is (it is not
/// factorised) with an empty reflector list.
///
/// # Errors
///
/// * `LinalgError::ValueError` if `panels` is empty.
/// * Any error from `tournament_combine_pair` (mismatched column counts).
pub fn tournament_qr_reduction(
    panels: &[Array2<f64>],
) -> LinalgResult<(Array2<f64>, Vec<HouseholderReflector>)> {
    if panels.is_empty() {
        return Err(LinalgError::ValueError(
            "tournament_qr_reduction: panels slice is empty".to_string(),
        ));
    }

    let mut current: Vec<Array2<f64>> = panels.to_vec();
    let mut all_reflectors: Vec<HouseholderReflector> = Vec::new();

    while current.len() > 1 {
        let mut next: Vec<Array2<f64>> = Vec::with_capacity(current.len() / 2 + 1);
        // Consume the round by value so an unpaired panel is moved, not cloned.
        let mut pending = current.into_iter();
        while let Some(top) = pending.next() {
            match pending.next() {
                Some(bottom) => {
                    let (r_new, refls) = tournament_combine_pair(&top, &bottom)?;
                    next.push(r_new);
                    all_reflectors.extend(refls);
                }
                None => next.push(top),
            }
        }
        current = next;
    }

    let r_final = current.pop().ok_or_else(|| {
        LinalgError::ComputationError("tournament_qr_reduction: empty result".to_string())
    })?;
    Ok((r_final, all_reflectors))
}
pub fn caqr_simulate(
a: &Array2<f64>,
_config: &DistribConfig,
) -> LinalgResult<(Array2<f64>, Array2<f64>)> {
let m = a.nrows();
let n = a.ncols();
if m == 0 || n == 0 {
return Err(LinalgError::ValueError(
"caqr_simulate: input matrix must be non-empty".to_string(),
));
}
let mut r_work = a.to_owned();
let mut all_reflectors: Vec<HouseholderReflector> = Vec::new();
let n_house = m.min(n);
for col in 0..n_house {
let sub_len = m - col;
let mut x = vec![0.0f64; sub_len];
for i in 0..sub_len {
x[i] = r_work[[col + i, col]];
}
let (v, beta) = HouseholderReflector::from_vector(&x);
let refl = HouseholderReflector {
v,
beta,
applied_to: col..col + sub_len,
};
if refl.beta.abs() > f64::EPSILON {
refl.apply_left_to(&mut r_work, col, col);
for i in (col + 1)..m {
r_work[[i, col]] = 0.0;
}
}
all_reflectors.push(refl);
}
let q = build_q_from_reflectors(&all_reflectors, m);
Ok((q, r_work))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Naive dense product `a * b`, used as an independent reference.
    fn matmul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
        let (rows, inner, cols) = (a.nrows(), a.ncols(), b.ncols());
        let mut product = Array2::<f64>::zeros((rows, cols));
        for r in 0..rows {
            for c in 0..cols {
                let mut acc = 0.0f64;
                for t in 0..inner {
                    acc += a[[r, t]] * b[[t, c]];
                }
                product[[r, c]] = acc;
            }
        }
        product
    }

    /// Frobenius norm of `a - b`, taken over the zipped element iterators.
    fn frob_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        a.iter()
            .zip(b.iter())
            .fold(0.0f64, |total, (x, y)| {
                let d = x - y;
                total + d * d
            })
            .sqrt()
    }

    /// `mat^T * mat`, which is the identity for an orthogonal `mat`.
    fn gram(mat: &Array2<f64>) -> Array2<f64> {
        matmul(&mat.t().to_owned(), mat)
    }

    /// Configuration literal shared by the CAQR tests.
    fn config_with_block(block_size: usize) -> DistribConfig {
        DistribConfig {
            block_size,
            n_proc_rows: 2,
            n_proc_cols: 2,
        }
    }

    #[test]
    fn test_householder_reflector_construction() {
        let input = vec![3.0, 4.0, 0.0];
        let (v, beta) = HouseholderReflector::from_vector(&input);
        // beta must match the definition 2 / (v^T v).
        let squared_norm: f64 = v.iter().map(|vi| vi * vi).sum();
        assert_abs_diff_eq!(beta, 2.0 / squared_norm, epsilon = 1e-12);
    }

    #[test]
    fn test_householder_reflector_orthogonality() {
        let input = vec![1.0, 2.0, 3.0];
        let (v, beta) = HouseholderReflector::from_vector(&input);
        let dim = v.len();
        // Form H = I - beta * v * v^T explicitly.
        let mut h = Array2::<f64>::eye(dim);
        for (i, vi) in v.iter().enumerate() {
            for (j, vj) in v.iter().enumerate() {
                h[[i, j]] -= beta * vi * vj;
            }
        }
        let identity = Array2::<f64>::eye(dim);
        assert_abs_diff_eq!(frob_diff(&gram(&h), &identity), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_local_qr_panel_qtq_identity() {
        let mut panel = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]];
        let rows = panel.nrows();
        let reflectors = local_qr_panel(&mut panel);
        let q = build_q_from_reflectors(&reflectors, rows);
        let identity = Array2::<f64>::eye(3);
        assert_abs_diff_eq!(frob_diff(&gram(&q), &identity), 0.0, epsilon = 1e-11);
    }

    #[test]
    fn test_local_qr_panel_qr_equals_a() {
        let original_a = array![
            [12.0_f64, -51.0, 4.0],
            [6.0, 167.0, -68.0],
            [-4.0, 24.0, -41.0]
        ];
        // After the call the working copy holds R.
        let mut r = original_a.clone();
        let reflectors = local_qr_panel(&mut r);
        let q = build_q_from_reflectors(&reflectors, original_a.nrows());
        let reconstructed = matmul(&q, &r);
        assert_abs_diff_eq!(frob_diff(&reconstructed, &original_a), 0.0, epsilon = 1e-9);
    }

    #[test]
    fn test_tournament_qr_preserves_r_factor() {
        let r0 = array![[3.0_f64, 1.0, 2.0], [0.0, 2.0, -1.0], [0.0, 0.0, 1.5]];
        let r1 = array![[2.0_f64, -1.0, 1.0], [0.0, 1.5, 0.5], [0.0, 0.0, 0.8]];
        let panels = [r0.clone(), r1.clone()];
        let (r_combined, _) = tournament_qr_reduction(&panels).expect("tournament failed");
        assert!(r_combined.nrows() <= 3);
        assert_eq!(r_combined.ncols(), 3);
        // Everything strictly below the diagonal must vanish.
        for row in 0..r_combined.nrows() {
            for col in 0..row {
                assert_abs_diff_eq!(r_combined[[row, col]], 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_caqr_qr_equals_a() {
        let a = array![
            [1.0_f64, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 10.0],
            [1.0, -1.0, 2.0]
        ];
        let config = config_with_block(2);
        let (q, r) = caqr_simulate(&a, &config).expect("caqr failed");
        let reconstructed = matmul(&q, &r);
        assert_abs_diff_eq!(frob_diff(&reconstructed, &a), 0.0, epsilon = 1e-9);
    }

    #[test]
    fn test_caqr_q_orthogonal() {
        let a = Array2::<f64>::from_shape_fn((8, 5), |(i, j)| i as f64 * 0.7 - j as f64 * 0.3);
        let config = config_with_block(3);
        let (q, _r) = caqr_simulate(&a, &config).expect("caqr failed");
        let identity = Array2::<f64>::eye(q.nrows());
        assert_abs_diff_eq!(frob_diff(&gram(&q), &identity), 0.0, epsilon = 1e-9);
    }
}