use super::super::{MR, NR};
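/// Baseline 8x6 microkernel using AVX2/FMA intrinsics.
///
/// Computes `C += A * B` for one MR x NR (8 x 6) tile: `a` points at a packed
/// A panel (`k` groups of MR contiguous f32), `b` at a packed B panel (`k`
/// groups of NR contiguous f32), and the tile's j-th group of 8 C values
/// starts at `c + j * ldc`.
///
/// # Safety
/// The caller must ensure AVX2 and FMA are available and that `a`, `b`, and
/// `c` are valid for the reads and writes implied by `k`, MR, NR, and `ldc`.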
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn microkernel_8x6_avx2(
    k: usize,
    a: *const f32,
    b: *const f32,
    c: *mut f32,
    ldc: usize,
) {
unsafe {
use std::arch::x86_64::*;
        // Load the 8x6 C tile: six accumulators of 8 f32, one per B column.
        let mut c0 = _mm256_loadu_ps(c);
let mut c1 = _mm256_loadu_ps(c.add(ldc));
let mut c2 = _mm256_loadu_ps(c.add(2 * ldc));
let mut c3 = _mm256_loadu_ps(c.add(3 * ldc));
let mut c4 = _mm256_loadu_ps(c.add(4 * ldc));
let mut c5 = _mm256_loadu_ps(c.add(5 * ldc));
        // One rank-1 update per k-step: broadcast each of the six B values
        // and fuse-multiply-accumulate against the 8-wide A column.
        for p in 0..k {
            let a_col = _mm256_loadu_ps(a.add(p * MR));
let b0 = _mm256_set1_ps(*b.add(p * NR));
let b1 = _mm256_set1_ps(*b.add(p * NR + 1));
let b2 = _mm256_set1_ps(*b.add(p * NR + 2));
let b3 = _mm256_set1_ps(*b.add(p * NR + 3));
let b4 = _mm256_set1_ps(*b.add(p * NR + 4));
let b5 = _mm256_set1_ps(*b.add(p * NR + 5));
c0 = _mm256_fmadd_ps(a_col, b0, c0);
c1 = _mm256_fmadd_ps(a_col, b1, c1);
c2 = _mm256_fmadd_ps(a_col, b2, c2);
c3 = _mm256_fmadd_ps(a_col, b3, c3);
c4 = _mm256_fmadd_ps(a_col, b4, c4);
c5 = _mm256_fmadd_ps(a_col, b5, c5);
}
        // Write the updated tile back to C.
        _mm256_storeu_ps(c, c0);
_mm256_storeu_ps(c.add(ldc), c1);
_mm256_storeu_ps(c.add(2 * ldc), c2);
_mm256_storeu_ps(c.add(3 * ldc), c3);
_mm256_storeu_ps(c.add(4 * ldc), c4);
_mm256_storeu_ps(c.add(5 * ldc), c5);
}
}
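/// Software-pipelined variant of the 8x6 kernel. Despite the `_asm` suffix it
/// is pure intrinsics: the k-loop is unrolled 4x and the loads/broadcasts for
/// step p+1 are interleaved with the FMAs of step p so the compiler can
/// schedule memory traffic ahead of the arithmetic that needs it. The true
/// inline-asm kernel is `microkernel_8x6_true_asm` below.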
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn microkernel_8x6_avx2_asm(
    k: usize,
    a: *const f32,
    b: *const f32,
    c: *mut f32,
    ldc: usize,
) {
unsafe {
use std::arch::x86_64::*;
        // Too short to fill the 4-deep pipeline; fall back to the simple kernel.
        if k < 4 {
microkernel_8x6_avx2(k, a, b, c, ldc);
return;
}
let mut c0 = _mm256_loadu_ps(c);
let mut c1 = _mm256_loadu_ps(c.add(ldc));
let mut c2 = _mm256_loadu_ps(c.add(2 * ldc));
let mut c3 = _mm256_loadu_ps(c.add(3 * ldc));
let mut c4 = _mm256_loadu_ps(c.add(4 * ldc));
let mut c5 = _mm256_loadu_ps(c.add(5 * ldc));
        let k_unrolled = k / 4;
        let k_remainder = k % 4;
        // Main loop, unrolled 4x: the loads and broadcasts for step p+1 are
        // issued between the FMAs of step p to hide their latency.
        for p in 0..k_unrolled {
            let base_p = p * 4;
            let a0 = _mm256_loadu_ps(a.add(base_p * MR));
            let b00 = _mm256_broadcast_ss(&*b.add(base_p * NR));
            let b01 = _mm256_broadcast_ss(&*b.add(base_p * NR + 1));
            let b02 = _mm256_broadcast_ss(&*b.add(base_p * NR + 2));
            let b03 = _mm256_broadcast_ss(&*b.add(base_p * NR + 3));
            let b04 = _mm256_broadcast_ss(&*b.add(base_p * NR + 4));
            let b05 = _mm256_broadcast_ss(&*b.add(base_p * NR + 5));
let a1 = _mm256_loadu_ps(a.add((base_p + 1) * MR));
c0 = _mm256_fmadd_ps(a0, b00, c0);
c1 = _mm256_fmadd_ps(a0, b01, c1);
c2 = _mm256_fmadd_ps(a0, b02, c2);
let b10 = _mm256_broadcast_ss(&*b.add((base_p + 1) * NR));
let b11 = _mm256_broadcast_ss(&*b.add((base_p + 1) * NR + 1));
let b12 = _mm256_broadcast_ss(&*b.add((base_p + 1) * NR + 2));
c3 = _mm256_fmadd_ps(a0, b03, c3);
c4 = _mm256_fmadd_ps(a0, b04, c4);
c5 = _mm256_fmadd_ps(a0, b05, c5);
let b13 = _mm256_broadcast_ss(&*b.add((base_p + 1) * NR + 3));
let b14 = _mm256_broadcast_ss(&*b.add((base_p + 1) * NR + 4));
let b15 = _mm256_broadcast_ss(&*b.add((base_p + 1) * NR + 5));
let a2 = _mm256_loadu_ps(a.add((base_p + 2) * MR));
c0 = _mm256_fmadd_ps(a1, b10, c0);
c1 = _mm256_fmadd_ps(a1, b11, c1);
c2 = _mm256_fmadd_ps(a1, b12, c2);
let b20 = _mm256_broadcast_ss(&*b.add((base_p + 2) * NR));
let b21 = _mm256_broadcast_ss(&*b.add((base_p + 2) * NR + 1));
let b22 = _mm256_broadcast_ss(&*b.add((base_p + 2) * NR + 2));
c3 = _mm256_fmadd_ps(a1, b13, c3);
c4 = _mm256_fmadd_ps(a1, b14, c4);
c5 = _mm256_fmadd_ps(a1, b15, c5);
let b23 = _mm256_broadcast_ss(&*b.add((base_p + 2) * NR + 3));
let b24 = _mm256_broadcast_ss(&*b.add((base_p + 2) * NR + 4));
let b25 = _mm256_broadcast_ss(&*b.add((base_p + 2) * NR + 5));
let a3 = _mm256_loadu_ps(a.add((base_p + 3) * MR));
c0 = _mm256_fmadd_ps(a2, b20, c0);
c1 = _mm256_fmadd_ps(a2, b21, c1);
c2 = _mm256_fmadd_ps(a2, b22, c2);
let b30 = _mm256_broadcast_ss(&*b.add((base_p + 3) * NR));
let b31 = _mm256_broadcast_ss(&*b.add((base_p + 3) * NR + 1));
let b32 = _mm256_broadcast_ss(&*b.add((base_p + 3) * NR + 2));
c3 = _mm256_fmadd_ps(a2, b23, c3);
c4 = _mm256_fmadd_ps(a2, b24, c4);
c5 = _mm256_fmadd_ps(a2, b25, c5);
let b33 = _mm256_broadcast_ss(&*b.add((base_p + 3) * NR + 3));
let b34 = _mm256_broadcast_ss(&*b.add((base_p + 3) * NR + 4));
let b35 = _mm256_broadcast_ss(&*b.add((base_p + 3) * NR + 5));
c0 = _mm256_fmadd_ps(a3, b30, c0);
c1 = _mm256_fmadd_ps(a3, b31, c1);
c2 = _mm256_fmadd_ps(a3, b32, c2);
c3 = _mm256_fmadd_ps(a3, b33, c3);
c4 = _mm256_fmadd_ps(a3, b34, c4);
c5 = _mm256_fmadd_ps(a3, b35, c5);
}
        // Finish the k % 4 leftover steps one at a time.
        let base_p = k_unrolled * 4;
        for p in 0..k_remainder {
let pp = base_p + p;
let a_col = _mm256_loadu_ps(a.add(pp * MR));
let b0 = _mm256_broadcast_ss(&*b.add(pp * NR));
let b1 = _mm256_broadcast_ss(&*b.add(pp * NR + 1));
let b2 = _mm256_broadcast_ss(&*b.add(pp * NR + 2));
let b3 = _mm256_broadcast_ss(&*b.add(pp * NR + 3));
let b4 = _mm256_broadcast_ss(&*b.add(pp * NR + 4));
let b5 = _mm256_broadcast_ss(&*b.add(pp * NR + 5));
c0 = _mm256_fmadd_ps(a_col, b0, c0);
c1 = _mm256_fmadd_ps(a_col, b1, c1);
c2 = _mm256_fmadd_ps(a_col, b2, c2);
c3 = _mm256_fmadd_ps(a_col, b3, c3);
c4 = _mm256_fmadd_ps(a_col, b4, c4);
c5 = _mm256_fmadd_ps(a_col, b5, c5);
}
_mm256_storeu_ps(c, c0);
_mm256_storeu_ps(c.add(ldc), c1);
_mm256_storeu_ps(c.add(2 * ldc), c2);
_mm256_storeu_ps(c.add(3 * ldc), c3);
_mm256_storeu_ps(c.add(4 * ldc), c4);
_mm256_storeu_ps(c.add(5 * ldc), c5);
}
}
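/// Hand-written inline-asm version of the 8x6 kernel. Register plan: ymm0-ymm5
/// hold the C accumulators, ymm6-ymm9 hold four preloaded A columns, and
/// ymm10-ymm15 are scratch for the B broadcasts.
///
/// # Safety
/// On top of the usual pointer/feature requirements: every main-loop
/// iteration, including the last, preloads the *next* four A columns, so the
/// kernel reads up to 128 bytes past the final A column the math consumes.
/// The packed A buffer must be padded (or otherwise safely readable) by at
/// least four extra columns.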
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn microkernel_8x6_true_asm(
k: usize,
a: *const f32,
b: *const f32,
c: *mut f32,
ldc: usize,
) {
unsafe {
use std::arch::asm;
if k < 4 {
microkernel_8x6_avx2(k, a, b, c, ldc);
return;
}
        let ldc_bytes = ldc * 4; // ldc is in elements; the asm needs a byte stride
        asm!(
            // Load the 8x6 C tile into ymm0-ymm5. x86 addressing only scales
            // a register by 1/2/4/8, so lea steps the base by 2*ldc instead.
            "vmovups ymm0, [{c_ptr}]",
            "vmovups ymm1, [{c_ptr} + {ldc}]",
            "vmovups ymm2, [{c_ptr} + {ldc}*2]",
            "lea {tmp}, [{c_ptr} + {ldc}*2]",
            "vmovups ymm3, [{tmp} + {ldc}]",
            "vmovups ymm4, [{tmp} + {ldc}*2]",
            "lea {tmp}, [{tmp} + {ldc}*2]",
            "vmovups ymm5, [{tmp} + {ldc}]",
            // Preload the first four A columns into ymm6-ymm9.
            "vmovups ymm6, [{a_ptr}]",
            "vmovups ymm7, [{a_ptr} + 32]",
            "vmovups ymm8, [{a_ptr} + 64]",
            "vmovups ymm9, [{a_ptr} + 96]",
            "add {a_ptr}, 128",
            "mov {k_cnt}, {k}",
            "shr {k_cnt}, 2",
            "test {k_cnt}, {k_cnt}",
            "jz 2f",
            // Main loop: four k-steps per iteration. Each A register is
            // reloaded with the column four steps ahead as soon as its FMAs
            // have been issued, keeping the FMA ports fed.
            ".p2align 4",
            "3:",
            // Step 0: A column in ymm6.
            "vbroadcastss ymm10, dword ptr [{b_ptr}]",
            "vbroadcastss ymm11, dword ptr [{b_ptr} + 4]",
            "vbroadcastss ymm12, dword ptr [{b_ptr} + 8]",
            "vfmadd231ps ymm0, ymm6, ymm10",
            "vfmadd231ps ymm1, ymm6, ymm11",
            "vfmadd231ps ymm2, ymm6, ymm12",
            "vbroadcastss ymm13, dword ptr [{b_ptr} + 12]",
            "vbroadcastss ymm14, dword ptr [{b_ptr} + 16]",
            "vbroadcastss ymm15, dword ptr [{b_ptr} + 20]",
            "vfmadd231ps ymm3, ymm6, ymm13",
            "vfmadd231ps ymm4, ymm6, ymm14",
            "vfmadd231ps ymm5, ymm6, ymm15",
            "vmovups ymm6, [{a_ptr}]",
            // Step 1: A column in ymm7.
            "vbroadcastss ymm10, dword ptr [{b_ptr} + 24]",
            "vbroadcastss ymm11, dword ptr [{b_ptr} + 28]",
            "vbroadcastss ymm12, dword ptr [{b_ptr} + 32]",
            "vfmadd231ps ymm0, ymm7, ymm10",
            "vfmadd231ps ymm1, ymm7, ymm11",
            "vfmadd231ps ymm2, ymm7, ymm12",
            "vbroadcastss ymm13, dword ptr [{b_ptr} + 36]",
            "vbroadcastss ymm14, dword ptr [{b_ptr} + 40]",
            "vbroadcastss ymm15, dword ptr [{b_ptr} + 44]",
            "vfmadd231ps ymm3, ymm7, ymm13",
            "vfmadd231ps ymm4, ymm7, ymm14",
            "vfmadd231ps ymm5, ymm7, ymm15",
            "vmovups ymm7, [{a_ptr} + 32]",
            // Step 2: A column in ymm8.
            "vbroadcastss ymm10, dword ptr [{b_ptr} + 48]",
            "vbroadcastss ymm11, dword ptr [{b_ptr} + 52]",
            "vbroadcastss ymm12, dword ptr [{b_ptr} + 56]",
            "vfmadd231ps ymm0, ymm8, ymm10",
            "vfmadd231ps ymm1, ymm8, ymm11",
            "vfmadd231ps ymm2, ymm8, ymm12",
            "vbroadcastss ymm13, dword ptr [{b_ptr} + 60]",
            "vbroadcastss ymm14, dword ptr [{b_ptr} + 64]",
            "vbroadcastss ymm15, dword ptr [{b_ptr} + 68]",
            "vfmadd231ps ymm3, ymm8, ymm13",
            "vfmadd231ps ymm4, ymm8, ymm14",
            "vfmadd231ps ymm5, ymm8, ymm15",
            "vmovups ymm8, [{a_ptr} + 64]",
            // Step 3: A column in ymm9.
            "vbroadcastss ymm10, dword ptr [{b_ptr} + 72]",
            "vbroadcastss ymm11, dword ptr [{b_ptr} + 76]",
            "vbroadcastss ymm12, dword ptr [{b_ptr} + 80]",
            "vfmadd231ps ymm0, ymm9, ymm10",
            "vfmadd231ps ymm1, ymm9, ymm11",
            "vfmadd231ps ymm2, ymm9, ymm12",
            "vbroadcastss ymm13, dword ptr [{b_ptr} + 84]",
            "vbroadcastss ymm14, dword ptr [{b_ptr} + 88]",
            "vbroadcastss ymm15, dword ptr [{b_ptr} + 92]",
            "vfmadd231ps ymm3, ymm9, ymm13",
            "vfmadd231ps ymm4, ymm9, ymm14",
            "vfmadd231ps ymm5, ymm9, ymm15",
            "vmovups ymm9, [{a_ptr} + 96]",
            "add {a_ptr}, 128",
            "add {b_ptr}, 96",
            "dec {k_cnt}",
            "jnz 3b",
            // Store the updated C tile.
            "2:",
"vmovups [{c_ptr}], ymm0",
"vmovups [{c_ptr} + {ldc}], ymm1",
"vmovups [{c_ptr} + {ldc}*2], ymm2",
"lea {tmp}, [{c_ptr} + {ldc}*2]",
"vmovups [{tmp} + {ldc}], ymm3",
"vmovups [{tmp} + {ldc}*2], ymm4",
"lea {tmp}, [{tmp} + {ldc}*2]",
"vmovups [{tmp} + {ldc}], ymm5",
a_ptr = inout(reg) a => _,
b_ptr = inout(reg) b => _,
c_ptr = in(reg) c,
k = in(reg) k,
ldc = in(reg) ldc_bytes,
k_cnt = out(reg) _,
tmp = out(reg) _,
out("ymm0") _,
out("ymm1") _,
out("ymm2") _,
out("ymm3") _,
out("ymm4") _,
out("ymm5") _,
out("ymm6") _,
out("ymm7") _,
out("ymm8") _,
out("ymm9") _,
out("ymm10") _,
out("ymm11") _,
out("ymm12") _,
out("ymm13") _,
out("ymm14") _,
out("ymm15") _,
options(nostack),
);
        // The asm block has already stored its partial result to C, so the
        // k % 4 leftover steps reload C and finish with intrinsics.
        let k_rem = k % 4;
        if k_rem > 0 {
let k_done = (k / 4) * 4;
let a_rem = a.add(k_done * MR);
let b_rem = b.add(k_done * NR);
use std::arch::x86_64::*;
let mut c0 = _mm256_loadu_ps(c);
let mut c1 = _mm256_loadu_ps(c.add(ldc));
let mut c2 = _mm256_loadu_ps(c.add(2 * ldc));
let mut c3 = _mm256_loadu_ps(c.add(3 * ldc));
let mut c4 = _mm256_loadu_ps(c.add(4 * ldc));
let mut c5 = _mm256_loadu_ps(c.add(5 * ldc));
for p in 0..k_rem {
let a_col = _mm256_loadu_ps(a_rem.add(p * MR));
let b0 = _mm256_broadcast_ss(&*b_rem.add(p * NR));
let b1 = _mm256_broadcast_ss(&*b_rem.add(p * NR + 1));
let b2 = _mm256_broadcast_ss(&*b_rem.add(p * NR + 2));
let b3 = _mm256_broadcast_ss(&*b_rem.add(p * NR + 3));
let b4 = _mm256_broadcast_ss(&*b_rem.add(p * NR + 4));
let b5 = _mm256_broadcast_ss(&*b_rem.add(p * NR + 5));
c0 = _mm256_fmadd_ps(a_col, b0, c0);
c1 = _mm256_fmadd_ps(a_col, b1, c1);
c2 = _mm256_fmadd_ps(a_col, b2, c2);
c3 = _mm256_fmadd_ps(a_col, b3, c3);
c4 = _mm256_fmadd_ps(a_col, b4, c4);
c5 = _mm256_fmadd_ps(a_col, b5, c5);
}
_mm256_storeu_ps(c, c0);
_mm256_storeu_ps(c.add(ldc), c1);
_mm256_storeu_ps(c.add(2 * ldc), c2);
_mm256_storeu_ps(c.add(3 * ldc), c3);
_mm256_storeu_ps(c.add(4 * ldc), c4);
_mm256_storeu_ps(c.add(5 * ldc), c5);
}
}
}
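/// 8x8 variant with the same 4x-unrolled structure as the pipelined 8x6
/// kernel but an 8-column register block. Note that the packing strides are
/// hardcoded to 8 for both panels, independent of the module-level MR and NR.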
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn microkernel_8x8_avx2_fma(
k: usize,
a: *const f32,
b: *const f32,
c: *mut f32,
ldc: usize,
) {
unsafe {
use std::arch::x86_64::*;
let mut c0 = _mm256_loadu_ps(c);
let mut c1 = _mm256_loadu_ps(c.add(ldc));
let mut c2 = _mm256_loadu_ps(c.add(2 * ldc));
let mut c3 = _mm256_loadu_ps(c.add(3 * ldc));
let mut c4 = _mm256_loadu_ps(c.add(4 * ldc));
let mut c5 = _mm256_loadu_ps(c.add(5 * ldc));
let mut c6 = _mm256_loadu_ps(c.add(6 * ldc));
let mut c7 = _mm256_loadu_ps(c.add(7 * ldc));
        let k4 = k / 4;
        let k_rem = k % 4;
        // 4x-unrolled main loop; both panels are packed with a stride of 8.
        for p4 in 0..k4 {
            let base = p4 * 4;
let a0 = _mm256_loadu_ps(a.add(base * 8));
let bp0 = b.add(base * 8);
c0 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0), c0);
c1 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0.add(1)), c1);
c2 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0.add(2)), c2);
c3 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0.add(3)), c3);
c4 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0.add(4)), c4);
c5 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0.add(5)), c5);
c6 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0.add(6)), c6);
c7 = _mm256_fmadd_ps(a0, _mm256_broadcast_ss(&*bp0.add(7)), c7);
let a1 = _mm256_loadu_ps(a.add((base + 1) * 8));
let bp1 = b.add((base + 1) * 8);
c0 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1), c0);
c1 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1.add(1)), c1);
c2 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1.add(2)), c2);
c3 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1.add(3)), c3);
c4 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1.add(4)), c4);
c5 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1.add(5)), c5);
c6 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1.add(6)), c6);
c7 = _mm256_fmadd_ps(a1, _mm256_broadcast_ss(&*bp1.add(7)), c7);
let a2 = _mm256_loadu_ps(a.add((base + 2) * 8));
let bp2 = b.add((base + 2) * 8);
c0 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2), c0);
c1 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2.add(1)), c1);
c2 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2.add(2)), c2);
c3 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2.add(3)), c3);
c4 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2.add(4)), c4);
c5 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2.add(5)), c5);
c6 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2.add(6)), c6);
c7 = _mm256_fmadd_ps(a2, _mm256_broadcast_ss(&*bp2.add(7)), c7);
let a3 = _mm256_loadu_ps(a.add((base + 3) * 8));
let bp3 = b.add((base + 3) * 8);
c0 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3), c0);
c1 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3.add(1)), c1);
c2 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3.add(2)), c2);
c3 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3.add(3)), c3);
c4 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3.add(4)), c4);
c5 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3.add(5)), c5);
c6 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3.add(6)), c6);
c7 = _mm256_fmadd_ps(a3, _mm256_broadcast_ss(&*bp3.add(7)), c7);
}
        // Finish the k % 4 leftover steps.
        let base_rem = k4 * 4;
        for p in 0..k_rem {
let pp = base_rem + p;
let a_col = _mm256_loadu_ps(a.add(pp * 8));
let bp = b.add(pp * 8);
c0 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp), c0);
c1 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp.add(1)), c1);
c2 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp.add(2)), c2);
c3 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp.add(3)), c3);
c4 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp.add(4)), c4);
c5 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp.add(5)), c5);
c6 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp.add(6)), c6);
c7 = _mm256_fmadd_ps(a_col, _mm256_broadcast_ss(&*bp.add(7)), c7);
}
_mm256_storeu_ps(c, c0);
_mm256_storeu_ps(c.add(ldc), c1);
_mm256_storeu_ps(c.add(2 * ldc), c2);
_mm256_storeu_ps(c.add(3 * ldc), c3);
_mm256_storeu_ps(c.add(4 * ldc), c4);
_mm256_storeu_ps(c.add(5 * ldc), c5);
_mm256_storeu_ps(c.add(6 * ldc), c6);
_mm256_storeu_ps(c.add(7 * ldc), c7);
}
}
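// A minimal correctness sketch, assuming MR = 8 and NR = 6 for the constants
// imported above and the `C += A * B` semantics described on the kernels; it
// is illustrative, not part of the original kernels. The test skips itself on
// machines without AVX2+FMA.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    // Scalar reference: C[j*ldc + i] += sum over p of A[p*MR + i] * B[p*NR + j].
    fn reference_8x6(k: usize, a: &[f32], b: &[f32], c: &mut [f32], ldc: usize) {
        for p in 0..k {
            for j in 0..NR {
                for i in 0..MR {
                    c[j * ldc + i] += a[p * MR + i] * b[p * NR + j];
                }
            }
        }
    }

    #[test]
    fn kernels_match_scalar_reference() {
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return; // required ISA extensions unavailable
        }
        let k = 7; // exercises both the unrolled body and the k % 4 remainder
        let ldc = MR;
        // Pad A with four extra columns so the true-asm kernel's preload
        // stays in readable memory.
        let a: Vec<f32> = (0..(k + 4) * MR).map(|i| (i % 13) as f32 * 0.5).collect();
        let b: Vec<f32> = (0..k * NR).map(|i| (i % 7) as f32 - 3.0).collect();
        let init = vec![1.0f32; NR * ldc];
        let mut c_ref = init.clone();
        reference_8x6(k, &a, &b, &mut c_ref, ldc);
        let mut c_plain = init.clone();
        let mut c_piped = init.clone();
        let mut c_asm = init.clone();
        unsafe {
            microkernel_8x6_avx2(k, a.as_ptr(), b.as_ptr(), c_plain.as_mut_ptr(), ldc);
            microkernel_8x6_avx2_asm(k, a.as_ptr(), b.as_ptr(), c_piped.as_mut_ptr(), ldc);
            microkernel_8x6_true_asm(k, a.as_ptr(), b.as_ptr(), c_asm.as_mut_ptr(), ldc);
        }
        for c_ker in [&c_plain, &c_piped, &c_asm] {
            for (r, v) in c_ref.iter().zip(c_ker.iter()) {
                assert!((r - v).abs() < 1e-3, "kernel mismatch: {r} vs {v}");
            }
        }
    }
}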