/// Accumulates a 12x4 block of double-precision products into `c`.
///
/// For every `p` in `0..k` the kernel reads 12 values from `a_pack` (starting at
/// `p * 12`) and 4 values from `b_pack` (starting at `p * 4`), and adds
/// `a_pack[p * 12 + i] * b_pack[p * 4 + j]` to element `j` of row `i` of the block.
/// Row `i` of the block lives at `c + i * ldc`. All 12 rows are loaded before the
/// loop and written back only after it finishes.
///
/// # Safety
///
/// - The running CPU must support the `avx2` and `fma` target features.
/// - `a_pack` must be valid for reads of `k * 12` `f64` values.
/// - `b_pack` must be valid for reads of `k * 4` `f64` values.
/// - For each `i` in `0..12`, `c + i * ldc` must be valid for reads and writes of
///   4 consecutive `f64` values (unaligned loads/stores are used).
#[target_feature(enable = "avx2,fma")]
// The whole body is raw-pointer SIMD work covered by the contract above.
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn kernel_12x4_avx2(
    a_pack: *const f64,
    b_pack: *const f64,
    c: *mut f64,
    k: usize,
    ldc: usize,
) {
    use std::arch::x86_64::*;

    /// Rows in the output block (values consumed from `a_pack` per step).
    const MR: usize = 12;
    /// Columns in the output block (values consumed from `b_pack` per step).
    const NR: usize = 4;

    // One 4-wide accumulator per output row, seeded with the current contents of C.
    let mut acc = [_mm256_setzero_pd(); MR];
    for (row, slot) in acc.iter_mut().enumerate() {
        *slot = _mm256_loadu_pd(c.add(row * ldc));
    }

    for step in 0..k {
        let a_col = a_pack.add(step * MR);
        let b_row = _mm256_loadu_pd(b_pack.add(step * NR));
        for (row, slot) in acc.iter_mut().enumerate() {
            // Splat one A value across the lanes, then fuse multiply-add onto the row.
            let a_splat = _mm256_broadcast_sd(&*a_col.add(row));
            *slot = _mm256_fmadd_pd(a_splat, b_row, *slot);
        }
    }

    // Write every accumulated row back to C.
    for (row, value) in acc.iter().enumerate() {
        _mm256_storeu_pd(c.add(row * ldc), *value);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether the running CPU supports every target feature that
    /// `kernel_12x4_avx2` is compiled with (`avx2` and `fma`).
    ///
    /// Both must be checked: calling the kernel on a CPU that lacks either one is
    /// undefined behaviour, and AVX2 support alone does not guarantee FMA.
    fn kernel_features_available() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    /// Multiplies a row-major 12 x k matrix by a row-major k x 4 matrix through the
    /// kernel (starting from a zeroed, densely packed C) and compares against a
    /// scalar triple loop.
    #[test]
    fn test_kernel_12x4_correctness() {
        if !kernel_features_available() {
            println!("Skipping - AVX2/FMA not available");
            return;
        }
        let k = 16;
        let a: Vec<f64> = (0..12 * k).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..k * 4).map(|i| (i % 10) as f64).collect();
        let mut c = vec![0.0; 12 * 4];

        // Pack A so that the 12 row entries belonging to step `p` are contiguous.
        let mut a_pack = vec![0.0; k * 12];
        for p in 0..k {
            for i in 0..12 {
                a_pack[p * 12 + i] = a[i * k + p];
            }
        }
        // B is already laid out as k rows of 4 values, which is the packed layout.
        let b_pack = b.clone();

        // SAFETY: features were checked above; `a_pack` holds k * 12 values,
        // `b_pack` holds k * 4 values and `c` holds 12 rows of 4 values (ldc = 4).
        unsafe {
            kernel_12x4_avx2(a_pack.as_ptr(), b_pack.as_ptr(), c.as_mut_ptr(), k, 4);
        }

        let mut c_expected = vec![0.0; 12 * 4];
        for i in 0..12 {
            for j in 0..4 {
                for p in 0..k {
                    c_expected[i * 4 + j] += a[i * k + p] * b[p * 4 + j];
                }
            }
        }
        for (i, (got, want)) in c.iter().zip(&c_expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-10,
                "Mismatch at {}: got {}, expected {}",
                i,
                got,
                want
            );
        }
    }

    /// Checks that the kernel adds onto the existing contents of C and honours a
    /// row stride larger than 4, leaving the padding columns untouched.
    #[test]
    fn test_kernel_12x4_accumulates_with_stride() {
        if !kernel_features_available() {
            println!("Skipping - AVX2/FMA not available");
            return;
        }
        let k = 3;
        let ldc = 7;
        // Small integer-valued inputs keep every product and sum exactly representable.
        let a_pack: Vec<f64> = (0..k * 12).map(|i| (i % 5) as f64 - 2.0).collect();
        let b_pack: Vec<f64> = (0..k * 4).map(|i| (i % 3) as f64 + 1.0).collect();
        let mut c: Vec<f64> = (0..12 * ldc).map(|i| i as f64).collect();

        let mut c_expected = c.clone();
        for i in 0..12 {
            for j in 0..4 {
                for p in 0..k {
                    c_expected[i * ldc + j] += a_pack[p * 12 + i] * b_pack[p * 4 + j];
                }
            }
        }

        // SAFETY: features were checked above; `a_pack` holds k * 12 values,
        // `b_pack` holds k * 4 values and `c` holds 12 rows of `ldc >= 4` values.
        unsafe {
            kernel_12x4_avx2(a_pack.as_ptr(), b_pack.as_ptr(), c.as_mut_ptr(), k, ldc);
        }
        assert_eq!(c, c_expected);
    }
}