use trueno_gemm_codegen::{avx512_microkernel, avx512_microkernel_broadcast_b};
// Microkernel macro invocations, one per (mr, nr) tile shape.
// NOTE(review): the test module in this file calls `microkernel_8x32_avx512_gen`,
// `microkernel_8x16_avx512_gen` and `microkernel_8x48_avx512_gen`, which are
// presumably the functions these three invocations expand to — confirm against
// the `trueno_gemm_codegen` macro definitions.
avx512_microkernel!(mr = 8, nr = 32);
avx512_microkernel!(mr = 8, nr = 16);
avx512_microkernel!(mr = 8, nr = 48);
// Broadcast-B variants with nr = 6 and mr = 32 / 48 / 64.
// NOTE(review): the tests call `microkernel_32x6_avx512_bcast_b`,
// `microkernel_48x6_avx512_bcast_b` and `microkernel_64x6_avx512_bcast_b`,
// presumably generated here — confirm against the macro definitions.
avx512_microkernel_broadcast_b!(mr = 32, nr = 6);
avx512_microkernel_broadcast_b!(mr = 48, nr = 6);
avx512_microkernel_broadcast_b!(mr = 64, nr = 6);
#[cfg(test)]
mod tests {
use super::*;
/// Scalar reference GEMM used to validate the generated microkernels.
///
/// Accumulates into `c` (indexed as `c[i * n + j]`) the sum over `p in 0..k`
/// of `a[p * m + i] * b[p * n + j]`. Existing contents of `c` are added to,
/// not overwritten. Indexing panics if any slice is too short for the
/// requested `m`, `n`, `k`.
fn gemm_reference(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    // Walk every output cell in row-major order (row outer, column inner).
    let cells = (0..m).flat_map(|row| (0..n).map(move |col| (row, col)));
    for (row, col) in cells {
        // Accumulate in ascending `depth` order so rounding matches a plain
        // triple loop exactly.
        for depth in 0..k {
            c[row * n + col] += a[depth * m + row] * b[depth * n + col];
        }
    }
}
/// Compares the generated 8x32 AVX-512 microkernel with `gemm_reference`
/// at k = 64.
///
/// The comparison only runs when the host CPU reports `avx512f`; otherwise
/// the test returns without asserting anything.
#[test]
fn test_codegen_8x32_correctness() {
    let (mr, nr, k) = (8, 32, 64);
    // Deterministic inputs in [0.0, 1.0) with a period-100 pattern.
    let pattern = |len: usize, mul: usize, add: usize| -> Vec<f32> {
        (0..len).map(|idx| ((idx * mul + add) % 100) as f32 / 100.0).collect()
    };
    let a = pattern(mr * k, 7, 3);
    let b = pattern(k * nr, 11, 5);
    let mut c_gen = vec![0.0_f32; mr * nr];
    let mut c_ref = vec![0.0_f32; mr * nr];
    gemm_reference(mr, nr, k, &a, &b, &mut c_ref);

    if !std::arch::is_x86_feature_detected!("avx512f") {
        return;
    }

    // SAFETY: AVX-512F support was verified just above. The buffers hold
    // mr*k, k*nr and mr*nr f32 values; this presumes the generated kernel
    // stays within those bounds when given `nr` as its last argument
    // (contract defined by the codegen macro — confirm there).
    unsafe {
        microkernel_8x32_avx512_gen(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);
    }

    // Largest absolute element-wise deviation from the reference.
    let mut max_diff = 0.0_f32;
    for (got, want) in c_gen.iter().zip(&c_ref) {
        max_diff = max_diff.max((got - want).abs());
    }
    assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x32");
}
/// Compares the generated 8x16 AVX-512 microkernel with `gemm_reference`
/// at k = 64.
///
/// The comparison only runs when the host CPU reports `avx512f`; otherwise
/// the test returns without asserting anything.
#[test]
fn test_codegen_8x16_correctness() {
    let (mr, nr, k) = (8, 16, 64);
    // Deterministic inputs in [0.0, 1.0) with a period-100 pattern.
    let pattern = |len: usize, mul: usize, add: usize| -> Vec<f32> {
        (0..len).map(|idx| ((idx * mul + add) % 100) as f32 / 100.0).collect()
    };
    let a = pattern(mr * k, 7, 3);
    let b = pattern(k * nr, 11, 5);
    let mut c_gen = vec![0.0_f32; mr * nr];
    let mut c_ref = vec![0.0_f32; mr * nr];
    gemm_reference(mr, nr, k, &a, &b, &mut c_ref);

    if !std::arch::is_x86_feature_detected!("avx512f") {
        return;
    }

    // SAFETY: AVX-512F support was verified just above. The buffers hold
    // mr*k, k*nr and mr*nr f32 values; this presumes the generated kernel
    // stays within those bounds when given `nr` as its last argument
    // (contract defined by the codegen macro — confirm there).
    unsafe {
        microkernel_8x16_avx512_gen(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);
    }

    // Largest absolute element-wise deviation from the reference.
    let mut max_diff = 0.0_f32;
    for (got, want) in c_gen.iter().zip(&c_ref) {
        max_diff = max_diff.max((got - want).abs());
    }
    assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x16");
}
/// Compares the generated 8x48 AVX-512 microkernel with `gemm_reference`
/// at k = 32.
///
/// The comparison only runs when the host CPU reports `avx512f`; otherwise
/// the test returns without asserting anything.
#[test]
fn test_codegen_8x48_correctness() {
    let (mr, nr, k) = (8, 48, 32);
    // Deterministic inputs in [0.0, 1.0) with a period-100 pattern.
    let pattern = |len: usize, mul: usize, add: usize| -> Vec<f32> {
        (0..len).map(|idx| ((idx * mul + add) % 100) as f32 / 100.0).collect()
    };
    let a = pattern(mr * k, 7, 3);
    let b = pattern(k * nr, 11, 5);
    let mut c_gen = vec![0.0_f32; mr * nr];
    let mut c_ref = vec![0.0_f32; mr * nr];
    gemm_reference(mr, nr, k, &a, &b, &mut c_ref);

    if !std::arch::is_x86_feature_detected!("avx512f") {
        return;
    }

    // SAFETY: AVX-512F support was verified just above. The buffers hold
    // mr*k, k*nr and mr*nr f32 values; this presumes the generated kernel
    // stays within those bounds when given `nr` as its last argument
    // (contract defined by the codegen macro — confirm there).
    unsafe {
        microkernel_8x48_avx512_gen(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);
    }

    // Largest absolute element-wise deviation from the reference.
    let mut max_diff = 0.0_f32;
    for (got, want) in c_gen.iter().zip(&c_ref) {
        max_diff = max_diff.max((got - want).abs());
    }
    assert!(max_diff < 1e-2, "C-CODEGEN-001: max diff {max_diff} >= 1e-2 for 8x48");
}
/// Compares the generated 32x6 broadcast-B AVX-512 microkernel with
/// `gemm_reference` at k = 64.
///
/// The comparison only runs when the host CPU reports `avx512f`; otherwise
/// the test returns without asserting anything.
#[test]
fn test_codegen_bcast_b_32x6_correctness() {
    let (mr, nr, k) = (32, 6, 64);
    // Deterministic inputs in [0.0, 1.0) with a period-100 pattern.
    let pattern = |len: usize, mul: usize, add: usize| -> Vec<f32> {
        (0..len).map(|idx| ((idx * mul + add) % 100) as f32 / 100.0).collect()
    };
    let a = pattern(mr * k, 7, 3);
    let b = pattern(k * nr, 11, 5);
    let mut c_gen = vec![0.0_f32; mr * nr];
    let mut c_ref = vec![0.0_f32; mr * nr];
    gemm_reference(mr, nr, k, &a, &b, &mut c_ref);

    if !std::arch::is_x86_feature_detected!("avx512f") {
        return;
    }

    // SAFETY: AVX-512F support was verified just above. The buffers hold
    // mr*k, k*nr and mr*nr f32 values; this presumes the generated kernel
    // stays within those bounds when given `nr` as its last argument
    // (contract defined by the codegen macro — confirm there).
    unsafe {
        microkernel_32x6_avx512_bcast_b(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);
    }

    // Largest absolute element-wise deviation from the reference.
    let mut max_diff = 0.0_f32;
    for (got, want) in c_gen.iter().zip(&c_ref) {
        max_diff = max_diff.max((got - want).abs());
    }
    assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002a: max diff {max_diff} >= 1e-2");
}
/// Compares the generated 48x6 broadcast-B AVX-512 microkernel with
/// `gemm_reference` at k = 64.
///
/// The comparison only runs when the host CPU reports `avx512f`; otherwise
/// the test returns without asserting anything.
#[test]
fn test_codegen_bcast_b_48x6_correctness() {
    let (mr, nr, k) = (48, 6, 64);
    // Deterministic inputs in [0.0, 1.0) with a period-100 pattern.
    let pattern = |len: usize, mul: usize, add: usize| -> Vec<f32> {
        (0..len).map(|idx| ((idx * mul + add) % 100) as f32 / 100.0).collect()
    };
    let a = pattern(mr * k, 7, 3);
    let b = pattern(k * nr, 11, 5);
    let mut c_gen = vec![0.0_f32; mr * nr];
    let mut c_ref = vec![0.0_f32; mr * nr];
    gemm_reference(mr, nr, k, &a, &b, &mut c_ref);

    if !std::arch::is_x86_feature_detected!("avx512f") {
        return;
    }

    // SAFETY: AVX-512F support was verified just above. The buffers hold
    // mr*k, k*nr and mr*nr f32 values; this presumes the generated kernel
    // stays within those bounds when given `nr` as its last argument
    // (contract defined by the codegen macro — confirm there).
    unsafe {
        microkernel_48x6_avx512_bcast_b(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);
    }

    // Largest absolute element-wise deviation from the reference.
    let mut max_diff = 0.0_f32;
    for (got, want) in c_gen.iter().zip(&c_ref) {
        max_diff = max_diff.max((got - want).abs());
    }
    assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002b: max diff {max_diff} >= 1e-2");
}
/// Compares the generated 64x6 broadcast-B AVX-512 microkernel with
/// `gemm_reference` at k = 32.
///
/// The comparison only runs when the host CPU reports `avx512f`; otherwise
/// the test returns without asserting anything.
#[test]
fn test_codegen_bcast_b_64x6_correctness() {
    let (mr, nr, k) = (64, 6, 32);
    // Deterministic inputs in [0.0, 1.0) with a period-100 pattern.
    let pattern = |len: usize, mul: usize, add: usize| -> Vec<f32> {
        (0..len).map(|idx| ((idx * mul + add) % 100) as f32 / 100.0).collect()
    };
    let a = pattern(mr * k, 7, 3);
    let b = pattern(k * nr, 11, 5);
    let mut c_gen = vec![0.0_f32; mr * nr];
    let mut c_ref = vec![0.0_f32; mr * nr];
    gemm_reference(mr, nr, k, &a, &b, &mut c_ref);

    if !std::arch::is_x86_feature_detected!("avx512f") {
        return;
    }

    // SAFETY: AVX-512F support was verified just above. The buffers hold
    // mr*k, k*nr and mr*nr f32 values; this presumes the generated kernel
    // stays within those bounds when given `nr` as its last argument
    // (contract defined by the codegen macro — confirm there).
    unsafe {
        microkernel_64x6_avx512_bcast_b(k, a.as_ptr(), b.as_ptr(), c_gen.as_mut_ptr(), nr);
    }

    // Largest absolute element-wise deviation from the reference.
    let mut max_diff = 0.0_f32;
    for (got, want) in c_gen.iter().zip(&c_ref) {
        max_diff = max_diff.max((got - want).abs());
    }
    assert!(max_diff < 1e-2, "FALSIFY-CODEGEN-002c: max diff {max_diff} >= 1e-2");
}
}