1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
//! Matrix packing functions for microkernel consumption
//!
//! These functions reorder matrix data into a layout optimized for the
//! 6×NR microkernels. Packing improves cache utilization by ensuring
//! sequential memory access in the innermost loop.
use super::MR;
/// Generates an `unsafe` A-panel packing routine named `$name` for element type `$ty`.
macro_rules! define_pack_a {
    ($name:ident, $ty:ty) => {
        /// Repack an `mc × kc` panel of A into MR-row micro-panels.
        ///
        /// Rows are processed in groups of `MR`. Within a group the output is
        /// ordered by `k`, and for each `k` exactly `MR` values are stored (one
        /// per row of the group):
        /// `[a[0,0], ..., a[MR-1,0], a[0,1], ..., a[MR-1,1], ...]`.
        /// When the last group has fewer than `MR` rows, the missing rows are
        /// stored as `0.0` so every group occupies `MR * kc` elements.
        ///
        /// # Safety
        /// - `a` must be readable at `a + i * lda + k` for every `i < mc` and
        ///   `k < kc` (an `mc × kc` panel with row stride `lda`)
        /// - `packed` must be valid for writing `(mc.div_ceil(MR) * MR) * kc` elements
        #[inline]
        pub unsafe fn $name(a: *const $ty, packed: *mut $ty, mc: usize, kc: usize, lda: usize) {
            // Index of the next free element in `packed`.
            let mut cursor = 0usize;

            for row0 in (0..mc).step_by(MR) {
                let rows = MR.min(mc - row0);

                if rows < MR {
                    // Ragged final group: real rows first, zero fill up to MR.
                    for col in 0..kc {
                        // SAFETY: the caller guarantees room for MR elements per
                        // (group, k) pair and readability of every (row, col) used.
                        let slot = packed.add(cursor);
                        for r in 0..MR {
                            let value = if r < rows {
                                a.add((row0 + r) * lda + col).read()
                            } else {
                                0.0
                            };
                            slot.add(r).write(value);
                        }
                        cursor += MR;
                    }
                    continue;
                }

                // Complete group: fixed trip count of MR, nothing to pad.
                for col in 0..kc {
                    // SAFETY: same caller contract as above; all MR rows exist.
                    let slot = packed.add(cursor);
                    for r in 0..MR {
                        slot.add(r).write(a.add((row0 + r) * lda + col).read());
                    }
                    cursor += MR;
                }
            }
        }
    };
}
/// Generates an `unsafe` B-panel packing routine named `$name` for element type `$ty`.
macro_rules! define_pack_b {
    ($name:ident, $ty:ty) => {
        /// Repack a `kc × nc` panel of row-major B into NR-column micro-panels.
        ///
        /// Columns are processed in groups of `NR`. Within a group the output is
        /// ordered by `k`:
        /// - a complete group stores `NR` consecutive elements per `k`, copied
        ///   in bulk because they are contiguous within each row of B;
        /// - a trailing group narrower than `NR` stores only its real columns
        ///   per `k` (stride equals the group's width, with no zero padding).
        ///
        /// # Safety
        /// - `b` must be readable at `b + k * ldb + j` for every `k < kc` and
        ///   `j < nc` (a `kc × nc` panel with row stride `ldb`)
        /// - `packed` must be valid for writing `(nc.div_ceil(NR) * NR) * kc` elements
        ///   (the routine itself stores `nc * kc` elements)
        /// - the source rows and `packed` must not overlap, since complete
        ///   groups are copied with a non-overlapping bulk copy
        ///
        /// # Panics
        /// Panics if `NR` is zero, because the column range is stepped by `NR`.
        #[inline]
        pub unsafe fn $name<const NR: usize>(
            b: *const $ty,
            packed: *mut $ty,
            nc: usize,
            kc: usize,
            ldb: usize,
        ) {
            // Index of the next free element in `packed`.
            let mut cursor = 0usize;

            for col0 in (0..nc).step_by(NR) {
                let width = NR.min(nc - col0);

                if width < NR {
                    // Narrow trailing group: keep the packed stride equal to
                    // `width` rather than padding to NR. The consuming edge
                    // kernels are documented to index this group as
                    // `b.add(kk * nr + j)` with nr == width, so padding would
                    // shift every row after the first. This group is always
                    // the last one, so its smaller footprint moves no later
                    // group's offset.
                    for row in 0..kc {
                        let src_base = row * ldb + col0;
                        for c in 0..width {
                            // SAFETY: (row, col0 + c) lies inside the caller's
                            // readable panel and the destination inside its
                            // writable buffer.
                            packed.add(cursor + c).write(b.add(src_base + c).read());
                        }
                        cursor += width;
                    }
                    continue;
                }

                // Complete group: one contiguous NR-element copy per row of B.
                for row in 0..kc {
                    // SAFETY: NR readable source elements and NR writable
                    // destination elements exist per the caller's contract,
                    // and the two regions do not overlap.
                    b.add(row * ldb + col0)
                        .copy_to_nonoverlapping(packed.add(cursor), NR);
                    cursor += NR;
                }
            }
        }
    };
}
// Instantiate the packing routines for single- and double-precision floats.
// The `pack_a_*` functions group rows by the module-level `MR` constant; the
// `pack_b_*` functions take the column-group width `NR` as a const generic.
define_pack_a!(pack_a_f32, f32);
define_pack_a!(pack_a_f64, f64);
define_pack_b!(pack_b_f32, f32);
define_pack_b!(pack_b_f64, f64);