use super::{MR, NR};
/// Packs the top-left `mc x kc` block of the row-major matrix `a` (leading
/// dimension `lda`) into `packed` as a sequence of `MR`-row panels.
///
/// Panel `p` covers rows `p * MR .. p * MR + MR` and occupies `MR * kc`
/// consecutive elements of `packed`. For each column `0..kc`, the `MR` values
/// of that column are stored next to each other. When `mc` is not a multiple
/// of `MR`, the rows of the last panel that fall past `mc` are written as
/// `0.0`, so every panel has the same footprint. Elements of `packed` beyond
/// `packed_a_size(mc, kc)` are left untouched.
///
/// # Panics
///
/// Panics if `packed` is shorter than `packed_a_size(mc, kc)`, or if the
/// requested block reaches past the end of `a`.
pub fn pack_a(a: &[f32], lda: usize, mc: usize, kc: usize, packed: &mut [f32]) {
    let panel_len = MR * kc;
    let panels = mc.div_ceil(MR);
    let needed = panels * panel_len;
    // Validate the destination once, up front, instead of failing with an
    // opaque index error after a partial write.
    assert!(
        packed.len() >= needed,
        "pack_a: packed buffer holds {} elements but {} are required",
        packed.len(),
        needed
    );

    for panel in 0..panels {
        let first_row = panel * MR;
        // Only the final panel can be partial; its missing rows become padding.
        let live_rows = MR.min(mc - first_row);
        let dst = &mut packed[panel * panel_len..(panel + 1) * panel_len];
        for (col, group) in dst.chunks_exact_mut(MR).enumerate() {
            let (live, pad) = group.split_at_mut(live_rows);
            for (row, slot) in live.iter_mut().enumerate() {
                *slot = a[(first_row + row) * lda + col];
            }
            pad.fill(0.0);
        }
    }
}
/// Packs the top-left `kc x nc` block of the row-major matrix `b` (leading
/// dimension `ldb`) into `packed` as a sequence of `NR`-column panels.
///
/// Panel `p` covers columns `p * NR .. p * NR + NR` and occupies `NR * kc`
/// consecutive elements of `packed`. For each row `0..kc`, the `NR` values of
/// that row within the panel are stored next to each other. When `nc` is not a
/// multiple of `NR`, the columns of the last panel that fall past `nc` are
/// written as `0.0`. Elements of `packed` beyond `packed_b_size(kc, nc)` are
/// left untouched.
///
/// # Panics
///
/// Panics if `packed` is shorter than `packed_b_size(kc, nc)`, or if the
/// requested block reaches past the end of `b`.
pub fn pack_b(b: &[f32], ldb: usize, kc: usize, nc: usize, packed: &mut [f32]) {
    let panel_len = NR * kc;
    let panels = nc.div_ceil(NR);
    let needed = panels * panel_len;
    // Validate the destination once, up front, instead of failing with an
    // opaque index error after a partial write.
    assert!(
        packed.len() >= needed,
        "pack_b: packed buffer holds {} elements but {} are required",
        packed.len(),
        needed
    );

    for panel in 0..panels {
        let first_col = panel * NR;
        // Only the final panel can be partial; its missing columns become padding.
        let live_cols = NR.min(nc - first_col);
        let dst = &mut packed[panel * panel_len..(panel + 1) * panel_len];
        for (row, group) in dst.chunks_exact_mut(NR).enumerate() {
            let (live, pad) = group.split_at_mut(live_cols);
            // The live part of a panel row is contiguous in row-major `b`.
            let src = row * ldb + first_col;
            live.copy_from_slice(&b[src..src + live_cols]);
            pad.fill(0.0);
        }
    }
}
/// Returns the number of `f32` elements `pack_a` writes for an `mc x kc` block.
///
/// `mc` is rounded up to a whole number of `MR`-row panels, because the last
/// panel is zero-padded to full height.
#[inline]
pub fn packed_a_size(mc: usize, kc: usize) -> usize {
    // `div_ceil` avoids the `mc + MR - 1` intermediate, which can overflow.
    mc.div_ceil(MR) * MR * kc
}
/// Returns the number of `f32` elements `pack_b` writes for a `kc x nc` block.
///
/// `nc` is rounded up to a whole number of `NR`-column panels, because the
/// last panel is zero-padded to full width.
#[inline]
pub fn packed_b_size(kc: usize, nc: usize) -> usize {
    // `div_ceil` avoids the `nc + NR - 1` intermediate, which can overflow.
    nc.div_ceil(NR) * NR * kc
}
/// Packs the `rows x cols` sub-block of the row-major matrix `a` (leading
/// dimension `lda`) whose top-left element is `a[row_start * lda + col_start]`
/// into `packed` as a sequence of `MR`-row panels.
///
/// Each panel occupies `MR * cols` consecutive elements. For each column of the
/// sub-block, the panel's `MR` values of that column are stored next to each
/// other. Rows of the last panel that fall past `rows` are written as `0.0`.
///
/// # Panics
///
/// Panics if `packed` holds fewer than `rows.div_ceil(MR) * MR * cols`
/// elements, or if the sub-block reaches past the end of `a`.
pub(super) fn pack_a_block(
    a: &[f32],
    lda: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    packed: &mut [f32],
) {
    let panel_len = MR * cols;
    let panels = rows.div_ceil(MR);
    let needed = panels * panel_len;
    // Validate the destination once, up front, instead of failing with an
    // opaque index error after a partial write.
    assert!(
        packed.len() >= needed,
        "pack_a_block: packed buffer holds {} elements but {} are required",
        packed.len(),
        needed
    );

    for panel in 0..panels {
        let ir = panel * MR;
        // Only the final panel can be partial; its missing rows become padding.
        let live_rows = MR.min(rows - ir);
        let first_row = row_start + ir;
        let dst = &mut packed[panel * panel_len..(panel + 1) * panel_len];
        for (col, group) in dst.chunks_exact_mut(MR).enumerate() {
            let (live, pad) = group.split_at_mut(live_rows);
            let src_col = col_start + col;
            for (row, slot) in live.iter_mut().enumerate() {
                *slot = a[(first_row + row) * lda + src_col];
            }
            pad.fill(0.0);
        }
    }
}
/// Packs the `rows x cols` sub-block of the row-major matrix `b` (leading
/// dimension `ldb`) whose top-left element is `b[row_start * ldb + col_start]`
/// into `packed` as a sequence of `NR`-column panels.
///
/// Each panel occupies `NR * rows` consecutive elements. For each row of the
/// sub-block, the panel's `NR` values of that row are stored next to each
/// other. Columns of the last panel that fall past `cols` are written as `0.0`.
///
/// # Panics
///
/// Panics if `packed` holds fewer than `cols.div_ceil(NR) * NR * rows`
/// elements, or if the sub-block reaches past the end of `b`.
pub(super) fn pack_b_block(
    b: &[f32],
    ldb: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    packed: &mut [f32],
) {
    let panel_len = NR * rows;
    let panels = cols.div_ceil(NR);
    let needed = panels * panel_len;
    // Validate the destination once, up front, instead of failing with an
    // opaque index error after a partial write.
    assert!(
        packed.len() >= needed,
        "pack_b_block: packed buffer holds {} elements but {} are required",
        packed.len(),
        needed
    );

    for panel in 0..panels {
        let jr = panel * NR;
        // Only the final panel can be partial; its missing columns become padding.
        let live_cols = NR.min(cols - jr);
        let first_col = col_start + jr;
        let dst = &mut packed[panel * panel_len..(panel + 1) * panel_len];
        for (row, group) in dst.chunks_exact_mut(NR).enumerate() {
            let (live, pad) = group.split_at_mut(live_cols);
            // The live part of a panel row is contiguous in row-major `b`.
            let src = (row_start + row) * ldb + first_col;
            live.copy_from_slice(&b[src..src + live_cols]);
            pad.fill(0.0);
        }
    }
}