#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use super::macros::{
define_microkernel_2x_f32, define_microkernel_2x_f64, define_microkernel_f32,
define_microkernel_f64,
};
// Instantiates the single-precision kernel `microkernel_6x16_f32` by expanding the
// shared `define_microkernel_f32!` template with 512-bit AVX-512 intrinsics.
// NOTE(review): the macro body lives in `super::macros` and is not visible here; the
// argument roles described below are inferred from names and values — confirm there.
define_microkernel_f32!(
// Identifier of the generated kernel function.
microkernel_6x16_f32,
// Presumably the number of f32 lanes per vector (a `__m512` holds 16 f32 values).
16,
// CPU feature names, presumably forwarded to `#[target_feature(enable = ...)]`.
"avx512f",
"fma",
// Unaligned 16-lane f32 load and store intrinsics.
_mm512_loadu_ps,
_mm512_storeu_ps,
// Broadcasts a single f32 into every lane.
_mm512_set1_ps,
// Per-lane fused multiply-add (a * b + c).
_mm512_fmadd_ps,
// Produces an all-zero vector.
_mm512_setzero_ps,
// Vector type operated on by the intrinsics above.
__m512
);
// Instantiates the double-precision kernel `microkernel_6x8_f64` by expanding the
// shared `define_microkernel_f64!` template with 512-bit AVX-512 intrinsics.
// NOTE(review): the macro body lives in `super::macros` and is not visible here; the
// argument roles described below are inferred from names and values — confirm there.
define_microkernel_f64!(
// Identifier of the generated kernel function.
microkernel_6x8_f64,
// Presumably the number of f64 lanes per vector (a `__m512d` holds 8 f64 values).
8,
// CPU feature names, presumably forwarded to `#[target_feature(enable = ...)]`.
"avx512f",
"fma",
// Unaligned 8-lane f64 load and store intrinsics.
_mm512_loadu_pd,
_mm512_storeu_pd,
// Broadcasts a single f64 into every lane.
_mm512_set1_pd,
// Per-lane fused multiply-add (a * b + c).
_mm512_fmadd_pd,
// Produces an all-zero vector.
_mm512_setzero_pd,
// Vector type operated on by the intrinsics above.
__m512d
);
// Expands the `define_microkernel_2x_f32!` template with 512-bit single-precision
// AVX-512 intrinsics, passing `microkernel_6x32_f32` as the kernel identifier.
// NOTE(review): the macro body lives in `super::macros` and is not visible here. The
// name suggests a 6x32 tile, i.e. two 16-lane vectors per row — confirm against the
// macro definition, along with the argument roles noted below.
define_microkernel_2x_f32!(
// Presumably the identifier of the generated kernel function.
microkernel_6x32_f32,
// Presumably the number of f32 lanes per vector (a `__m512` holds 16 f32 values).
16,
// CPU feature names, presumably forwarded to `#[target_feature(enable = ...)]`.
"avx512f",
"fma",
// Unaligned 16-lane f32 load and store intrinsics.
_mm512_loadu_ps,
_mm512_storeu_ps,
// Broadcasts a single f32 into every lane.
_mm512_set1_ps,
// Per-lane fused multiply-add (a * b + c).
_mm512_fmadd_ps,
// Produces an all-zero vector.
_mm512_setzero_ps,
// Vector type operated on by the intrinsics above.
__m512
);
// Expands the `define_microkernel_2x_f64!` template with 512-bit double-precision
// AVX-512 intrinsics, passing `microkernel_6x16_f64` as the kernel identifier.
// NOTE(review): the macro body lives in `super::macros` and is not visible here. The
// name suggests a 6x16 tile, i.e. two 8-lane vectors per row — confirm against the
// macro definition, along with the argument roles noted below.
define_microkernel_2x_f64!(
// Presumably the identifier of the generated kernel function.
microkernel_6x16_f64,
// Presumably the number of f64 lanes per vector (a `__m512d` holds 8 f64 values).
8,
// CPU feature names, presumably forwarded to `#[target_feature(enable = ...)]`.
"avx512f",
"fma",
// Unaligned 8-lane f64 load and store intrinsics.
_mm512_loadu_pd,
_mm512_storeu_pd,
// Broadcasts a single f64 into every lane.
_mm512_set1_pd,
// Per-lane fused multiply-add (a * b + c).
_mm512_fmadd_pd,
// Produces an all-zero vector.
_mm512_setzero_pd,
// Vector type operated on by the intrinsics above.
__m512d
);
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether the running CPU advertises both `avx512f` and `fma`.
    ///
    /// When either feature is missing, a skip notice is written to stderr so the
    /// calling test can return early instead of executing unsupported instructions.
    fn avx512_fma_available() -> bool {
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
            return true;
        }
        eprintln!("Skipping AVX-512 test - CPU doesn't support AVX-512F+FMA");
        false
    }

    /// Runs the f32 kernel on a two-step product and checks every output element.
    ///
    /// The inputs are chosen so that the expected result is
    /// `c[i][j] = a[i] * b[j] + a[6 + i] * b[16 + j] = (i + 1) + (j + 1)`.
    #[test]
    fn test_microkernel_6x16_f32_basic() {
        if !avx512_fma_available() {
            return;
        }

        // First six entries are 1..=6, the remaining six are all ones.
        let a: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, //
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ];

        // First sixteen entries are ones, the remaining sixteen count 1..=16.
        let mut b: Vec<f32> = vec![1.0; 32];
        for (offset, slot) in b[16..].iter_mut().enumerate() {
            *slot = (offset + 1) as f32;
        }

        let mut c: Vec<f32> = vec![0.0; 6 * 16];

        // SAFETY: the feature check above passed, and `a`, `b`, `c` hold 12, 32 and 96
        // elements respectively.
        // NOTE(review): assumes the kernel touches at most 6*2 elements of A, 16*2 of B
        // and a 6x16 tile of C with row stride 16 — confirm against the macro definition.
        unsafe {
            microkernel_6x16_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 16, true);
        }

        for (i, row) in c.chunks_exact(16).enumerate() {
            for (j, &actual) in row.iter().enumerate() {
                let expected = (i + 1) as f32 + (j + 1) as f32;
                assert!(
                    (actual - expected).abs() < 1e-5,
                    "Mismatch at [{i}][{j}]: expected {expected}, got {actual}"
                );
            }
        }
    }

    /// Runs the f64 kernel on a two-step product and checks every output element.
    ///
    /// The inputs are chosen so that the expected result is
    /// `c[i][j] = a[i] * b[j] + a[6 + i] * b[8 + j] = (i + 1) + (j + 1)`.
    #[test]
    fn test_microkernel_6x8_f64_basic() {
        if !avx512_fma_available() {
            return;
        }

        // First six entries are 1..=6, the remaining six are all ones.
        let a: Vec<f64> = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, //
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ];

        // First eight entries are ones, the remaining eight count 1..=8.
        let mut b: Vec<f64> = vec![1.0; 16];
        for (offset, slot) in b[8..].iter_mut().enumerate() {
            *slot = (offset + 1) as f64;
        }

        let mut c: Vec<f64> = vec![0.0; 6 * 8];

        // SAFETY: the feature check above passed, and `a`, `b`, `c` hold 12, 16 and 48
        // elements respectively.
        // NOTE(review): assumes the kernel touches at most 6*2 elements of A, 8*2 of B
        // and a 6x8 tile of C with row stride 8 — confirm against the macro definition.
        unsafe {
            microkernel_6x8_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 8, true);
        }

        for (i, row) in c.chunks_exact(8).enumerate() {
            for (j, &actual) in row.iter().enumerate() {
                let expected = (i + 1) as f64 + (j + 1) as f64;
                assert!(
                    (actual - expected).abs() < 1e-10,
                    "Mismatch at [{i}][{j}]: expected {expected}, got {actual}"
                );
            }
        }
    }

    /// Checks that passing `false` as the last argument adds the product to the
    /// values already stored in `c`: 100 + (1 * 1 + 1 * 1) = 102 for every element.
    #[test]
    fn test_microkernel_accumulate() {
        if !avx512_fma_available() {
            return;
        }

        let a: Vec<f32> = vec![1.0; 12];
        let b: Vec<f32> = vec![1.0; 32];
        let mut c: Vec<f32> = vec![100.0; 6 * 16];

        // SAFETY: the feature check above passed, and `a`, `b`, `c` hold 12, 32 and 96
        // elements respectively.
        // NOTE(review): assumes the kernel touches at most 6*2 elements of A, 16*2 of B
        // and a 6x16 tile of C with row stride 16 — confirm against the macro definition.
        unsafe {
            microkernel_6x16_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 16, false);
        }

        for (flat, &actual) in c.iter().enumerate() {
            // Recover the row/column position for the failure message.
            let (i, j) = (flat / 16, flat % 16);
            let expected = 102.0f32;
            assert!(
                (actual - expected).abs() < 1e-5,
                "Mismatch at [{i}][{j}]: expected {expected}, got {actual}"
            );
        }
    }
}