use crate::tensor::DenseTensor;
use crate::tensor::TensorBase;
use crate::tensor::TensorError;
/// Computes the thin QR decomposition `A ≈ Q * R` of a 2-D, row-major `m x n` tensor with
/// `m >= n`.
///
/// The algorithm is modified Gram-Schmidt followed by a second re-orthogonalization sweep per
/// column.
///
/// Returns `(Q, R)`:
/// - `Q` has shape `[m, n]`.
/// - `R` has shape `[n, n]` and is upper triangular. Its diagonal entries are the column norms
///   after projection, so they are always non-negative.
///
/// A column whose residual after projection is negligible is treated as linearly dependent on
/// the earlier columns. "Negligible" means the residual norm is `<= 1e-12` relative to the
/// column's original norm, or `<= 1e-14` absolute. Such a column is set to zero in `Q` and gets
/// a `0.0` diagonal entry in `R`. For rank-deficient input, `Q` is therefore not fully
/// orthonormal, but `Q * R` still approximately reconstructs the input.
///
/// # Errors
///
/// - [`TensorError::DimensionMismatch`] if the tensor is not 2-D, or if its data length is not
///   `m * n`.
/// - [`TensorError::ShapeMismatch`] if `m < n`.
pub fn qr_decompose(tensor: &DenseTensor) -> Result<(DenseTensor, DenseTensor), TensorError> {
    // Absolute threshold below which a norm or projection coefficient is treated as zero.
    const ABS_TOL: f64 = 1e-14;
    // Relative residual threshold below which a column counts as linearly dependent.
    const REL_TOL: f64 = 1e-12;

    // Dot product of columns `a` and `b` of the row-major `rows x cols` buffer `buf`.
    // Accumulates from 0.0 in row order, the same order as a plain `+=` loop.
    fn column_dot(buf: &[f64], rows: usize, cols: usize, a: usize, b: usize) -> f64 {
        (0..rows).fold(0.0, |acc, k| acc + buf[k * cols + a] * buf[k * cols + b])
    }

    // In place: column `dst` -= `scale` * column `src`.
    fn subtract_scaled_column(
        buf: &mut [f64],
        rows: usize,
        cols: usize,
        dst: usize,
        src: usize,
        scale: f64,
    ) {
        for k in 0..rows {
            buf[k * cols + dst] -= scale * buf[k * cols + src];
        }
    }

    let shape = tensor.shape();
    if shape.len() != 2 {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: shape.len(),
        });
    }
    let (m, n) = (shape[0], shape[1]);
    if m < n {
        return Err(TensorError::ShapeMismatch {
            expected: vec![n, m],
            got: vec![m, n],
        });
    }
    let data = tensor.data();
    // The column indexing below assumes exactly `m * n` contiguous elements. Reject a
    // mismatched buffer instead of panicking on an out-of-bounds index. `saturating_mul`
    // stops an overflowing shape from wrapping to a small value that happens to match.
    let expected_len = m.saturating_mul(n);
    if data.len() != expected_len {
        return Err(TensorError::DimensionMismatch {
            expected: expected_len,
            got: data.len(),
        });
    }

    let mut q = data.to_vec();
    let mut r = vec![0.0; n * n];
    for j in 0..n {
        let original_norm = column_dot(&q, m, n, j, j).sqrt();

        // First modified Gram-Schmidt sweep: each projection uses the already-updated column.
        for i in 0..j {
            let proj = column_dot(&q, m, n, i, j);
            r[i * n + j] = proj;
            subtract_scaled_column(&mut q, m, n, j, i, proj);
        }

        // Second sweep recovers orthogonality lost to rounding in the first. The
        // corrections are folded into R so that Q * R still reproduces the input.
        for i in 0..j {
            let proj = column_dot(&q, m, n, i, j);
            if proj.abs() > ABS_TOL {
                subtract_scaled_column(&mut q, m, n, j, i, proj);
                r[i * n + j] += proj;
            }
        }

        let residual_norm = column_dot(&q, m, n, j, j).sqrt();
        let rel_norm = if original_norm > ABS_TOL {
            residual_norm / original_norm
        } else {
            residual_norm
        };
        if rel_norm > REL_TOL && residual_norm > ABS_TOL {
            r[j * n + j] = residual_norm;
            for k in 0..m {
                q[k * n + j] /= residual_norm;
            }
        } else {
            // Numerically dependent column: drop its direction entirely.
            r[j * n + j] = 0.0;
            for k in 0..m {
                q[k * n + j] = 0.0;
            }
        }
    }
    // No sign normalization is needed: every diagonal entry of R is a square-root norm or
    // exactly 0.0, so it is already non-negative.

    let q_tensor = DenseTensor::from_vec(q, vec![m, n]);
    let r_tensor = DenseTensor::from_vec(r, vec![n, n]);
    Ok((q_tensor, r_tensor))
}
/// Returns the `Q` factor of [`qr_decompose`] for `tensor` and discards `R`.
///
/// # Errors
///
/// Propagates any error from [`qr_decompose`] unchanged.
pub fn orthogonalize(tensor: &DenseTensor) -> Result<DenseTensor, TensorError> {
    qr_decompose(tensor).map(|(q_factor, _)| q_factor)
}
/// Orthonormalizes, in place, the columns of a row-major matrix with shape `[m, n]` and
/// `m >= n`.
///
/// The method is modified Gram-Schmidt with a second re-orthogonalization sweep per column.
/// A column whose residual after projection is negligible is overwritten with zeros. That
/// means a residual norm `<= 1e-12` relative to its original norm, or `<= 1e-14` absolute.
///
/// Returns the largest absolute deviation of the resulting columns' Gram matrix `QᵀQ` from the
/// identity. A value near 0 means the columns are orthonormal. A value of at least 1 means at
/// least one column was zeroed, because its diagonal Gram entry is 0 instead of 1.
///
/// # Errors
///
/// - [`TensorError::DimensionMismatch`] if `shape` is not 2-D, or if `data.len() != m * n`.
/// - [`TensorError::ShapeMismatch`] if `m < n`.
pub fn orthogonalize_in_place(data: &mut [f64], shape: &[usize]) -> Result<f64, TensorError> {
    // Absolute threshold below which a norm or projection coefficient is treated as zero.
    const ABS_TOL: f64 = 1e-14;
    // Relative residual threshold below which a column counts as linearly dependent.
    const REL_TOL: f64 = 1e-12;

    // Dot product of columns `a` and `b` of the row-major `rows x cols` buffer `buf`.
    // Accumulates from 0.0 in row order, the same order as a plain `+=` loop.
    fn column_dot(buf: &[f64], rows: usize, cols: usize, a: usize, b: usize) -> f64 {
        (0..rows).fold(0.0, |acc, k| acc + buf[k * cols + a] * buf[k * cols + b])
    }

    // In place: column `dst` -= `scale` * column `src`.
    fn subtract_scaled_column(
        buf: &mut [f64],
        rows: usize,
        cols: usize,
        dst: usize,
        src: usize,
        scale: f64,
    ) {
        for k in 0..rows {
            buf[k * cols + dst] -= scale * buf[k * cols + src];
        }
    }

    if shape.len() != 2 {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: shape.len(),
        });
    }
    let (m, n) = (shape[0], shape[1]);
    if m < n {
        return Err(TensorError::ShapeMismatch {
            expected: vec![n, m],
            got: vec![m, n],
        });
    }
    // A plain `m * n` can overflow for a hostile shape. In release builds it would wrap,
    // possibly to a value equal to `data.len()`, and the loops below would then index out
    // of bounds. A saturated product can never equal the length of a real `f64` slice.
    let expected_len = m.saturating_mul(n);
    if data.len() != expected_len {
        return Err(TensorError::DimensionMismatch {
            expected: expected_len,
            got: data.len(),
        });
    }

    for j in 0..n {
        let original_norm = column_dot(data, m, n, j, j).sqrt();

        // First modified Gram-Schmidt sweep: each projection uses the already-updated column.
        for i in 0..j {
            let proj = column_dot(data, m, n, i, j);
            subtract_scaled_column(data, m, n, j, i, proj);
        }

        // Second sweep recovers orthogonality lost to rounding in the first.
        for i in 0..j {
            let proj = column_dot(data, m, n, i, j);
            if proj.abs() > ABS_TOL {
                subtract_scaled_column(data, m, n, j, i, proj);
            }
        }

        let residual_norm = column_dot(data, m, n, j, j).sqrt();
        let rel_norm = if original_norm > ABS_TOL {
            residual_norm / original_norm
        } else {
            residual_norm
        };
        if rel_norm > REL_TOL && residual_norm > ABS_TOL {
            for k in 0..m {
                data[k * n + j] /= residual_norm;
            }
        } else {
            // Numerically dependent column: drop its direction entirely.
            for k in 0..m {
                data[k * n + j] = 0.0;
            }
        }
    }

    // QᵀQ is symmetric, and each entry's products are the same in either index order, so
    // scanning the upper triangle (diagonal included) gives exactly the full-matrix maximum.
    let mut max_error: f64 = 0.0;
    for i in 0..n {
        for j in i..n {
            let expected = if i == j { 1.0 } else { 0.0 };
            let error = (column_dot(data, m, n, i, j) - expected).abs();
            max_error = max_error.max(error);
        }
    }
    Ok(max_error)
}
/// Reports whether the columns of a 2-D, row-major tensor with shape `[m, n]` are orthonormal.
///
/// Every entry of the Gram matrix `AᵀA` must be within the absolute `tolerance` of the
/// identity: diagonal entries near 1, off-diagonal entries near 0.
///
/// Returns `false` in any of these cases:
/// - the tensor is not 2-D;
/// - `m < n`, so the columns cannot be orthonormal;
/// - any Gram entry deviates by more than `tolerance`;
/// - a deviation is NaN, for example because the data contains NaN or `tolerance` is NaN.
pub fn is_orthogonal(tensor: &DenseTensor, tolerance: f64) -> bool {
    let shape = tensor.shape();
    if shape.len() != 2 {
        return false;
    }
    let (m, n) = (shape[0], shape[1]);
    if m < n {
        return false;
    }
    let data = tensor.data();
    // AᵀA is symmetric, so checking the upper triangle (diagonal included) is sufficient.
    (0..n).all(|i| {
        (i..n).all(|j| {
            let dot = (0..m).fold(0.0, |acc, k| acc + data[k * n + i] * data[k * n + j]);
            let expected = if i == j { 1.0 } else { 0.0 };
            // `<=` is false when either side is NaN. That makes NaN input fail the check,
            // whereas the earlier `> tolerance` rejection test let NaN slip through as
            // "orthogonal".
            (dot - expected).abs() <= tolerance
        })
    })
}
/// Prints a one-line summary of `tensor` to stdout, prefixed with `label`.
///
/// The summary contains the shape and the minimum, maximum and mean of the data.
///
/// An empty tensor is reported as `empty`. Otherwise the folds below would print the
/// meaningless `min=inf, max=-inf, mean=NaN` (the mean would be 0/0).
///
/// `f64::min`/`f64::max` skip NaN operands, so NaN entries are visible only through the mean.
pub fn debug_matrix(tensor: &DenseTensor, label: &str) {
    let shape = tensor.shape();
    let data = tensor.data();
    if data.is_empty() {
        println!("{}: shape={:?}, empty", label, shape);
        return;
    }
    let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
    let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let mean_val: f64 = data.iter().sum::<f64>() / data.len() as f64;
    println!("{}: shape={:?}, min={:.6}, max={:.6}, mean={:.6}",
        label, shape, min_val, max_val, mean_val);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::traits::TensorOps;

    /// Asserts that `q * r` matches `original` entry-wise within 1e-5.
    ///
    /// It also checks that both have the same number of elements. That check matters because
    /// `zip` stops at the shorter slice, so a truncated product would otherwise pass vacuously.
    fn assert_reconstructs(original: &DenseTensor, q: DenseTensor, r: DenseTensor) {
        let reconstructed = q.matmul(&r);
        let orig_data: &[f64] = original.data();
        let recon_data: &[f64] = reconstructed.data();
        assert_eq!(
            orig_data.len(),
            recon_data.len(),
            "Reconstruction has the wrong number of elements"
        );
        for (a, b) in orig_data.iter().zip(recon_data.iter()) {
            assert!((a - b).abs() < 1e-5, "Reconstruction failed: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_qr_decomposition() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![4, 2],
        );
        let (q, r) = qr_decompose(&tensor).unwrap();
        assert_eq!(q.shape(), &[4, 2]);
        assert_eq!(r.shape(), &[2, 2]);
        assert!(is_orthogonal(&q, 1e-5));
        assert_reconstructs(&tensor, q, r);
    }

    #[test]
    fn test_orthogonalize() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 0.0, 0.0, 1.0],
            vec![2, 2],
        );
        let ortho = orthogonalize(&tensor).unwrap();
        assert!(is_orthogonal(&ortho, 1e-5));
    }

    #[test]
    fn test_qr_reconstruction() {
        let original = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![3, 2],
        );
        let (q, r) = qr_decompose(&original).unwrap();
        assert_reconstructs(&original, q, r);
    }

    #[test]
    fn test_qr_square() {
        let tensor = DenseTensor::from_vec(
            vec![4.0, 1.0, 2.0, 3.0],
            vec![2, 2],
        );
        let (q, r) = qr_decompose(&tensor).unwrap();
        assert_eq!(q.shape(), &[2, 2]);
        assert_eq!(r.shape(), &[2, 2]);
        assert!(is_orthogonal(&q, 1e-5));
        assert_reconstructs(&tensor, q, r);
    }

    #[test]
    fn test_qr_identity() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 0.0, 0.0, 1.0],
            vec![2, 2],
        );
        let (q, r) = qr_decompose(&tensor).unwrap();
        assert!(is_orthogonal(&q, 1e-5));
        let r_data: &[f64] = r.data();
        assert!((r_data[0] - 1.0).abs() < 1e-5);
        assert!(r_data[1].abs() < 1e-5);
        assert!(r_data[2].abs() < 1e-5);
        assert!((r_data[3] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_qr_positive_diagonal() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![3, 2],
        );
        let (_, r) = qr_decompose(&tensor).unwrap();
        let r_data: &[f64] = r.data();
        assert!(r_data[0] >= 0.0, "R[0,0] should be non-negative");
        assert!(r_data[3] >= 0.0, "R[1,1] should be non-negative");
    }

    #[test]
    fn test_qr_rank_deficient() {
        // The second column is exactly twice the first, so it has no independent direction:
        // its Q column is zeroed and R[1,1] is 0, yet Q * R still reproduces the input.
        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0],
            vec![3, 2],
        );
        let (q, r) = qr_decompose(&tensor).unwrap();
        let r_data: &[f64] = r.data();
        assert_eq!(r_data[3], 0.0, "R[1,1] should be zero for a dependent column");
        let q_data: &[f64] = q.data();
        assert!(
            q_data.iter().skip(1).step_by(2).all(|&x| x == 0.0),
            "Q column 1 should be zeroed"
        );
        assert_reconstructs(&tensor, q, r);
    }

    #[test]
    fn test_qr_rejects_wide_matrix() {
        let wide = DenseTensor::from_vec(
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            vec![2, 3],
        );
        assert!(qr_decompose(&wide).is_err());
        assert!(!is_orthogonal(&wide, 1e-5));
    }

    #[test]
    fn test_is_orthogonal_rejects_non_orthogonal() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![3, 2],
        );
        assert!(!is_orthogonal(&tensor, 1e-5));
    }

    #[test]
    fn test_orthogonalize_in_place() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![3, 2];
        let error = orthogonalize_in_place(&mut data, &shape).unwrap();
        assert!(error < 1e-10, "Orthogonalization error too large: {}", error);
        let tensor = DenseTensor::from_vec(data, shape);
        assert!(is_orthogonal(&tensor, 1e-10));
    }

    #[test]
    fn test_orthogonalize_in_place_identity() {
        let mut data = vec![1.0, 0.0, 0.0, 1.0];
        let shape = vec![2, 2];
        let error = orthogonalize_in_place(&mut data, &shape).unwrap();
        assert!(error < 1e-10);
        assert_eq!(data, vec![1.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_orthogonalize_in_place_rank_deficient() {
        // A dependent column is zeroed, so its diagonal Gram entry is 0 instead of 1
        // and the reported error is 1.
        let mut data = vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0];
        let error = orthogonalize_in_place(&mut data, &[3, 2]).unwrap();
        assert!((error - 1.0).abs() < 1e-10, "Unexpected error: {}", error);
        assert!(data.iter().skip(1).step_by(2).all(|&x| x == 0.0));
    }

    #[test]
    fn test_orthogonalize_in_place_error() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![4];
        assert!(orthogonalize_in_place(&mut data, &shape).is_err());
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 4];
        assert!(orthogonalize_in_place(&mut data, &shape).is_err());
        let mut data = vec![1.0, 2.0, 3.0];
        let shape = vec![2, 2];
        assert!(orthogonalize_in_place(&mut data, &shape).is_err());
    }
}