use ariadnetor_core::Scalar;
use ariadnetor_core::backend::{ComputeBackend, ExecPolicy, GemmDescriptor, MemoryOrder};
use ariadnetor_native::NativeBackend;
use num_complex::Complex;
use rstest::rstest;
/// Multiplying the 2x2 identity by `B` (row-major, alpha = 1, beta = 0) must
/// reproduce `B` in the output buffer.
#[test]
fn test_gemm_f64_identity() {
    let backend = NativeBackend::new();
    let identity = [1.0f64, 0.0, 0.0, 1.0];
    let rhs = [5.0f64, 6.0, 7.0, 8.0];
    let mut out = [0.0f64; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 2,
            alpha: 1.0,
            a: &identity,
            b: &rhs,
            beta: 0.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::RowMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [5.0f64, 6.0, 7.0, 8.0]);
}
/// Plain row-major 2x2 `f64` product: `[[1, 2], [3, 4]] * [[5, 6], [7, 8]]`
/// with alpha = 1 and beta = 0.
#[test]
fn test_gemm_f64_basic() {
    let backend = NativeBackend::new();
    let lhs = [1.0f64, 2.0, 3.0, 4.0];
    let rhs = [5.0f64, 6.0, 7.0, 8.0];
    let mut out = [0.0f64; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 2,
            alpha: 1.0,
            a: &lhs,
            b: &rhs,
            beta: 0.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::RowMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [19.0f64, 22.0, 43.0, 50.0]);
}
/// Row-major 2x2 `f64` product with scaling: the output starts at all ones and
/// must end up as `2 * (A * B) + 3 * C`.
#[test]
fn test_gemm_f64_alpha_beta() {
    let backend = NativeBackend::new();
    let lhs = [1.0f64, 2.0, 3.0, 4.0];
    let rhs = [5.0f64, 6.0, 7.0, 8.0];
    // Non-zero initial contents so the beta term contributes to the result.
    let mut out = [1.0f64; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 2,
            alpha: 2.0,
            a: &lhs,
            b: &rhs,
            beta: 3.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::RowMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [41.0f64, 47.0, 89.0, 103.0]);
}
/// Non-square row-major case: a 2x3 matrix times a 3x2 matrix yields a 2x2
/// result (alpha = 1, beta = 0).
#[test]
fn test_gemm_f64_rectangular() {
    let backend = NativeBackend::new();
    // 2x3, stored row by row.
    let lhs = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    // 3x2, stored row by row.
    let rhs = [7.0f64, 8.0, 9.0, 10.0, 11.0, 12.0];
    let mut out = [0.0f64; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 3,
            alpha: 1.0,
            a: &lhs,
            b: &rhs,
            beta: 0.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::RowMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [58.0f64, 64.0, 139.0, 154.0]);
}
/// Single-precision counterpart of the basic row-major 2x2 product
/// (alpha = 1, beta = 0).
#[test]
fn test_gemm_f32_basic() {
    let backend = NativeBackend::new();
    let lhs = [1.0f32, 2.0, 3.0, 4.0];
    let rhs = [5.0f32, 6.0, 7.0, 8.0];
    let mut out = [0.0f32; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 2,
            alpha: 1.0,
            a: &lhs,
            b: &rhs,
            beta: 0.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::RowMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [19.0f32, 22.0, 43.0, 50.0]);
}
/// Single-precision row-major 2x2 product with scaling: the output starts at
/// all twos and must end up as `2 * (A * B) + 3 * C`.
#[test]
fn test_gemm_f32_alpha_beta() {
    let backend = NativeBackend::new();
    let lhs = [1.0f32, 2.0, 3.0, 4.0];
    let rhs = [5.0f32, 6.0, 7.0, 8.0];
    // Non-zero initial contents so the beta term contributes to the result.
    let mut out = [2.0f32; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 2,
            alpha: 2.0,
            a: &lhs,
            b: &rhs,
            beta: 3.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::RowMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [44.0f32, 50.0, 92.0, 106.0]);
}
/// Row-major 2x2 complex (`f64`) product with alpha = 1 and beta = 0.
///
/// Every output element is compared against a hand-computed reference, so an
/// error that affects only part of the result cannot slip through.
#[test]
fn test_gemm_c64_basic() {
    let backend = NativeBackend::new();
    let a = [
        Complex::new(1.0f64, 1.0),
        Complex::new(2.0, 1.0),
        Complex::new(3.0, 1.0),
        Complex::new(4.0, 1.0),
    ];
    let b = [
        Complex::new(5.0f64, 1.0),
        Complex::new(6.0, 1.0),
        Complex::new(7.0, 1.0),
        Complex::new(8.0, 1.0),
    ];
    let mut c = [Complex::new(0.0f64, 0.0); 4];
    let desc = GemmDescriptor {
        m: 2,
        n: 2,
        k: 2,
        alpha: Complex::new(1.0, 0.0),
        a: &a,
        b: &b,
        beta: Complex::new(0.0, 0.0),
        c: &mut c,
        trans_a: false,
        trans_b: false,
        order: MemoryOrder::RowMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.gemm(desc).unwrap();

    // Each term uses (p + i)(q + i) = (pq - 1) + (p + q)i.
    let expected = [
        Complex::new(17.0f64, 15.0),
        Complex::new(20.0, 17.0),
        Complex::new(41.0, 19.0),
        Complex::new(48.0, 21.0),
    ];
    for (idx, (got, want)) in c.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got.re - want.re).abs() < 1e-10 && (got.im - want.im).abs() < 1e-10,
            "c[{idx}] = {got:?}, expected {want:?}"
        );
    }
}
/// Row-major 2x2 complex (`f64`) product with alpha = 2 and a purely imaginary
/// beta = i, using diagonal operands.
///
/// All four outputs are checked: the diagonal must equal
/// `2 * (3 + 4i) + i * (1 + i) = 5 + 9i` and the off-diagonal entries must
/// stay zero.
#[test]
fn test_gemm_c64_alpha_beta() {
    let backend = NativeBackend::new();
    // A is the identity.
    let a = [
        Complex::new(1.0f64, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(1.0, 0.0),
    ];
    let b = [
        Complex::new(3.0f64, 4.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(3.0, 4.0),
    ];
    let mut c = [
        Complex::new(1.0f64, 1.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(1.0, 1.0),
    ];
    let desc = GemmDescriptor {
        m: 2,
        n: 2,
        k: 2,
        alpha: Complex::new(2.0, 0.0),
        a: &a,
        b: &b,
        beta: Complex::new(0.0, 1.0),
        c: &mut c,
        trans_a: false,
        trans_b: false,
        order: MemoryOrder::RowMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.gemm(desc).unwrap();

    let expected = [
        Complex::new(5.0f64, 9.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(5.0, 9.0),
    ];
    for (idx, (got, want)) in c.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got.re - want.re).abs() < 1e-10 && (got.im - want.im).abs() < 1e-10,
            "c[{idx}] = {got:?}, expected {want:?}"
        );
    }
}
/// Row-major 2x2 complex (`f32`) product with alpha = 1 and beta = 0.
///
/// Every output element is compared against a hand-computed reference, so an
/// error that affects only part of the result cannot slip through.
#[test]
fn test_gemm_c32_basic() {
    let backend = NativeBackend::new();
    let a = [
        Complex::new(1.0f32, 1.0),
        Complex::new(2.0, 1.0),
        Complex::new(3.0, 1.0),
        Complex::new(4.0, 1.0),
    ];
    let b = [
        Complex::new(5.0f32, 1.0),
        Complex::new(6.0, 1.0),
        Complex::new(7.0, 1.0),
        Complex::new(8.0, 1.0),
    ];
    let mut c = [Complex::new(0.0f32, 0.0); 4];
    let desc = GemmDescriptor {
        m: 2,
        n: 2,
        k: 2,
        alpha: Complex::new(1.0, 0.0),
        a: &a,
        b: &b,
        beta: Complex::new(0.0, 0.0),
        c: &mut c,
        trans_a: false,
        trans_b: false,
        order: MemoryOrder::RowMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.gemm(desc).unwrap();

    // Each term uses (p + i)(q + i) = (pq - 1) + (p + q)i.
    let expected = [
        Complex::new(17.0f32, 15.0),
        Complex::new(20.0, 17.0),
        Complex::new(41.0, 19.0),
        Complex::new(48.0, 21.0),
    ];
    for (idx, (got, want)) in c.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got.re - want.re).abs() < 1e-4 && (got.im - want.im).abs() < 1e-4,
            "c[{idx}] = {got:?}, expected {want:?}"
        );
    }
}
/// Row-major 2x2 complex (`f32`) product with alpha = 2 and a purely imaginary
/// beta = i, using diagonal operands.
///
/// All four outputs are checked: the diagonal must equal
/// `2 * (3 + 4i) + i * (2 + 3i) = 3 + 10i` and the off-diagonal entries must
/// stay zero.
#[test]
fn test_gemm_c32_alpha_beta() {
    let backend = NativeBackend::new();
    // A is the identity.
    let a = [
        Complex::new(1.0f32, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(1.0, 0.0),
    ];
    let b = [
        Complex::new(3.0f32, 4.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(3.0, 4.0),
    ];
    let mut c = [
        Complex::new(2.0f32, 3.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(2.0, 3.0),
    ];
    let desc = GemmDescriptor {
        m: 2,
        n: 2,
        k: 2,
        alpha: Complex::new(2.0, 0.0),
        a: &a,
        b: &b,
        beta: Complex::new(0.0, 1.0),
        c: &mut c,
        trans_a: false,
        trans_b: false,
        order: MemoryOrder::RowMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.gemm(desc).unwrap();

    let expected = [
        Complex::new(3.0f32, 10.0),
        Complex::new(0.0, 0.0),
        Complex::new(0.0, 0.0),
        Complex::new(3.0, 10.0),
    ];
    for (idx, (got, want)) in c.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got.re - want.re).abs() < 1e-4 && (got.im - want.im).abs() < 1e-4,
            "c[{idx}] = {got:?}, expected {want:?}"
        );
    }
}
/// Column-major 2x2 `f64` product with alpha = 2 and beta = 3; operands and
/// the expected output are all listed column by column.
#[test]
fn test_gemm_f64_colmajor() {
    let backend = NativeBackend::new();
    // Columns of [[1, 2], [3, 4]].
    let lhs = [1.0f64, 3.0, 2.0, 4.0];
    // Columns of [[5, 6], [7, 8]].
    let rhs = [5.0f64, 7.0, 6.0, 8.0];
    let mut out = [2.0f64; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 2,
            alpha: 2.0,
            a: &lhs,
            b: &rhs,
            beta: 3.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::ColumnMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [44.0f64, 92.0, 50.0, 106.0]);
}
/// Column-major 2x2 `f32` product with alpha = 2 and beta = 3; operands and
/// the expected output are all listed column by column.
#[test]
fn test_gemm_f32_colmajor() {
    let backend = NativeBackend::new();
    // Columns of [[1, 2], [3, 4]].
    let lhs = [1.0f32, 3.0, 2.0, 4.0];
    // Columns of [[5, 6], [7, 8]].
    let rhs = [5.0f32, 7.0, 6.0, 8.0];
    let mut out = [2.0f32; 4];

    backend
        .gemm(GemmDescriptor {
            m: 2,
            n: 2,
            k: 2,
            alpha: 2.0,
            a: &lhs,
            b: &rhs,
            beta: 3.0,
            c: &mut out,
            trans_a: false,
            trans_b: false,
            order: MemoryOrder::ColumnMajor,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    assert_eq!(out, [44.0f32, 92.0, 50.0, 106.0]);
}
/// Column-major 2x2 complex (`f64`) product with alpha = 2 and beta = i.
///
/// The whole output buffer is checked (in column-major order), so the test
/// also covers where each element is written, not just the value of the first.
#[test]
fn test_gemm_c64_colmajor() {
    let backend = NativeBackend::new();
    // Columns of [[1+i, 2+i], [3+i, 4+i]].
    let a = [
        Complex::new(1.0f64, 1.0),
        Complex::new(3.0, 1.0),
        Complex::new(2.0, 1.0),
        Complex::new(4.0, 1.0),
    ];
    // Columns of [[5+i, 6+i], [7+i, 8+i]].
    let b = [
        Complex::new(5.0f64, 1.0),
        Complex::new(7.0, 1.0),
        Complex::new(6.0, 1.0),
        Complex::new(8.0, 1.0),
    ];
    let mut c = [Complex::new(2.0f64, 3.0); 4];
    let desc = GemmDescriptor {
        m: 2,
        n: 2,
        k: 2,
        alpha: Complex::new(2.0, 0.0),
        a: &a,
        b: &b,
        beta: Complex::new(0.0, 1.0),
        c: &mut c,
        trans_a: false,
        trans_b: false,
        order: MemoryOrder::ColumnMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.gemm(desc).unwrap();

    // 2 * (A * B) + i * (2 + 3i), i.e. 2 * (A * B) + (-3 + 2i), listed
    // column by column: C00, C10, C01, C11.
    let expected = [
        Complex::new(31.0f64, 32.0),
        Complex::new(79.0, 40.0),
        Complex::new(37.0, 36.0),
        Complex::new(93.0, 44.0),
    ];
    for (idx, (got, want)) in c.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got.re - want.re).abs() < 1e-10 && (got.im - want.im).abs() < 1e-10,
            "c[{idx}] = {got:?}, expected {want:?}"
        );
    }
}
/// Column-major 2x2 complex (`f32`) product with alpha = 2 and beta = i.
///
/// The whole output buffer is checked (in column-major order), so the test
/// also covers where each element is written, not just the value of the first.
#[test]
fn test_gemm_c32_colmajor() {
    let backend = NativeBackend::new();
    // Columns of [[1+i, 2+i], [3+i, 4+i]].
    let a = [
        Complex::new(1.0f32, 1.0),
        Complex::new(3.0, 1.0),
        Complex::new(2.0, 1.0),
        Complex::new(4.0, 1.0),
    ];
    // Columns of [[5+i, 6+i], [7+i, 8+i]].
    let b = [
        Complex::new(5.0f32, 1.0),
        Complex::new(7.0, 1.0),
        Complex::new(6.0, 1.0),
        Complex::new(8.0, 1.0),
    ];
    let mut c = [Complex::new(2.0f32, 3.0); 4];
    let desc = GemmDescriptor {
        m: 2,
        n: 2,
        k: 2,
        alpha: Complex::new(2.0, 0.0),
        a: &a,
        b: &b,
        beta: Complex::new(0.0, 1.0),
        c: &mut c,
        trans_a: false,
        trans_b: false,
        order: MemoryOrder::ColumnMajor,
        policy: ExecPolicy::Sequential,
    };
    backend.gemm(desc).unwrap();

    // 2 * (A * B) + i * (2 + 3i), i.e. 2 * (A * B) + (-3 + 2i), listed
    // column by column: C00, C10, C01, C11.
    let expected = [
        Complex::new(31.0f32, 32.0),
        Complex::new(79.0, 40.0),
        Complex::new(37.0, 36.0),
        Complex::new(93.0, 44.0),
    ];
    for (idx, (got, want)) in c.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got.re - want.re).abs() < 1e-3 && (got.im - want.im).abs() < 1e-3,
            "c[{idx}] = {got:?}, expected {want:?}"
        );
    }
}
/// Flat buffer offset of logical element `(i, j)` in a `rows` x `cols` matrix
/// stored in the given memory order.
fn layout_index(i: usize, j: usize, rows: usize, cols: usize, order: MemoryOrder) -> usize {
    // Pick which index is the slow-moving one and how long each run is.
    let (major, minor, run_len) = match order {
        MemoryOrder::RowMajor => (i, j, cols),
        MemoryOrder::ColumnMajor => (j, i, rows),
    };
    major * run_len + minor
}
/// Packs a logical `rows` x `cols` matrix, given in row-major order in `data`,
/// into a freshly allocated buffer laid out according to `order`.
fn encode<T: Scalar>(data: &[T], rows: usize, cols: usize, order: MemoryOrder) -> Vec<T> {
    let mut packed = vec![T::zero(); rows * cols];
    let coords = (0..rows).flat_map(|i| (0..cols).map(move |j| (i, j)));
    for (i, j) in coords {
        packed[layout_index(i, j, rows, cols, order)] = data[i * cols + j];
    }
    packed
}
/// Inverse of `encode`: reads a `rows` x `cols` matrix stored in `buf` with the
/// given memory order and returns its elements in logical row-major order.
fn decode<T: Scalar>(buf: &[T], rows: usize, cols: usize, order: MemoryOrder) -> Vec<T> {
    // Visiting (i, j) in row-major sequence makes the collected position of
    // each element equal to `i * cols + j`.
    (0..rows)
        .flat_map(|i| (0..cols).map(move |j| (i, j)))
        .map(|(i, j)| buf[layout_index(i, j, rows, cols, order)])
        .collect()
}
/// Transposes a row-major `rows` x `cols` matrix, returning the `cols` x `rows`
/// result, also in row-major order.
fn transpose_logical<T: Scalar>(data: &[T], rows: usize, cols: usize) -> Vec<T> {
    // Row `j` of the transpose is column `j` of the input, so walking the
    // input column by column emits the output in order.
    (0..cols)
        .flat_map(|j| (0..rows).map(move |i| data[i * cols + j]))
        .collect()
}
/// Runs one GEMM through `NativeBackend` for a given transpose-flag / memory
/// order combination and checks it against a naive reference product.
///
/// The logical operands are always the same: `op(A)` is 2x3 and `op(B)` is
/// 3x4, filled with `mk(1.0), mk(2.0), ...` in row-major order. When a
/// transpose flag is set, the buffer handed to the backend holds the transpose
/// of that operand instead, so every combination is expected to produce the
/// same logical result.
///
/// * `trans_a`, `trans_b` - transpose flags forwarded to the descriptor.
/// * `order` - memory order used for all three buffers.
/// * `mk` - converts an `f64` into the scalar type under test.
///
/// Panics if any output element differs from the reference by more than the
/// real type's machine epsilon.
fn check_gemm_combination<T: Scalar>(
    trans_a: bool,
    trans_b: bool,
    order: MemoryOrder,
    mk: fn(f64) -> T,
) {
    let (m, n, k) = (2usize, 4usize, 3usize);
    let op_a: Vec<T> = (1..=m * k).map(|x| mk(x as f64)).collect();
    let op_b: Vec<T> = (1..=k * n).map(|x| mk(x as f64)).collect();

    // Naive reference product, stored in logical row-major order.
    let reference: Vec<T> = (0..m)
        .flat_map(|i| (0..n).map(move |j| (i, j)))
        .map(|(i, j)| {
            (0..k).fold(T::zero(), |acc, p| acc + op_a[i * k + p] * op_b[p * n + j])
        })
        .collect();

    // Choose what is physically stored for each operand, then lay it out.
    let (a_stored, a_rows, a_cols) = if trans_a {
        (transpose_logical(&op_a, m, k), k, m)
    } else {
        (op_a, m, k)
    };
    let a_buf = encode(&a_stored, a_rows, a_cols, order);

    let (b_stored, b_rows, b_cols) = if trans_b {
        (transpose_logical(&op_b, k, n), n, k)
    } else {
        (op_b, k, n)
    };
    let b_buf = encode(&b_stored, b_rows, b_cols, order);

    let mut c_buf = vec![T::zero(); m * n];
    NativeBackend::new()
        .gemm(GemmDescriptor {
            m,
            n,
            k,
            alpha: T::one(),
            a: &a_buf,
            b: &b_buf,
            beta: T::zero(),
            c: &mut c_buf,
            trans_a,
            trans_b,
            order,
            policy: ExecPolicy::Sequential,
        })
        .unwrap();

    let got = decode(&c_buf, m, n, order);
    let tolerance = <T::Real as num_traits::Float>::epsilon();
    // The difference is formed by adding the reference scaled by -1.
    let neg_one = -<T::Real as num_traits::One>::one();
    for (g, r) in got.iter().zip(&reference) {
        let err = (*g + r.scale_real(neg_one)).abs();
        assert!(
            err <= tolerance,
            "GEMM mismatch: trans_a={trans_a}, trans_b={trans_b}, order={order:?}"
        );
    }
}
/// Runs the shared GEMM check for `f64` over every combination of the two
/// transpose flags and both memory orders.
#[rstest]
fn gemm_layout_transpose_invariance_f64(
    #[values(false, true)] trans_a: bool,
    #[values(false, true)] trans_b: bool,
    #[values(MemoryOrder::RowMajor, MemoryOrder::ColumnMajor)] order: MemoryOrder,
) {
    // `f64` inputs are used as-is.
    let to_scalar: fn(f64) -> f64 = |value| value;
    check_gemm_combination(trans_a, trans_b, order, to_scalar);
}
/// Runs the shared GEMM check for `f32` over every combination of the two
/// transpose flags and both memory orders.
#[rstest]
fn gemm_layout_transpose_invariance_f32(
    #[values(false, true)] trans_a: bool,
    #[values(false, true)] trans_b: bool,
    #[values(MemoryOrder::RowMajor, MemoryOrder::ColumnMajor)] order: MemoryOrder,
) {
    // Inputs are narrowed from `f64` to `f32`.
    let to_scalar: fn(f64) -> f32 = |value| value as f32;
    check_gemm_combination(trans_a, trans_b, order, to_scalar);
}
/// Runs the shared GEMM check for `Complex<f64>` over every combination of the
/// two transpose flags and both memory orders.
#[rstest]
fn gemm_layout_transpose_invariance_c64(
    #[values(false, true)] trans_a: bool,
    #[values(false, true)] trans_b: bool,
    #[values(MemoryOrder::RowMajor, MemoryOrder::ColumnMajor)] order: MemoryOrder,
) {
    // Inputs become complex numbers with a zero imaginary part.
    let to_scalar: fn(f64) -> Complex<f64> = |re| Complex::new(re, 0.0);
    check_gemm_combination(trans_a, trans_b, order, to_scalar);
}
/// Runs the shared GEMM check for `Complex<f32>` over every combination of the
/// two transpose flags and both memory orders.
#[rstest]
fn gemm_layout_transpose_invariance_c32(
    #[values(false, true)] trans_a: bool,
    #[values(false, true)] trans_b: bool,
    #[values(MemoryOrder::RowMajor, MemoryOrder::ColumnMajor)] order: MemoryOrder,
) {
    // Inputs are narrowed to `f32` and given a zero imaginary part.
    let to_scalar: fn(f64) -> Complex<f32> = |re| Complex::new(re as f32, 0.0);
    check_gemm_combination(trans_a, trans_b, order, to_scalar);
}