use super::MR;
macro_rules! define_pack_a {
    ($name:ident, $ty:ty) => {
        /// Copies an `mc` x `kc` block of `a` into `packed` as a sequence of
        /// panels that are `MR` rows tall.
        ///
        /// Element `(row, col)` of the block is read from `a[row * lda + col]`.
        /// Panels are written back to back; inside a panel the `MR` row values
        /// of column 0 come first, then those of column 1, and so on. When
        /// `mc` is not a multiple of `MR`, the missing rows of the last panel
        /// are written as `0.0`.
        ///
        /// # Panics
        ///
        /// Panics if `MR` is zero, because the row range is stepped by `MR`.
        ///
        /// # Safety
        ///
        /// * `a` must be valid for reads at every offset `row * lda + col`
        ///   with `row < mc` and `col < kc`.
        /// * `packed` must be valid for writes of `ceil(mc / MR) * MR * kc`
        ///   elements.
        #[inline]
        pub unsafe fn $name(a: *const $ty, packed: *mut $ty, mc: usize, kc: usize, lda: usize) {
            // Offset in `packed` of the column slot currently being filled.
            let mut out = 0usize;
            for row_base in (0..mc).step_by(MR) {
                let live_rows = MR.min(mc - row_base);
                if live_rows != MR {
                    // Tail panel: copy the rows that exist, then zero-fill the rest.
                    for col in 0..kc {
                        for r in 0..live_rows {
                            packed
                                .add(out + r)
                                .write(a.add((row_base + r) * lda + col).read());
                        }
                        for r in live_rows..MR {
                            packed.add(out + r).write(0.0);
                        }
                        out += MR;
                    }
                } else {
                    // Full panel: the inner trip count is the constant `MR`.
                    for col in 0..kc {
                        for r in 0..MR {
                            packed
                                .add(out + r)
                                .write(a.add((row_base + r) * lda + col).read());
                        }
                        out += MR;
                    }
                }
            }
        }
    };
}
macro_rules! define_pack_b {
    ($name:ident, $ty:ty) => {
        /// Copies a `kc` x `nc` block of `b` into `packed` as a sequence of
        /// panels that are `NR` columns wide.
        ///
        /// Element `(row, col)` of the block is read from `b[row * ldb + col]`.
        /// Panels are written back to back; inside a panel each of the `kc`
        /// rows contributes `NR` consecutive values. When `nc` is not a
        /// multiple of `NR`, the missing columns of the last panel are written
        /// as `0.0`.
        ///
        /// # Panics
        ///
        /// Panics if `NR` is zero, because the column range is stepped by `NR`.
        ///
        /// # Safety
        ///
        /// * `b` must be valid for reads at every offset `row * ldb + col`
        ///   with `row < kc` and `col < nc`.
        /// * `packed` must be valid for writes of `ceil(nc / NR) * NR * kc`
        ///   elements and must not overlap the region read from `b`.
        #[inline]
        pub unsafe fn $name<const NR: usize>(
            b: *const $ty,
            packed: *mut $ty,
            nc: usize,
            kc: usize,
            ldb: usize,
        ) {
            // Offset in `packed` of the row slot currently being filled.
            let mut out = 0usize;
            for col_base in (0..nc).step_by(NR) {
                let live_cols = NR.min(nc - col_base);
                if live_cols < NR {
                    // Tail panel: copy the columns that exist, zero-fill the rest.
                    for row in 0..kc {
                        let src = row * ldb + col_base;
                        for c in 0..live_cols {
                            packed.add(out + c).write(b.add(src + c).read());
                        }
                        for c in live_cols..NR {
                            packed.add(out + c).write(0.0);
                        }
                        out += NR;
                    }
                } else {
                    // Full panel: each row is a contiguous run of `NR` values.
                    for row in 0..kc {
                        b.add(row * ldb + col_base)
                            .copy_to_nonoverlapping(packed.add(out), NR);
                        out += NR;
                    }
                }
            }
        }
    };
}
// Instantiate the packing routines for `f32` and `f64`. Each invocation
// expands to one `pub unsafe fn` with the given name and element type; the
// `pack_b_*` functions additionally take the panel width as a const generic
// parameter `NR`.
define_pack_a!(pack_a_f32, f32);
define_pack_a!(pack_a_f64, f64);
define_pack_b!(pack_b_f32, f32);
define_pack_b!(pack_b_f64, f64);