use crate::gemm;
use crate::generic_params::*;
use typenum::*;
use typenum::{U2, U4};
/// Common signature of every `f32` GEMM kernel entry point in this file.
///
/// Each `gemm::gemm_loop::<SgemmCache, _>` instantiation is cast to this pointer type so that
/// different kernel configurations can share one dispatch table. The parameter names mirror
/// those of `sgemm`; they are documentation only and do not affect the type.
type FS = unsafe fn(
    m: usize,
    k: usize,
    n: usize,
    alpha: f32,
    a: *const f32,
    rsa: isize,
    csa: isize,
    b: *const f32,
    rsb: isize,
    csb: isize,
    beta: f32,
    c: *mut f32,
    rsc: isize,
    csc: isize,
    multithread: bool,
);
/// Common signature of every `f64` GEMM kernel entry point in this file.
///
/// Each `gemm::gemm_loop::<DgemmCache, _>` instantiation is cast to this pointer type so that
/// different kernel configurations can share one dispatch table. The parameter names mirror
/// those of `dgemm`; they are documentation only and do not affect the type.
type FD = unsafe fn(
    m: usize,
    k: usize,
    n: usize,
    alpha: f64,
    a: *const f64,
    rsa: isize,
    csa: isize,
    b: *const f64,
    rsb: isize,
    csb: isize,
    beta: f64,
    c: *mut f64,
    rsc: isize,
    csc: isize,
    multithread: bool,
);
static THIN_SGEMMS: [&FS; 16] = [
&(gemm::gemm_loop::<SgemmCache, S32x1t> as FS), &(gemm::gemm_loop::<SgemmCache, S32x1t> as FS), &(gemm::gemm_loop::<SgemmCache, S32x2t> as FS), &(gemm::gemm_loop::<SgemmCache, S24x3t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x4t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x5t> as FS), &(gemm::gemm_loop::<SgemmCache, S8x6t> as FS), &(gemm::gemm_loop::<SgemmCache, S8x7t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x4t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x5t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x5t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x4t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x4t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x5t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x5t> as FS), &(gemm::gemm_loop::<SgemmCache, S16x5t> as FS), ];
/// Single-precision (`f32`) matrix multiply entry point.
///
/// Picks a kernel configuration from `n` and `csc`, then forwards every argument unchanged to
/// `gemm::gemm_loop` with the `SgemmCache` blocking parameters:
///
/// - `n < THIN_SGEMMS.len()` (16): the width-specific kernel stored at `THIN_SGEMMS[n]`;
/// - `n > 28` and `csc == 1`: the `S4x16` configuration;
/// - otherwise: the `S16x4t` configuration.
///
/// NOTE(review): argument naming presumably follows the strided-GEMM convention
/// `C ← alpha·A·B + beta·C` with `A` m×k, `B` k×n, `C` m×n and `rs*`/`cs*` the row/column
/// strides of each matrix. Confirm against `gemm::gemm_loop`.
///
/// # Safety
///
/// This function validates nothing itself. Every pointer, stride and dimension requirement is
/// that of `gemm::gemm_loop`, and the caller must uphold it.
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn sgemm(
    m: usize,
    k: usize,
    n: usize,
    alpha: f32,
    a: *const f32,
    rsa: isize,
    csa: isize,
    b: *const f32,
    rsb: isize,
    csb: isize,
    beta: f32,
    c: *mut f32,
    rsc: isize,
    csc: isize,
    multithread: bool,
) {
    match THIN_SGEMMS.get(n) {
        // Narrow output: one dedicated kernel per width.
        Some(&&thin_kernel) => {
            thin_kernel(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread)
        }
        // Wide output whose `csc` stride is 1.
        None if n > 28 && csc == 1 => gemm::gemm_loop::<SgemmCache, S4x16>(
            m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread,
        ),
        // Every remaining shape.
        None => gemm::gemm_loop::<SgemmCache, S16x4t>(
            m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread,
        ),
    }
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f32` kernel configuration with `MR = U8`, `NR = U8`, `KU = U8`, `TR = U0`, `FMA = U1`.
#[allow(unused)]
pub struct S8x8;

impl KernelConfig for S8x8 {
    type T = f32;
    type MR = U8;
    type NR = U8;
    type KU = U8;
    type TR = U0;
    type FMA = U1;
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f32` kernel configuration with `MR = U8`, `NR = U8`, `KU = U8`, `TR = U1`, `FMA = U1`.
#[allow(unused)]
pub struct S8x8t;

impl KernelConfig for S8x8t {
    type T = f32;
    type MR = U8;
    type NR = U8;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f32` kernel configuration with `MR = U6`, `NR = U16`, `KU = U4`, `TR = U0`, `FMA = U1`.
///
/// Unlike the other configurations here, it unrolls `KU` by 4 rather than 8.
#[allow(unused)]
pub struct S6x16;

impl KernelConfig for S6x16 {
    type T = f32;
    type MR = U6;
    type NR = U16;
    type KU = U4;
    type TR = U0;
    type FMA = U1;
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f32` kernel configuration with `MR = U5`, `NR = U16`, `KU = U8`, `TR = U0`, `FMA = U1`.
#[allow(unused)]
pub struct S5x16;

impl KernelConfig for S5x16 {
    type T = f32;
    type MR = U5;
    type NR = U16;
    type KU = U8;
    type TR = U0;
    type FMA = U1;
}
/// `f32` kernel configuration (`MR = U4`, `NR = U16`, `KU = U8`, `TR = U0`, `FMA = U1`).
///
/// `sgemm` selects it for wide outputs (`n > 28`) when `csc == 1`.
pub struct S4x16;

impl KernelConfig for S4x16 {
    type T = f32;
    type MR = U4;
    type NR = U16;
    type KU = U8;
    type TR = U0;
    type FMA = U1;
}
/// `f32` kernel configuration (`MR = U8`, `NR = U7`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_SGEMMS` uses it for `n == 7`.
pub struct S8x7t;

impl KernelConfig for S8x7t {
    type T = f32;
    type MR = U8;
    type NR = U7;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
// `THIN_SGEMMS` references this type (for `n == 6`), so the `allow(unused)` is redundant within
// this file but harmless.
/// `f32` kernel configuration (`MR = U8`, `NR = U6`, `KU = U8`, `TR = U1`, `FMA = U1`).
#[allow(unused)]
pub struct S8x6t;

impl KernelConfig for S8x6t {
    type T = f32;
    type MR = U8;
    type NR = U6;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f32` kernel configuration (`MR = U16`, `NR = U5`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_SGEMMS` uses it for `n` in {5, 9, 10, 13, 14, 15}.
pub struct S16x5t;

impl KernelConfig for S16x5t {
    type T = f32;
    type MR = U16;
    type NR = U5;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f32` kernel configuration (`MR = U16`, `NR = U4`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_SGEMMS` uses it for `n` in {4, 8, 11, 12}. It is also `sgemm`'s fallback for
/// `n >= 16` whenever the `S4x16` path is not taken.
pub struct S16x4t;

impl KernelConfig for S16x4t {
    type T = f32;
    type MR = U16;
    type NR = U4;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f32` kernel configuration (`MR = U24`, `NR = U3`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_SGEMMS` uses it for `n == 3`.
pub struct S24x3t;

impl KernelConfig for S24x3t {
    type T = f32;
    type MR = U24;
    type NR = U3;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f32` kernel configuration (`MR = U32`, `NR = U2`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_SGEMMS` uses it for `n == 2`.
pub struct S32x2t;

impl KernelConfig for S32x2t {
    type T = f32;
    type MR = U32;
    type NR = U2;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f32` kernel configuration (`MR = U32`, `NR = U1`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_SGEMMS` uses it for `n == 0` and `n == 1`.
pub struct S32x1t;

impl KernelConfig for S32x1t {
    type T = f32;
    type MR = U32;
    type NR = U1;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
static THIN_DGEMMS: [&FD; 16] = [
&(gemm::gemm_loop::<DgemmCache, D16x1t> as FD), &(gemm::gemm_loop::<DgemmCache, D16x1t> as FD), &(gemm::gemm_loop::<DgemmCache, D16x2t> as FD), &(gemm::gemm_loop::<DgemmCache, D12x3t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x4t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x5t> as FD), &(gemm::gemm_loop::<DgemmCache, D4x6t> as FD), &(gemm::gemm_loop::<DgemmCache, D4x7t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x4t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x5t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x5t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x4t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x4t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x5t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x5t> as FD), &(gemm::gemm_loop::<DgemmCache, D8x5t> as FD), ];
/// Double-precision (`f64`) matrix multiply entry point.
///
/// Picks a kernel configuration from `n` and `csc`, then forwards every argument unchanged to
/// `gemm::gemm_loop` with the `DgemmCache` blocking parameters:
///
/// - `n < THIN_DGEMMS.len()` (16): the width-specific kernel stored at `THIN_DGEMMS[n]`;
/// - `n > 28` and `csc == 1`: the `D4x8` configuration;
/// - otherwise: the `D8x4t` configuration.
///
/// NOTE(review): argument naming presumably follows the strided-GEMM convention
/// `C ← alpha·A·B + beta·C` with `A` m×k, `B` k×n, `C` m×n and `rs*`/`cs*` the row/column
/// strides of each matrix. Confirm against `gemm::gemm_loop`.
///
/// # Safety
///
/// This function validates nothing itself. Every pointer, stride and dimension requirement is
/// that of `gemm::gemm_loop`, and the caller must uphold it.
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn dgemm(
    m: usize,
    k: usize,
    n: usize,
    alpha: f64,
    a: *const f64,
    rsa: isize,
    csa: isize,
    b: *const f64,
    rsb: isize,
    csb: isize,
    beta: f64,
    c: *mut f64,
    rsc: isize,
    csc: isize,
    multithread: bool,
) {
    // `get` bounds-checks against the very table being indexed, so the thin-kernel range can
    // never drift from `THIN_DGEMMS`'s actual length.
    match THIN_DGEMMS.get(n) {
        // Narrow output: one dedicated kernel per width.
        Some(&&thin_kernel) => {
            thin_kernel(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread)
        }
        // Wide output whose `csc` stride is 1.
        None if n > 28 && csc == 1 => gemm::gemm_loop::<DgemmCache, D4x8>(
            m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread,
        ),
        // Every remaining shape.
        None => gemm::gemm_loop::<DgemmCache, D8x4t>(
            m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread,
        ),
    }
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f64` kernel configuration with `MR = U8`, `NR = U4`, `KU = U8`, `TR = U0`, `FMA = U1`.
#[allow(unused)]
pub struct D8x4;

impl KernelConfig for D8x4 {
    type T = f64;
    type MR = U8;
    type NR = U4;
    type KU = U8;
    type TR = U0;
    type FMA = U1;
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f64` kernel configuration with `MR = U4`, `NR = U8`, `KU = U8`, `TR = U1`, `FMA = U1`.
#[allow(unused)]
pub struct D4x8t;

impl KernelConfig for D4x8t {
    type T = f64;
    type MR = U4;
    type NR = U8;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f64` kernel configuration with `MR = U6`, `NR = U8`, `KU = U4`, `TR = U0`, `FMA = U1`.
///
/// Unlike most configurations here, it unrolls `KU` by 4 rather than 8.
#[allow(unused)]
pub struct D6x8;

impl KernelConfig for D6x8 {
    type T = f64;
    type MR = U6;
    type NR = U8;
    type KU = U4;
    type TR = U0;
    type FMA = U1;
}
// No use of this type appears in the dispatch code of this file; `allow(unused)` keeps the
// dead-code lint quiet.
/// `f64` kernel configuration with `MR = U5`, `NR = U8`, `KU = U8`, `TR = U0`, `FMA = U1`.
#[allow(unused)]
pub struct D5x8;

impl KernelConfig for D5x8 {
    type T = f64;
    type MR = U5;
    type NR = U8;
    type KU = U8;
    type TR = U0;
    type FMA = U1;
}
/// `f64` kernel configuration (`MR = U4`, `NR = U8`, `KU = U8`, `TR = U0`, `FMA = U1`).
///
/// `dgemm` selects it for wide outputs (`n > 28`) when `csc == 1`.
pub struct D4x8;

impl KernelConfig for D4x8 {
    type T = f64;
    type MR = U4;
    type NR = U8;
    type KU = U8;
    type TR = U0;
    type FMA = U1;
}
/// `f64` kernel configuration (`MR = U4`, `NR = U7`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_DGEMMS` uses it for `n == 7`.
pub struct D4x7t;

impl KernelConfig for D4x7t {
    type T = f64;
    type MR = U4;
    type NR = U7;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
#[allow(unused)]
pub struct D4x6t; impl KernelConfig for D4x6t {
type T = f64;
type MR = U8;
type NR = U6;
type KU = U8; type TR = U1;
type FMA = U1;
}
/// `f64` kernel configuration (`MR = U8`, `NR = U5`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_DGEMMS` uses it for `n` in {5, 9, 10, 13, 14, 15}.
pub struct D8x5t;

impl KernelConfig for D8x5t {
    type T = f64;
    type MR = U8;
    type NR = U5;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f64` kernel configuration (`MR = U8`, `NR = U4`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_DGEMMS` uses it for `n` in {4, 8, 11, 12}. It is also `dgemm`'s fallback for
/// `n >= 16` whenever the `D4x8` path is not taken.
pub struct D8x4t;

impl KernelConfig for D8x4t {
    type T = f64;
    type MR = U8;
    type NR = U4;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f64` kernel configuration (`MR = U12`, `NR = U3`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_DGEMMS` uses it for `n == 3`.
pub struct D12x3t;

impl KernelConfig for D12x3t {
    type T = f64;
    type MR = U12;
    type NR = U3;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f64` kernel configuration (`MR = U16`, `NR = U2`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_DGEMMS` uses it for `n == 2`.
pub struct D16x2t;

impl KernelConfig for D16x2t {
    type T = f64;
    type MR = U16;
    type NR = U2;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// `f64` kernel configuration (`MR = U16`, `NR = U1`, `KU = U8`, `TR = U1`, `FMA = U1`).
///
/// `THIN_DGEMMS` uses it for `n == 0` and `n == 1`.
pub struct D16x1t;

impl KernelConfig for D16x1t {
    type T = f64;
    type MR = U16;
    type NR = U1;
    type KU = U8;
    type TR = U1;
    type FMA = U1;
}
/// Cache configuration passed to every `gemm::gemm_loop` instantiation used by `sgemm`.
///
/// NOTE(review): the meaning of each value is defined by `CacheConfigValues` in
/// `generic_params`. Presumably `A` is an alignment and `MC`/`NC`/`KC` are the blocking
/// sizes along m/n/k; confirm before relying on this.
pub struct SgemmCache;

impl CacheConfigValues for SgemmCache {
    type A = U32;
    type MT = U128;
    type MC = U64;
    type NC = U1024;
    type KC = U256;
}
/// Cache configuration passed to every `gemm::gemm_loop` instantiation used by `dgemm`.
///
/// Compared with the `f32` configuration, `MC` and `NC` are halved (`U32`, `U512`), while `A`,
/// `MT` and `KC` are the same.
/// NOTE(review): presumably this keeps the byte footprint of the blocks equal for the wider
/// element type. Confirm against `CacheConfigValues` in `generic_params`.
pub struct DgemmCache;

impl CacheConfigValues for DgemmCache {
    type A = U32;
    type MT = U128;
    type MC = U32;
    type NC = U512;
    type KC = U256;
}