use ariadnetor_core::backend::{
BackendError, ComputeBackend, ExecPolicy, MemoryOrder, SvdDescriptor,
};
use ariadnetor_native::NativeBackend;
use num_complex::Complex;
/// Converts a dense `rows × cols` matrix from row-major to column-major layout.
///
/// Element `(i, j)` is read from `row_major[i * cols + j]` and lands at index
/// `j * rows + i` of the returned vector. Empty matrices (`rows == 0` or
/// `cols == 0`) yield an empty vector.
///
/// # Panics
///
/// Panics if `row_major.len() != rows * cols`.
fn to_col_major<T: Copy>(row_major: &[T], rows: usize, cols: usize) -> Vec<T> {
    assert_eq!(
        row_major.len(),
        rows * cols,
        "to_col_major: slice length does not match {rows}x{cols} matrix",
    );
    // Walk columns in the outer position so the output is produced in
    // column-major order directly; no placeholder element is needed, which
    // keeps empty inputs from indexing `row_major[0]`.
    (0..cols)
        .flat_map(move |j| (0..rows).map(move |i| row_major[i * cols + j]))
        .collect()
}
/// Real (`f64`) SVD of a square 2×2 matrix.
///
/// Checks that the singular values are strictly descending and non-negative,
/// and that `U · diag(S) · Vt` rebuilds every input entry. All buffers are
/// column-major: `U` is `m × k`, `Vt` is `k × n`.
#[test]
fn test_svd_f64_square() {
    let backend = NativeBackend::new();
    let (m, n, k) = (2, 2, 2);
    let a_logical = [1.0f64, 2.0, 3.0, 4.0];
    let a = to_col_major(&a_logical, m, n);
    let (mut u, mut s, mut vt) = ([0.0f64; 4], [0.0f64; 2], [0.0f64; 4]);

    backend
        .svd(SvdDescriptor {
            m,
            n,
            a: &a,
            u: &mut u,
            s: &mut s,
            vt: &mut vt,
            order: MemoryOrder::ColumnMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert!(s[0] > s[1]);
    assert!(s[1] >= 0.0);

    for i in 0..m {
        for j in 0..n {
            let val = (0..k).fold(0.0f64, |acc, l| acc + u[l * m + i] * s[l] * vt[j * k + l]);
            let expected = a_logical[i * n + j];
            assert!(
                (val - expected).abs() < 1e-10,
                "Reconstruction mismatch at ({i},{j}): {val} vs {expected}",
            );
        }
    }
}
/// Real (`f64`) SVD of a wide 2×3 matrix (`k = min(m, n) = 2`).
///
/// Checks that the singular values are strictly descending and that the
/// thin factors `U` (`m × k`) and `Vt` (`k × n`), both column-major,
/// reconstruct the input.
#[test]
fn test_svd_f64_rectangular() {
    let backend = NativeBackend::new();
    let (m, n, k) = (2, 3, 2);
    let a_logical = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let a = to_col_major(&a_logical, m, n);
    let (mut u, mut s, mut vt) = (vec![0.0f64; m * k], vec![0.0f64; k], vec![0.0f64; k * n]);

    let desc = SvdDescriptor {
        m,
        n,
        a: &a,
        u: &mut u,
        s: &mut s,
        vt: &mut vt,
        order: MemoryOrder::ColumnMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.svd(desc).unwrap();

    assert!(s[0] > s[1]);

    // Visit entries in row-major order, matching `a_logical`.
    let positions = (0..m).flat_map(|i| (0..n).map(move |j| (i, j)));
    for (i, j) in positions {
        let mut val = 0.0f64;
        for (l, &sigma) in s.iter().enumerate() {
            val += u[l * m + i] * sigma * vt[j * k + l];
        }
        let expected = a_logical[i * n + j];
        assert!(
            (val - expected).abs() < 1e-10,
            "Reconstruction mismatch at ({i},{j})"
        );
    }
}
/// Single-precision (`f32`) SVD of a square 2×2 matrix.
///
/// Checks that the singular values are strictly descending and that
/// `U · diag(S) · Vt` matches the input within an `f32`-appropriate
/// tolerance of `1e-4`.
#[test]
fn test_svd_f32_basic() {
    let backend = NativeBackend::new();
    let (m, n, k) = (2, 2, 2);
    let a_logical = [1.0f32, 2.0, 3.0, 4.0];
    let a = to_col_major(&a_logical, m, n);
    let (mut u, mut s, mut vt) = ([0.0f32; 4], [0.0f32; 2], [0.0f32; 4]);

    backend
        .svd(SvdDescriptor {
            m,
            n,
            a: &a,
            u: &mut u,
            s: &mut s,
            vt: &mut vt,
            order: MemoryOrder::ColumnMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert!(s[0] > s[1]);

    // `a_logical` is row-major, so its flat index decomposes into (row, col).
    for (idx, &expected) in a_logical.iter().enumerate() {
        let (i, j) = (idx / n, idx % n);
        let val = (0..k)
            .map(|l| u[l * m + i] * s[l] * vt[j * k + l])
            .fold(0.0f32, |acc, term| acc + term);
        assert!(
            (val - expected).abs() < 1e-4,
            "Reconstruction mismatch at ({i},{j}): {val} vs {expected}",
        );
    }
}
/// Complex (`Complex<f64>`) SVD of the Hermitian matrix `[[2, 1-i], [1+i, 3]]`.
///
/// Checks that the real singular values are strictly descending and
/// non-negative, and that `U · diag(S) · Vt` (column-major factors) rebuilds
/// the input.
#[test]
fn test_svd_c64_hermitian() {
    let backend = NativeBackend::new();
    let (m, n, k) = (2, 2, 2);
    let a_logical = [
        Complex::new(2.0, 0.0),
        Complex::new(1.0, -1.0),
        Complex::new(1.0, 1.0),
        Complex::new(3.0, 0.0),
    ];
    let a = to_col_major(&a_logical, m, n);

    let zero = Complex::new(0.0, 0.0);
    let mut u = vec![zero; m * k];
    let mut s = vec![0.0f64; k];
    let mut vt = vec![zero; k * n];

    backend
        .svd(SvdDescriptor {
            m,
            n,
            a: &a,
            u: &mut u,
            s: &mut s,
            vt: &mut vt,
            order: MemoryOrder::ColumnMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert!(s[0] > s[1]);
    assert!(s[1] >= 0.0);

    for i in 0..m {
        for j in 0..n {
            let val = (0..k).fold(zero, |acc, l| acc + u[l * m + i] * s[l] * vt[j * k + l]);
            let expected = a_logical[i * n + j];
            let diff = (val - expected).norm();
            assert!(
                diff < 1e-10,
                "SVD reconstruction mismatch at ({i},{j}): {val} vs {expected}",
            );
        }
    }
}
/// Complex (`Complex<f64>`) SVD of a wide 2×3 matrix (`k = min(m, n) = 2`).
///
/// Checks that the singular values are strictly descending and that the
/// column-major thin factors reconstruct every input entry.
#[test]
fn test_svd_c64_rectangular() {
    let backend = NativeBackend::new();
    let (m, n, k) = (2, 3, 2);
    let a_logical = [
        Complex::new(1.0, 2.0),
        Complex::new(3.0, 0.0),
        Complex::new(0.0, 1.0),
        Complex::new(4.0, -1.0),
        Complex::new(2.0, 3.0),
        Complex::new(1.0, 1.0),
    ];
    let a = to_col_major(&a_logical, m, n);
    let mut u = vec![Complex::new(0.0, 0.0); m * k];
    let mut s = vec![0.0f64; k];
    let mut vt = vec![Complex::new(0.0, 0.0); k * n];

    backend
        .svd(SvdDescriptor {
            m,
            n,
            a: &a,
            u: &mut u,
            s: &mut s,
            vt: &mut vt,
            order: MemoryOrder::ColumnMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert!(s[0] > s[1]);

    // Visit entries in row-major order, matching `a_logical`.
    let positions = (0..m).flat_map(|i| (0..n).map(move |j| (i, j)));
    for (i, j) in positions {
        let mut val = Complex::new(0.0, 0.0);
        for (l, &sigma) in s.iter().enumerate() {
            val += u[l * m + i] * sigma * vt[j * k + l];
        }
        let diff = (val - a_logical[i * n + j]).norm();
        assert!(diff < 1e-10, "SVD reconstruction mismatch at ({i},{j})");
    }
}
/// Complex (`Complex<f64>`) SVD of a general 2×2 matrix: both singular-vector
/// factors must be unitary.
///
/// `U` is `m × k` column-major (column `c` occupies `u[c*m .. (c+1)*m]`), so
/// `U^H · U` must equal `I_k`. `Vt` is `k × n` column-major (entry `(r, c)` is
/// `vt[c*k + r]`), so `Vt · Vt^H` must equal `I_k`.
#[test]
fn test_svd_c64_unitary_check() {
    let backend = NativeBackend::new();
    let a_logical = [
        Complex::new(1.0, 2.0),
        Complex::new(3.0, -1.0),
        Complex::new(0.0, 4.0),
        Complex::new(2.0, 1.0),
    ];
    let (m, n, k) = (2, 2, 2);
    let a = to_col_major(&a_logical, m, n);
    let mut u = vec![Complex::new(0.0, 0.0); m * k];
    let mut s = vec![0.0f64; k];
    let mut vt = vec![Complex::new(0.0, 0.0); k * n];
    let desc = SvdDescriptor {
        m,
        n,
        a: &a,
        u: &mut u,
        s: &mut s,
        vt: &mut vt,
        order: MemoryOrder::ColumnMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.svd(desc).unwrap();

    // Compare the full complex product against the identity entry. Checking
    // only `|val| - expected` would accept a diagonal of magnitude < 1 (the
    // difference goes negative) and would ignore phase (e.g. -1 or i).
    let identity = |i: usize, j: usize| Complex::new(if i == j { 1.0f64 } else { 0.0 }, 0.0);

    for i in 0..k {
        for j in 0..k {
            let mut val = Complex::new(0.0, 0.0);
            for l in 0..m {
                val += u[i * m + l].conj() * u[j * m + l];
            }
            assert!(
                (val - identity(i, j)).norm() < 1e-10,
                "U^H * U not identity at ({i},{j}): {val}"
            );
        }
    }

    // Rows of Vt must be orthonormal: (Vt · Vt^H)[i][j] = sum_c Vt[i][c] * conj(Vt[j][c]).
    for i in 0..k {
        for j in 0..k {
            let mut val = Complex::new(0.0, 0.0);
            for l in 0..n {
                val += vt[l * k + i] * vt[l * k + j].conj();
            }
            assert!(
                (val - identity(i, j)).norm() < 1e-10,
                "Vt * Vt^H not identity at ({i},{j}): {val}"
            );
        }
    }
}
/// A descriptor tagged `MemoryOrder::RowMajor` must be refused with
/// `BackendError::InvalidArgument`; any other outcome fails the test and is
/// printed in the panic message.
#[test]
fn test_svd_rejects_row_major_order() {
    let backend = NativeBackend::new();
    let (m, n) = (2usize, 2usize);
    let a = [0.0f64; 4];
    let (mut u, mut s, mut vt) = ([0.0f64; 4], [0.0f64; 2], [0.0f64; 4]);

    let outcome = backend.svd(SvdDescriptor {
        m,
        n,
        a: &a,
        u: &mut u,
        s: &mut s,
        vt: &mut vt,
        order: MemoryOrder::RowMajor,
        policy: ExecPolicy::Sequential,
    });

    match outcome {
        Err(BackendError::InvalidArgument(_)) => {}
        result => panic!("expected InvalidArgument for RowMajor SVD, got {result:?}"),
    }
}
/// Single-precision complex (`Complex<f32>`) SVD of the 2×2 matrix
/// `[[2, 1-i], [1+i, 3]]`.
///
/// Checks that the singular values are strictly descending and that the
/// column-major factors reconstruct the input within `1e-4`.
#[test]
fn test_svd_c32_basic() {
    let backend = NativeBackend::new();
    let (m, n, k) = (2, 2, 2);
    let a_logical = [
        Complex::new(2.0f32, 0.0),
        Complex::new(1.0, -1.0),
        Complex::new(1.0, 1.0),
        Complex::new(3.0, 0.0),
    ];
    let a = to_col_major(&a_logical, m, n);
    let mut u = vec![Complex::new(0.0f32, 0.0); m * k];
    let mut s = vec![0.0f32; k];
    let mut vt = vec![Complex::new(0.0f32, 0.0); k * n];

    let desc = SvdDescriptor {
        m,
        n,
        a: &a,
        u: &mut u,
        s: &mut s,
        vt: &mut vt,
        order: MemoryOrder::ColumnMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.svd(desc).unwrap();

    assert!(s[0] > s[1]);

    // `a_logical` is row-major, so its flat index decomposes into (row, col).
    for (idx, &expected) in a_logical.iter().enumerate() {
        let (i, j) = (idx / n, idx % n);
        let val = (0..k)
            .map(|l| u[l * m + i] * s[l] * vt[j * k + l])
            .fold(Complex::new(0.0f32, 0.0), |acc, term| acc + term);
        let diff = (val - expected).norm();
        assert!(diff < 1e-4, "SVD reconstruction mismatch at ({i},{j})");
    }
}