#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn microkernel_16x8_avx512(
k: usize,
a: *const f32,
b: *const f32,
c: *mut f32,
ldc: usize,
) {
unsafe {
use std::arch::x86_64::*;
let mut c0 = _mm512_loadu_ps(c);
let mut c1 = _mm512_loadu_ps(c.add(ldc));
let mut c2 = _mm512_loadu_ps(c.add(2 * ldc));
let mut c3 = _mm512_loadu_ps(c.add(3 * ldc));
let mut c4 = _mm512_loadu_ps(c.add(4 * ldc));
let mut c5 = _mm512_loadu_ps(c.add(5 * ldc));
let mut c6 = _mm512_loadu_ps(c.add(6 * ldc));
let mut c7 = _mm512_loadu_ps(c.add(7 * ldc));
let k4 = k / 4;
let k_rem = k % 4;
for p4 in 0..k4 {
let base = p4 * 4;
let a0 = _mm512_loadu_ps(a.add(base * 16));
let bp0 = b.add(base * 8);
c0 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0), c0);
c1 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(1)), c1);
c2 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(2)), c2);
c3 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(3)), c3);
c4 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(4)), c4);
c5 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(5)), c5);
c6 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(6)), c6);
c7 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(7)), c7);
let a1 = _mm512_loadu_ps(a.add((base + 1) * 16));
let bp1 = b.add((base + 1) * 8);
c0 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1), c0);
c1 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(1)), c1);
c2 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(2)), c2);
c3 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(3)), c3);
c4 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(4)), c4);
c5 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(5)), c5);
c6 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(6)), c6);
c7 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(7)), c7);
let a2 = _mm512_loadu_ps(a.add((base + 2) * 16));
let bp2 = b.add((base + 2) * 8);
c0 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2), c0);
c1 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(1)), c1);
c2 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(2)), c2);
c3 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(3)), c3);
c4 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(4)), c4);
c5 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(5)), c5);
c6 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(6)), c6);
c7 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(7)), c7);
let a3 = _mm512_loadu_ps(a.add((base + 3) * 16));
let bp3 = b.add((base + 3) * 8);
c0 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3), c0);
c1 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(1)), c1);
c2 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(2)), c2);
c3 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(3)), c3);
c4 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(4)), c4);
c5 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(5)), c5);
c6 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(6)), c6);
c7 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(7)), c7);
}
let base_rem = k4 * 4;
for p in 0..k_rem {
let pp = base_rem + p;
let a_col = _mm512_loadu_ps(a.add(pp * 16));
let bp = b.add(pp * 8);
c0 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp), c0);
c1 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(1)), c1);
c2 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(2)), c2);
c3 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(3)), c3);
c4 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(4)), c4);
c5 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(5)), c5);
c6 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(6)), c6);
c7 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(7)), c7);
}
_mm512_storeu_ps(c, c0);
_mm512_storeu_ps(c.add(ldc), c1);
_mm512_storeu_ps(c.add(2 * ldc), c2);
_mm512_storeu_ps(c.add(3 * ldc), c3);
_mm512_storeu_ps(c.add(4 * ldc), c4);
_mm512_storeu_ps(c.add(5 * ldc), c5);
_mm512_storeu_ps(c.add(6 * ldc), c6);
_mm512_storeu_ps(c.add(7 * ldc), c7);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn microkernel_32x6_avx512(
k: usize,
a: *const f32,
b: *const f32,
c: *mut f32,
ldc: usize,
) {
unsafe {
use std::arch::x86_64::*;
let mut c00 = _mm512_loadu_ps(c);
let mut c01 = _mm512_loadu_ps(c.add(ldc));
let mut c02 = _mm512_loadu_ps(c.add(2 * ldc));
let mut c03 = _mm512_loadu_ps(c.add(3 * ldc));
let mut c04 = _mm512_loadu_ps(c.add(4 * ldc));
let mut c05 = _mm512_loadu_ps(c.add(5 * ldc));
let mut c10 = _mm512_loadu_ps(c.add(16));
let mut c11 = _mm512_loadu_ps(c.add(ldc + 16));
let mut c12 = _mm512_loadu_ps(c.add(2 * ldc + 16));
let mut c13 = _mm512_loadu_ps(c.add(3 * ldc + 16));
let mut c14 = _mm512_loadu_ps(c.add(4 * ldc + 16));
let mut c15 = _mm512_loadu_ps(c.add(5 * ldc + 16));
let k2 = k / 2;
let k_rem = k % 2;
for p2 in 0..k2 {
let base = p2 * 2;
let a0_lo = _mm512_loadu_ps(a.add(base * 32));
let a0_hi = _mm512_loadu_ps(a.add(base * 32 + 16));
let bp0 = b.add(base * 6);
let b0 = _mm512_set1_ps(*bp0);
c00 = _mm512_fmadd_ps(a0_lo, b0, c00);
c10 = _mm512_fmadd_ps(a0_hi, b0, c10);
let b1 = _mm512_set1_ps(*bp0.add(1));
c01 = _mm512_fmadd_ps(a0_lo, b1, c01);
c11 = _mm512_fmadd_ps(a0_hi, b1, c11);
let b2 = _mm512_set1_ps(*bp0.add(2));
c02 = _mm512_fmadd_ps(a0_lo, b2, c02);
c12 = _mm512_fmadd_ps(a0_hi, b2, c12);
let b3 = _mm512_set1_ps(*bp0.add(3));
c03 = _mm512_fmadd_ps(a0_lo, b3, c03);
c13 = _mm512_fmadd_ps(a0_hi, b3, c13);
let b4 = _mm512_set1_ps(*bp0.add(4));
c04 = _mm512_fmadd_ps(a0_lo, b4, c04);
c14 = _mm512_fmadd_ps(a0_hi, b4, c14);
let b5 = _mm512_set1_ps(*bp0.add(5));
c05 = _mm512_fmadd_ps(a0_lo, b5, c05);
c15 = _mm512_fmadd_ps(a0_hi, b5, c15);
let a1_lo = _mm512_loadu_ps(a.add((base + 1) * 32));
let a1_hi = _mm512_loadu_ps(a.add((base + 1) * 32 + 16));
let bp1 = b.add((base + 1) * 6);
let b0 = _mm512_set1_ps(*bp1);
c00 = _mm512_fmadd_ps(a1_lo, b0, c00);
c10 = _mm512_fmadd_ps(a1_hi, b0, c10);
let b1 = _mm512_set1_ps(*bp1.add(1));
c01 = _mm512_fmadd_ps(a1_lo, b1, c01);
c11 = _mm512_fmadd_ps(a1_hi, b1, c11);
let b2 = _mm512_set1_ps(*bp1.add(2));
c02 = _mm512_fmadd_ps(a1_lo, b2, c02);
c12 = _mm512_fmadd_ps(a1_hi, b2, c12);
let b3 = _mm512_set1_ps(*bp1.add(3));
c03 = _mm512_fmadd_ps(a1_lo, b3, c03);
c13 = _mm512_fmadd_ps(a1_hi, b3, c13);
let b4 = _mm512_set1_ps(*bp1.add(4));
c04 = _mm512_fmadd_ps(a1_lo, b4, c04);
c14 = _mm512_fmadd_ps(a1_hi, b4, c14);
let b5 = _mm512_set1_ps(*bp1.add(5));
c05 = _mm512_fmadd_ps(a1_lo, b5, c05);
c15 = _mm512_fmadd_ps(a1_hi, b5, c15);
}
let base_rem = k2 * 2;
for p in 0..k_rem {
let pp = base_rem + p;
let a_lo = _mm512_loadu_ps(a.add(pp * 32));
let a_hi = _mm512_loadu_ps(a.add(pp * 32 + 16));
let bp = b.add(pp * 6);
let b0 = _mm512_set1_ps(*bp);
c00 = _mm512_fmadd_ps(a_lo, b0, c00);
c10 = _mm512_fmadd_ps(a_hi, b0, c10);
let b1 = _mm512_set1_ps(*bp.add(1));
c01 = _mm512_fmadd_ps(a_lo, b1, c01);
c11 = _mm512_fmadd_ps(a_hi, b1, c11);
let b2 = _mm512_set1_ps(*bp.add(2));
c02 = _mm512_fmadd_ps(a_lo, b2, c02);
c12 = _mm512_fmadd_ps(a_hi, b2, c12);
let b3 = _mm512_set1_ps(*bp.add(3));
c03 = _mm512_fmadd_ps(a_lo, b3, c03);
c13 = _mm512_fmadd_ps(a_hi, b3, c13);
let b4 = _mm512_set1_ps(*bp.add(4));
c04 = _mm512_fmadd_ps(a_lo, b4, c04);
c14 = _mm512_fmadd_ps(a_hi, b4, c14);
let b5 = _mm512_set1_ps(*bp.add(5));
c05 = _mm512_fmadd_ps(a_lo, b5, c05);
c15 = _mm512_fmadd_ps(a_hi, b5, c15);
}
_mm512_storeu_ps(c, c00);
_mm512_storeu_ps(c.add(ldc), c01);
_mm512_storeu_ps(c.add(2 * ldc), c02);
_mm512_storeu_ps(c.add(3 * ldc), c03);
_mm512_storeu_ps(c.add(4 * ldc), c04);
_mm512_storeu_ps(c.add(5 * ldc), c05);
_mm512_storeu_ps(c.add(16), c10);
_mm512_storeu_ps(c.add(ldc + 16), c11);
_mm512_storeu_ps(c.add(2 * ldc + 16), c12);
_mm512_storeu_ps(c.add(3 * ldc + 16), c13);
_mm512_storeu_ps(c.add(4 * ldc + 16), c14);
_mm512_storeu_ps(c.add(5 * ldc + 16), c15);
}
}