Skip to main content

trueno/blis/microkernels/
avx512.rs

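//! AVX-512 SIMD Microkernels
//!
//! Two microkernel variants:
//! - **16×8**: Original tile, 8 zmm accumulators. Used by gemm_blis_avx512_large.
//! - **32×6**: Larger tile (Phase 4, Appendix D), 12 zmm accumulators.
//!   Uses 2 vectors of 16 f32 (rows 0–15 and 16–31) × 6 columns = 12 accumulators + 2 A loads.
//!   1.5× more FMAs per K step than 16×8 (12 vs 8), better register utilization.
//!
//! Intended register allocation for the 16×8 kernel (the compiler picks the actual zmm numbers):
//! - zmm0-zmm7: 8 columns of C (16 f32 each) = 128 outputs in registers
//! - one A column loaded per K step; each B value broadcast to all 16 lanes
//!
//! The 16×8 kernel unrolls K 4-way and the 32×6 kernel 2-way, which reduces loop overhead.
//! FMA latency is hidden by the independent accumulators (8 or 12 dependency chains), not by
//! the unroll itself: each accumulator is still updated serially along K. Keeping 2 FMA ports
//! busy needs roughly `2 × FMA latency` independent chains in flight.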
14
15/// 16×8 AVX-512 microkernel — 4-way K-unrolled.
16/// A: 16×K packed column-major. B: K×8 packed row-major.
17/// C: 16×8 column-major with stride ldc.
18#[cfg(target_arch = "x86_64")]
19#[target_feature(enable = "avx512f")]
20pub unsafe fn microkernel_16x8_avx512(
21    k: usize,
22    a: *const f32,
23    b: *const f32,
24    c: *mut f32,
25    ldc: usize,
26) {
27    unsafe {
28        use std::arch::x86_64::*;
29
30        // Load C (8 columns of 16 elements)
31        let mut c0 = _mm512_loadu_ps(c);
32        let mut c1 = _mm512_loadu_ps(c.add(ldc));
33        let mut c2 = _mm512_loadu_ps(c.add(2 * ldc));
34        let mut c3 = _mm512_loadu_ps(c.add(3 * ldc));
35        let mut c4 = _mm512_loadu_ps(c.add(4 * ldc));
36        let mut c5 = _mm512_loadu_ps(c.add(5 * ldc));
37        let mut c6 = _mm512_loadu_ps(c.add(6 * ldc));
38        let mut c7 = _mm512_loadu_ps(c.add(7 * ldc));
39
40        let k4 = k / 4;
41        let k_rem = k % 4;
42
43        for p4 in 0..k4 {
44            let base = p4 * 4;
45
46            // K+0
47            let a0 = _mm512_loadu_ps(a.add(base * 16));
48            let bp0 = b.add(base * 8);
49            c0 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0), c0);
50            c1 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(1)), c1);
51            c2 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(2)), c2);
52            c3 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(3)), c3);
53            c4 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(4)), c4);
54            c5 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(5)), c5);
55            c6 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(6)), c6);
56            c7 = _mm512_fmadd_ps(a0, _mm512_set1_ps(*bp0.add(7)), c7);
57
58            // K+1
59            let a1 = _mm512_loadu_ps(a.add((base + 1) * 16));
60            let bp1 = b.add((base + 1) * 8);
61            c0 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1), c0);
62            c1 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(1)), c1);
63            c2 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(2)), c2);
64            c3 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(3)), c3);
65            c4 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(4)), c4);
66            c5 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(5)), c5);
67            c6 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(6)), c6);
68            c7 = _mm512_fmadd_ps(a1, _mm512_set1_ps(*bp1.add(7)), c7);
69
70            // K+2
71            let a2 = _mm512_loadu_ps(a.add((base + 2) * 16));
72            let bp2 = b.add((base + 2) * 8);
73            c0 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2), c0);
74            c1 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(1)), c1);
75            c2 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(2)), c2);
76            c3 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(3)), c3);
77            c4 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(4)), c4);
78            c5 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(5)), c5);
79            c6 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(6)), c6);
80            c7 = _mm512_fmadd_ps(a2, _mm512_set1_ps(*bp2.add(7)), c7);
81
82            // K+3
83            let a3 = _mm512_loadu_ps(a.add((base + 3) * 16));
84            let bp3 = b.add((base + 3) * 8);
85            c0 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3), c0);
86            c1 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(1)), c1);
87            c2 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(2)), c2);
88            c3 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(3)), c3);
89            c4 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(4)), c4);
90            c5 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(5)), c5);
91            c6 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(6)), c6);
92            c7 = _mm512_fmadd_ps(a3, _mm512_set1_ps(*bp3.add(7)), c7);
93        }
94
95        // Remainder
96        let base_rem = k4 * 4;
97        for p in 0..k_rem {
98            let pp = base_rem + p;
99            let a_col = _mm512_loadu_ps(a.add(pp * 16));
100            let bp = b.add(pp * 8);
101            c0 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp), c0);
102            c1 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(1)), c1);
103            c2 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(2)), c2);
104            c3 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(3)), c3);
105            c4 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(4)), c4);
106            c5 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(5)), c5);
107            c6 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(6)), c6);
108            c7 = _mm512_fmadd_ps(a_col, _mm512_set1_ps(*bp.add(7)), c7);
109        }
110
111        // Store C
112        _mm512_storeu_ps(c, c0);
113        _mm512_storeu_ps(c.add(ldc), c1);
114        _mm512_storeu_ps(c.add(2 * ldc), c2);
115        _mm512_storeu_ps(c.add(3 * ldc), c3);
116        _mm512_storeu_ps(c.add(4 * ldc), c4);
117        _mm512_storeu_ps(c.add(5 * ldc), c5);
118        _mm512_storeu_ps(c.add(6 * ldc), c6);
119        _mm512_storeu_ps(c.add(7 * ldc), c7);
120    }
121}
122
123/// 32×6 AVX-512 microkernel — 2-way K-unrolled.
124///
125/// A: 32×K packed column-major (two consecutive zmm rows per K step).
126/// B: K×6 packed row-major.
127/// C: 32×6 column-major with stride ldc (32 = 2 zmm rows).
128///
129/// Register allocation (14 of 32 zmm used):
130///   zmm0-zmm5:  row 0 accumulators (C[0..16, j] for j=0..6)
131///   zmm6-zmm11: row 1 accumulators (C[16..32, j] for j=0..6)
132///   zmm12-zmm13: A column loads (rows 0-15, 16-31)
133///   B: broadcast from memory via vbroadcastss (no register needed)
134///
135/// FMAs per K step: 12 (2 rows × 6 cols). With 2-way unroll: 24 FMAs.
136/// vs 16×8: 8 FMAs/step → 32 FMAs/4-unroll. This kernel: 1.5× more FMA/step.
137///
138/// Appendix D optimization #1: increase register utilization from 25% to 44%.
139#[cfg(target_arch = "x86_64")]
140#[target_feature(enable = "avx512f")]
141pub unsafe fn microkernel_32x6_avx512(
142    k: usize,
143    a: *const f32,
144    b: *const f32,
145    c: *mut f32,
146    ldc: usize,
147) {
148    unsafe {
149        use std::arch::x86_64::*;
150
151        // Load C: 2 zmm rows × 6 columns = 12 accumulators
152        let mut c00 = _mm512_loadu_ps(c);
153        let mut c01 = _mm512_loadu_ps(c.add(ldc));
154        let mut c02 = _mm512_loadu_ps(c.add(2 * ldc));
155        let mut c03 = _mm512_loadu_ps(c.add(3 * ldc));
156        let mut c04 = _mm512_loadu_ps(c.add(4 * ldc));
157        let mut c05 = _mm512_loadu_ps(c.add(5 * ldc));
158        let mut c10 = _mm512_loadu_ps(c.add(16));
159        let mut c11 = _mm512_loadu_ps(c.add(ldc + 16));
160        let mut c12 = _mm512_loadu_ps(c.add(2 * ldc + 16));
161        let mut c13 = _mm512_loadu_ps(c.add(3 * ldc + 16));
162        let mut c14 = _mm512_loadu_ps(c.add(4 * ldc + 16));
163        let mut c15 = _mm512_loadu_ps(c.add(5 * ldc + 16));
164
165        let k2 = k / 2;
166        let k_rem = k % 2;
167
168        // Main loop: 2-way K-unrolled
169        for p2 in 0..k2 {
170            let base = p2 * 2;
171
172            // K+0: load A row0 and row1
173            let a0_lo = _mm512_loadu_ps(a.add(base * 32));
174            let a0_hi = _mm512_loadu_ps(a.add(base * 32 + 16));
175            let bp0 = b.add(base * 6);
176
177            // 6 FMAs for row 0
178            let b0 = _mm512_set1_ps(*bp0);
179            c00 = _mm512_fmadd_ps(a0_lo, b0, c00);
180            c10 = _mm512_fmadd_ps(a0_hi, b0, c10);
181            let b1 = _mm512_set1_ps(*bp0.add(1));
182            c01 = _mm512_fmadd_ps(a0_lo, b1, c01);
183            c11 = _mm512_fmadd_ps(a0_hi, b1, c11);
184            let b2 = _mm512_set1_ps(*bp0.add(2));
185            c02 = _mm512_fmadd_ps(a0_lo, b2, c02);
186            c12 = _mm512_fmadd_ps(a0_hi, b2, c12);
187            let b3 = _mm512_set1_ps(*bp0.add(3));
188            c03 = _mm512_fmadd_ps(a0_lo, b3, c03);
189            c13 = _mm512_fmadd_ps(a0_hi, b3, c13);
190            let b4 = _mm512_set1_ps(*bp0.add(4));
191            c04 = _mm512_fmadd_ps(a0_lo, b4, c04);
192            c14 = _mm512_fmadd_ps(a0_hi, b4, c14);
193            let b5 = _mm512_set1_ps(*bp0.add(5));
194            c05 = _mm512_fmadd_ps(a0_lo, b5, c05);
195            c15 = _mm512_fmadd_ps(a0_hi, b5, c15);
196
197            // K+1
198            let a1_lo = _mm512_loadu_ps(a.add((base + 1) * 32));
199            let a1_hi = _mm512_loadu_ps(a.add((base + 1) * 32 + 16));
200            let bp1 = b.add((base + 1) * 6);
201
202            let b0 = _mm512_set1_ps(*bp1);
203            c00 = _mm512_fmadd_ps(a1_lo, b0, c00);
204            c10 = _mm512_fmadd_ps(a1_hi, b0, c10);
205            let b1 = _mm512_set1_ps(*bp1.add(1));
206            c01 = _mm512_fmadd_ps(a1_lo, b1, c01);
207            c11 = _mm512_fmadd_ps(a1_hi, b1, c11);
208            let b2 = _mm512_set1_ps(*bp1.add(2));
209            c02 = _mm512_fmadd_ps(a1_lo, b2, c02);
210            c12 = _mm512_fmadd_ps(a1_hi, b2, c12);
211            let b3 = _mm512_set1_ps(*bp1.add(3));
212            c03 = _mm512_fmadd_ps(a1_lo, b3, c03);
213            c13 = _mm512_fmadd_ps(a1_hi, b3, c13);
214            let b4 = _mm512_set1_ps(*bp1.add(4));
215            c04 = _mm512_fmadd_ps(a1_lo, b4, c04);
216            c14 = _mm512_fmadd_ps(a1_hi, b4, c14);
217            let b5 = _mm512_set1_ps(*bp1.add(5));
218            c05 = _mm512_fmadd_ps(a1_lo, b5, c05);
219            c15 = _mm512_fmadd_ps(a1_hi, b5, c15);
220        }
221
222        // Remainder
223        let base_rem = k2 * 2;
224        for p in 0..k_rem {
225            let pp = base_rem + p;
226            let a_lo = _mm512_loadu_ps(a.add(pp * 32));
227            let a_hi = _mm512_loadu_ps(a.add(pp * 32 + 16));
228            let bp = b.add(pp * 6);
229            let b0 = _mm512_set1_ps(*bp);
230            c00 = _mm512_fmadd_ps(a_lo, b0, c00);
231            c10 = _mm512_fmadd_ps(a_hi, b0, c10);
232            let b1 = _mm512_set1_ps(*bp.add(1));
233            c01 = _mm512_fmadd_ps(a_lo, b1, c01);
234            c11 = _mm512_fmadd_ps(a_hi, b1, c11);
235            let b2 = _mm512_set1_ps(*bp.add(2));
236            c02 = _mm512_fmadd_ps(a_lo, b2, c02);
237            c12 = _mm512_fmadd_ps(a_hi, b2, c12);
238            let b3 = _mm512_set1_ps(*bp.add(3));
239            c03 = _mm512_fmadd_ps(a_lo, b3, c03);
240            c13 = _mm512_fmadd_ps(a_hi, b3, c13);
241            let b4 = _mm512_set1_ps(*bp.add(4));
242            c04 = _mm512_fmadd_ps(a_lo, b4, c04);
243            c14 = _mm512_fmadd_ps(a_hi, b4, c14);
244            let b5 = _mm512_set1_ps(*bp.add(5));
245            c05 = _mm512_fmadd_ps(a_lo, b5, c05);
246            c15 = _mm512_fmadd_ps(a_hi, b5, c15);
247        }
248
249        // Store C: 2 rows × 6 columns
250        _mm512_storeu_ps(c, c00);
251        _mm512_storeu_ps(c.add(ldc), c01);
252        _mm512_storeu_ps(c.add(2 * ldc), c02);
253        _mm512_storeu_ps(c.add(3 * ldc), c03);
254        _mm512_storeu_ps(c.add(4 * ldc), c04);
255        _mm512_storeu_ps(c.add(5 * ldc), c05);
256        _mm512_storeu_ps(c.add(16), c10);
257        _mm512_storeu_ps(c.add(ldc + 16), c11);
258        _mm512_storeu_ps(c.add(2 * ldc + 16), c12);
259        _mm512_storeu_ps(c.add(3 * ldc + 16), c13);
260        _mm512_storeu_ps(c.add(4 * ldc + 16), c14);
261        _mm512_storeu_ps(c.add(5 * ldc + 16), c15);
262    }
263}