use crate::prefetch_read;
use crate::{generic_params::*, prefetch_write};
use num_traits::Float;
use std::cmp::min;
use typenum::Unsigned;
use typenum_loops::Loop;
/// Micro-kernel for edge tiles: accumulate `alpha * A * B` into `c`, but
/// only touch the top-left `rows x cols` corner of the MR x NR tile, so it
/// is safe on the ragged right/bottom edges of the output matrix.
///
/// `rsc`/`csc` are the row/column strides of `c` in elements.
///
/// # Safety
/// `a` must be readable for `k * MR` elements, `b` for `k * NR` elements,
/// and `c` must be valid for reads and writes at every offset
/// `rsc * i + csc * j` with `i < min(MR, rows)` and `j < min(NR, cols)`.
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub unsafe fn masked_kernel<K: KernelConfig>(
    k: usize,
    alpha: K::T,
    a: *const K::T,
    b: *const K::T,
    c: *mut K::T,
    rsc: isize,
    csc: isize,
    rows: usize,
    cols: usize,
) {
    // Clamp the tile to the live region of the output.
    let live_rows = min(K::MR::to_usize(), rows);
    let live_cols = min(K::NR::to_usize(), cols);
    prefetch_read(a as *mut i8);
    prefetch_read(b as *mut i8);
    write_prefetch::<K>(c, rsc, csc);
    // NOTE: the col-outer / row-inner write order is kept fixed so that the
    // order of the `+=` accumulations into `c` is deterministic even for
    // degenerate (aliasing) stride combinations.
    if K::TR::to_usize() == 0 {
        // Accumulator laid out ab[row][col].
        let ab = kernel_compute::<K>(k, alpha, a, b);
        for col in 0..live_cols {
            for row in 0..live_rows {
                let dst = c.offset(rsc * row as isize + csc * col as isize);
                *dst += ab[row][col];
            }
        }
    } else {
        // Transposed accumulator laid out ab[col][row].
        let ab = kernel_compute_trans::<K>(k, alpha, a, b);
        for col in 0..live_cols {
            for row in 0..live_rows {
                let dst = c.offset(rsc * row as isize + csc * col as isize);
                *dst += ab[col][row];
            }
        }
    }
}
/// Full-tile micro-kernel: accumulate `alpha * A * B` into the complete
/// MR x NR tile of `c` (no edge masking). Dispatches at compile time on
/// `K::TR` between the row-major and transposed accumulator layouts.
///
/// # Safety
/// `a` must be readable for `k * MR` elements, `b` for `k * NR` elements,
/// and `c` must be valid for reads and writes at every offset
/// `rsc * i + csc * j` with `i < MR`, `j < NR`.
#[inline(never)]
pub unsafe fn kernel<K: KernelConfig>(
    k: usize,
    alpha: K::T,
    a: *const K::T,
    b: *const K::T,
    c: *mut K::T,
    rsc: isize,
    csc: isize,
) {
    prefetch_read(a as *mut i8);
    prefetch_read(b as *mut i8);
    write_prefetch::<K>(c, rsc, csc);
    // `TR` is a compile-time constant, so this branch folds away.
    let transposed = K::TR::to_usize() != 0;
    if transposed {
        let ab = kernel_compute_trans::<K>(k, alpha, a, b);
        kernel_write_trans::<K>(c, rsc, csc, &ab);
    } else {
        let ab = kernel_compute::<K>(k, alpha, a, b);
        kernel_write::<K>(c, rsc, csc, &ab);
    }
}
/// Compute `alpha * A * B` for one micro-tile into a register accumulator
/// indexed `ab[row][col]` (MR rows, NR columns).
///
/// The panels are packed: the MR values of A's column `l` start at
/// `a + l * MR`, and the NR values of B's row `l` start at `b + l * NR`.
///
/// # Safety
/// `a` must be readable for `k * MR` elements and `b` for `k * NR` elements.
#[inline(always)]
unsafe fn kernel_compute<K: KernelConfig>(
k: usize,
alpha: K::T,
a: *const K::T,
b: *const K::T,
) -> GA<GA<K::T, K::NR>, K::MR> {
// Zero-initialized MR x NR accumulator.
let mut ab = <GA<GA<K::T, K::NR>, K::MR>>::default();
// k-loop, partially unrolled by a factor of K::KU; `l` is the absolute index.
K::KU::partial_unroll(k, &mut |l, _| {
// Shadowed pointers to A's column `l` and B's row `l` in the packed panels.
let a = a.add(l * K::MR::to_usize());
let b = b.add(l * K::NR::to_usize());
// Fully unrolled rank-1 update of the whole tile (row-major order).
K::MR::full_unroll(&mut |i| {
K::NR::full_unroll(&mut |j| {
// K::FMA is a compile-time switch between fused and separate mul/add.
if K::FMA::to_usize() > 0 {
ab[i][j] = at::<K::T>(a, i).mul_add(at::<K::T>(b, j), ab[i][j]);
} else {
ab[i][j] += at::<K::T>(a, i) * at::<K::T>(b, j);
}
});
});
});
// Scale the finished tile by alpha before returning it.
K::MR::full_unroll(&mut |i| {
K::NR::full_unroll(&mut |j| {
ab[i][j] = ab[i][j] * alpha;
});
});
ab
}
/// Same computation as `kernel_compute`, but with the accumulator stored
/// transposed: `ab[col][row]` (NR outer, MR inner). The unroll nesting is
/// swapped to match, which can suit targets where the transposed layout
/// vectorizes better — selected via `K::TR`.
///
/// # Safety
/// `a` must be readable for `k * MR` elements and `b` for `k * NR` elements.
#[inline(always)]
unsafe fn kernel_compute_trans<K: KernelConfig>(
k: usize,
alpha: K::T,
a: *const K::T,
b: *const K::T,
) -> GA<GA<K::T, K::MR>, K::NR> {
// Zero-initialized NR x MR (transposed) accumulator.
let mut ab = <GA<GA<K::T, K::MR>, K::NR>>::default();
// k-loop, partially unrolled by a factor of K::KU.
K::KU::partial_unroll(k, &mut |l, _| {
// Shadowed pointers to A's column `l` and B's row `l` in the packed panels.
let a = a.add(l * K::MR::to_usize());
let b = b.add(l * K::NR::to_usize());
// Fully unrolled rank-1 update, column-major over the tile.
K::NR::full_unroll(&mut |j| {
K::MR::full_unroll(&mut |i| {
// K::FMA is a compile-time switch between fused and separate mul/add.
if K::FMA::to_usize() > 0 {
ab[j][i] = at::<K::T>(a, i).mul_add(at::<K::T>(b, j), ab[j][i]);
} else {
ab[j][i] += at::<K::T>(a, i) * at::<K::T>(b, j);
}
});
});
});
// Scale the finished tile by alpha before returning it.
K::NR::full_unroll(&mut |j| {
K::MR::full_unroll(&mut |i| {
ab[j][i] = ab[j][i] * alpha;
});
});
ab
}
/// Issue prefetch-for-write hints over the MR x NR destination tile of `c`.
///
/// When one stride is 1 a whole column/row of the tile is contiguous in
/// memory, so a single hint at its head suffices; with general strides
/// every element gets its own hint.
#[inline(always)]
unsafe fn write_prefetch<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize) {
    if rsc == 1 {
        // Unit row stride: each of the NR columns is contiguous — hint its head.
        K::NR::full_unroll(&mut |col| {
            prefetch_write(c.offset(csc * col as isize) as *mut i8);
        });
    } else if csc == 1 {
        // Unit column stride: each of the MR rows is contiguous — hint its head.
        K::MR::full_unroll(&mut |row| {
            prefetch_write(c.offset(rsc * row as isize) as *mut i8);
        });
    } else {
        // General strides: elements may land on arbitrary cache lines.
        for row in 0..K::MR::to_usize() {
            for col in 0..K::NR::to_usize() {
                prefetch_write(c.offset(rsc * row as isize + csc * col as isize) as *mut i8);
            }
        }
    }
}
/// Add the full MR x NR row-major accumulator into `c`: `C += ab`.
///
/// The three branches are behaviorally identical; branching on a unit
/// stride replaces `rsc`/`csc` with a compile-time constant offset so the
/// contiguous cases can vectorize. (The former `1 * i as isize` identity
/// multiplications are dropped — clippy `identity_op`.)
///
/// # Safety
/// `c` must be valid for reads and writes at every offset
/// `rsc * i + csc * j` with `i < MR`, `j < NR`.
#[inline(always)]
unsafe fn kernel_write<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize, ab: &GA<GA<K::T, K::NR>, K::MR>) {
    if rsc == 1 {
        // Unit row stride: columns of C are contiguous.
        for i in 0..K::MR::to_usize() {
            for j in 0..K::NR::to_usize() {
                let v = c.offset(i as isize + csc * j as isize);
                *v += ab[i][j];
            }
        }
    } else if csc == 1 {
        // Unit column stride: rows of C are contiguous.
        for i in 0..K::MR::to_usize() {
            for j in 0..K::NR::to_usize() {
                let v = c.offset(rsc * i as isize + j as isize);
                *v += ab[i][j];
            }
        }
    } else {
        // General strides.
        for i in 0..K::MR::to_usize() {
            for j in 0..K::NR::to_usize() {
                let v = c.offset(rsc * i as isize + csc * j as isize);
                *v += ab[i][j];
            }
        }
    }
}
/// Add the full transposed (NR x MR, `ab[col][row]`) accumulator into `c`:
/// `C += ab^T`. Mirrors `kernel_write` but iterates column-outer to match
/// the transposed accumulator layout.
///
/// The three branches are behaviorally identical; branching on a unit
/// stride replaces `rsc`/`csc` with a compile-time constant offset so the
/// contiguous cases can vectorize. (The former `1 * i as isize` identity
/// multiplications are dropped — clippy `identity_op`.)
///
/// # Safety
/// `c` must be valid for reads and writes at every offset
/// `rsc * i + csc * j` with `i < MR`, `j < NR`.
#[inline(always)]
unsafe fn kernel_write_trans<K: KernelConfig>(c: *mut K::T, rsc: isize, csc: isize, ab: &GA<GA<K::T, K::MR>, K::NR>) {
    if rsc == 1 {
        // Unit row stride: columns of C are contiguous.
        for j in 0..K::NR::to_usize() {
            for i in 0..K::MR::to_usize() {
                let v = c.offset(i as isize + csc * j as isize);
                *v += ab[j][i];
            }
        }
    } else if csc == 1 {
        // Unit column stride: rows of C are contiguous.
        for j in 0..K::NR::to_usize() {
            for i in 0..K::MR::to_usize() {
                let v = c.offset(rsc * i as isize + j as isize);
                *v += ab[j][i];
            }
        }
    } else {
        // General strides.
        for j in 0..K::NR::to_usize() {
            for i in 0..K::MR::to_usize() {
                let v = c.offset(rsc * i as isize + csc * j as isize);
                *v += ab[j][i];
            }
        }
    }
}
/// Read element `i` from the buffer at `ptr`.
///
/// # Safety
/// `ptr.add(i)` must point to a valid, initialized `T` within the same
/// allocated object as `ptr`.
#[inline(always)]
unsafe fn at<T: Copy>(ptr: *const T, i: usize) -> T {
    // `add` takes the index as usize directly, consistent with the `a.add(..)`
    // calls in the compute kernels, and avoids the `i as isize` wrap cast
    // that `offset` required.
    *ptr.add(i)
}