#![allow(unsafe_op_in_unsafe_fn)]
use std::cell::RefCell;
use std::time::Instant;
use crate::error::TruenoError;
#[cfg(target_arch = "x86_64")]
use super::microkernels::microkernel_16x8_avx512;
#[cfg(target_arch = "x86_64")]
use super::microkernels::microkernel_8x6_true_asm;
use super::microkernels::microkernel_scalar;
use super::packing::{pack_a_block, pack_b_block, packed_a_size, packed_b_size};
#[cfg(target_arch = "x86_64")]
use super::packing::{pack_a_block_512, pack_b_block_512, packed_a_size_512, packed_b_size_512};
use super::prepacked::PrepackedB;
use super::profiler::{BlisProfileLevel, BlisProfiler};
use super::reference::gemm_reference;
use super::{KC, MC, MR, NC, NR};
#[cfg(target_arch = "x86_64")]
use super::{KC_512, MC_512, MR_512, NC_512, NR_512};
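// Thread-local scratch buffers reused across GEMM calls so the packed-A and
// packed-B panels and the micro-tile accumulator are not reallocated per call.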
thread_local! {
static TL_PACKED_A: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
static TL_PACKED_B: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
static TL_C_MICRO: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
}
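/// Copies an `mr` x `nr` tile of row-major C into the micro-tile buffer
/// (column-major with leading dimension `MR`), zero-padding rows past `mr`
/// and columns past `nr`.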
#[inline(always)]
fn load_c_tile(
c: &[f32],
c_micro: &mut [f32],
row: usize,
col: usize,
mr: usize,
nr: usize,
n: usize,
) {
for jj in 0..nr {
for ii in 0..mr {
c_micro[jj * MR + ii] = c[(row + ii) * n + (col + jj)];
}
for ii in mr..MR {
c_micro[jj * MR + ii] = 0.0;
}
}
for jj in nr..NR {
for ii in 0..MR {
c_micro[jj * MR + ii] = 0.0;
}
}
}
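/// Writes the valid `mr` x `nr` portion of the micro-tile buffer back into
/// row-major C; the zero-padded rows and columns are not stored.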
#[inline(always)]
fn store_c_tile(
c: &mut [f32],
c_micro: &[f32],
row: usize,
col: usize,
mr: usize,
nr: usize,
n: usize,
) {
for jj in 0..nr {
for ii in 0..mr {
c[(row + ii) * n + (col + jj)] = c_micro[jj * MR + ii];
}
}
}
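/// Selects a microkernel for one packed-panel multiply: the 8x6 AVX2+FMA
/// assembly kernel for full tiles on capable CPUs, otherwise the scalar
/// reference kernel. `is_x86_feature_detected!` caches its result, so the
/// per-tile check is cheap.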
#[inline(always)]
fn dispatch_microkernel(
kc: usize,
a_panel: &[f32],
b_panel: &[f32],
c_micro: &mut [f32],
mr_block: usize,
nr_block: usize,
) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2")
&& is_x86_feature_detected!("fma")
&& mr_block == MR
&& nr_block == NR
{
unsafe {
microkernel_8x6_true_asm(
kc,
a_panel.as_ptr(),
b_panel.as_ptr(),
c_micro.as_mut_ptr(),
MR,
);
}
return;
}
}
microkernel_scalar(kc, a_panel, b_panel, c_micro, MR);
}
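/// Inner two loops of the BLIS hierarchy: sweeps MRxNR micro-tiles over one
/// MCxNC macroblock using the already-packed A and B panels, optionally
/// recording per-micro-tile and per-macroblock timings.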
#[allow(clippy::too_many_arguments)]
fn compute_macroblock(
c: &mut [f32],
packed_a: &[f32],
packed_b: &[f32],
c_micro: &mut [f32],
ic: usize,
jc: usize,
mc_block: usize,
nc_block: usize,
kc_block: usize,
n: usize,
profiler: &mut Option<&mut BlisProfiler>,
) {
let track_time = profiler.is_some();
let midi_start = if track_time { Some(Instant::now()) } else { None };
for ir in (0..mc_block).step_by(MR) {
let mr_block = MR.min(mc_block - ir);
for jr in (0..nc_block).step_by(NR) {
let nr_block = NR.min(nc_block - jr);
let micro_start = if track_time { Some(Instant::now()) } else { None };
let a_panel = &packed_a[(ir / MR) * MR * kc_block..];
let b_panel = &packed_b[(jr / NR) * NR * kc_block..];
load_c_tile(c, c_micro, ic + ir, jc + jr, mr_block, nr_block, n);
dispatch_microkernel(kc_block, a_panel, b_panel, c_micro, mr_block, nr_block);
store_c_tile(c, c_micro, ic + ir, jc + jr, mr_block, nr_block, n);
record_prof(
profiler,
BlisProfileLevel::Micro,
micro_start,
(2 * mr_block * nr_block * kc_block) as u64,
);
}
}
record_prof(
profiler,
BlisProfileLevel::Midi,
midi_start,
(2 * mc_block * nc_block * kc_block) as u64,
);
}
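/// Direct row-major GEMM with no packing, using an 8x8 register tile and a
/// k loop unrolled by 4 with FMA broadcasts from A.
///
/// # Safety
/// Requires AVX2+FMA at runtime and `m % 8 == 0 && n % 8 == 0`: rows of B
/// and C are loaded/stored as full 8-wide vectors with no edge handling.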
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_direct_rowmajor(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
use std::arch::x86_64::*;
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let c_ptr = c.as_mut_ptr();
unsafe {
for ir in (0..m).step_by(8) {
for jr in (0..n).step_by(8) {
let c_base = c_ptr.add(ir * n + jr);
let mut c0 = _mm256_loadu_ps(c_base);
let mut c1 = _mm256_loadu_ps(c_base.add(n));
let mut c2 = _mm256_loadu_ps(c_base.add(2 * n));
let mut c3 = _mm256_loadu_ps(c_base.add(3 * n));
let mut c4 = _mm256_loadu_ps(c_base.add(4 * n));
let mut c5 = _mm256_loadu_ps(c_base.add(5 * n));
let mut c6 = _mm256_loadu_ps(c_base.add(6 * n));
let mut c7 = _mm256_loadu_ps(c_base.add(7 * n));
let a0 = a_ptr.add(ir * k);
let a1 = a_ptr.add((ir + 1) * k);
let a2 = a_ptr.add((ir + 2) * k);
let a3 = a_ptr.add((ir + 3) * k);
let a4 = a_ptr.add((ir + 4) * k);
let a5 = a_ptr.add((ir + 5) * k);
let a6 = a_ptr.add((ir + 6) * k);
let a7 = a_ptr.add((ir + 7) * k);
let b_base = b_ptr.add(jr);
let k4 = k / 4;
let k_rem = k % 4;
for p4 in 0..k4 {
let p = p4 * 4;
let b_row = _mm256_loadu_ps(b_base.add(p * n));
c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p)), b_row, c0);
c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p)), b_row, c1);
c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p)), b_row, c2);
c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p)), b_row, c3);
c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p)), b_row, c4);
c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p)), b_row, c5);
c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p)), b_row, c6);
c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p)), b_row, c7);
let b_row = _mm256_loadu_ps(b_base.add((p + 1) * n));
c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 1)), b_row, c0);
c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 1)), b_row, c1);
c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 1)), b_row, c2);
c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 1)), b_row, c3);
c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 1)), b_row, c4);
c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 1)), b_row, c5);
c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 1)), b_row, c6);
c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 1)), b_row, c7);
let b_row = _mm256_loadu_ps(b_base.add((p + 2) * n));
c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 2)), b_row, c0);
c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 2)), b_row, c1);
c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 2)), b_row, c2);
c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 2)), b_row, c3);
c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 2)), b_row, c4);
c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 2)), b_row, c5);
c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 2)), b_row, c6);
c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 2)), b_row, c7);
let b_row = _mm256_loadu_ps(b_base.add((p + 3) * n));
c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 3)), b_row, c0);
c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 3)), b_row, c1);
c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 3)), b_row, c2);
c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 3)), b_row, c3);
c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 3)), b_row, c4);
c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 3)), b_row, c5);
c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 3)), b_row, c6);
c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 3)), b_row, c7);
}
let base_rem = k4 * 4;
for rp in 0..k_rem {
let pp = base_rem + rp;
let b_row = _mm256_loadu_ps(b_base.add(pp * n));
c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(pp)), b_row, c0);
c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(pp)), b_row, c1);
c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(pp)), b_row, c2);
c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(pp)), b_row, c3);
c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(pp)), b_row, c4);
c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(pp)), b_row, c5);
c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(pp)), b_row, c6);
c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(pp)), b_row, c7);
}
_mm256_storeu_ps(c_base, c0);
_mm256_storeu_ps(c_base.add(n), c1);
_mm256_storeu_ps(c_base.add(2 * n), c2);
_mm256_storeu_ps(c_base.add(3 * n), c3);
_mm256_storeu_ps(c_base.add(4 * n), c4);
_mm256_storeu_ps(c_base.add(5 * n), c5);
_mm256_storeu_ps(c_base.add(6 * n), c6);
_mm256_storeu_ps(c_base.add(7 * n), c7);
}
}
}
Ok(())
}
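/// Strided small-matrix fallback: gathers 8-row columns of A and C into
/// vectors and broadcasts scalars of B, handling arbitrary `m` and `n`
/// (partial tiles go through temporary stack buffers). The `cv` accumulator
/// array is sized for NR = 6.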
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_small_strided_avx2(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
use std::arch::x86_64::*;
unsafe {
for jr in (0..n).step_by(NR) {
let nr = NR.min(n - jr);
for ir in (0..m).step_by(MR) {
let mr = MR.min(m - ir);
let mut cv = [_mm256_setzero_ps(); 6];
for j in 0..nr {
if mr == MR {
cv[j] = _mm256_set_ps(
*c.get_unchecked((ir + 7) * n + jr + j),
*c.get_unchecked((ir + 6) * n + jr + j),
*c.get_unchecked((ir + 5) * n + jr + j),
*c.get_unchecked((ir + 4) * n + jr + j),
*c.get_unchecked((ir + 3) * n + jr + j),
*c.get_unchecked((ir + 2) * n + jr + j),
*c.get_unchecked((ir + 1) * n + jr + j),
*c.get_unchecked(ir * n + jr + j),
);
} else {
let mut t = [0.0f32; 8];
for i in 0..mr {
t[i] = *c.get_unchecked((ir + i) * n + jr + j);
}
cv[j] = _mm256_loadu_ps(t.as_ptr());
}
}
for p in 0..k {
let a_col = if mr == MR {
_mm256_set_ps(
*a.get_unchecked((ir + 7) * k + p),
*a.get_unchecked((ir + 6) * k + p),
*a.get_unchecked((ir + 5) * k + p),
*a.get_unchecked((ir + 4) * k + p),
*a.get_unchecked((ir + 3) * k + p),
*a.get_unchecked((ir + 2) * k + p),
*a.get_unchecked((ir + 1) * k + p),
*a.get_unchecked(ir * k + p),
)
} else {
let mut t = [0.0f32; 8];
for i in 0..mr {
t[i] = *a.get_unchecked((ir + i) * k + p);
}
_mm256_loadu_ps(t.as_ptr())
};
let bp = b.as_ptr().add(p * n + jr);
if nr == NR {
cv[0] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp), cv[0]);
cv[1] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(1)), cv[1]);
cv[2] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(2)), cv[2]);
cv[3] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(3)), cv[3]);
cv[4] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(4)), cv[4]);
cv[5] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(5)), cv[5]);
} else {
for j in 0..nr {
cv[j] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(j)), cv[j]);
}
}
}
for j in 0..nr {
let mut t = [0.0f32; 8];
_mm256_storeu_ps(t.as_mut_ptr(), cv[j]);
for i in 0..mr {
*c.get_unchecked_mut((ir + i) * n + jr + j) = t[i];
}
}
}
}
}
Ok(())
}
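/// Small-matrix path for `m % 8 == 0 && n % 8 == 0`: packs all of B once,
/// transposes each 8xk panel of A on the fly via 8x8 in-register transposes,
/// and calls the 8x8 AVX2+FMA microkernel. Assumes `k <= 256` (guaranteed by
/// the caller) so the packed-A panel fits in a fixed stack buffer.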
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_small_nopack_8x8(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
use crate::blis::microkernels::microkernel_8x8_avx2_fma;
use std::arch::x86_64::*;
let panels_m = m / 8;
let panels_n = n / 8;
let mut packed_a = [0.0f32; 8 * 256];
let mut c_micro = [0.0f32; 64];
let mut all_packed_b = vec![0.0f32; panels_n * k * 8];
unsafe {
for jr_panel in 0..panels_n {
let jr = jr_panel * 8;
let b_dst = all_packed_b.as_mut_ptr().add(jr_panel * k * 8);
for p in 0..k {
_mm256_storeu_ps(b_dst.add(p * 8), _mm256_loadu_ps(b.as_ptr().add(p * n + jr)));
}
}
for ir_panel in 0..panels_m {
let ir = ir_panel * 8;
let k_blocks = k / 8;
let k_rem_start = k_blocks * 8;
for kb in 0..k_blocks {
let p = kb * 8;
let r0 = _mm256_loadu_ps(a.as_ptr().add(ir * k + p));
let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 1) * k + p));
let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 2) * k + p));
let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 3) * k + p));
let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 4) * k + p));
let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 5) * k + p));
let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 6) * k + p));
let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 7) * k + p));
let t0 = _mm256_unpacklo_ps(r0, r1);
let t1 = _mm256_unpackhi_ps(r0, r1);
let t2 = _mm256_unpacklo_ps(r2, r3);
let t3 = _mm256_unpackhi_ps(r2, r3);
let t4 = _mm256_unpacklo_ps(r4, r5);
let t5 = _mm256_unpackhi_ps(r4, r5);
let t6 = _mm256_unpacklo_ps(r6, r7);
let t7 = _mm256_unpackhi_ps(r6, r7);
let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);
let dst = packed_a.as_mut_ptr().add(p * 8);
_mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
_mm256_storeu_ps(dst.add(8), _mm256_permute2f128_ps(u1, u5, 0x20));
_mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u2, u6, 0x20));
_mm256_storeu_ps(dst.add(24), _mm256_permute2f128_ps(u3, u7, 0x20));
_mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u0, u4, 0x31));
_mm256_storeu_ps(dst.add(40), _mm256_permute2f128_ps(u1, u5, 0x31));
_mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u2, u6, 0x31));
_mm256_storeu_ps(dst.add(56), _mm256_permute2f128_ps(u3, u7, 0x31));
}
for p in k_rem_start..k {
for i in 0..8 {
*packed_a.get_unchecked_mut(p * 8 + i) = *a.get_unchecked((ir + i) * k + p);
}
}
for jr_panel in 0..panels_n {
let jr = jr_panel * 8;
let packed_b_ptr = all_packed_b.as_ptr().add(jr_panel * k * 8);
for jj in 0..8 {
for ii in 0..8 {
c_micro[jj * 8 + ii] = *c.get_unchecked((ir + ii) * n + jr + jj);
}
}
microkernel_8x8_avx2_fma(
k,
packed_a.as_ptr(),
packed_b_ptr,
c_micro.as_mut_ptr(),
8,
);
for jj in 0..8 {
for ii in 0..8 {
*c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 8 + ii];
}
}
}
}
}
Ok(())
}
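/// Unused (`#[allow(dead_code)]`) packed small-matrix variant: packs A and B
/// into zero-padded 8-wide panels up front, then runs the 8x8 microkernel
/// per tile with a copy-in/copy-out micro-tile.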
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(dead_code)]
unsafe fn gemm_small_8x8(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
use crate::blis::microkernels::microkernel_8x8_avx2_fma;
let panels_m = (m + 7) / 8;
let panels_n = (n + 7) / 8;
// Size the packing buffers in whole 8-wide panels so the final partial
// panel can be zero-padded without indexing past the end of the Vec.
let mut packed_a = vec![0.0f32; panels_m * 8 * k];
let mut packed_b = vec![0.0f32; panels_n * 8 * k];
let mut c_micro = [0.0f32; 8 * 8];
for panel in 0..panels_m {
let ir = panel * 8;
let mr = 8.min(m - ir);
for p in 0..k {
for i in 0..8 {
unsafe {
packed_a[panel * 8 * k + p * 8 + i] =
if i < mr { *a.get_unchecked((ir + i) * k + p) } else { 0.0 };
}
}
}
}
for panel in 0..panels_n {
let jr = panel * 8;
let nr = 8.min(n - jr);
for p in 0..k {
for j in 0..8 {
unsafe {
packed_b[panel * 8 * k + p * 8 + j] =
if j < nr { *b.get_unchecked(p * n + jr + j) } else { 0.0 };
}
}
}
}
unsafe {
for ir_panel in 0..panels_m {
let ir = ir_panel * 8;
let mr = 8.min(m - ir);
for jr_panel in 0..panels_n {
let jr = jr_panel * 8;
let nr = 8.min(n - jr);
for jj in 0..8 {
for ii in 0..8 {
c_micro[jj * 8 + ii] = if ii < mr && jj < nr {
*c.get_unchecked((ir + ii) * n + jr + jj)
} else {
0.0
};
}
}
let ap = packed_a.as_ptr().add(ir_panel * 8 * k);
let bp = packed_b.as_ptr().add(jr_panel * 8 * k);
microkernel_8x8_avx2_fma(k, ap, bp, c_micro.as_mut_ptr(), 8);
for jj in 0..nr {
for ii in 0..mr {
*c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 8 + ii];
}
}
}
}
}
Ok(())
}
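/// Small-matrix AVX-512 path for `m % 16 == 0 && n % 8 == 0`: packs all of B
/// once, transposes each 16xk panel of A via two interleaved 8x8 in-register
/// transposes, and calls the 16x8 AVX-512 microkernel. Assumes `k <= 256`
/// for the fixed-size packed-A stack buffer.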
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn gemm_small_avx512_16x8(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
use std::arch::x86_64::*;
let panels_m = m / 16;
let panels_n = n / 8;
let mut all_packed_b = vec![0.0f32; panels_n * k * 8];
unsafe {
for jr_panel in 0..panels_n {
let jr = jr_panel * 8;
let b_dst = all_packed_b.as_mut_ptr().add(jr_panel * k * 8);
for p in 0..k {
_mm256_storeu_ps(b_dst.add(p * 8), _mm256_loadu_ps(b.as_ptr().add(p * n + jr)));
}
}
}
let mut packed_a = [0.0f32; 16 * 256];
let mut c_micro = [0.0f32; 16 * 8];
unsafe {
for ir_panel in 0..panels_m {
let ir = ir_panel * 16;
let k_blocks = k / 8;
let k_rem_start = k_blocks * 8;
for kb in 0..k_blocks {
let p = kb * 8;
let r0 = _mm256_loadu_ps(a.as_ptr().add(ir * k + p));
let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 1) * k + p));
let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 2) * k + p));
let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 3) * k + p));
let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 4) * k + p));
let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 5) * k + p));
let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 6) * k + p));
let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 7) * k + p));
let t0 = _mm256_unpacklo_ps(r0, r1);
let t1 = _mm256_unpackhi_ps(r0, r1);
let t2 = _mm256_unpacklo_ps(r2, r3);
let t3 = _mm256_unpackhi_ps(r2, r3);
let t4 = _mm256_unpacklo_ps(r4, r5);
let t5 = _mm256_unpackhi_ps(r4, r5);
let t6 = _mm256_unpacklo_ps(r6, r7);
let t7 = _mm256_unpackhi_ps(r6, r7);
let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);
let dst = packed_a.as_mut_ptr().add(p * 16);
_mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
_mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u1, u5, 0x20));
_mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u2, u6, 0x20));
_mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u3, u7, 0x20));
_mm256_storeu_ps(dst.add(64), _mm256_permute2f128_ps(u0, u4, 0x31));
_mm256_storeu_ps(dst.add(80), _mm256_permute2f128_ps(u1, u5, 0x31));
_mm256_storeu_ps(dst.add(96), _mm256_permute2f128_ps(u2, u6, 0x31));
_mm256_storeu_ps(dst.add(112), _mm256_permute2f128_ps(u3, u7, 0x31));
let r0 = _mm256_loadu_ps(a.as_ptr().add((ir + 8) * k + p));
let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 9) * k + p));
let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 10) * k + p));
let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 11) * k + p));
let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 12) * k + p));
let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 13) * k + p));
let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 14) * k + p));
let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 15) * k + p));
let t0 = _mm256_unpacklo_ps(r0, r1);
let t1 = _mm256_unpackhi_ps(r0, r1);
let t2 = _mm256_unpacklo_ps(r2, r3);
let t3 = _mm256_unpackhi_ps(r2, r3);
let t4 = _mm256_unpacklo_ps(r4, r5);
let t5 = _mm256_unpackhi_ps(r4, r5);
let t6 = _mm256_unpacklo_ps(r6, r7);
let t7 = _mm256_unpackhi_ps(r6, r7);
let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);
let dst_lo = packed_a.as_mut_ptr().add(p * 16 + 8);
_mm256_storeu_ps(dst_lo, _mm256_permute2f128_ps(u0, u4, 0x20));
_mm256_storeu_ps(dst_lo.add(16), _mm256_permute2f128_ps(u1, u5, 0x20));
_mm256_storeu_ps(dst_lo.add(32), _mm256_permute2f128_ps(u2, u6, 0x20));
_mm256_storeu_ps(dst_lo.add(48), _mm256_permute2f128_ps(u3, u7, 0x20));
_mm256_storeu_ps(dst_lo.add(64), _mm256_permute2f128_ps(u0, u4, 0x31));
_mm256_storeu_ps(dst_lo.add(80), _mm256_permute2f128_ps(u1, u5, 0x31));
_mm256_storeu_ps(dst_lo.add(96), _mm256_permute2f128_ps(u2, u6, 0x31));
_mm256_storeu_ps(dst_lo.add(112), _mm256_permute2f128_ps(u3, u7, 0x31));
}
for p in k_rem_start..k {
for i in 0..16 {
*packed_a.get_unchecked_mut(p * 16 + i) = *a.get_unchecked((ir + i) * k + p);
}
}
for jr_panel in 0..panels_n {
let jr = jr_panel * 8;
let packed_b_ptr = all_packed_b.as_ptr().add(jr_panel * k * 8);
for jj in 0..8 {
for ii in 0..16 {
c_micro[jj * 16 + ii] = *c.get_unchecked((ir + ii) * n + jr + jj);
}
}
microkernel_16x8_avx512(
k,
packed_a.as_ptr(),
packed_b_ptr,
c_micro.as_mut_ptr(),
16,
);
for jj in 0..8 {
for ii in 0..16 {
*c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 16 + ii];
}
}
}
}
}
Ok(())
}
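/// Checks that the slice lengths of A, B, and C match the requested
/// `m x k`, `k x n`, and `m x n` shapes.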
fn validate_gemm_dims(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &[f32],
) -> Result<(), TruenoError> {
if a.len() != m * k {
return Err(TruenoError::InvalidInput(format!(
"A size mismatch: expected {}, got {}",
m * k,
a.len()
)));
}
if b.len() != k * n {
return Err(TruenoError::InvalidInput(format!(
"B size mismatch: expected {}, got {}",
k * n,
b.len()
)));
}
if c.len() != m * n {
return Err(TruenoError::InvalidInput(format!(
"C size mismatch: expected {}, got {}",
m * n,
c.len()
)));
}
Ok(())
}
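/// Records one profiler sample (elapsed nanoseconds plus a flop count) when
/// both a profiler and a start timestamp are present.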
#[inline(always)]
fn record_prof(
profiler: &mut Option<&mut BlisProfiler>,
level: BlisProfileLevel,
start: Option<Instant>,
flops: u64,
) {
if let (Some(ref mut prof), Some(s)) = (profiler.as_deref_mut(), start) {
prof.record(level, s.elapsed().as_nanos() as u64, flops);
}
}
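/// Single-precision GEMM (`C += A * B`, all row-major) using BLIS-style
/// cache blocking with MC/NC/KC macroblocks and MRxNR microkernels.
///
/// Dispatches in order: the reference kernel for tiny problems, no-pack
/// small-matrix kernels for dimensions up to 256, the AVX-512 blocked path,
/// the AVX2 row-major-C blocked path, and finally the generic packed path
/// with optional profiling.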
pub fn gemm_blis(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
mut profiler: Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
contract_pre_flops_per_tile!();
validate_gemm_dims(m, n, k, a, b, c)?;
if m == 0 || n == 0 || k == 0 {
return Ok(());
}
if m * n * k < 4096 {
return gemm_reference(m, n, k, a, b, c);
}
#[cfg(target_arch = "x86_64")]
if profiler.is_none()
&& m <= 256
&& n <= 256
&& k <= 256
&& is_x86_feature_detected!("avx2")
&& is_x86_feature_detected!("fma")
{
unsafe {
if m <= 128 && n <= 128 && m % 8 == 0 && n % 8 == 0 {
return gemm_direct_rowmajor(m, n, k, a, b, c);
}
if is_x86_feature_detected!("avx512f") && m >= 16 && m % 16 == 0 && n % 8 == 0 {
return gemm_small_avx512_16x8(m, n, k, a, b, c);
}
if m >= MR && m % 8 == 0 && n % 8 == 0 {
return gemm_small_nopack_8x8(m, n, k, a, b, c);
}
return gemm_small_strided_avx2(m, n, k, a, b, c);
}
}
#[cfg(target_arch = "x86_64")]
if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
return unsafe { gemm_blis_avx512_large(m, n, k, a, b, c, &mut profiler) };
}
#[cfg(target_arch = "x86_64")]
if profiler.is_none() && is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { gemm_blis_nr8_rowmajor_c(m, n, k, a, b, c) };
}
let track_time = profiler.is_some();
let start = if track_time { Some(Instant::now()) } else { None };
let mc = MC.min(m);
let nc = NC.min(n);
let kc = KC.min(k);
let needed_a = packed_a_size(mc, kc);
let needed_b = packed_b_size(kc, nc);
let needed_c = MR * NR;
TL_PACKED_A.with(|tl_a| {
TL_PACKED_B.with(|tl_b| {
TL_C_MICRO.with(|tl_c| {
let mut packed_a = tl_a.borrow_mut();
let mut packed_b = tl_b.borrow_mut();
let mut c_micro = tl_c.borrow_mut();
if packed_a.len() < needed_a {
packed_a.resize(needed_a, 0.0);
}
if packed_b.len() < needed_b {
packed_b.resize(needed_b, 0.0);
}
if c_micro.len() < needed_c {
c_micro.resize(needed_c, 0.0);
}
for jc in (0..n).step_by(NC) {
let nc_block = NC.min(n - jc);
for pc in (0..k).step_by(KC) {
let kc_block = KC.min(k - pc);
let pack_start = if track_time { Some(Instant::now()) } else { None };
pack_b_block(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);
for ic in (0..m).step_by(MC) {
let mc_block = MC.min(m - ic);
let pack_start = if track_time { Some(Instant::now()) } else { None };
pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);
compute_macroblock(
c,
&packed_a,
&packed_b,
&mut c_micro,
ic,
jc,
mc_block,
nc_block,
kc_block,
n,
&mut profiler,
);
}
}
}
if let (Some(prof), Some(s)) = (profiler, start) {
prof.record(
BlisProfileLevel::Macro,
s.elapsed().as_nanos() as u64,
(2 * m * n * k) as u64,
);
}
});
});
});
contract_post_flops_per_tile!(c);
Ok(())
}
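/// AVX-512 blocked path for large matrices. Blocking parameters (including
/// the 8x48, 8x32, or 8x16 microkernel shape) come from the `cache_topology`
/// presets (8x32 when `n >= 32`, else 8x16); partial tiles fall back to a
/// scalar loop over the packed panels.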
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
unsafe fn gemm_blis_avx512_large(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
profiler: &mut Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
let track_time = profiler.is_some();
let start = if track_time { Some(Instant::now()) } else { None };
let blk = if n >= 32 {
super::cache_topology::blocking_8x32()
} else {
super::cache_topology::blocking_8x16()
};
let mr = blk.mr;
let nr = blk.nr;
let mc = blk.mc.min(m);
let nc = blk.nc.min(n);
let kc_param = blk.kc;
TL_PACKED_A.with(|tl_a| {
TL_PACKED_B.with(|tl_b| {
let mut packed_a = tl_a.borrow_mut();
let mut packed_b = tl_b.borrow_mut();
let needed_a = packed_a_size(mc, kc_param);
let b_panels = (nc + nr - 1) / nr;
let needed_b = b_panels * nr * kc_param;
if packed_a.len() < needed_a {
packed_a.resize(needed_a, 0.0);
}
if packed_b.len() < needed_b {
packed_b.resize(needed_b, 0.0);
}
for jc in (0..n).step_by(nc) {
let nc_block = nc.min(n - jc);
for pc in (0..k).step_by(kc_param) {
let kc_block = kc_param.min(k - pc);
if nr == 48 {
pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, 48, &mut packed_b);
} else if nr == 32 {
pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, 32, &mut packed_b);
} else {
pack_b_block_nr16(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
}
for ic in (0..m).step_by(mc) {
let mc_block = mc.min(m - ic);
pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
let panels_m = (mc_block + mr - 1) / mr;
let panels_n = (nc_block + nr - 1) / nr;
for ir_panel in 0..panels_m {
let ir = ir_panel * mr;
let mr_block = mr.min(mc_block - ir);
for jr_panel in 0..panels_n {
let jr = jr_panel * nr;
let nr_block = nr.min(nc_block - jr);
let a_panel = &packed_a[ir_panel * mr * kc_block..];
let b_panel = &packed_b[jr_panel * nr * kc_block..];
if mr_block == 8 && nr_block == 48 && nr == 48 {
unsafe {
super::microkernels::codegen::microkernel_8x48_avx512_gen(
kc_block,
a_panel.as_ptr(),
b_panel.as_ptr(),
c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
n,
);
}
} else if mr_block == 8 && nr_block == 32 && nr == 32 {
unsafe {
avx512_microkernel_8x32_rowmajor(
kc_block,
a_panel.as_ptr(),
b_panel.as_ptr(),
c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
n,
);
}
} else if mr_block == 8 && nr_block == 16 && nr == 16 {
unsafe {
avx512_microkernel_8x16_rowmajor(
kc_block,
a_panel.as_ptr(),
b_panel.as_ptr(),
c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
n,
);
}
} else {
for ir_local in 0..mr_block {
for jr_local in 0..nr_block {
let mut sum =
c[(ic + ir + ir_local) * n + (jc + jr + jr_local)];
for p in 0..kc_block {
sum += a_panel[p * mr + ir_local]
* b_panel[p * nr + jr_local];
}
c[(ic + ir + ir_local) * n + (jc + jr + jr_local)] =
sum;
}
}
}
}
}
}
}
}
if let (Some(prof), Some(s)) = (profiler.as_mut(), start) {
prof.record_avx512_blis(m, n, k, s.elapsed());
}
Ok(())
})
})
}
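/// AVX-512 variant built around a 64x6 microkernel that broadcasts from the
/// packed B panel. Uses freshly allocated packing buffers rather than the
/// thread-local scratch space.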
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
unsafe fn gemm_blis_avx512_bcast_b(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
let blk = super::cache_topology::blocking_64x6_bcast_b();
let mr = blk.mr;
let nr = blk.nr;
let mc = blk.mc.min(m);
let nc = blk.nc.min(n);
let kc = blk.kc;
let a_panels = (mc + mr - 1) / mr;
let needed_a = a_panels * mr * kc;
let b_panels = (nc + nr - 1) / nr;
let needed_b = b_panels * nr * kc;
let mut packed_a = vec![0.0f32; needed_a];
let mut packed_b = vec![0.0f32; needed_b];
for jc in (0..n).step_by(nc) {
let nc_block = nc.min(n - jc);
for pc in (0..k).step_by(kc) {
let kc_block = kc.min(k - pc);
pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, nr, &mut packed_b);
for ic in (0..m).step_by(mc) {
let mc_block = mc.min(m - ic);
pack_a_block_generic(a, k, ic, pc, mc_block, kc_block, mr, &mut packed_a);
let panels_m = (mc_block + mr - 1) / mr;
let panels_n = (nc_block + nr - 1) / nr;
for ir_panel in 0..panels_m {
let ir = ir_panel * mr;
let mr_block = mr.min(mc_block - ir);
for jr_panel in 0..panels_n {
let jr = jr_panel * nr;
let nr_block = nr.min(nc_block - jr);
let a_panel = &packed_a[ir_panel * mr * kc_block..];
let b_panel = &packed_b[jr_panel * nr * kc_block..];
if mr_block == 64 && nr_block == 6 {
unsafe {
super::microkernels::codegen::microkernel_64x6_avx512_bcast_b(
kc_block,
a_panel.as_ptr(),
b_panel.as_ptr(),
c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
n,
);
}
} else {
for ir_local in 0..mr_block {
for jr_local in 0..nr_block {
let mut sum = 0.0f32;
for p in 0..kc_block {
sum +=
a_panel[p * mr + ir_local] * b_panel[p * nr + jr_local];
}
c[(ic + ir + ir_local) * n + (jc + jr + jr_local)] += sum;
}
}
}
}
}
}
}
}
Ok(())
}
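/// Packs a `rows` x `cols` block of A into `mr`-row panels: within a panel,
/// the `mr` row entries for each k index are stored contiguously, with the
/// final partial panel zero-padded.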
fn pack_a_block_generic(
a: &[f32],
lda: usize,
row_start: usize,
col_start: usize,
rows: usize,
cols: usize,
mr: usize,
packed: &mut [f32],
) {
let panels = (rows + mr - 1) / mr;
let mut pack_idx = 0;
for panel in 0..panels {
let ir = panel * mr;
let mr_actual = mr.min(rows - ir);
for col in 0..cols {
for row in 0..mr {
if row < mr_actual {
packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
} else {
packed[pack_idx] = 0.0;
}
pack_idx += 1;
}
}
}
}
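/// Public entry point for the broadcast-B variant: validates dimensions,
/// then uses the AVX-512 64x6 kernel when available, falling back to
/// `gemm_blis` otherwise.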
#[cfg(target_arch = "x86_64")]
pub fn gemm_blis_broadcast_b(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
if a.len() != m * k || b.len() != k * n || c.len() != m * n {
return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
}
if std::arch::is_x86_feature_detected!("avx512f")
&& std::arch::is_x86_feature_detected!("fma")
{
unsafe { gemm_blis_avx512_bcast_b(m, n, k, a, b, c) }
} else {
gemm_blis(m, n, k, a, b, c, None)
}
}
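/// 8x16 AVX-512 microkernel accumulating directly into row-major C with
/// leading dimension `ldc`; A is packed 8-wide per k step, B 16-wide.
///
/// # Safety
/// Caller must ensure AVX-512F+FMA support and that `a`, `b`, and `c` point
/// to buffers valid for `k` steps and an 8x16 tile of C.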
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
pub(super) unsafe fn avx512_microkernel_8x16_rowmajor(
k: usize,
a: *const f32,
b: *const f32,
c: *mut f32,
ldc: usize,
) {
use std::arch::x86_64::*;
let mut c0 = _mm512_loadu_ps(c);
let mut c1 = _mm512_loadu_ps(c.add(ldc));
let mut c2 = _mm512_loadu_ps(c.add(2 * ldc));
let mut c3 = _mm512_loadu_ps(c.add(3 * ldc));
let mut c4 = _mm512_loadu_ps(c.add(4 * ldc));
let mut c5 = _mm512_loadu_ps(c.add(5 * ldc));
let mut c6 = _mm512_loadu_ps(c.add(6 * ldc));
let mut c7 = _mm512_loadu_ps(c.add(7 * ldc));
for p in 0..k {
let b_row = _mm512_loadu_ps(b.add(p * 16));
let ap = a.add(p * 8);
c0 = _mm512_fmadd_ps(_mm512_set1_ps(*ap), b_row, c0);
c1 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(1)), b_row, c1);
c2 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(2)), b_row, c2);
c3 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(3)), b_row, c3);
c4 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(4)), b_row, c4);
c5 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(5)), b_row, c5);
c6 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(6)), b_row, c6);
c7 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(7)), b_row, c7);
}
_mm512_storeu_ps(c, c0);
_mm512_storeu_ps(c.add(ldc), c1);
_mm512_storeu_ps(c.add(2 * ldc), c2);
_mm512_storeu_ps(c.add(3 * ldc), c3);
_mm512_storeu_ps(c.add(4 * ldc), c4);
_mm512_storeu_ps(c.add(5 * ldc), c5);
_mm512_storeu_ps(c.add(6 * ldc), c6);
_mm512_storeu_ps(c.add(7 * ldc), c7);
}
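/// 8x32 AVX-512 microkernel: like the 8x16 variant, but each row of C is
/// held in two 16-lane vectors.
///
/// # Safety
/// Same requirements as `avx512_microkernel_8x16_rowmajor`, for an 8x32 tile.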
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
pub(super) unsafe fn avx512_microkernel_8x32_rowmajor(
k: usize,
a: *const f32,
b: *const f32,
c: *mut f32,
ldc: usize,
) {
use std::arch::x86_64::*;
let mut c0l = _mm512_loadu_ps(c);
let mut c0h = _mm512_loadu_ps(c.add(16));
let mut c1l = _mm512_loadu_ps(c.add(ldc));
let mut c1h = _mm512_loadu_ps(c.add(ldc + 16));
let mut c2l = _mm512_loadu_ps(c.add(2 * ldc));
let mut c2h = _mm512_loadu_ps(c.add(2 * ldc + 16));
let mut c3l = _mm512_loadu_ps(c.add(3 * ldc));
let mut c3h = _mm512_loadu_ps(c.add(3 * ldc + 16));
let mut c4l = _mm512_loadu_ps(c.add(4 * ldc));
let mut c4h = _mm512_loadu_ps(c.add(4 * ldc + 16));
let mut c5l = _mm512_loadu_ps(c.add(5 * ldc));
let mut c5h = _mm512_loadu_ps(c.add(5 * ldc + 16));
let mut c6l = _mm512_loadu_ps(c.add(6 * ldc));
let mut c6h = _mm512_loadu_ps(c.add(6 * ldc + 16));
let mut c7l = _mm512_loadu_ps(c.add(7 * ldc));
let mut c7h = _mm512_loadu_ps(c.add(7 * ldc + 16));
for p in 0..k {
let bl = _mm512_loadu_ps(b.add(p * 32));
let bh = _mm512_loadu_ps(b.add(p * 32 + 16));
let ap = a.add(p * 8);
let a0 = _mm512_set1_ps(*ap);
c0l = _mm512_fmadd_ps(a0, bl, c0l);
c0h = _mm512_fmadd_ps(a0, bh, c0h);
let a1 = _mm512_set1_ps(*ap.add(1));
c1l = _mm512_fmadd_ps(a1, bl, c1l);
c1h = _mm512_fmadd_ps(a1, bh, c1h);
let a2 = _mm512_set1_ps(*ap.add(2));
c2l = _mm512_fmadd_ps(a2, bl, c2l);
c2h = _mm512_fmadd_ps(a2, bh, c2h);
let a3 = _mm512_set1_ps(*ap.add(3));
c3l = _mm512_fmadd_ps(a3, bl, c3l);
c3h = _mm512_fmadd_ps(a3, bh, c3h);
let a4 = _mm512_set1_ps(*ap.add(4));
c4l = _mm512_fmadd_ps(a4, bl, c4l);
c4h = _mm512_fmadd_ps(a4, bh, c4h);
let a5 = _mm512_set1_ps(*ap.add(5));
c5l = _mm512_fmadd_ps(a5, bl, c5l);
c5h = _mm512_fmadd_ps(a5, bh, c5h);
let a6 = _mm512_set1_ps(*ap.add(6));
c6l = _mm512_fmadd_ps(a6, bl, c6l);
c6h = _mm512_fmadd_ps(a6, bh, c6h);
let a7 = _mm512_set1_ps(*ap.add(7));
c7l = _mm512_fmadd_ps(a7, bl, c7l);
c7h = _mm512_fmadd_ps(a7, bh, c7h);
}
_mm512_storeu_ps(c, c0l);
_mm512_storeu_ps(c.add(16), c0h);
_mm512_storeu_ps(c.add(ldc), c1l);
_mm512_storeu_ps(c.add(ldc + 16), c1h);
_mm512_storeu_ps(c.add(2 * ldc), c2l);
_mm512_storeu_ps(c.add(2 * ldc + 16), c2h);
_mm512_storeu_ps(c.add(3 * ldc), c3l);
_mm512_storeu_ps(c.add(3 * ldc + 16), c3h);
_mm512_storeu_ps(c.add(4 * ldc), c4l);
_mm512_storeu_ps(c.add(4 * ldc + 16), c4h);
_mm512_storeu_ps(c.add(5 * ldc), c5l);
_mm512_storeu_ps(c.add(5 * ldc + 16), c5h);
_mm512_storeu_ps(c.add(6 * ldc), c6l);
_mm512_storeu_ps(c.add(6 * ldc + 16), c6h);
_mm512_storeu_ps(c.add(7 * ldc), c7l);
_mm512_storeu_ps(c.add(7 * ldc + 16), c7h);
}
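/// Packs a `kc` x `nc` block of B into 16-column panels: for each k step the
/// 16 column entries are contiguous, zero-padded past `nc`.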
pub(super) fn pack_b_block_nr16(
b: &[f32],
ldb: usize,
pc: usize,
jc: usize,
kc: usize,
nc: usize,
packed: &mut [f32],
) {
let nr = 16;
let panels = (nc + nr - 1) / nr;
for panel in 0..panels {
let j_start = panel * nr;
let nr_local = nr.min(nc - j_start);
for p in 0..kc {
for j in 0..nr_local {
packed[panel * nr * kc + p * nr + j] = b[(pc + p) * ldb + (jc + j_start + j)];
}
for j in nr_local..nr {
packed[panel * nr * kc + p * nr + j] = 0.0;
}
}
}
}
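/// Packs a `kc` x `nc` block of B into `nr`-column panels, zero-padding the
/// final partial panel; 32-column panels take an AVX-512 fast path when
/// available.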
pub(super) fn pack_b_block_generic(
b: &[f32],
ldb: usize,
pc: usize,
jc: usize,
kc: usize,
nc: usize,
nr: usize,
packed: &mut [f32],
) {
#[cfg(target_arch = "x86_64")]
if nr == 32 && std::arch::is_x86_feature_detected!("avx512f") {
unsafe {
pack_b_block_nr32_avx512(b, ldb, pc, jc, kc, nc, packed);
}
return;
}
let panels = (nc + nr - 1) / nr;
for panel in 0..panels {
let j_start = panel * nr;
let nr_local = nr.min(nc - j_start);
for p in 0..kc {
let dst_base = panel * nr * kc + p * nr;
for j in 0..nr_local {
packed[dst_base + j] = b[(pc + p) * ldb + (jc + j_start + j)];
}
for j in nr_local..nr {
packed[dst_base + j] = 0.0;
}
}
}
}
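/// AVX-512 packing for full 32-column B panels: copies 32 floats (two
/// 16-lane vectors) per k step with the k loop unrolled by 2; partial panels
/// use the scalar path.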
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn pack_b_block_nr32_avx512(
b: &[f32],
ldb: usize,
pc: usize,
jc: usize,
kc: usize,
nc: usize,
packed: &mut [f32],
) {
use std::arch::x86_64::*;
let nr = 32;
let panels = (nc + nr - 1) / nr;
for panel in 0..panels {
let j_start = panel * nr;
let nr_local = nr.min(nc - j_start);
if nr_local == 32 {
let panel_base = panel * nr * kc;
let b_col = jc + j_start;
let kc2 = kc / 2 * 2;
let mut p = 0;
while p < kc2 {
let src0 = b.as_ptr().add((pc + p) * ldb + b_col);
let src1 = b.as_ptr().add((pc + p + 1) * ldb + b_col);
let dst0 = packed.as_mut_ptr().add(panel_base + p * nr);
let dst1 = packed.as_mut_ptr().add(panel_base + (p + 1) * nr);
let v0a = _mm512_loadu_ps(src0);
let v0b = _mm512_loadu_ps(src0.add(16));
let v1a = _mm512_loadu_ps(src1);
let v1b = _mm512_loadu_ps(src1.add(16));
_mm512_storeu_ps(dst0, v0a);
_mm512_storeu_ps(dst0.add(16), v0b);
_mm512_storeu_ps(dst1, v1a);
_mm512_storeu_ps(dst1.add(16), v1b);
p += 2;
}
while p < kc {
let src = b.as_ptr().add((pc + p) * ldb + b_col);
let dst = packed.as_mut_ptr().add(panel_base + p * nr);
let v0 = _mm512_loadu_ps(src);
let v1 = _mm512_loadu_ps(src.add(16));
_mm512_storeu_ps(dst, v0);
_mm512_storeu_ps(dst.add(16), v1);
p += 1;
}
} else {
for p in 0..kc {
let dst_base = panel * nr * kc + p * nr;
for j in 0..nr_local {
packed[dst_base + j] = b[(pc + p) * ldb + (jc + j_start + j)];
}
for j in nr_local..nr {
packed[dst_base + j] = 0.0;
}
}
}
}
}
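/// AVX2 blocked path with 8-column B panels and a fully unrolled 8x8
/// microkernel that accumulates directly into row-major C, avoiding the
/// micro-tile copy of the generic path; partial tiles use a scalar loop.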
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_blis_nr8_rowmajor_c(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
use std::arch::x86_64::*;
let mc = 64_usize.min(m);
let nc = 1024_usize.min(n);
let kc_param = KC;
let nr = 8_usize;
let mr = MR;
TL_PACKED_A.with(|tl_a| {
TL_PACKED_B.with(|tl_b| {
let mut packed_a = tl_a.borrow_mut();
let mut packed_b = tl_b.borrow_mut();
let needed_a = packed_a_size(mc, kc_param);
let needed_b = packed_b_size_512(kc_param, nc);
if packed_a.len() < needed_a {
packed_a.resize(needed_a, 0.0);
}
if packed_b.len() < needed_b {
packed_b.resize(needed_b, 0.0);
}
for jc in (0..n).step_by(nc) {
let nc_block = nc.min(n - jc);
for pc in (0..k).step_by(kc_param) {
let kc_block = kc_param.min(k - pc);
pack_b_block_512(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
for ic in (0..m).step_by(mc) {
let mc_block = mc.min(m - ic);
pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
let panels_m = (mc_block + mr - 1) / mr;
let panels_n = (nc_block + nr - 1) / nr;
for ir_panel in 0..panels_m {
let ir = ir_panel * mr;
let mr_block = mr.min(mc_block - ir);
for jr_panel in 0..panels_n {
let jr = jr_panel * nr;
let nr_block = nr.min(nc_block - jr);
let a_panel = &packed_a[ir_panel * mr * kc_block..];
let b_panel = &packed_b[jr_panel * nr * kc_block..];
if mr_block == 8 && nr_block == 8 {
unsafe {
let c_base = c.as_mut_ptr().add((ic + ir) * n + (jc + jr));
let mut c0 = _mm256_loadu_ps(c_base);
let mut c1 = _mm256_loadu_ps(c_base.add(n));
let mut c2 = _mm256_loadu_ps(c_base.add(2 * n));
let mut c3 = _mm256_loadu_ps(c_base.add(3 * n));
let mut c4 = _mm256_loadu_ps(c_base.add(4 * n));
let mut c5 = _mm256_loadu_ps(c_base.add(5 * n));
let mut c6 = _mm256_loadu_ps(c_base.add(6 * n));
let mut c7 = _mm256_loadu_ps(c_base.add(7 * n));
let ap = a_panel.as_ptr();
let bp = b_panel.as_ptr();
let k4 = kc_block / 4;
let k_rem = kc_block % 4;
for p4 in 0..k4 {
let p = p4 * 4;
let b_row = _mm256_loadu_ps(bp.add(p * 8));
c0 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8)),
b_row,
c0,
);
c1 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8 + 1)),
b_row,
c1,
);
c2 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8 + 2)),
b_row,
c2,
);
c3 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8 + 3)),
b_row,
c3,
);
c4 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8 + 4)),
b_row,
c4,
);
c5 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8 + 5)),
b_row,
c5,
);
c6 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8 + 6)),
b_row,
c6,
);
c7 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(p * 8 + 7)),
b_row,
c7,
);
let b_row = _mm256_loadu_ps(bp.add((p + 1) * 8));
c0 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8)),
b_row,
c0,
);
c1 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 1)),
b_row,
c1,
);
c2 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 2)),
b_row,
c2,
);
c3 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 3)),
b_row,
c3,
);
c4 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 4)),
b_row,
c4,
);
c5 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 5)),
b_row,
c5,
);
c6 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 6)),
b_row,
c6,
);
c7 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 7)),
b_row,
c7,
);
let b_row = _mm256_loadu_ps(bp.add((p + 2) * 8));
c0 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8)),
b_row,
c0,
);
c1 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 1)),
b_row,
c1,
);
c2 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 2)),
b_row,
c2,
);
c3 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 3)),
b_row,
c3,
);
c4 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 4)),
b_row,
c4,
);
c5 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 5)),
b_row,
c5,
);
c6 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 6)),
b_row,
c6,
);
c7 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 7)),
b_row,
c7,
);
let b_row = _mm256_loadu_ps(bp.add((p + 3) * 8));
c0 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8)),
b_row,
c0,
);
c1 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 1)),
b_row,
c1,
);
c2 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 2)),
b_row,
c2,
);
c3 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 3)),
b_row,
c3,
);
c4 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 4)),
b_row,
c4,
);
c5 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 5)),
b_row,
c5,
);
c6 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 6)),
b_row,
c6,
);
c7 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 7)),
b_row,
c7,
);
}
let base_rem = k4 * 4;
for rp in 0..k_rem {
let pp = base_rem + rp;
let b_row = _mm256_loadu_ps(bp.add(pp * 8));
c0 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8)),
b_row,
c0,
);
c1 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8 + 1)),
b_row,
c1,
);
c2 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8 + 2)),
b_row,
c2,
);
c3 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8 + 3)),
b_row,
c3,
);
c4 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8 + 4)),
b_row,
c4,
);
c5 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8 + 5)),
b_row,
c5,
);
c6 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8 + 6)),
b_row,
c6,
);
c7 = _mm256_fmadd_ps(
_mm256_broadcast_ss(&*ap.add(pp * 8 + 7)),
b_row,
c7,
);
}
_mm256_storeu_ps(c_base, c0);
_mm256_storeu_ps(c_base.add(n), c1);
_mm256_storeu_ps(c_base.add(2 * n), c2);
_mm256_storeu_ps(c_base.add(3 * n), c3);
_mm256_storeu_ps(c_base.add(4 * n), c4);
_mm256_storeu_ps(c_base.add(5 * n), c5);
_mm256_storeu_ps(c_base.add(6 * n), c6);
_mm256_storeu_ps(c_base.add(7 * n), c7);
}
} else {
for p in 0..kc_block {
for jj in 0..nr_block {
let b_val = b_panel[p * nr + jj];
for ii in 0..mr_block {
c[(ic + ir + ii) * n + (jc + jr + jj)] +=
a_panel[p * mr + ii] * b_val;
}
}
}
}
}
}
}
}
}
});
});
Ok(())
}
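/// Unused (`#[allow(dead_code)]`) packed AVX-512 path using the 16x8
/// microkernel with a copy-in/copy-out micro-tile, mirroring the generic
/// path's structure. Callers must ensure AVX-512F support before invoking
/// it, since the microkernel call is not feature-checked here.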
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn gemm_blis_avx512_packed(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
let mc = MC_512.min(m);
let nc = NC_512.min(n);
let kc = KC_512.min(k);
let needed_a = packed_a_size_512(mc, kc);
let needed_b = packed_b_size_512(kc, nc);
let needed_c = MR_512 * NR_512;
TL_PACKED_A.with(|tl_a| {
TL_PACKED_B.with(|tl_b| {
TL_C_MICRO.with(|tl_c| {
let mut packed_a = tl_a.borrow_mut();
let mut packed_b = tl_b.borrow_mut();
let mut c_micro = tl_c.borrow_mut();
if packed_a.len() < needed_a {
packed_a.resize(needed_a, 0.0);
}
if packed_b.len() < needed_b {
packed_b.resize(needed_b, 0.0);
}
if c_micro.len() < needed_c {
c_micro.resize(needed_c, 0.0);
}
for jc in (0..n).step_by(NC_512) {
let nc_block = NC_512.min(n - jc);
for pc in (0..k).step_by(KC_512) {
let kc_block = KC_512.min(k - pc);
pack_b_block_512(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
for ic in (0..m).step_by(MC_512) {
let mc_block = MC_512.min(m - ic);
pack_a_block_512(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
for ir in (0..mc_block).step_by(MR_512) {
let mr_block = MR_512.min(mc_block - ir);
for jr in (0..nc_block).step_by(NR_512) {
let nr_block = NR_512.min(nc_block - jr);
let a_panel = &packed_a[(ir / MR_512) * MR_512 * kc_block..];
let b_panel = &packed_b[(jr / NR_512) * NR_512 * kc_block..];
for jj in 0..nr_block {
for ii in 0..mr_block {
c_micro[jj * MR_512 + ii] =
c[(ic + ir + ii) * n + (jc + jr + jj)];
}
for ii in mr_block..MR_512 {
c_micro[jj * MR_512 + ii] = 0.0;
}
}
for jj in nr_block..NR_512 {
for ii in 0..MR_512 {
c_micro[jj * MR_512 + ii] = 0.0;
}
}
if mr_block == MR_512 && nr_block == NR_512 {
unsafe {
microkernel_16x8_avx512(
kc_block,
a_panel.as_ptr(),
b_panel.as_ptr(),
c_micro.as_mut_ptr(),
MR_512,
);
}
} else {
for p in 0..kc_block {
for jj in 0..NR_512 {
let b_val = b_panel[p * NR_512 + jj];
for ii in 0..MR_512 {
c_micro[jj * MR_512 + ii] +=
a_panel[p * MR_512 + ii] * b_val;
}
}
}
}
for jj in 0..nr_block {
for ii in 0..mr_block {
c[(ic + ir + ii) * n + (jc + jr + jj)] =
c_micro[jj * MR_512 + ii];
}
}
}
}
}
}
}
});
});
});
Ok(())
}
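/// GEMM variant that consumes a `PrepackedB` (B already packed into
/// KC x NC tiles), skipping B packing entirely; A is still packed per
/// macroblock and the packed tiles are indexed by `(jc_idx, pc_idx)`.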
pub fn gemm_blis_with_prepacked_b(
m: usize,
n: usize,
k: usize,
a: &[f32],
prepacked_b: &PrepackedB,
c: &mut [f32],
mut profiler: Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
if a.len() != m * k {
return Err(TruenoError::InvalidInput(format!(
"A size mismatch: expected {}, got {}",
m * k,
a.len()
)));
}
if c.len() != m * n {
return Err(TruenoError::InvalidInput(format!(
"C size mismatch: expected {}, got {}",
m * n,
c.len()
)));
}
if prepacked_b.k != k || prepacked_b.n != n {
return Err(TruenoError::InvalidInput(format!(
"PrepackedB dimension mismatch: expected ({}, {}), got ({}, {})",
k, n, prepacked_b.k, prepacked_b.n
)));
}
if m == 0 || n == 0 || k == 0 {
return Ok(());
}
let track_time = profiler.is_some();
let start = if track_time { Some(Instant::now()) } else { None };
let mc = MC.min(m);
let kc = KC.min(k);
let needed_a = packed_a_size(mc, kc);
let needed_c = MR * NR;
TL_PACKED_A.with(|tl_a| {
TL_C_MICRO.with(|tl_c| {
let mut packed_a = tl_a.borrow_mut();
let mut c_micro = tl_c.borrow_mut();
if packed_a.len() < needed_a {
packed_a.resize(needed_a, 0.0);
} else {
packed_a[..needed_a].fill(0.0);
}
if c_micro.len() < needed_c {
c_micro.resize(needed_c, 0.0);
} else {
c_micro[..needed_c].fill(0.0);
}
for (jc_idx, jc) in (0..n).step_by(NC).enumerate() {
let nc_block = NC.min(n - jc);
for (pc_idx, pc) in (0..k).step_by(KC).enumerate() {
let kc_block = KC.min(k - pc);
let packed_b_tile = prepacked_b.tile(jc_idx, pc_idx);
for ic in (0..m).step_by(MC) {
let mc_block = MC.min(m - ic);
let pack_start = if track_time { Some(Instant::now()) } else { None };
pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);
compute_macroblock(
c,
&packed_a,
packed_b_tile,
&mut c_micro,
ic,
jc,
mc_block,
nc_block,
kc_block,
n,
&mut profiler,
);
}
}
}
if let (Some(prof), Some(s)) = (profiler, start) {
prof.record(
BlisProfileLevel::Macro,
s.elapsed().as_nanos() as u64,
(2 * m * n * k) as u64,
);
}
});
});
Ok(())
}