use super::super::*;
/// `gemm_blis` must reject an A buffer whose length is not `m * k`, and the
/// error's `Display` text must name the offending operand.
#[test]
fn test_validate_gemm_dims_a_mismatch() {
    let (m, n, k) = (3, 5, 4);
    // A is two elements short of m * k; B and C are sized correctly.
    let a = vec![1.0f32; m * k - 2];
    let b = vec![1.0f32; k * n];
    let mut c = vec![0.0f32; m * n];
    let outcome = gemm_blis(m, n, k, &a, &b, &mut c, None);
    assert!(outcome.is_err());
    let text = outcome.unwrap_err().to_string();
    assert!(text.contains("A size mismatch"), "Got: {}", text);
}
/// `gemm_blis` must reject a B buffer whose length is not `k * n`, and the
/// error's `Display` text must name the offending operand.
#[test]
fn test_validate_gemm_dims_b_mismatch() {
    let (m, n, k) = (3, 5, 4);
    // B is one element short of k * n; A and C are sized correctly.
    let a = vec![1.0f32; m * k];
    let b = vec![1.0f32; k * n - 1];
    let mut c = vec![0.0f32; m * n];
    let outcome = gemm_blis(m, n, k, &a, &b, &mut c, None);
    assert!(outcome.is_err());
    let text = outcome.unwrap_err().to_string();
    assert!(text.contains("B size mismatch"), "Got: {}", text);
}
/// `gemm_blis` must reject a C buffer whose length is not `m * n`, and the
/// error's `Display` text must name the offending operand.
#[test]
fn test_validate_gemm_dims_c_mismatch() {
    let (m, n, k) = (3, 5, 4);
    // C is one element short of m * n; A and B are sized correctly.
    let a = vec![1.0f32; m * k];
    let b = vec![1.0f32; k * n];
    let mut c = vec![0.0f32; m * n - 1];
    let outcome = gemm_blis(m, n, k, &a, &b, &mut c, None);
    assert!(outcome.is_err());
    let text = outcome.unwrap_err().to_string();
    assert!(text.contains("C size mismatch"), "Got: {}", text);
}
/// With correctly sized operands `gemm_blis` must succeed and also compute
/// the right product. A (m*k) and B (k*n) are all ones and C starts zeroed,
/// so every entry of the result must equal the inner dimension `k`.
#[test]
fn test_validate_gemm_dims_all_correct() {
    let m = 4;
    let n = 3;
    let k = 5;
    let a = vec![1.0f32; m * k];
    let b = vec![1.0f32; k * n];
    let mut c = vec![0.0f32; m * n];
    gemm_blis(m, n, k, &a, &b, &mut c, None)
        .expect("gemm_blis should accept correctly sized operands");
    // Checking only for success would let a wrong product pass; verify values too.
    let expected = k as f32;
    for (i, &v) in c.iter().enumerate() {
        assert!((v - expected).abs() < 1e-5, "c[{}] = {}, expected {}", i, v, expected);
    }
}
/// A zero-row problem (`m == 0`) must be accepted. A and C are empty, while
/// B is still sized `k * n`.
#[test]
fn test_gemm_blis_zero_m() {
    let (n, k) = (5, 4);
    let a = Vec::<f32>::new();
    let b = vec![1.0f32; k * n];
    let mut c = Vec::<f32>::new();
    assert!(gemm_blis(0, n, k, &a, &b, &mut c, None).is_ok());
}
/// A zero-column problem (`n == 0`) must be accepted. B and C are empty,
/// while A is still sized `m * k`.
#[test]
fn test_gemm_blis_zero_n() {
    let (m, k) = (3, 4);
    let a = vec![1.0f32; m * k];
    let b = Vec::<f32>::new();
    let mut c = Vec::<f32>::new();
    assert!(gemm_blis(m, 0, k, &a, &b, &mut c, None).is_ok());
}
/// An inner dimension of zero must be accepted. A and B are empty, so every
/// output entry is an empty sum; starting from a zeroed C, the result must
/// still be all zeros.
#[test]
fn test_gemm_blis_zero_k() {
    let a: Vec<f32> = vec![];
    let b: Vec<f32> = vec![];
    let mut c = vec![0.0f32; 15];
    gemm_blis(3, 5, 0, &a, &b, &mut c, None).expect("k == 0 should be accepted");
    // Guards against the kernel writing uninitialised/garbage values when k == 0.
    assert!(c.iter().all(|&v| v == 0.0), "k = 0 must leave C zero, got {:?}", c);
}
/// A small 4x4 product with B = identity must reproduce A (within 1e-5).
/// NOTE(review): the name suggests this size goes through a reference path
/// inside `gemm_blis`; confirm against its dispatch threshold.
#[test]
fn test_gemm_blis_falls_to_reference() {
    let dim = 4;
    // A holds the values 1.0 ..= 16.0 in order.
    let a: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    // B is the dim x dim identity: ones exactly where row == column.
    let b: Vec<f32> = (0..dim * dim)
        .map(|idx| if idx / dim == idx % dim { 1.0 } else { 0.0 })
        .collect();
    let mut c = vec![0.0f32; dim * dim];
    gemm_blis(dim, dim, dim, &a, &b, &mut c, None).unwrap();
    for (i, (&got, &want)) in c.iter().zip(a.iter()).enumerate() {
        assert!((got - want).abs() < 1e-5, "c[{}] = {}, expected {}", i, got, want);
    }
}
/// `gemm_blis_parallel` must reject an A buffer one element short of `m * k`.
#[test]
fn test_gemm_blis_parallel_a_mismatch() {
    let (m, n, k) = (2, 4, 3);
    let a = vec![1.0f32; m * k - 1];
    let b = vec![1.0f32; k * n];
    let mut c = vec![0.0f32; m * n];
    assert!(gemm_blis_parallel(m, n, k, &a, &b, &mut c).is_err());
}
/// `gemm_blis_parallel` must reject a B buffer one element short of `k * n`.
#[test]
fn test_gemm_blis_parallel_b_mismatch() {
    let (m, n, k) = (2, 4, 3);
    let a = vec![1.0f32; m * k];
    let b = vec![1.0f32; k * n - 1];
    let mut c = vec![0.0f32; m * n];
    assert!(gemm_blis_parallel(m, n, k, &a, &b, &mut c).is_err());
}
/// `gemm_blis_parallel` must reject a C buffer one element short of `m * n`.
#[test]
fn test_gemm_blis_parallel_c_mismatch() {
    let (m, n, k) = (2, 4, 3);
    let a = vec![1.0f32; m * k];
    let b = vec![1.0f32; k * n];
    let mut c = vec![0.0f32; m * n - 1];
    assert!(gemm_blis_parallel(m, n, k, &a, &b, &mut c).is_err());
}
/// An 8x8 problem through `gemm_blis_parallel` must agree element-wise with
/// `gemm_reference` to within 1e-3.
/// NOTE(review): the name implies this size is below the parallel threshold;
/// confirm against `gemm_blis_parallel`.
#[test]
fn test_gemm_blis_parallel_small_falls_to_sequential() {
    let n = 8;
    let len = n * n;
    let a: Vec<f32> = (0..len).map(|i| (i % 7) as f32).collect();
    let b: Vec<f32> = (0..len).map(|i| ((i + 3) % 5) as f32).collect();
    let mut c_par = vec![0.0f32; len];
    let mut c_ref = vec![0.0f32; len];
    gemm_blis_parallel(n, n, n, &a, &b, &mut c_par).unwrap();
    gemm_reference(n, n, n, &a, &b, &mut c_ref).unwrap();
    for (i, (&par, &reference)) in c_par.iter().zip(c_ref.iter()).enumerate() {
        assert!(
            (par - reference).abs() < 1e-3,
            "Mismatch at {}: par={}, ref={}",
            i,
            par,
            reference
        );
    }
}
/// A 128x128 problem through `gemm_blis_parallel` must match `gemm_reference`,
/// with the largest absolute element difference below 0.1.
#[test]
fn test_gemm_blis_parallel_large_matrix() {
    let n = 128;
    let a: Vec<f32> = (0..n * n).map(|i| (i % 11) as f32 * 0.1).collect();
    let b: Vec<f32> = (0..n * n).map(|i| (i % 7) as f32 * 0.1).collect();
    let mut c_par = vec![0.0f32; n * n];
    let mut c_ref = vec![0.0f32; n * n];
    gemm_blis_parallel(n, n, n, &a, &b, &mut c_par).unwrap();
    gemm_reference(n, n, n, &a, &b, &mut c_ref).unwrap();
    let max_diff = c_par
        .iter()
        .zip(c_ref.iter())
        .map(|(&p, &r)| (p - r).abs())
        .fold(0.0f32, f32::max);
    assert!(max_diff < 1e-1, "Max diff: {}", max_diff);
}
/// A tall rectangular problem (m = 200, n = 50, k = 60) through
/// `gemm_blis_parallel` must succeed and agree with `gemm_reference`.
/// Checking only that some output is non-zero would let an incorrect product
/// pass, so each element is compared within a relative tolerance.
/// NOTE(review): assumes `gemm_reference` takes (m, n, k, a, b, c) in the same
/// order as `gemm_blis`; the other tests only call it with square inputs.
#[test]
fn test_gemm_blis_parallel_rectangular_tall() {
    let m = 200;
    let n = 50;
    let k = 60;
    let a: Vec<f32> = (0..m * k).map(|i| ((i % 9) as f32) * 0.1).collect();
    let b: Vec<f32> = (0..k * n).map(|i| ((i % 5) as f32) * 0.1).collect();
    let mut c = vec![0.0f32; m * n];
    let mut c_ref = vec![0.0f32; m * n];
    gemm_blis_parallel(m, n, k, &a, &b, &mut c).unwrap();
    gemm_reference(m, n, k, &a, &b, &mut c_ref).unwrap();
    assert!(c.iter().any(|&v| v != 0.0));
    for (i, (&got, &want)) in c.iter().zip(c_ref.iter()).enumerate() {
        // Relative tolerance: blocked/parallel summation order differs from the
        // reference, and output magnitudes here reach roughly k * 0.8 * 0.4.
        let tol = 1e-3 * want.abs().max(1.0);
        assert!((got - want).abs() <= tol, "Mismatch at {}: par={}, ref={}", i, got, want);
    }
}