// trueno/blis/packing.rs
//! BLIS matrix packing routines.
//!
//! Packing transforms row-major matrices into micro-panel layouts optimized
//! for sequential access in the microkernel, ensuring optimal cache line
//! utilization and aligned loads for SIMD.
//!
//! # References
//!
//! - Van Zee, F. G., & Van de Geijn, R. A. (2015). BLIS: A Framework for Rapidly Instantiating
//!   BLAS Functionality. ACM TOMS, 41(3), Fig. 4.

12use super::{MR, NR};
13#[cfg(target_arch = "x86_64")]
14use super::{MR_512, NR_512};
15
16/// Pack A into MC x KC panel with MR-aligned micro-panels
17///
18/// Memory layout (Van Zee & Van de Geijn, 2015, Fig. 4):
19/// Original A (row-major):     Packed A (column-major micro-panels):
20/// [a00 a01 a02 ...]           [a00 a10 a20 ... a(MR-1)0 | a01 a11 ...]
21/// [a10 a11 a12 ...]            \____ MR elements ____/
22///
23/// This layout ensures:
24/// 1. Sequential access in the microkernel
25/// 2. Optimal cache line utilization
26/// 3. Aligned loads for SIMD
27pub fn pack_a(
28    a: &[f32],
29    lda: usize, // Leading dimension of A (number of columns in original)
30    mc: usize,  // Number of rows to pack
31    kc: usize,  // Number of columns to pack
32    packed: &mut [f32],
33) {
34    let mut pack_idx = 0;
35
36    // Process MR rows at a time
37    let full_panels = mc / MR;
38    let remainder = mc % MR;
39
40    for panel in 0..full_panels {
41        let row_start = panel * MR;
42
43        for col in 0..kc {
44            for row in 0..MR {
45                packed[pack_idx] = a[(row_start + row) * lda + col];
46                pack_idx += 1;
47            }
48        }
49    }
50
51    // Handle remainder rows (pad with zeros)
52    if remainder > 0 {
53        let row_start = full_panels * MR;
54
55        for col in 0..kc {
56            for row in 0..MR {
57                if row < remainder {
58                    packed[pack_idx] = a[(row_start + row) * lda + col];
59                } else {
60                    packed[pack_idx] = 0.0; // Zero padding
61                }
62                pack_idx += 1;
63            }
64        }
65    }
66}
67
68/// Pack B into KC x NC panel with NR-aligned micro-panels
69///
70/// Memory layout:
71/// Original B (row-major):     Packed B (row-major micro-panels):
72/// [b00 b01 b02 ...]           [b00 b01 ... b(NR-1) | b10 b11 ...]
73/// [b10 b11 b12 ...]            \____ NR elements ____/
74pub fn pack_b(
75    b: &[f32],
76    ldb: usize, // Leading dimension of B (number of columns in original)
77    kc: usize,  // Number of rows to pack
78    nc: usize,  // Number of columns to pack
79    packed: &mut [f32],
80) {
81    let mut pack_idx = 0;
82
83    let full_panels = nc / NR;
84    let remainder = nc % NR;
85
86    for panel in 0..full_panels {
87        let col_start = panel * NR;
88
89        for row in 0..kc {
90            for col in 0..NR {
91                packed[pack_idx] = b[row * ldb + col_start + col];
92                pack_idx += 1;
93            }
94        }
95    }
96
97    // Handle remainder columns (pad with zeros)
98    if remainder > 0 {
99        let col_start = full_panels * NR;
100
101        for row in 0..kc {
102            for col in 0..NR {
103                if col < remainder {
104                    packed[pack_idx] = b[row * ldb + col_start + col];
105                } else {
106                    packed[pack_idx] = 0.0;
107                }
108                pack_idx += 1;
109            }
110        }
111    }
112}
113
114/// Compute required packed A buffer size
115#[inline]
116pub fn packed_a_size(mc: usize, kc: usize) -> usize {
117    let panels = (mc + MR - 1) / MR;
118    panels * MR * kc
119}
120
121/// Compute required packed B buffer size
122#[inline]
123pub fn packed_b_size(kc: usize, nc: usize) -> usize {
124    let panels = (nc + NR - 1) / NR;
125    panels * NR * kc
126}
127
128/// Pack A block from row-major source.
129/// Full MR=8 panels use AVX2 8×8 transpose when available (8× fewer scalar ops).
130pub(super) fn pack_a_block(
131    a: &[f32],
132    lda: usize,
133    row_start: usize,
134    col_start: usize,
135    rows: usize,
136    cols: usize,
137    packed: &mut [f32],
138) {
139    let panels = (rows + MR - 1) / MR;
140
141    #[cfg(target_arch = "x86_64")]
142    if is_x86_feature_detected!("avx2") {
143        // SAFETY: AVX2 verified above
144        unsafe {
145            pack_a_block_avx2(a, lda, row_start, col_start, rows, cols, panels, packed);
146        }
147        return;
148    }
149
150    // Scalar fallback
151    let mut pack_idx = 0;
152    for panel in 0..panels {
153        let ir = panel * MR;
154        let mr_actual = MR.min(rows - ir);
155
156        for col in 0..cols {
157            for row in 0..MR {
158                if row < mr_actual {
159                    packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
160                } else {
161                    packed[pack_idx] = 0.0;
162                }
163                pack_idx += 1;
164            }
165        }
166    }
167}
168
/// AVX2 SIMD-accelerated A packing: 8×8 transpose blocks.
/// Processes 8 columns at a time using AVX2 unpack/shuffle/permute
/// to transpose 8×8 blocks from row-major to column-major in-register.
/// Reduces per-element cost from 1 scalar load + 1 scalar store to
/// ~0.25 SIMD ops per element (8× improvement).
///
/// # Safety
///
/// The caller must guarantee that AVX2 is available on the running CPU,
/// that `a` covers rows `row_start..row_start + rows` and columns
/// `col_start..col_start + cols` under leading dimension `lda`, and that
/// `packed` can hold `panels * MR * cols` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn pack_a_block_avx2(
    a: &[f32],
    lda: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    panels: usize,
    packed: &mut [f32],
) {
    use std::arch::x86_64::*;

    // Linear write cursor into `packed`; micro-panels are emitted back to back.
    let mut pack_idx = 0;

    for panel in 0..panels {
        let ir = panel * MR;
        // Rows actually present in this panel (< MR only for the last one).
        let mr_actual = MR.min(rows - ir);

        if mr_actual == MR {
            // Full panel: 8×8 SIMD transpose blocks
            let k_blocks = cols / 8;
            let k_rem_start = k_blocks * 8;

            for kb in 0..k_blocks {
                let p = kb * 8;
                let base = row_start + ir;
                let col = col_start + p;

                // SAFETY: AVX2 verified by caller. Pointers are within bounds
                // (base+7 < row_start+rows, col+7 < col_start+cols).
                unsafe {
                    // Load 8 consecutive source rows, 8 columns each (row-major).
                    let r0 = _mm256_loadu_ps(a.as_ptr().add(base * lda + col));
                    let r1 = _mm256_loadu_ps(a.as_ptr().add((base + 1) * lda + col));
                    let r2 = _mm256_loadu_ps(a.as_ptr().add((base + 2) * lda + col));
                    let r3 = _mm256_loadu_ps(a.as_ptr().add((base + 3) * lda + col));
                    let r4 = _mm256_loadu_ps(a.as_ptr().add((base + 4) * lda + col));
                    let r5 = _mm256_loadu_ps(a.as_ptr().add((base + 5) * lda + col));
                    let r6 = _mm256_loadu_ps(a.as_ptr().add((base + 6) * lda + col));
                    let r7 = _mm256_loadu_ps(a.as_ptr().add((base + 7) * lda + col));

                    // 8×8 transpose via unpack + shuffle + permute2f128
                    // Stage 1: interleave 32-bit lanes of adjacent row pairs.
                    let t0 = _mm256_unpacklo_ps(r0, r1);
                    let t1 = _mm256_unpackhi_ps(r0, r1);
                    let t2 = _mm256_unpacklo_ps(r2, r3);
                    let t3 = _mm256_unpackhi_ps(r2, r3);
                    let t4 = _mm256_unpacklo_ps(r4, r5);
                    let t5 = _mm256_unpackhi_ps(r4, r5);
                    let t6 = _mm256_unpacklo_ps(r6, r7);
                    let t7 = _mm256_unpackhi_ps(r6, r7);

                    // Stage 2: gather 64-bit pairs across row quads
                    // (0x44 selects low pairs, 0xEE selects high pairs).
                    let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
                    let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
                    let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
                    let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
                    let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
                    let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
                    let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
                    let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

                    // Stage 3: recombine 128-bit halves (0x20 = low halves,
                    // 0x31 = high halves) and store each transposed column
                    // as one contiguous MR=8 run in the packed panel.
                    let dst = packed.as_mut_ptr().add(pack_idx);
                    _mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
                    _mm256_storeu_ps(dst.add(8), _mm256_permute2f128_ps(u1, u5, 0x20));
                    _mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u2, u6, 0x20));
                    _mm256_storeu_ps(dst.add(24), _mm256_permute2f128_ps(u3, u7, 0x20));
                    _mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u0, u4, 0x31));
                    _mm256_storeu_ps(dst.add(40), _mm256_permute2f128_ps(u1, u5, 0x31));
                    _mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u2, u6, 0x31));
                    _mm256_storeu_ps(dst.add(56), _mm256_permute2f128_ps(u3, u7, 0x31));
                }
                pack_idx += 64; // 8 columns × MR(=8) rows written per block
            }

            // Remainder columns: scalar
            for col in k_rem_start..cols {
                for row in 0..MR {
                    packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
                    pack_idx += 1;
                }
            }
        } else {
            // Edge panel: scalar with zero-pad
            for col in 0..cols {
                for row in 0..MR {
                    if row < mr_actual {
                        packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
                    } else {
                        packed[pack_idx] = 0.0;
                    }
                    pack_idx += 1;
                }
            }
        }
    }
}
270
271/// Pack B block from row-major source
272pub(super) fn pack_b_block(
273    b: &[f32],
274    ldb: usize,
275    row_start: usize,
276    col_start: usize,
277    rows: usize,
278    cols: usize,
279    packed: &mut [f32],
280) {
281    let mut pack_idx = 0;
282    let panels = (cols + NR - 1) / NR;
283
284    for panel in 0..panels {
285        let jr = panel * NR;
286        let nr_actual = NR.min(cols - jr);
287
288        for row in 0..rows {
289            for col in 0..NR {
290                if col < nr_actual {
291                    packed[pack_idx] = b[(row_start + row) * ldb + col_start + jr + col];
292                } else {
293                    packed[pack_idx] = 0.0;
294                }
295                pack_idx += 1;
296            }
297        }
298    }
299}
300
301// ============================================================================
302// AVX-512 packing (MR=16, NR=8)
303// ============================================================================
304
305/// Compute required packed A buffer size for AVX-512
306#[cfg(target_arch = "x86_64")]
307#[inline]
308pub fn packed_a_size_512(mc: usize, kc: usize) -> usize {
309    let panels = (mc + MR_512 - 1) / MR_512;
310    panels * MR_512 * kc
311}
312
313/// Compute required packed B buffer size for AVX-512
314#[cfg(target_arch = "x86_64")]
315#[inline]
316pub fn packed_b_size_512(kc: usize, nc: usize) -> usize {
317    let panels = (nc + NR_512 - 1) / NR_512;
318    panels * NR_512 * kc
319}
320
321/// Pack A block with MR_512=16 micro-panels for AVX-512 microkernel.
322/// A is row-major; packed A is column-major micro-panels of width MR_512=16.
323/// Full panels (mr_actual == 16): SIMD gather from 16 rows via scalar
324/// (gather is limited by row stride), but the inner loop is tight.
325#[cfg(target_arch = "x86_64")]
326#[allow(dead_code)] // Used by gemm_blis_avx512_packed (retained for AVX-512-only systems)
327pub(super) fn pack_a_block_512(
328    a: &[f32],
329    lda: usize,
330    row_start: usize,
331    col_start: usize,
332    rows: usize,
333    cols: usize,
334    packed: &mut [f32],
335) {
336    let mut pack_idx = 0;
337    let panels = (rows + MR_512 - 1) / MR_512;
338
339    for panel in 0..panels {
340        let ir = panel * MR_512;
341        let mr_actual = MR_512.min(rows - ir);
342
343        if mr_actual == MR_512 {
344            // Full panel: no zero-padding needed
345            for col in 0..cols {
346                for row in 0..MR_512 {
347                    packed[pack_idx + row] = a[(row_start + ir + row) * lda + col_start + col];
348                }
349                pack_idx += MR_512;
350            }
351        } else {
352            // Edge panel: zero-pad
353            for col in 0..cols {
354                for row in 0..mr_actual {
355                    packed[pack_idx + row] = a[(row_start + ir + row) * lda + col_start + col];
356                }
357                for row in mr_actual..MR_512 {
358                    packed[pack_idx + row] = 0.0;
359                }
360                pack_idx += MR_512;
361            }
362        }
363    }
364}
365
366/// Pack B block with NR_512=8 micro-panels for AVX-512 microkernel.
367/// B is row-major; each NR_512=8 column slice is contiguous → SIMD copy.
368/// Uses AVX2 _mm256_loadu_ps / _mm256_storeu_ps for full panels (8 f32 = 32B).
369#[cfg(target_arch = "x86_64")]
370pub(super) fn pack_b_block_512(
371    b: &[f32],
372    ldb: usize,
373    row_start: usize,
374    col_start: usize,
375    rows: usize,
376    cols: usize,
377    packed: &mut [f32],
378) {
379    let panels = (cols + NR_512 - 1) / NR_512;
380    let use_simd = is_x86_feature_detected!("avx2");
381
382    for panel in 0..panels {
383        let jr = panel * NR_512;
384        let nr_actual = NR_512.min(cols - jr);
385        let dst_base = panel * NR_512 * rows;
386
387        if nr_actual == NR_512 && use_simd {
388            // Full panel: SIMD 8-wide copy per row
389            // SAFETY: AVX2 verified above, src/dst aligned to f32.
390            unsafe {
391                use std::arch::x86_64::*;
392                for row in 0..rows {
393                    let src = b.as_ptr().add((row_start + row) * ldb + col_start + jr);
394                    let dst = packed.as_mut_ptr().add(dst_base + row * NR_512);
395                    _mm256_storeu_ps(dst, _mm256_loadu_ps(src));
396                }
397            }
398        } else {
399            // Edge panel or no SIMD: scalar with zero-pad
400            let mut pack_idx = dst_base;
401            for row in 0..rows {
402                for col in 0..NR_512 {
403                    if col < nr_actual {
404                        packed[pack_idx] = b[(row_start + row) * ldb + col_start + jr + col];
405                    } else {
406                        packed[pack_idx] = 0.0;
407                    }
408                    pack_idx += 1;
409                }
410            }
411        }
412    }
413}
414
415// ============================================================================
416// 32×6 Packing (Phase 4, Appendix D)
417// ============================================================================
418
419use super::{MR_512V2, NR_512V2};
420
421/// Compute required packed A buffer size for 32×6 microkernel.
422#[cfg(target_arch = "x86_64")]
423#[inline]
424#[allow(dead_code)] // Reserved for column-major 32×6 BLIS path
425pub fn packed_a_size_v2(mc: usize, kc: usize) -> usize {
426    let panels = (mc + MR_512V2 - 1) / MR_512V2;
427    panels * MR_512V2 * kc
428}
429
430/// Compute required packed B buffer size for 32×6 microkernel.
431#[cfg(target_arch = "x86_64")]
432#[inline]
433#[allow(dead_code)] // Reserved for column-major 32×6 BLIS path
434pub fn packed_b_size_v2(kc: usize, nc: usize) -> usize {
435    let panels = (nc + NR_512V2 - 1) / NR_512V2;
436    panels * NR_512V2 * kc
437}
438
439/// Pack A block with MR_512V2=32 micro-panels for 32×6 microkernel.
440#[cfg(target_arch = "x86_64")]
441#[allow(dead_code)] // Reserved for column-major 32×6 BLIS path
442pub(super) fn pack_a_block_v2(
443    a: &[f32],
444    lda: usize,
445    row_start: usize,
446    col_start: usize,
447    rows: usize,
448    cols: usize,
449    packed: &mut [f32],
450) {
451    let mut pack_idx = 0;
452    let panels = (rows + MR_512V2 - 1) / MR_512V2;
453
454    for panel in 0..panels {
455        let ir = panel * MR_512V2;
456        let mr_actual = MR_512V2.min(rows - ir);
457
458        for col in 0..cols {
459            for row in 0..mr_actual {
460                packed[pack_idx + row] = a[(row_start + ir + row) * lda + col_start + col];
461            }
462            for row in mr_actual..MR_512V2 {
463                packed[pack_idx + row] = 0.0;
464            }
465            pack_idx += MR_512V2;
466        }
467    }
468}
469
470/// Pack B block with NR_512V2=6 micro-panels for 32×6 microkernel.
471#[cfg(target_arch = "x86_64")]
472#[allow(dead_code)] // Reserved for column-major 32×6 BLIS path
473pub(super) fn pack_b_block_v2(
474    b: &[f32],
475    ldb: usize,
476    row_start: usize,
477    col_start: usize,
478    rows: usize,
479    cols: usize,
480    packed: &mut [f32],
481) {
482    let panels = (cols + NR_512V2 - 1) / NR_512V2;
483
484    for panel in 0..panels {
485        let jr = panel * NR_512V2;
486        let nr_actual = NR_512V2.min(cols - jr);
487        let dst_base = panel * NR_512V2 * rows;
488
489        for row in 0..rows {
490            let pack_idx = dst_base + row * NR_512V2;
491            for col in 0..nr_actual {
492                packed[pack_idx + col] = b[(row_start + row) * ldb + col_start + jr + col];
493            }
494            for col in nr_actual..NR_512V2 {
495                packed[pack_idx + col] = 0.0;
496            }
497        }
498    }
499}