use ariadnetor_core::Scalar;
use ariadnetor_core::backend::{
BackendError, ComputeBackend, ExecPolicy, MemoryOrder, SolveDescriptor,
};
use ariadnetor_native::NativeBackend;
use num_complex::Complex;
/// Solves `A * X = B` with [`NativeBackend`] and checks that the returned `X`
/// reproduces `B` to within `tol`.
///
/// All buffers are column-major:
/// - `a` is `n x n`, with element `(i, k)` stored at `k * n + i`;
/// - `b` and the computed solution are `n x nrhs`, with element `(i, j)` at `j * n + i`.
///
/// Every scalar is widened to `Complex<f64>` through `to_c64`, so the residual
/// `A * X - B` is always formed in double precision whatever `T` is.
///
/// # Panics
///
/// Panics if `a.len() != n * n` or `b.len() != n * nrhs`, if the backend
/// returns an error, or if any entry of `A * X - B` has a real or imaginary
/// part whose magnitude is not strictly below `tol`.
fn assert_solve_laws<T: Scalar>(
    a: &[T],
    b: &[T],
    n: usize,
    nrhs: usize,
    tol: f64,
    to_c64: impl Fn(T) -> Complex<f64>,
) {
    // Reject mis-sized fixtures up front. Otherwise an oversized buffer would be
    // silently truncated by the residual loop below, and an undersized one would
    // only surface as an opaque index-out-of-bounds panic.
    assert_eq!(a.len(), n * n, "`a` must hold n*n = {} elements", n * n);
    assert_eq!(b.len(), n * nrhs, "`b` must hold n*nrhs = {} elements", n * nrhs);

    let backend = NativeBackend::new();
    let mut x = vec![T::zero(); n * nrhs];
    backend
        .solve(SolveDescriptor {
            n,
            nrhs,
            a,
            b,
            x: &mut x,
            order: MemoryOrder::ColumnMajor,
            policy: ExecPolicy::Sequential,
        })
        .expect("NativeBackend::solve returned an error");

    let a64: Vec<Complex<f64>> = a.iter().map(|&v| to_c64(v)).collect();
    let x64: Vec<Complex<f64>> = x.iter().map(|&v| to_c64(v)).collect();
    let b64: Vec<Complex<f64>> = b.iter().map(|&v| to_c64(v)).collect();

    for j in 0..nrhs {
        for i in 0..n {
            // Row i of A (column-major) dotted with column j of X.
            let mut ax = Complex::new(0.0, 0.0);
            for k in 0..n {
                ax += a64[k * n + i] * x64[j * n + k];
            }
            let expected = b64[j * n + i];
            // Componentwise absolute check on the residual.
            let diff = ax - expected;
            assert!(
                diff.re.abs() < tol && diff.im.abs() < tol,
                "A*X != B at i={i}, j={j}: ax={ax:?}, b={b:?}",
                b = expected,
            );
        }
    }
}
/// Real double-precision solve of a 3x3 tridiagonal system with two right-hand sides.
#[test]
fn test_solve_f64() {
    const N: usize = 3;
    const NRHS: usize = 2;
    // Column-major: the three columns are [2, 1, 0], [1, 3, 1], [0, 1, 2].
    let matrix: [f64; N * N] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    // Right-hand-side columns: [1, 0, 1] and [0, 1, 0].
    let rhs: [f64; N * NRHS] = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
    let widen = |re: f64| Complex::new(re, 0.0);
    assert_solve_laws(&matrix, &rhs, N, NRHS, 1e-10, widen);
}
/// Real single-precision solve; same system as the `f64` test with a looser tolerance.
#[test]
fn test_solve_f32() {
    const N: usize = 3;
    const NRHS: usize = 2;
    // Column-major: the three columns are [2, 1, 0], [1, 3, 1], [0, 1, 2].
    let matrix: [f32; N * N] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    // Right-hand-side columns: [1, 0, 1] and [0, 1, 0].
    let rhs: [f32; N * NRHS] = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
    // f32 -> f64 is lossless, so `From` is equivalent to an `as` cast here.
    let widen = |re: f32| Complex::new(f64::from(re), 0.0);
    assert_solve_laws(&matrix, &rhs, N, NRHS, 1e-4, widen);
}
/// Complex double-precision solve of a 3x3 Hermitian system with two right-hand sides.
#[test]
fn test_solve_c64() {
    let c = Complex::<f64>::new;
    // Column-major; each line below is one column of A.
    let matrix = [
        c(2.0, 0.0), c(1.0, 1.0), c(0.0, 0.0),
        c(1.0, -1.0), c(3.0, 0.0), c(1.0, 1.0),
        c(0.0, 0.0), c(1.0, -1.0), c(2.0, 0.0),
    ];
    // Right-hand-side columns: [1, i, 1] and [0, 1, 0].
    let rhs = [
        c(1.0, 0.0), c(0.0, 1.0), c(1.0, 0.0),
        c(0.0, 0.0), c(1.0, 0.0), c(0.0, 0.0),
    ];
    assert_solve_laws(&matrix, &rhs, 3, 2, 1e-10, |z| z);
}
/// Complex single-precision solve; same system as the `Complex<f64>` test with a looser tolerance.
#[test]
fn test_solve_c32() {
    let c = Complex::<f32>::new;
    // Column-major; each line below is one column of A.
    let matrix = [
        c(2.0, 0.0), c(1.0, 1.0), c(0.0, 0.0),
        c(1.0, -1.0), c(3.0, 0.0), c(1.0, 1.0),
        c(0.0, 0.0), c(1.0, -1.0), c(2.0, 0.0),
    ];
    // Right-hand-side columns: [1, i, 1] and [0, 1, 0].
    let rhs = [
        c(1.0, 0.0), c(0.0, 1.0), c(1.0, 0.0),
        c(0.0, 0.0), c(1.0, 0.0), c(0.0, 0.0),
    ];
    // Component-wise lossless widening from f32 to f64.
    let widen = |z: Complex<f32>| Complex::new(f64::from(z.re), f64::from(z.im));
    assert_solve_laws(&matrix, &rhs, 3, 2, 1e-3, widen);
}
/// A row-major request must be refused with `BackendError::InvalidArgument`.
#[test]
fn test_solve_rejects_row_major_order() {
    let a = [0.0f64; 4];
    let b = [0.0f64; 2];
    let mut x = [0.0f64; 2];
    let backend = NativeBackend::new();
    let outcome = backend.solve(SolveDescriptor {
        n: 2,
        nrhs: 1,
        a: &a,
        b: &b,
        x: &mut x,
        order: MemoryOrder::RowMajor,
        policy: ExecPolicy::Sequential,
    });
    match outcome {
        Err(BackendError::InvalidArgument(_)) => {}
        result => panic!("expected InvalidArgument for RowMajor solve, got {result:?}"),
    }
}