use generic_array::ArrayLength;
use num_traits::identities::{One, Zero};
use parking_lot::{Condvar, Mutex};
use rawpointer::PointerExt;
use smallvec::SmallVec;
use std::cmp::{max, min};
use std::mem::align_of;
use std::mem::size_of;
use std::sync::atomic::{AtomicUsize, Ordering};
use threadpool::ThreadPool;
use typenum::Unsigned;
use typenum_loops::Loop;
use crate::generic_kernel;
use crate::generic_params::SgemmCache;
use crate::generic_params::*;
use crate::hwl_kernels;
use crate::util::range_chunk;
use crate::util::round_up_div;
use crate::util::round_up_to;
use crate::{prefetch_read, reset_ftz_and_daz, set_ftz_and_daz, snb_kernels};
lazy_static! {
static ref NUM_CPUS: usize = num_cpus::get();
static ref THREAD_POOL: Mutex<ThreadPool> = Mutex::new(ThreadPool::new(*NUM_CPUS));
}
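/// General matrix multiplication, `f32`: `C <- alpha * A * B + beta * C`.
///
/// `a` points to an `m`-by-`k` matrix, `b` to a `k`-by-`n` matrix and `c` to an
/// `m`-by-`n` matrix; `rs*`/`cs*` give the row and column strides of each operand in
/// elements, so a row-major `m`-by-`k` `a` uses `rsa = k as isize`, `csa = 1`. This entry
/// point may run on the shared thread pool; see [`sgemm_st`] for the single-threaded form.
///
/// Illustrative sketch (marked `ignore`; adjust the import path for this crate before use):
///
/// ```ignore
/// // 2x2 row-major operands; C is overwritten because beta = 0.
/// let a = [1.0_f32, 2.0, 3.0, 4.0];
/// let b = [5.0_f32, 6.0, 7.0, 8.0];
/// let mut c = [0.0_f32; 4];
/// unsafe {
///     sgemm(2, 2, 2, 1.0, a.as_ptr(), 2, 1, b.as_ptr(), 2, 1, 0.0, c.as_mut_ptr(), 2, 1);
/// }
/// ```
///
/// # Safety
/// Every pointer must be valid for the extent implied by its dimensions and strides, and
/// `c` must be valid for writes and must not alias `a` or `b`.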
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn sgemm(
m: usize,
k: usize,
n: usize,
alpha: f32,
a: *const f32,
rsa: isize,
csa: isize,
b: *const f32,
rsb: isize,
csb: isize,
beta: f32,
c: *mut f32,
rsc: isize,
csc: isize,
) {
sgemm_flex(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, true)
}
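/// Single-threaded variant of [`sgemm`]: same operation and safety requirements, but the
/// work is never dispatched to the shared thread pool.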
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn sgemm_st(
m: usize,
k: usize,
n: usize,
alpha: f32,
a: *const f32,
rsa: isize,
csa: isize,
b: *const f32,
rsb: isize,
csb: isize,
beta: f32,
c: *mut f32,
rsc: isize,
csc: isize,
) {
sgemm_flex(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, false)
}
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
unsafe fn sgemm_flex(
m: usize,
k: usize,
n: usize,
alpha: f32,
a: *const f32,
rsa: isize,
csa: isize,
b: *const f32,
rsb: isize,
csb: isize,
beta: f32,
c: *mut f32,
rsc: isize,
csc: isize,
multithread: bool,
) {
if k == 0 || m == 0 || n == 0 {
return;
}
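// If n > m, compute the transposed problem C^T = B^T * A^T instead by swapping the
// operands and exchanging row/column strides; this keeps m (the dimension that is blocked
// per thread in LOOP 3) as the larger of the two C dimensions.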
let (m, k, n, a, rsa, csa, b, rsb, csb, c, rsc, csc) = if n > m {
(n, k, m, b, csb, rsb, a, csa, rsa, c, csc, rsc)
} else {
(m, k, n, a, rsa, csa, b, rsb, csb, c, rsc, csc)
};
if cfg!(arch_haswell) {
hwl_kernels::sgemm(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
} else if cfg!(arch_sandybridge) {
snb_kernels::sgemm(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
} else if cfg!(arch_penryn) {
gemm_loop::<SgemmCache, S4x4>(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
} else if cfg!(arch_generic4x4fma) {
gemm_loop::<SgemmCache, S4x4fma>(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
} else {
gemm_loop::<SgemmCache, S4x4>(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
}
}
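/// General matrix multiplication, `f64`: `C <- alpha * A * B + beta * C`.
///
/// Same calling convention and safety requirements as [`sgemm`], with `f64` operands.
/// May run on the shared thread pool; see [`dgemm_st`] for the single-threaded form.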
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn dgemm(
m: usize,
k: usize,
n: usize,
alpha: f64,
a: *const f64,
rsa: isize,
csa: isize,
b: *const f64,
rsb: isize,
csb: isize,
beta: f64,
c: *mut f64,
rsc: isize,
csc: isize,
) {
dgemm_flex(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, true)
}
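/// Single-threaded variant of [`dgemm`]: same operation and safety requirements, but the
/// work is never dispatched to the shared thread pool.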
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn dgemm_st(
m: usize,
k: usize,
n: usize,
alpha: f64,
a: *const f64,
rsa: isize,
csa: isize,
b: *const f64,
rsb: isize,
csb: isize,
beta: f64,
c: *mut f64,
rsc: isize,
csc: isize,
) {
dgemm_flex(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, false)
}
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
unsafe fn dgemm_flex(
m: usize,
k: usize,
n: usize,
alpha: f64,
a: *const f64,
rsa: isize,
csa: isize,
b: *const f64,
rsb: isize,
csb: isize,
beta: f64,
c: *mut f64,
rsc: isize,
csc: isize,
multithread: bool,
) {
if k == 0 || m == 0 || n == 0 {
return;
}
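// As in `sgemm_flex`: if n > m, solve the transposed problem C^T = B^T * A^T by swapping
// operands and strides so the blocked m dimension is the larger one.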
let (m, k, n, a, rsa, csa, b, rsb, csb, c, rsc, csc) = if n > m {
(n, k, m, b, csb, rsb, a, csa, rsa, c, csc, rsc)
} else {
(m, k, n, a, rsa, csa, b, rsb, csb, c, rsc, csc)
};
if cfg!(arch_haswell) {
hwl_kernels::dgemm(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread)
} else if cfg!(arch_sandybridge) {
snb_kernels::dgemm(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread)
} else if cfg!(arch_penryn) {
gemm_loop::<DgemmCache, D2x4>(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
} else if cfg!(arch_generic4x4fma) {
gemm_loop::<DgemmCache, D2x4fma>(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
} else {
gemm_loop::<DgemmCache, D2x4>(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc, multithread);
}
}
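// Decide how many worker threads to use and the mc blocking size (rows of C per LOOP 3
// iteration). With `no_multithreading` this is one thread and the cache-derived mc;
// otherwise the thread count is capped so every thread gets a worthwhile amount of work,
// and mc is shrunk so each thread's share of m splits into blocks within the cache limit.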
#[cfg(no_multithreading)]
fn get_num_threads_and_cmc<C: CacheConfig<K>, K: KernelConfig>(_m: usize, _k: usize, _n: usize) -> (usize, usize) {
(1, C::mc())
}
#[cfg(not(no_multithreading))]
fn get_num_threads_and_cmc<C: CacheConfig<K>, K: KernelConfig>(m: usize, k: usize, n: usize) -> (usize, usize) {
let m_bands = round_up_div(m, K::MR::to_usize());
let max_mc_bands = {
let max_size = C::kc() * C::mc();
let max_kc = max(min(k, C::kc()), 1);
(max_size / max_kc + C::mc()) / (K::MR::to_usize() * 2)
};
let min_split_mc_bands = min(max_mc_bands, {
let max_compute = C::mc() * C::kc() * C::nc();
let min_compute = max(max_compute / C::multithread_factor(), 1);
let max_kc = max(min(k, C::kc()), 1);
let max_nc = max(min(round_up_to(n, K::NR::to_usize()), C::nc()), 1);
round_up_div(round_up_div(min_compute, max_nc * max_kc), K::MR::to_usize())
});
let num_threads = {
let full_blocks = m_bands / min_split_mc_bands;
max(min(*NUM_CPUS, full_blocks), 1)
};
let mc = {
let m_bands_per_thread = max(round_up_div(m_bands, num_threads), 1);
let blocks_per_thread = round_up_div(m_bands_per_thread, max_mc_bands);
let mc_bands = min(round_up_div(m_bands_per_thread, blocks_per_thread), m_bands);
mc_bands * K::MR::to_usize()
};
debug_assert!(
mc <= max_mc_bands * K::MR::to_usize(),
"mc{} min{} max{}",
mc,
min_split_mc_bands,
max_mc_bands
);
debug_assert!(num_threads <= *NUM_CPUS);
debug_assert_eq!(0, mc % K::MR::to_usize());
(num_threads, mc)
}
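/// Blocked GEMM driver in the BLIS style: LOOP 5 splits n into nc panels, LOOP 4 splits k
/// into kc panels and packs B into a shared buffer, LOOP 3 splits m into mc panels, packs A
/// into a per-thread buffer and calls `gemm_packed` (LOOPs 2/1 around the microkernel).
/// When more than one thread is warranted, LOOP 3 iterations are handed out dynamically via
/// an atomic counter on the shared thread pool.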
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn gemm_loop<C: CacheConfig<K>, K: KernelConfig>(
m: usize,
k: usize,
n: usize,
alpha: K::T,
a: *const K::T,
rsa: isize,
csa: isize,
b: *const K::T,
rsb: isize,
csb: isize,
beta: K::T,
c: *mut K::T,
rsc: isize,
csc: isize,
multithread: bool,
) {
debug_assert!(m * n == 0 || (rsc != 0 && csc != 0));
let knr = K::NR::to_usize();
let kmr = K::MR::to_usize();
let cnc = C::nc();
let ckc = C::kc();
let (num_threads, cmc) = get_num_threads_and_cmc::<C, K>(m, k, n);
assert_eq!(0, cnc % knr);
assert_eq!(0, cmc % kmr);
let pool_opt = if num_threads > 1 && multithread {
Some(THREAD_POOL.lock())
} else {
None
};
assert!(C::alignment() % size_of::<K::T>() == 0);
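// One backing allocation holds `num_threads` packed-A panels (spaced `app_stride` elements
// apart) followed by a single packed-B panel shared by all threads, with alignment padding
// before the first panel and between panels.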
let (mut vec, app_stride, a_offset, b_offset) = aligned_packing_vec::<K, C::A>(m, k, n, cmc, ckc, cnc, num_threads);
let ptr = vec.as_mut_ptr();
let app_base = ptr.add(a_offset);
let bpp = ptr.add(b_offset);
debug_assert_eq!(bpp as usize % align_of::<K::T>(), 0);
for (l5, nc) in range_chunk(n, cnc) {
dprint!("LOOP 5, {}, nc={}", l5, nc);
let b = b.stride_offset(csb, cnc * l5);
let c = c.stride_offset(csc, cnc * l5);
for (l4, kc) in range_chunk(k, ckc) {
dprint!("LOOP 4, {}, kc={}", l4, kc);
let b = b.stride_offset(rsb, ckc * l4);
let a = a.stride_offset(csa, ckc * l4);
debug!(for elt in &mut packv {
*elt = K::T::one();
});
pack::<K::T, K::NR>(kc, nc, knr, bpp, b, csb, rsb);
if let (Some(pool), true) = (pool_opt.as_ref(), num_threads > 1) {
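// Multithreaded LOOP 3: bundle the raw pointers each worker needs (its own packed-A
// buffer, the shared packed-B panel, A and C) together with a shared work counter and a
// mutex/condvar pair used to signal when the last worker finishes this (l5, l4) block.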
struct Ptrs<T: Element> {
app: *mut T,
bpp: *mut T,
a: *const T,
c: *mut T,
loop_counter: *const AtomicUsize,
sync: *const (Mutex<bool>, Condvar, AtomicUsize),
}
unsafe impl<T: Element> Send for Ptrs<T> {}
let sync = (Mutex::new(false), Condvar::new(), AtomicUsize::new(num_threads));
let loop_counter = AtomicUsize::new(0);
for cpu_id in 0..num_threads {
let p = Ptrs::<K::T> {
app: app_base.offset(app_stride * cpu_id as isize),
bpp,
a,
c,
loop_counter: &loop_counter as *const _,
sync: &sync as *const _,
};
debug_assert_eq!(p.app as usize % align_of::<K::T>(), 0);
pool.execute(move || {
let bpp = p.bpp;
let app = p.app;
let a = p.a;
let c = p.c;
let (ref lock, ref cvar, ref thread_counter) = *p.sync;
let mut next_id = (*p.loop_counter).fetch_add(1, Ordering::Relaxed);
let mxcsr = set_ftz_and_daz();
for (l3, mc) in range_chunk(m, cmc) {
if l3 < next_id {
continue;
}
dprint!("LOOP 3, {}, mc={}, id={}", l3, mc);
let a = a.stride_offset(rsa, cmc * l3);
let c = c.stride_offset(rsc, cmc * l3);
pack::<K::T, K::MR>(kc, mc, kmr, app, a, rsa, csa);
let betap = if l4 == 0 { beta } else { K::T::one() };
gemm_packed::<K>(nc, kc, mc, alpha, app, bpp, betap, c, rsc, csc);
next_id = (*p.loop_counter).fetch_add(1, Ordering::Relaxed);
}
let x = thread_counter.fetch_sub(1, Ordering::AcqRel);
if x == 1 {
{
*lock.lock() = true;
}
cvar.notify_all();
}
reset_ftz_and_daz(mxcsr);
});
}
let (ref lock, ref cvar, ref thread_counter) = sync;
let mut finished = lock.lock();
while !*finished {
cvar.wait(&mut finished);
}
debug_assert!(thread_counter.load(Ordering::SeqCst) == 0);
} else {
let app = app_base;
let mxcsr = set_ftz_and_daz();
for (l3, mc) in range_chunk(m, cmc) {
dprint!("LOOP 3, {}, mc={}", l3, mc);
let a = a.stride_offset(rsa, cmc * l3);
let c = c.stride_offset(rsc, cmc * l3);
pack::<K::T, K::MR>(kc, mc, kmr, app, a, rsa, csa);
let betap = if l4 == 0 { beta } else { <K::T>::one() };
gemm_packed::<K>(nc, kc, mc, alpha, app, bpp, betap, c, rsc, csc);
}
reset_ftz_and_daz(mxcsr);
}
}
}
}
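/// LOOPs 2 and 1 over already packed panels: apply `beta` to the (mc x nc) block of C once,
/// then walk nr-wide columns of packed B and mr-tall rows of packed A, calling the plain
/// microkernel for full tiles and the masked kernel for partial edge tiles.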
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
unsafe fn gemm_packed<K: KernelConfig>(
nc: usize,
kc: usize,
mc: usize,
alpha: K::T,
app: *const K::T,
bpp: *const K::T,
beta: K::T,
c: *mut K::T,
rsc: isize,
csc: isize,
) {
let mr = K::MR::to_usize();
let nr = K::NR::to_usize();
if beta.is_zero() {
zero_block::<K::T>(mc, nc, c, rsc, csc);
} else if beta != K::T::one() {
scale_block::<K::T>(beta, mc, nc, c, rsc, csc);
}
for (l2, nr_) in range_chunk(nc, nr) {
let bpp = bpp.stride_offset(1, kc * nr * l2);
let c = c.stride_offset(csc, nr * l2);
for (l1, mr_) in range_chunk(mc, mr) {
let app = app.stride_offset(1, kc * mr * l1);
let c = c.stride_offset(rsc, mr * l1);
if nr_ < nr || mr_ < mr {
generic_kernel::masked_kernel::<K>(kc, alpha, &*app, &*bpp, &mut *c, rsc, csc, mr_, nr_);
} else {
generic_kernel::kernel::<K>(kc, alpha, app, bpp, c, rsc, csc);
}
}
}
}
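// Multiply every element of the (rows x cols) block of C by `beta`, choosing the loop
// order that keeps the unit-stride dimension innermost when C is row- or column-contiguous.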
unsafe fn scale_block<T: Element>(beta: T, rows: usize, cols: usize, c: *mut T, rsc: isize, csc: isize) {
if rsc == 1 {
for col in 0..cols {
for row in 0..rows {
let cptr = c.offset(row as isize + csc * col as isize);
*cptr = *cptr * beta;
}
}
} else if csc == 1 {
for row in 0..rows {
for col in 0..cols {
let cptr = c.offset(rsc * row as isize + col as isize);
*cptr = *cptr * beta;
}
}
} else {
for col in 0..cols {
for row in 0..rows {
let cptr = c.offset(rsc * row as isize + csc * col as isize);
*cptr = *cptr * beta;
}
}
}
}
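// Set every element of the (rows x cols) block of C to zero (the `beta == 0` case), using
// the same layout-aware loop ordering as `scale_block`.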
unsafe fn zero_block<T: Element>(rows: usize, cols: usize, c: *mut T, rsc: isize, csc: isize) {
if rsc == 1 {
for col in 0..cols {
for row in 0..rows {
let cptr = c.offset(row as isize + csc * col as isize);
*cptr = T::zero();
}
}
} else if csc == 1 {
for row in 0..rows {
for col in 0..cols {
let cptr = c.offset(rsc * row as isize + col as isize);
*cptr = T::zero();
}
}
} else {
for col in 0..cols {
for row in 0..rows {
let cptr = c.offset(rsc * row as isize + csc * col as isize);
*cptr = T::zero();
}
}
}
}
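/// Allocate the packing buffer: `num_a` packed-A panels plus one packed-B panel, sized for
/// at most (cmc x ckc) and (ckc x cnc) blocks and padded so every panel can start on an
/// `A`-byte boundary. Returns the vector, the per-panel A stride (in elements), and the
/// element offsets of the first A panel and of the B panel. The contents are deliberately
/// left uninitialized (`set_len`), which is part of why the function is `unsafe`.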
#[inline(always)]
unsafe fn aligned_packing_vec<K: KernelConfig, A: Unsigned>(
m: usize,
k: usize,
n: usize,
cmc: usize,
ckc: usize,
cnc: usize,
num_a: usize,
) -> (SmallVec<[K::T; 128]>, isize, usize, usize) {
let m = min(m, cmc);
let k = min(k, ckc);
let n = min(n, cnc);
let align = A::to_usize();
assert!(align % size_of::<K::T>() == 0);
let align_elems = align / size_of::<K::T>();
let apack_size = k * round_up_to(m, K::MR::to_usize());
let bpack_size = k * round_up_to(n, K::NR::to_usize());
// Note: despite the names, these paddings are element counts, not byte counts.
let padding_bytes1 = align_elems;
let padding_bytes2 = if align_elems == 0 {
0
} else {
round_up_to(apack_size, align_elems) - apack_size
};
let nelem = padding_bytes1 + (apack_size + padding_bytes2) * num_a + bpack_size;
let mut v = SmallVec::with_capacity(nelem);
v.set_len(nelem);
dprint!("packed nelem={}, apack={}, bpack={}, m={} k={} n={}", nelem, apack_size, bpack_size, m, k, n);
let mut a_offset = 0;
if align != 0 {
let current_misalignment = v.as_ptr() as usize % align;
debug_assert!(current_misalignment % size_of::<K::T>() == 0); if current_misalignment != 0 {
a_offset = (align - current_misalignment) / size_of::<K::T>();
}
}
let b_offset = a_offset + (apack_size + padding_bytes2) * num_a;
(v, (apack_size + padding_bytes2) as isize, a_offset, b_offset)
}
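// Pack `mc` rows by `kc` columns of an operand (A panels with MR, or B panels with NR via
// swapped strides) into contiguous micro-panels of `mr` rows, dispatching on whether the
// k dimension (csa == 1), the row dimension (rsa == 1) or neither is contiguous, then
// zero-pad the trailing partial panel so the kernels always see full `mr`-row panels.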
unsafe fn pack<T: Element, MR: Loop + ArrayLength<T>>(
kc: usize,
mc: usize,
mr: usize,
pack: *mut T,
a: *const T,
rsa: isize,
csa: isize,
) {
debug_assert_eq!(mr, MR::to_usize());
if csa == 1 {
part_pack_row_major::<T, MR>(kc, mc, mr, pack, a, rsa, csa);
} else if rsa == 1 {
part_pack_col_major::<T, MR>(kc, mc, mr, pack, a, rsa, csa);
} else {
part_pack_strided::<T, MR>(kc, mc, mr, pack, a, rsa, csa);
}
let rest = mc % mr;
if rest > 0 {
part_pack_end::<T, MR>(kc, mc, mr, pack, a, rsa, csa, rest);
}
}
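// Packing when the k dimension of the source is contiguous (csa == 1): each micro-panel is
// written as kc groups of mr elements, with prefetches of upcoming rows issued between the
// two halves of the kc loop.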
#[inline(never)]
unsafe fn part_pack_row_major<T: Element, MR: Loop + ArrayLength<T>>(
kc: usize,
mc: usize,
mr: usize,
pack: *mut T,
a: *const T,
rsa: isize,
csa: isize,
) {
debug_assert_eq!(mr, MR::to_usize());
debug_assert_eq!(csa, 1);
let csa = 1isize;
let mr = MR::to_usize();
for ir in 0..mc / mr {
let a = a.offset((ir * mr) as isize * rsa);
let pack = pack.add(ir * mr * kc);
let kc_prefetch = kc.saturating_sub(128 / mr);
for j in 0..kc_prefetch {
let a = a.stride_offset(csa, j);
let mut arr = <GA<T, MR>>::default();
MR::full_unroll(&mut |i| {
arr[i] = *a.stride_offset(rsa, i);
});
MR::full_unroll(&mut |i| {
*pack.add(j * mr + i) = arr[i];
});
}
MR::full_unroll(&mut |i| {
prefetch_read(a.offset(((ir + 1) * mr + i) as isize * rsa) as *mut i8);
});
for j in kc_prefetch..kc {
let a = a.stride_offset(csa, j);
let mut arr = <GA<T, MR>>::default();
MR::full_unroll(&mut |i| {
arr[i] = *a.stride_offset(rsa, i);
});
MR::full_unroll(&mut |i| {
*pack.add(j * mr + i) = arr[i];
});
}
}
}
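// Packing when the row (mr) dimension of the source is contiguous (rsa == 1): mr
// consecutive source elements form one packed group, with prefetches issued a few k steps
// ahead.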
#[inline(never)]
unsafe fn part_pack_col_major<T: Element, MR: Loop + ArrayLength<T>>(
kc: usize,
mc: usize,
mr: usize,
pack: *mut T,
a: *const T,
rsa: isize,
csa: isize,
) {
debug_assert_eq!(mr, MR::to_usize());
debug_assert_eq!(rsa, 1);
let rsa = 1isize;
let mr = MR::to_usize();
for ir in 0..mc / mr {
let a = a.offset((ir * mr) as isize * rsa);
let pack = pack.add(ir * mr * kc);
prefetch_read(a.offset(((ir + 1) * mr) as isize * rsa) as *mut i8);
for j in 0..kc {
prefetch_read(a.stride_offset(csa, j + 64 / mr) as *mut i8);
let mut arr = <GA<T, MR>>::default();
let a = a.stride_offset(csa, j);
MR::full_unroll(&mut |i| {
arr[i] = *a.stride_offset(rsa, i);
});
MR::full_unroll(&mut |i| {
*pack.add(j * mr + i) = arr[i];
});
}
}
}
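// Generic fallback for arbitrary strides: gather each mr-element group with strided loads.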
#[inline(never)]
unsafe fn part_pack_strided<T: Element, MR: Loop + ArrayLength<T>>(
kc: usize,
mc: usize,
mr: usize,
pack: *mut T,
a: *const T,
rsa: isize,
csa: isize,
) {
debug_assert_eq!(mr, MR::to_usize());
let mr = MR::to_usize();
for ir in 0..mc / mr {
let a = a.offset((ir * mr) as isize * rsa);
let pack = pack.add(ir * mr * kc);
for j in 0..kc {
MR::full_unroll(&mut |i| {
*pack.add(j * mr + i) = *a.stride_offset(rsa, i).stride_offset(csa, j);
});
}
}
}
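// Pack the final partial panel: copy the remaining `rest` rows and zero-fill the unused
// slots of each packed group so the kernels can operate on a full mr-row panel.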
#[allow(clippy::too_many_arguments)]
unsafe fn part_pack_end<T: Element, MR: Loop + ArrayLength<T>>(
kc: usize,
mc: usize,
mr: usize,
pack: *mut T,
a: *const T,
rsa: isize,
csa: isize,
rest: usize,
) {
debug_assert_eq!(mr, MR::to_usize());
let mr = MR::to_usize();
let mut pack = pack.add((mc / mr) * mr * kc);
let row_offset = mc - rest;
for j in 0..kc {
MR::full_unroll(&mut |i| {
if i < rest {
*pack = *a.stride_offset(rsa, i + row_offset).stride_offset(csa, j);
} else {
*pack = T::zero();
}
pack.inc();
});
}
}