use scirs2_core::ndarray::{Array1, Array2};
/// Computes Gram–Schmidt orthogonalisation data for the rows of `basis`.
///
/// Row `b_i` is orthogonalised in turn against the already-produced vectors
/// `b*_0 .. b*_{i-1}` (modified Gram–Schmidt: each projection coefficient is
/// taken against the partially reduced vector).
///
/// # Returns
/// `(mu, bnorm_sq)` where, for `j < i`, `mu[[i, j]] = <b_i, b*_j> / ||b*_j||^2`,
/// and `bnorm_sq[i] = ||b*_i||^2`. Entries of `mu` with `j >= i` (including the
/// diagonal) are left at `0.0`.
///
/// Any `b*_j` whose squared norm is below `1e-14` is treated as numerically
/// zero: no projection onto it is performed and the matching `mu[[i, j]]` is `0.0`.
pub fn gram_schmidt(basis: &Array2<f64>) -> (Array2<f64>, Array1<f64>) {
    let rows = basis.nrows();
    let mut mu = Array2::<f64>::zeros((rows, rows));
    let mut bnorm_sq = Array1::<f64>::zeros(rows);
    // Orthogonalised vectors b*_0, b*_1, ... built up row by row.
    let mut ortho: Vec<Vec<f64>> = Vec::with_capacity(rows);

    for i in 0..rows {
        let mut v: Vec<f64> = basis.row(i).iter().copied().collect();
        for (j, prev) in ortho.iter().enumerate() {
            let norm_j = bnorm_sq[j];
            if norm_j < 1e-14 {
                // Degenerate direction: nothing to project away.
                mu[[i, j]] = 0.0;
                continue;
            }
            let coeff = v.iter().zip(prev).map(|(a, b)| a * b).sum::<f64>() / norm_j;
            mu[[i, j]] = coeff;
            v.iter_mut().zip(prev).for_each(|(x, p)| *x -= coeff * p);
        }
        bnorm_sq[i] = v.iter().map(|x| x * x).sum();
        ortho.push(v);
    }

    (mu, bnorm_sq)
}
/// Refreshes Gram–Schmidt data in place after rows `k - 1` and `k` of a basis
/// have been swapped.
///
/// `basis` is read as-is, so it is expected to already hold the swapped rows.
/// For rows `i < k - 1` the stored `mu[[i, ..]]` and `bnorm_sq[i]` are kept, and
/// the vectors `b*_i` are rebuilt from them (`b*_i = b_i - sum_{j<i} mu[[i, j]] b*_j`)
/// instead of being recomputed. Every row from `k - 1` onward gets freshly
/// computed `mu[[i, j]]` (for `j < i`) and `bnorm_sq[i]`, using the same rules
/// as [`gram_schmidt`], including the `1e-14` near-zero-norm guard. When
/// `k == 0` or `k == 1` the start row saturates to 0, so everything is recomputed.
///
/// NOTE(review): `mu` is assumed to be at least `nrows x nrows` and `bnorm_sq`
/// at least `nrows` long. Out-of-range indices panic through ndarray indexing.
pub fn update_gram_schmidt_after_swap(
    basis: &Array2<f64>,
    mu: &mut Array2<f64>,
    bnorm_sq: &mut Array1<f64>,
    k: usize,
) {
    let rows = basis.nrows();
    let dim = basis.ncols();
    let start = k.saturating_sub(1);
    let mut ortho: Vec<Vec<f64>> = Vec::with_capacity(rows);

    // Untouched prefix: rebuild b*_i from the coefficients already stored in `mu`.
    for i in 0..start {
        let mut v: Vec<f64> = (0..dim).map(|c| basis[[i, c]]).collect();
        for (j, prev) in ortho.iter().enumerate() {
            let coeff = mu[[i, j]];
            v.iter_mut().zip(prev).for_each(|(x, p)| *x -= coeff * p);
        }
        ortho.push(v);
    }

    // Affected suffix: recompute projection coefficients and squared norms.
    for i in start..rows {
        let mut v: Vec<f64> = (0..dim).map(|c| basis[[i, c]]).collect();
        for (j, prev) in ortho.iter().enumerate() {
            let norm_j = bnorm_sq[j];
            if norm_j < 1e-14 {
                mu[[i, j]] = 0.0;
                continue;
            }
            let coeff = v.iter().zip(prev).map(|(a, b)| a * b).sum::<f64>() / norm_j;
            mu[[i, j]] = coeff;
            v.iter_mut().zip(prev).for_each(|(x, p)| *x -= coeff * p);
        }
        bnorm_sq[i] = v.iter().map(|x| x * x).sum();
        ortho.push(v);
    }
}
/// Performs one size-reduction step of basis row `k` against basis row `j`.
///
/// With `q = round(mu[[k, j]])` (`f64::round`, halves away from zero):
/// - `basis`: row `k` becomes `b_k - q * b_j`;
/// - `unimod`: the same row operation is applied to row `k` of the transform;
/// - `mu`: `mu[[k, l]] -= q * mu[[j, l]]` for every `l < j`, and
///   `mu[[k, j]] -= q`. The last update uses the implicit unit diagonal
///   `mu[[j, j]] == 1`, which is not stored in `mu`.
///
/// The function returns immediately without touching anything when `q == 0`.
/// It does not take `bnorm_sq`: subtracting a multiple of an earlier basis
/// vector leaves the orthogonalised vectors, and so their squared norms,
/// unchanged.
///
/// NOTE(review): callers are expected to pass `j < k`. This is not checked here.
pub fn size_reduce_step(
    basis: &mut Array2<f64>,
    unimod: &mut Array2<f64>,
    mu: &mut Array2<f64>,
    k: usize,
    j: usize,
) {
    let q = mu[[k, j]].round();
    if q == 0.0 {
        return;
    }

    // b_k <- b_k - q * b_j across the full ambient dimension of the basis.
    for col in 0..basis.ncols() {
        let bj = basis[[j, col]];
        basis[[k, col]] -= q * bj;
    }

    // Mirror the row operation on the transform. Iterate over unimod's own
    // column count instead of assuming it equals `basis.nrows()`, so no
    // column is silently skipped and no bound is inherited from another matrix.
    for col in 0..unimod.ncols() {
        let uj = unimod[[j, col]];
        unimod[[k, col]] -= q * uj;
    }

    // Update the Gram–Schmidt coefficients of row k.
    for l in 0..j {
        let mujl = mu[[j, l]];
        mu[[k, l]] -= q * mujl;
    }
    mu[[k, j]] -= q;
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Absolute tolerance used for floating-point comparisons in these tests.
    const TOL: f64 = 1e-10;

    /// Rebuilds `b*_i = b_i - sum_{j<i} mu[[i, j]] * b*_j` from `basis` and `mu`.
    fn reconstruct_b_star(basis: &Array2<f64>, mu: &Array2<f64>) -> Vec<Vec<f64>> {
        let n = basis.nrows();
        let mut b_star: Vec<Vec<f64>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut bsi: Vec<f64> = basis.row(i).iter().copied().collect();
            for (j, prev) in b_star.iter().enumerate() {
                let c = mu[[i, j]];
                for (x, p) in bsi.iter_mut().zip(prev) {
                    *x -= c * p;
                }
            }
            b_star.push(bsi);
        }
        b_star
    }

    /// Asserts that the strictly lower-triangular parts of two `mu` tables and
    /// the two squared-norm vectors agree within `TOL`.
    fn assert_gs_close(
        mu: &Array2<f64>,
        bnorm_sq: &Array1<f64>,
        mu_ref: &Array2<f64>,
        bnorm_ref: &Array1<f64>,
    ) {
        assert_eq!(bnorm_sq.len(), bnorm_ref.len());
        for i in 0..bnorm_ref.len() {
            for j in 0..i {
                assert!(
                    (mu[[i, j]] - mu_ref[[i, j]]).abs() < TOL,
                    "mu[{}][{}] = {}, expected {}",
                    i,
                    j,
                    mu[[i, j]],
                    mu_ref[[i, j]]
                );
            }
            assert!(
                (bnorm_sq[i] - bnorm_ref[i]).abs() < TOL,
                "bnorm_sq[{}] = {}, expected {}",
                i,
                bnorm_sq[i],
                bnorm_ref[i]
            );
        }
    }

    #[test]
    fn test_gram_schmidt_orthogonality() {
        let basis = array![[1.0, 1.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let (mu, bnorm_sq) = gram_schmidt(&basis);
        let n = basis.nrows();
        let b_star = reconstruct_b_star(&basis, &mu);
        for i in 0..n {
            for j in 0..i {
                let dot: f64 = b_star[i].iter().zip(b_star[j].iter()).map(|(a, b)| a * b).sum();
                assert!(dot.abs() < TOL, "b̃_{} and b̃_{} not orthogonal: dot={}", i, j, dot);
            }
        }
        for i in 0..n {
            let ns: f64 = b_star[i].iter().map(|x| x * x).sum();
            assert!((bnorm_sq[i] - ns).abs() < TOL, "bnorm_sq[{}] mismatch", i);
        }
    }

    #[test]
    fn test_gram_schmidt_mu_coefficients() {
        let basis = array![[1.0, 0.0], [1.0, 1.0]];
        let (mu, _bnorm_sq) = gram_schmidt(&basis);
        assert!((mu[[1, 0]] - 1.0).abs() < TOL);
    }

    #[test]
    fn test_gram_schmidt_identity() {
        let basis = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let (mu, bnorm_sq) = gram_schmidt(&basis);
        for i in 0..3 {
            for j in 0..i {
                assert!(mu[[i, j]].abs() < TOL, "mu[{}][{}] = {}", i, j, mu[[i, j]]);
            }
            assert!((bnorm_sq[i] - 1.0).abs() < TOL);
        }
    }

    /// A linearly dependent row yields a zero b* vector; later rows must not
    /// project onto it (mu = 0) and must still be orthogonalised correctly.
    #[test]
    fn test_gram_schmidt_dependent_rows() {
        let basis = array![[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]];
        let (mu, bnorm_sq) = gram_schmidt(&basis);
        assert!((mu[[1, 0]] - 2.0).abs() < TOL);
        assert!(bnorm_sq[1].abs() < TOL);
        assert!(mu[[2, 0]].abs() < TOL);
        assert_eq!(mu[[2, 1]], 0.0);
        assert!((bnorm_sq[2] - 1.0).abs() < TOL);
    }

    /// Swapping rows 1 and 2 (k = 2) and updating incrementally must match a
    /// full recomputation; this exercises the prefix-reconstruction path.
    #[test]
    fn test_update_after_swap_matches_recompute_middle() {
        let basis = array![
            [2.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 1.0, 0.0],
            [0.0, 1.0, 4.0, 1.0],
            [1.0, 0.0, 1.0, 5.0]
        ];
        let swapped = array![
            [2.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 4.0, 1.0],
            [1.0, 3.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 5.0]
        ];
        let (mut mu, mut bnorm_sq) = gram_schmidt(&basis);
        update_gram_schmidt_after_swap(&swapped, &mut mu, &mut bnorm_sq, 2);
        let (mu_ref, bnorm_ref) = gram_schmidt(&swapped);
        assert_gs_close(&mu, &bnorm_sq, &mu_ref, &bnorm_ref);
    }

    /// Swapping rows 0 and 1 (k = 1) means every row must be recomputed.
    #[test]
    fn test_update_after_swap_matches_recompute_front() {
        let basis = array![[1.0, 1.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let swapped = array![[1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]];
        let (mut mu, mut bnorm_sq) = gram_schmidt(&basis);
        update_gram_schmidt_after_swap(&swapped, &mut mu, &mut bnorm_sq, 1);
        let (mu_ref, bnorm_ref) = gram_schmidt(&swapped);
        assert_gs_close(&mu, &bnorm_sq, &mu_ref, &bnorm_ref);
    }

    /// One size-reduction step must update the basis, the transform and `mu`
    /// consistently, and leave the squared norms valid.
    #[test]
    fn test_size_reduce_step_updates_everything() {
        let original = array![[1.0, 0.0, 0.0], [0.4, 1.0, 0.0], [2.3, 1.6, 1.0]];
        let mut basis = original.clone();
        let mut unimod = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let (mut mu, bnorm_sq) = gram_schmidt(&basis);

        // mu[2][1] = 1.6, so q = 2 and b_2 <- b_2 - 2 * b_1.
        size_reduce_step(&mut basis, &mut unimod, &mut mu, 2, 1);

        let expected = array![[1.0, 0.0, 0.0], [0.4, 1.0, 0.0], [1.5, -0.4, 1.0]];
        for (got, want) in basis.iter().zip(expected.iter()) {
            assert!((got - want).abs() < TOL, "basis entry {} != {}", got, want);
        }
        assert_eq!(
            unimod,
            array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -2.0, 1.0]]
        );

        // The transform must map the original basis onto the reduced one.
        let product = unimod.dot(&original);
        for (got, want) in product.iter().zip(basis.iter()) {
            assert!((got - want).abs() < TOL, "U*B entry {} != {}", got, want);
        }

        // The incrementally updated mu must match a fresh orthogonalisation,
        // and the squared norms must be unchanged by size reduction.
        let (mu_ref, bnorm_ref) = gram_schmidt(&basis);
        assert_gs_close(&mu, &bnorm_sq, &mu_ref, &bnorm_ref);
        assert!(mu[[2, 1]].abs() <= 0.5 + TOL);
    }

    /// When round(mu[k][j]) == 0 the step must be a no-op.
    #[test]
    fn test_size_reduce_step_noop_when_already_reduced() {
        let original = array![[1.0, 0.0], [0.3, 1.0]];
        let mut basis = original.clone();
        let identity = array![[1.0, 0.0], [0.0, 1.0]];
        let mut unimod = identity.clone();
        let (mut mu, _bnorm_sq) = gram_schmidt(&basis);
        let mu_before = mu.clone();

        size_reduce_step(&mut basis, &mut unimod, &mut mu, 1, 0);

        assert_eq!(basis, original);
        assert_eq!(unimod, identity);
        assert_eq!(mu, mu_before);
    }
}