#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use super::macros::{
define_microkernel_2x_f32, define_microkernel_2x_f64, define_microkernel_f32,
define_microkernel_f64,
};
// Instantiates `microkernel_6x8_f32` via `define_microkernel_f32!`, wiring in the
// 256-bit single-precision AVX2/FMA intrinsics.
//
// The tests in this file call the generated function inside `unsafe` as
// `microkernel_6x8_f32(a_ptr, b_ptr, c_ptr, 2, 8, flag)`: with `flag == true` the
// 6x8 output holds just the product, with `flag == false` the product is added to
// the values already in the output buffer.
//
// NOTE(review): the macro body lives in `super::macros` and is not visible here;
// the notes on how each argument is used are inferred from the argument itself —
// confirm against the macro definition.
define_microkernel_f32!(
// Name of the function to generate.
microkernel_6x8_f32,
// Presumably the number of f32 lanes in one `__m256` vector (8 x 32 bit).
8,
// Presumably the target features the generated function is compiled with.
"avx2",
"fma",
// Unaligned load of 8 packed f32 values.
_mm256_loadu_ps,
// Unaligned store of 8 packed f32 values.
_mm256_storeu_ps,
// Broadcast of one f32 scalar to all 8 lanes.
_mm256_set1_ps,
// Fused multiply-add on packed f32 (a * b + c).
_mm256_fmadd_ps,
// All-zero vector; presumably used to initialise the accumulators.
_mm256_setzero_ps,
// Vector register type the intrinsics above operate on.
__m256
);
// Instantiates `microkernel_6x4_f64` via `define_microkernel_f64!`, wiring in the
// 256-bit double-precision AVX2/FMA intrinsics.
//
// The test in this file calls the generated function inside `unsafe` as
// `microkernel_6x4_f64(a_ptr, b_ptr, c_ptr, 2, 4, true)` and checks a 6x4 output
// stored with 4 values per row.
//
// NOTE(review): the macro body lives in `super::macros` and is not visible here;
// the notes on how each argument is used are inferred from the argument itself —
// confirm against the macro definition.
define_microkernel_f64!(
// Name of the function to generate.
microkernel_6x4_f64,
// Presumably the number of f64 lanes in one `__m256d` vector (4 x 64 bit).
4,
// Presumably the target features the generated function is compiled with.
"avx2",
"fma",
// Unaligned load of 4 packed f64 values.
_mm256_loadu_pd,
// Unaligned store of 4 packed f64 values.
_mm256_storeu_pd,
// Broadcast of one f64 scalar to all 4 lanes.
_mm256_set1_pd,
// Fused multiply-add on packed f64 (a * b + c).
_mm256_fmadd_pd,
// All-zero vector; presumably used to initialise the accumulators.
_mm256_setzero_pd,
// Vector register type the intrinsics above operate on.
__m256d
);
// Instantiates `microkernel_6x16_f32` via `define_microkernel_2x_f32!`, using the
// same 256-bit single-precision AVX2/FMA intrinsics as `microkernel_6x8_f32`.
//
// NOTE(review): the macro body lives in `super::macros` and is not visible here.
// Judging by the "2x" in the macro name and the 6x16 in the function name, the
// kernel presumably handles two 8-lane vectors (16 columns) per row — confirm
// against the macro definition. No caller of this kernel is visible in this file.
define_microkernel_2x_f32!(
// Presumably the name of the function to generate.
microkernel_6x16_f32,
// Presumably the number of f32 lanes in one `__m256` vector (8 x 32 bit).
8,
// Presumably the target features the generated function is compiled with.
"avx2",
"fma",
// Unaligned load of 8 packed f32 values.
_mm256_loadu_ps,
// Unaligned store of 8 packed f32 values.
_mm256_storeu_ps,
// Broadcast of one f32 scalar to all 8 lanes.
_mm256_set1_ps,
// Fused multiply-add on packed f32 (a * b + c).
_mm256_fmadd_ps,
// All-zero vector; presumably used to initialise the accumulators.
_mm256_setzero_ps,
// Vector register type the intrinsics above operate on.
__m256
);
// Instantiates `microkernel_6x8_f64` via `define_microkernel_2x_f64!`, using the
// same 256-bit double-precision AVX2/FMA intrinsics as `microkernel_6x4_f64`.
//
// NOTE(review): the macro body lives in `super::macros` and is not visible here.
// Judging by the "2x" in the macro name and the 6x8 in the function name, the
// kernel presumably handles two 4-lane vectors (8 columns) per row — confirm
// against the macro definition. No caller of this kernel is visible in this file.
define_microkernel_2x_f64!(
// Presumably the name of the function to generate.
microkernel_6x8_f64,
// Presumably the number of f64 lanes in one `__m256d` vector (4 x 64 bit).
4,
// Presumably the target features the generated function is compiled with.
"avx2",
"fma",
// Unaligned load of 4 packed f64 values.
_mm256_loadu_pd,
// Unaligned store of 4 packed f64 values.
_mm256_storeu_pd,
// Broadcast of one f64 scalar to all 4 lanes.
_mm256_set1_pd,
// Fused multiply-add on packed f64 (a * b + c).
_mm256_fmadd_pd,
// All-zero vector; presumably used to initialise the accumulators.
_mm256_setzero_pd,
// Vector register type the intrinsics above operate on.
__m256d
);
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether the calling test must be skipped because the CPU lacks
    /// AVX2 or FMA; prints the skip notice to stderr when it does.
    fn skip_without_avx2_fma() -> bool {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return false;
        }
        eprintln!("Skipping AVX2 test - CPU doesn't support AVX2+FMA");
        true
    }

    /// Overwrite mode, f32: the inputs are chosen so that output element
    /// `[i][j]` must equal `(i + 1) + (j + 1)`.
    #[test]
    fn test_microkernel_6x8_f32_basic() {
        if skip_without_avx2_fma() {
            return;
        }

        // `a`: six values 1..=6 followed by six ones.
        let mut a: Vec<f32> = (1..=6u8).map(f32::from).collect();
        a.resize(12, 1.0);

        // `b`: eight ones followed by the values 1..=8.
        let mut b: Vec<f32> = vec![1.0; 8];
        b.extend((1..=8u8).map(f32::from));

        let mut c: Vec<f32> = vec![0.0; 6 * 8];

        // SAFETY: AVX2 and FMA were detected above. The buffers hold 12, 16 and
        // 48 elements; this assumes the kernel touches no more than 6 * 2 values
        // of `a`, 8 * 2 values of `b` and a 6x8 tile of `c` with row stride 8 —
        // TODO confirm against the macro definition.
        unsafe {
            microkernel_6x8_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 8, true);
        }

        for (i, row) in c.chunks_exact(8).enumerate() {
            for (j, &actual) in row.iter().enumerate() {
                let expected = (i + 1) as f32 + (j + 1) as f32;
                assert!(
                    (actual - expected).abs() < 1e-5,
                    "Mismatch at [{i}][{j}]: expected {expected}, got {actual}"
                );
            }
        }
    }

    /// Overwrite mode, f64: same construction as the f32 test with a 6x4
    /// output, so element `[i][j]` must equal `(i + 1) + (j + 1)`.
    #[test]
    fn test_microkernel_6x4_f64_basic() {
        if skip_without_avx2_fma() {
            return;
        }

        // `a`: six values 1..=6 followed by six ones.
        let mut a: Vec<f64> = (1..=6u8).map(f64::from).collect();
        a.resize(12, 1.0);

        // `b`: four ones followed by the values 1..=4.
        let mut b: Vec<f64> = vec![1.0; 4];
        b.extend((1..=4u8).map(f64::from));

        let mut c: Vec<f64> = vec![0.0; 6 * 4];

        // SAFETY: AVX2 and FMA were detected above. The buffers hold 12, 8 and
        // 24 elements; this assumes the kernel touches no more than 6 * 2 values
        // of `a`, 4 * 2 values of `b` and a 6x4 tile of `c` with row stride 4 —
        // TODO confirm against the macro definition.
        unsafe {
            microkernel_6x4_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 4, true);
        }

        for (i, row) in c.chunks_exact(4).enumerate() {
            for (j, &actual) in row.iter().enumerate() {
                let expected = (i + 1) as f64 + (j + 1) as f64;
                assert!(
                    (actual - expected).abs() < 1e-10,
                    "Mismatch at [{i}][{j}]: expected {expected}, got {actual}"
                );
            }
        }
    }

    /// Accumulate mode (final argument `false`): all-ones inputs contribute 2
    /// per element on top of the 100 already stored in `c`.
    #[test]
    fn test_microkernel_accumulate() {
        if skip_without_avx2_fma() {
            return;
        }

        let a: Vec<f32> = vec![1.0; 12];
        let b: Vec<f32> = vec![1.0; 16];
        let mut c: Vec<f32> = vec![100.0; 6 * 8];

        // SAFETY: AVX2 and FMA were detected above. Buffer sizes match the f32
        // overwrite test (12, 16 and 48 elements); the same assumption about the
        // kernel's access pattern applies — TODO confirm against the macro
        // definition.
        unsafe {
            microkernel_6x8_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 8, false);
        }

        for (i, row) in c.chunks_exact(8).enumerate() {
            for (j, &actual) in row.iter().enumerate() {
                let expected = 102.0f32;
                assert!(
                    (actual - expected).abs() < 1e-5,
                    "Mismatch at [{i}][{j}]: expected {expected}, got {actual}"
                );
            }
        }
    }
}