use super::super::*;
/// Hand-checked 2x2 product in row-major layout:
/// [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
#[test]
fn test_gemm_reference_2x2() {
    let lhs = vec![1.0, 2.0, 3.0, 4.0];
    let rhs = vec![5.0, 6.0, 7.0, 8.0];
    // Output starts zeroed so the result equals the plain product.
    let mut out = vec![0.0; 4];

    gemm_reference(2, 2, 2, &lhs, &rhs, &mut out).unwrap();

    assert_eq!(out, vec![19.0, 22.0, 43.0, 50.0]);
}
/// Multiplying a 3x3 matrix by the 3x3 identity must reproduce the matrix.
#[test]
fn test_gemm_reference_identity() {
    let lhs = vec![
        1.0, 2.0, 3.0, //
        4.0, 5.0, 6.0, //
        7.0, 8.0, 9.0,
    ];
    let identity = vec![
        1.0, 0.0, 0.0, //
        0.0, 1.0, 0.0, //
        0.0, 0.0, 1.0,
    ];
    let mut out = vec![0.0; 9];

    gemm_reference(3, 3, 3, &lhs, &identity, &mut out).unwrap();

    assert_eq!(out, lhs);
}
/// The output buffer is accumulated into rather than overwritten: with B = I,
/// the expected result is the pre-existing C plus A.
#[test]
fn test_gemm_reference_accumulation() {
    let lhs = vec![1.0, 2.0, 3.0, 4.0];
    let rhs_identity = vec![1.0, 0.0, 0.0, 1.0];
    // Non-zero starting contents; these must survive and be added to.
    let mut out = vec![10.0, 20.0, 30.0, 40.0];

    gemm_reference(2, 2, 2, &lhs, &rhs_identity, &mut out).unwrap();

    assert_eq!(out, vec![11.0, 22.0, 33.0, 44.0]);
}
/// Non-square operands: a 2x3 matrix times a 3x2 matrix yields a 2x2 result.
/// The third dimension argument (3) is the shared inner dimension.
#[test]
fn test_gemm_reference_rectangular() {
    let lhs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let mut out = vec![0.0; 4];

    gemm_reference(2, 2, 3, &lhs, &rhs, &mut out).unwrap();

    // Row 0: 1*7+2*9+3*11, 1*8+2*10+3*12; row 1: 4*7+5*9+6*11, 4*8+5*10+6*12.
    assert_eq!(out, vec![58.0, 64.0, 139.0, 154.0]);
}
/// An A buffer shorter than the declared 2x2 shape (3 elements) must be rejected.
#[test]
fn test_gemm_reference_size_mismatch() {
    let short_lhs = vec![1.0, 2.0, 3.0];
    let rhs = vec![1.0, 2.0, 3.0, 4.0];
    let mut out = vec![0.0; 4];

    let outcome = gemm_reference(2, 2, 2, &short_lhs, &rhs, &mut out);

    assert!(matches!(outcome, Err(_)));
}
/// A B buffer with only 2 elements for a declared 2x2 shape must be rejected.
#[test]
fn test_gemm_reference_b_size_mismatch() {
    let lhs = vec![1.0; 4];
    let short_rhs = vec![1.0, 2.0];
    let mut out = vec![0.0; 4];

    let outcome = gemm_reference(2, 2, 2, &lhs, &short_rhs, &mut out);

    assert!(outcome.is_err());
}
/// An output buffer with only 2 elements for a 2x2 result must be rejected.
#[test]
fn test_gemm_reference_c_size_mismatch() {
    let lhs = vec![1.0; 4];
    let rhs = vec![1.0; 4];
    let mut short_out = vec![0.0; 2];

    let outcome = gemm_reference(2, 2, 2, &lhs, &rhs, &mut short_out);

    assert!(outcome.is_err());
}
/// With finite inputs, the strict-guard variant produces the same 2x2 product as
/// `gemm_reference`.
#[test]
fn test_gemm_reference_with_jidoka_basic() {
    let jidoka = JidokaGuard::strict();
    let lhs = vec![1.0, 2.0, 3.0, 4.0];
    let rhs = vec![5.0, 6.0, 7.0, 8.0];
    let mut out = vec![0.0; 4];

    gemm_reference_with_jidoka(2, 2, 2, &lhs, &rhs, &mut out, &jidoka).unwrap();

    assert_eq!(out, vec![19.0, 22.0, 43.0, 50.0]);
}
/// A NaN in the first operand must make the strict-guard GEMM return an error.
#[test]
fn test_gemm_reference_with_jidoka_nan_input() {
    let jidoka = JidokaGuard::strict();
    let poisoned_lhs = vec![f32::NAN, 2.0, 3.0, 4.0];
    let rhs = vec![5.0, 6.0, 7.0, 8.0];
    let mut out = vec![0.0; 4];

    let outcome = gemm_reference_with_jidoka(2, 2, 2, &poisoned_lhs, &rhs, &mut out, &jidoka);

    assert!(matches!(outcome, Err(_)));
}
/// A positive infinity in the first operand must make the strict-guard GEMM return
/// an error.
#[test]
fn test_gemm_reference_with_jidoka_inf_input() {
    let jidoka = JidokaGuard::strict();
    let poisoned_lhs = vec![f32::INFINITY, 2.0, 3.0, 4.0];
    let rhs = vec![5.0, 6.0, 7.0, 8.0];
    let mut out = vec![0.0; 4];

    let outcome = gemm_reference_with_jidoka(2, 2, 2, &poisoned_lhs, &rhs, &mut out, &jidoka);

    assert!(matches!(outcome, Err(_)));
}
/// A positive infinity in the second operand (B) must also be caught by the
/// strict guard.
#[test]
fn test_gemm_reference_with_jidoka_inf_in_b() {
    let jidoka = JidokaGuard::strict();
    let lhs = vec![1.0, 2.0, 3.0, 4.0];
    let poisoned_rhs = vec![5.0, f32::INFINITY, 7.0, 8.0];
    let mut out = vec![0.0; 4];

    let outcome = gemm_reference_with_jidoka(2, 2, 2, &lhs, &poisoned_rhs, &mut out, &jidoka);

    assert!(matches!(outcome, Err(_)));
}
/// Transposes a 10x13 row-major matrix and checks every element lands at its
/// mirrored position. The odd shape is presumably chosen so the dimensions do not
/// divide evenly into `transpose`'s internal tiling (NOTE(review): confirm against
/// the implementation).
#[test]
fn test_transpose_large_with_remainder() {
    let (rows, cols) = (10, 13);
    let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
    let mut b = vec![0.0; rows * cols];

    transpose(rows, cols, &a, &mut b).unwrap();

    // Visit (r, c) in row-major order; element (r, c) of `a` becomes (c, r) of `b`.
    let coords = (0..rows).flat_map(|r| (0..cols).map(move |c| (r, c)));
    for (r, c) in coords {
        assert_eq!(b[c * rows + r], a[r * cols + c], "transpose mismatch at ({}, {})", r, c);
    }
}
/// Transposes a 16x16 matrix. Per the test name, 16 presumably equals the block
/// size used inside `transpose` (NOTE(review): confirm), so no remainder handling
/// is exercised.
#[test]
fn test_transpose_exact_block_size() {
    let rows = 16;
    let cols = 16;
    let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
    let mut b = vec![0.0; rows * cols];

    transpose(rows, cols, &a, &mut b).unwrap();

    // Walk `a` one row at a time; each value must appear at the mirrored index.
    for (r, row) in a.chunks(cols).enumerate() {
        for (c, &value) in row.iter().enumerate() {
            assert_eq!(b[c * rows + r], value);
        }
    }
}
/// A 4-element output cannot hold the transpose of a 2x3 input, so `transpose`
/// must return an error.
#[test]
fn test_transpose_size_mismatch() {
    let a = vec![1.0; 6];
    let mut too_small = vec![0.0; 4];

    let outcome = transpose(2, 3, &a, &mut too_small);

    assert!(outcome.is_err());
}
/// Hand-checked transpose of a tiny 2x3 matrix:
/// [[1, 2, 3], [4, 5, 6]] becomes [[1, 4], [2, 5], [3, 6]].
/// Per the test name, inputs this small presumably take a scalar (non-tiled) path
/// in `transpose` (NOTE(review): confirm).
#[test]
fn test_transpose_small_scalar_path() {
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut b = vec![0.0; 6];

    transpose(2, 3, &a, &mut b).unwrap();

    assert_eq!(b, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}