use crate::blis::*;
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21a_true_asm_matches_scalar_k64() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    let k = 64;
    // Deterministic pseudo-random operands in [0.0, 0.99].
    let a: Vec<f32> = (0..MR * k).map(|i| ((i % 100) as f32) * 0.01).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| ((i % 100) as f32) * 0.01).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    // Largest element-wise relative difference; the epsilon floor guards
    // against division by zero when the reference value is (near) zero.
    let mut max_rel_diff: f32 = 0.0;
    for (expect, got) in reference.iter().zip(observed.iter()) {
        let rel = (expect - got).abs() / expect.abs().max(1e-10);
        max_rel_diff = max_rel_diff.max(rel);
    }
    assert!(max_rel_diff < 1e-5, "F21a: ASM microkernel k=64 max_rel_diff={}", max_rel_diff);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21a_true_asm_matches_scalar_k256() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    let k = 256;
    // Deterministic pseudo-random operands in [0.0, 0.99].
    let a: Vec<f32> = (0..MR * k).map(|i| ((i % 100) as f32) * 0.01).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| ((i % 100) as f32) * 0.01).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    // Absolute (not relative) tolerance: with k=256 accumulations the
    // magnitudes are large enough that 1e-4 absolute is a fair bound.
    let mut max_diff: f32 = 0.0;
    for (expect, got) in reference.iter().zip(observed.iter()) {
        max_diff = max_diff.max((expect - got).abs());
    }
    assert!(max_diff < 1e-4, "F21a: ASM microkernel k=256 max_diff={}", max_diff);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21a_true_asm_matches_scalar_k1024() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    let k = 1024;
    // Deterministic pseudo-random operands in [0.0, 0.49].
    let a: Vec<f32> = (0..MR * k).map(|i| ((i % 50) as f32) * 0.01).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| ((i % 50) as f32) * 0.01).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    // Looser absolute tolerance: 1024 accumulation steps magnify rounding
    // differences between the scalar and the FMA-based reduction order.
    let mut max_diff: f32 = 0.0;
    for (expect, got) in reference.iter().zip(observed.iter()) {
        max_diff = max_diff.max((expect - got).abs());
    }
    assert!(max_diff < 1e-3, "F21a: ASM microkernel k=1024 max_diff={}", max_diff);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21h_k_remainder_k1() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    // k=1 exercises the pure remainder path of the unrolled k-loop.
    let k = 1;
    let a: Vec<f32> = (0..MR * k).map(|i| (i as f32) + 1.0).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| (i as f32) + 1.0).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    // Compare every entry of the MRxNR tile, reporting the first mismatch.
    for (idx, (expect, got)) in reference.iter().zip(observed.iter()).enumerate() {
        assert!(
            (expect - got).abs() < 1e-5,
            "F21h: k=1 mismatch at {}: {} vs {}",
            idx,
            expect,
            got
        );
    }
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21h_k_remainder_k5() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    // k=5 is not a multiple of the unroll factor; exercises the remainder path.
    let k = 5;
    let a: Vec<f32> = (0..MR * k).map(|i| ((i % 10) as f32) * 0.1).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| ((i % 10) as f32) * 0.1).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    let mut max_diff: f32 = 0.0;
    for (expect, got) in reference.iter().zip(observed.iter()) {
        max_diff = max_diff.max((expect - got).abs());
    }
    assert!(max_diff < 1e-5, "F21h: k=5 remainder max_diff={}", max_diff);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21h_k_remainder_k7() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    // k=7 is not a multiple of the unroll factor; exercises the remainder path.
    let k = 7;
    let a: Vec<f32> = (0..MR * k).map(|i| ((i % 10) as f32) * 0.1).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| ((i % 10) as f32) * 0.1).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    let mut max_diff: f32 = 0.0;
    for (expect, got) in reference.iter().zip(observed.iter()) {
        max_diff = max_diff.max((expect - got).abs());
    }
    assert!(max_diff < 1e-5, "F21h: k=7 remainder max_diff={}", max_diff);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21h_k_remainder_k9() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    // k=9 is not a multiple of the unroll factor; exercises the remainder path.
    let k = 9;
    let a: Vec<f32> = (0..MR * k).map(|i| ((i % 10) as f32) * 0.1).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| ((i % 10) as f32) * 0.1).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    let mut max_diff: f32 = 0.0;
    for (expect, got) in reference.iter().zip(observed.iter()) {
        max_diff = max_diff.max((expect - got).abs());
    }
    assert!(max_diff < 1e-5, "F21h: k=9 remainder max_diff={}", max_diff);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21j_asm_faster_than_intrinsics() {
    // Sanity benchmark: the true-ASM kernel must not be dramatically slower
    // than the intrinsics kernel (ratio = intrinsics_time / asm_time >= 0.3).
    //
    // Fix: the original timed loops never passed their inputs or the output
    // buffer through `std::hint::black_box`, so the optimizer was free to
    // elide or const-fold the intra-crate kernel calls (the result of each
    // iteration is overwritten and otherwise unused), which would make both
    // timings meaningless. Black-boxing the pointers and the result keeps
    // every call observable.
    if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
        return;
    }
    let k = 256;
    let a: Vec<f32> = (0..MR * k).map(|i| (i as f32) * 0.001).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| (i as f32) * 0.001).collect();
    let mut c = vec![0.0; MR * NR];
    // Warm-up: populate caches and let the CPU ramp up before timing.
    for _ in 0..10 {
        unsafe {
            microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), MR);
        }
        c.fill(0.0);
    }
    let iterations = 1000;
    let start_asm = std::time::Instant::now();
    for _ in 0..iterations {
        unsafe {
            microkernel_8x6_true_asm(
                k,
                std::hint::black_box(a.as_ptr()),
                std::hint::black_box(b.as_ptr()),
                c.as_mut_ptr(),
                MR,
            );
        }
    }
    // Keep the accumulated result live so the loop above cannot be elided.
    std::hint::black_box(&c);
    let asm_time = start_asm.elapsed();
    c.fill(0.0);
    let start_intrinsics = std::time::Instant::now();
    for _ in 0..iterations {
        unsafe {
            microkernel_8x6_avx2(
                k,
                std::hint::black_box(a.as_ptr()),
                std::hint::black_box(b.as_ptr()),
                c.as_mut_ptr(),
                MR,
            );
        }
    }
    std::hint::black_box(&c);
    let intrinsics_time = start_intrinsics.elapsed();
    let ratio = intrinsics_time.as_nanos() as f64 / asm_time.as_nanos() as f64;
    assert!(
        ratio >= 0.3,
        "F21j: ASM should not be significantly slower than intrinsics. Ratio: {:.2}",
        ratio
    );
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_f21c_pipeline_correctness() {
    // The hand-written kernel needs AVX2 + FMA; skip on older hardware.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return;
    }
    let k = 16;
    let a: Vec<f32> = (0..MR * k).map(|i| (i as f32) * 0.1).collect();
    let b: Vec<f32> = (0..k * NR).map(|i| (i as f32) * 0.01).collect();
    let mut reference = vec![0.0; MR * NR];
    let mut observed = vec![0.0; MR * NR];
    microkernel_scalar(k, &a, &b, &mut reference, MR);
    unsafe {
        microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), observed.as_mut_ptr(), MR);
    }
    // Check every tile entry with a relative tolerance; the epsilon floor
    // guards against division by (near) zero reference values.
    for (idx, (expect, got)) in reference.iter().zip(observed.iter()).enumerate() {
        let rel_diff = (expect - got).abs() / expect.abs().max(1e-10);
        assert!(
            rel_diff < 1e-5,
            "F21c: Pipeline incorrect at {}: scalar={}, asm={}, rel_diff={}",
            idx,
            expect,
            got,
            rel_diff
        );
    }
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_gemm_with_true_asm_microkernel() {
    // End-to-end check: the full blocked GEMM must agree with the naive
    // reference implementation on a 128x128 problem.
    let n = 128;
    let a: Vec<f32> = (0..n * n).map(|i| ((i % 10) as f32) * 0.1).collect();
    let b: Vec<f32> = (0..n * n).map(|i| ((i % 7) as f32) * 0.1).collect();
    let mut expected = vec![0.0; n * n];
    let mut actual = vec![0.0; n * n];
    gemm_reference(n, n, n, &a, &b, &mut expected).unwrap();
    gemm_blis(n, n, n, &a, &b, &mut actual, None).unwrap();
    let mut max_diff: f32 = 0.0;
    for (r, v) in expected.iter().zip(actual.iter()) {
        max_diff = max_diff.max((r - v).abs());
    }
    assert!(max_diff < 1e-2, "GEMM with true ASM microkernel: max_diff={}", max_diff);
}