use rawpointer::PointerExt;
use core::ptr::copy_nonoverlapping;
use crate::kernel::ConstNum;
use crate::kernel::Element;
/// Pack a `mc` x `kc` block of the matrix behind `a` into `pack`, in panels
/// of `MR::VALUE` rows.
///
/// This is the plain entry point; it forwards every argument unchanged to
/// `pack_impl`, which does the actual copying.
///
/// - `kc`: number of columns to pack.
/// - `mc`: number of rows to pack.
/// - `pack`: destination buffer.
/// - `a`: pointer to the first source element.
/// - `rsa`, `csa`: row and column strides of the source.
///
/// # Safety
///
/// The destination is written through a raw pointer and the source is read
/// through `a`; neither access is bounds checked. The caller must guarantee
/// that every source element addressed through `rsa`/`csa` is readable and
/// that `pack` can hold all packed (and padded) elements.
pub(crate) unsafe fn pack<MR, T>(
    kc: usize,
    mc: usize,
    pack: &mut [T],
    a: *const T,
    rsa: isize,
    csa: isize,
) where
    T: Element,
    MR: ConstNum,
{
    pack_impl::<MR, T>(kc, mc, pack, a, rsa, csa)
}
/// Same operation as `pack`, but compiled with the `avx2` target feature
/// enabled (x86 / x86_64 only).
///
/// `pack_impl` is marked `#[inline(always)]`, so its body is meant to be
/// inlined here and compiled with AVX2 available.
///
/// # Safety
///
/// In addition to the pointer/buffer requirements of `pack_impl` (no access
/// is bounds checked), the caller must make sure the running CPU supports
/// AVX2 before calling this function.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn pack_avx2<MR, T>(
    kc: usize,
    mc: usize,
    pack: &mut [T],
    a: *const T,
    rsa: isize,
    csa: isize,
) where
    T: Element,
    MR: ConstNum,
{
    pack_impl::<MR, T>(kc, mc, pack, a, rsa, csa)
}
/// Copy a `mc` x `kc` block of a strided matrix into `pack`, panel by panel.
///
/// The rows are split into panels of `mr = MR::VALUE` rows. Panels are
/// written one after another; within a panel the `kc` columns are written in
/// order, and each column contributes `mr` consecutive elements (one per
/// panel row). If `mc` is not a multiple of `mr`, a final partial panel is
/// written whose missing rows are filled with `T::zero()`.
///
/// - `kc`: number of columns to pack.
/// - `mc`: number of rows to pack.
/// - `pack`: destination buffer; only its base pointer is used.
/// - `a`: pointer to the first source element.
/// - `rsa`, `csa`: row and column strides passed to `stride_offset`.
///
/// # Safety
///
/// Nothing here is bounds checked. The caller must guarantee that
/// - every element reached from `a` via `rsa`/`csa` for rows `0..mc` and
///   columns `0..kc` is valid to read,
/// - `pack` has room for `kc * mr` elements per panel, including the padded
///   tail panel,
/// - the source elements do not overlap `pack` (required by
///   `copy_nonoverlapping`).
#[inline(always)]
unsafe fn pack_impl<MR, T>(
    kc: usize,
    mc: usize,
    pack: &mut [T],
    a: *const T,
    rsa: isize,
    csa: isize,
) where
    T: Element,
    MR: ConstNum,
{
    let dst = pack.as_mut_ptr();
    let mr = MR::VALUE;
    let full_panels = mc / mr;
    // Running write position in `pack`, counted in elements.
    let mut written = 0;

    if rsa == 1 {
        // Unit row stride: the `mr` rows of one panel column are adjacent in
        // memory, so each column of a panel is moved with a single copy.
        for panel in 0..full_panels {
            let first_row = panel * mr;
            for col in 0..kc {
                let src = a.stride_offset(rsa, first_row).stride_offset(csa, col);
                copy_nonoverlapping(src, dst.add(written), mr);
                written += mr;
            }
        }
    } else {
        // Arbitrary strides: gather the panel one element at a time.
        for panel in 0..full_panels {
            let first_row = panel * mr;
            for col in 0..kc {
                for row in first_row..first_row + mr {
                    let src = a.stride_offset(rsa, row).stride_offset(csa, col);
                    copy_nonoverlapping(src, dst.add(written), 1);
                    written += 1;
                }
            }
        }
    }

    let zero = T::zero();

    // Rows left over after the last full panel form one partial panel that
    // is padded up to `mr` rows with zeros.
    let tail_rows = mc % mr;
    if tail_rows > 0 {
        let tail_start = full_panels * mr;
        for col in 0..kc {
            // Real rows first ...
            for row in 0..tail_rows {
                let src = a
                    .stride_offset(rsa, tail_start + row)
                    .stride_offset(csa, col);
                copy_nonoverlapping(src, dst.add(written), 1);
                written += 1;
            }
            // ... then the zero padding for the rows that do not exist.
            for _ in tail_rows..mr {
                *dst.add(written) = zero;
                written += 1;
            }
        }
    }
}