Skip to main content

aprender_gemm_codegen/
lib.rs

1//! trueno-gemm-codegen: Compile-time GEMM microkernel code generation.
2//!
3//! Contract: cgp-gemm-codegen-v1.yaml (C-CODEGEN-001 through C-CODEGEN-004)
4//!
5//! Generates shape-specialized AVX-512 microkernels at compile time via proc macros.
6//! Sovereign implementation — no external BLAS dependencies.
7//!
8//! # Usage
9//! ```ignore
10//! use trueno_gemm_codegen::avx512_microkernel;
11//!
12//! avx512_microkernel!(mr = 8, nr = 32);
13//! // Generates: pub unsafe fn microkernel_8x32_avx512_gen(k, a, b, c, ldc)
14//! ```
15
16use proc_macro::TokenStream;
17use proc_macro2::TokenStream as TokenStream2;
18use quote::{format_ident, quote};
19use syn::parse::{Parse, ParseStream};
20use syn::{LitInt, Token};
21
22/// Parameters for microkernel generation.
23struct MicrokernelParams {
24    mr: usize,
25    nr: usize,
26}
27
28impl Parse for MicrokernelParams {
29    fn parse(input: ParseStream) -> syn::Result<Self> {
30        // Parse: mr = N, nr = M
31        let _mr_ident: syn::Ident = input.parse()?;
32        let _eq: Token![=] = input.parse()?;
33        let mr_lit: LitInt = input.parse()?;
34        let _comma: Token![,] = input.parse()?;
35        let _nr_ident: syn::Ident = input.parse()?;
36        let _eq2: Token![=] = input.parse()?;
37        let nr_lit: LitInt = input.parse()?;
38
39        Ok(MicrokernelParams {
40            mr: mr_lit.base10_parse()?,
41            nr: nr_lit.base10_parse()?,
42        })
43    }
44}
45
46/// Generate an AVX-512 row-major C microkernel.
47///
48/// Layout: A is MR×K packed column-major, B is K×NR packed row-major,
49/// C is MR×NR row-major with stride `ldc`.
50///
51/// Strategy for row-major C (MR rows, NR columns):
52/// - Each C row spans ceil(NR/16) zmm registers
53/// - Total accumulators = MR * ceil(NR/16)
54/// - Per K step: load ceil(NR/16) B zmm, broadcast MR A scalars, MR*ceil(NR/16) FMAs
55///
56/// Register budget check (C-CODEGEN-004):
57///   accumulators + B loads + headroom <= 32 zmm
58#[proc_macro]
59pub fn avx512_microkernel(input: TokenStream) -> TokenStream {
60    let params = syn::parse_macro_input!(input as MicrokernelParams);
61    let mr = params.mr;
62    let nr = params.nr;
63
64    let b_regs = nr.div_ceil(16); // ceil(NR/16) zmm registers for B
65    let acc_count = mr * b_regs; // Total accumulator registers
66    let total_regs = acc_count + b_regs + 4; // +4 headroom for A broadcasts
67
68    if total_regs > 32 {
69        return syn::Error::new(
70            proc_macro2::Span::call_site(),
71            format!(
72                "C-CODEGEN-004: {mr}x{nr} needs {total_regs} zmm registers (max 32). \
73                 Reduce MR or NR. accumulators={acc_count}, B_loads={b_regs}"
74            ),
75        )
76        .to_compile_error()
77        .into();
78    }
79
80    let fn_name = format_ident!("microkernel_{}x{}_avx512_gen", mr, nr);
81
82    // Generate accumulator identifiers: c{row}_{half}
83    let mut acc_idents = Vec::new();
84    for row in 0..mr {
85        for h in 0..b_regs {
86            acc_idents.push(format_ident!("c{}_{}", row, h));
87        }
88    }
89
90    // Generate C load statements
91    let c_loads = generate_c_loads(mr, b_regs, &acc_idents);
92
93    // Generate the inner K-loop body
94    let k_body = generate_k_body(mr, nr, b_regs, &acc_idents);
95
96    // Generate C store statements
97    let c_stores = generate_c_stores(mr, b_regs, &acc_idents);
98
99    let doc = format!(
100        "Generated {mr}x{nr} AVX-512 microkernel ({acc_count} zmm accumulators, \
101         {b_regs} B loads, {} FMAs/K-step). Contract: cgp-gemm-codegen-v1.yaml.",
102        mr * b_regs
103    );
104
105    let output = quote! {
106        #[doc = #doc]
107        #[cfg(target_arch = "x86_64")]
108        #[target_feature(enable = "avx512f", enable = "fma")]
109        pub unsafe fn #fn_name(
110            k: usize,
111            a: *const f32,
112            b: *const f32,
113            c: *mut f32,
114            ldc: usize,
115        ) {
116            use std::arch::x86_64::*;
117
118            // Load C accumulators
119            #(#c_loads)*
120
121            // Main K loop
122            for p in 0..k {
123                #(#k_body)*
124            }
125
126            // Store C accumulators
127            #(#c_stores)*
128        }
129    };
130
131    output.into()
132}
133
134/// Generate C load statements for all accumulators.
135fn generate_c_loads(
136    mr: usize,
137    b_regs: usize,
138    acc_idents: &[proc_macro2::Ident],
139) -> Vec<TokenStream2> {
140    let mut loads = Vec::new();
141    for row in 0..mr {
142        for h in 0..b_regs {
143            let ident = &acc_idents[row * b_regs + h];
144            let offset = if row == 0 && h == 0 {
145                quote! { c }
146            } else if h == 0 {
147                let row_val = row;
148                quote! { c.add(#row_val * ldc) }
149            } else {
150                let row_val = row;
151                let col_offset = h * 16;
152                quote! { c.add(#row_val * ldc + #col_offset) }
153            };
154            loads.push(quote! {
155                let mut #ident = _mm512_loadu_ps(#offset);
156            });
157        }
158    }
159    loads
160}
161
162/// Generate the inner K-loop body.
163fn generate_k_body(
164    mr: usize,
165    nr: usize,
166    b_regs: usize,
167    acc_idents: &[proc_macro2::Ident],
168) -> Vec<TokenStream2> {
169    let mut body = Vec::new();
170
171    // Load B registers
172    let nr_val = nr;
173    for h in 0..b_regs {
174        let b_ident = format_ident!("b{}", h);
175        let offset = h * 16;
176        body.push(quote! {
177            let #b_ident = _mm512_loadu_ps(b.add(p * #nr_val + #offset));
178        });
179    }
180
181    // Broadcast A and FMA for each row
182    let mr_val = mr;
183    for row in 0..mr {
184        let a_ident = format_ident!("a{}", row);
185        body.push(quote! {
186            let #a_ident = _mm512_set1_ps(*a.add(p * #mr_val + #row));
187        });
188        for h in 0..b_regs {
189            let c_ident = &acc_idents[row * b_regs + h];
190            let b_ident = format_ident!("b{}", h);
191            body.push(quote! {
192                #c_ident = _mm512_fmadd_ps(#a_ident, #b_ident, #c_ident);
193            });
194        }
195    }
196
197    body
198}
199
200/// Generate an AVX-512 broadcast-B microkernel (faer-style).
201///
202/// Layout: A is MR×K packed column-major (MR contiguous per K step),
203/// B is K×NR packed row-major (NR contiguous per K step).
204/// C is MR×NR with stride `ldc`, stored in MR/16 zmm chunks per column.
205///
206/// Strategy (broadcast-B):
207/// - Each K step: load MR/16 zmm from A, broadcast NR B scalars
208/// - Each accumulator holds 16 elements of one column of C
209/// - Total accumulators = (MR/16) × NR
210/// - Per K step: MR/16 A loads + NR B broadcasts + (MR/16)*NR FMAs
211///
212/// Advantage over broadcast-A: NR can be small (6), keeping B panel tiny,
213/// allowing KC to stay large (256+). This matches faer's nano-gemm approach.
214#[proc_macro]
215pub fn avx512_microkernel_broadcast_b(input: TokenStream) -> TokenStream {
216    let params = syn::parse_macro_input!(input as MicrokernelParams);
217    let mr = params.mr;
218    let nr = params.nr;
219
220    if mr % 16 != 0 {
221        return syn::Error::new(
222            proc_macro2::Span::call_site(),
223            format!("broadcast-B requires MR divisible by 16, got MR={mr}"),
224        )
225        .to_compile_error()
226        .into();
227    }
228
229    let a_regs = mr / 16; // zmm registers for A loads
230    let acc_count = a_regs * nr; // Total accumulator registers
231    let total_regs = acc_count + a_regs + 4; // +4 headroom for B broadcasts
232
233    if total_regs > 32 {
234        return syn::Error::new(
235            proc_macro2::Span::call_site(),
236            format!(
237                "broadcast-B {mr}x{nr} needs {total_regs} zmm registers (max 32). \
238                 Reduce MR or NR. accumulators={acc_count}, A_loads={a_regs}"
239            ),
240        )
241        .to_compile_error()
242        .into();
243    }
244
245    let fn_name = format_ident!("microkernel_{}x{}_avx512_bcast_b", mr, nr);
246
247    // Accumulator identifiers: c{a_chunk}_{col}
248    let mut acc_idents = Vec::new();
249    for col in 0..nr {
250        for chunk in 0..a_regs {
251            acc_idents.push(format_ident!("c{}_{}", chunk, col));
252        }
253    }
254
255    // C load: load MR/16 zmm per column, NR columns
256    let c_loads = generate_bcast_b_c_loads(mr, nr, a_regs, &acc_idents);
257
258    // K-loop body
259    let k_body = generate_bcast_b_k_body(mr, nr, a_regs, &acc_idents);
260
261    // C store
262    let c_stores = generate_bcast_b_c_stores(mr, nr, a_regs, &acc_idents);
263
264    let doc = format!(
265        "Generated {mr}x{nr} AVX-512 broadcast-B microkernel ({acc_count} zmm accumulators, \
266         {a_regs} A loads, {nr} B broadcasts, {} FMAs/K-step). \
267         Faer-style: small NR keeps B panel tiny, large KC.",
268        a_regs * nr
269    );
270
271    let output = quote! {
272        #[doc = #doc]
273        #[cfg(target_arch = "x86_64")]
274        #[target_feature(enable = "avx512f", enable = "fma")]
275        pub unsafe fn #fn_name(
276            k: usize,
277            a: *const f32,
278            b: *const f32,
279            c: *mut f32,
280            ldc: usize,
281        ) {
282            use std::arch::x86_64::*;
283
284            // Load C accumulators (column-major within tile)
285            #(#c_loads)*
286
287            // Main K loop
288            for p in 0..k {
289                #(#k_body)*
290            }
291
292            // Store C accumulators
293            #(#c_stores)*
294        }
295    };
296
297    output.into()
298}
299
300/// Generate C loads for broadcast-B: each column j has MR/16 zmm registers.
301/// C layout: row-major with stride ldc. C[i][j] = c[i*ldc + j].
302/// For column j, chunk c: load C[c*16..c*16+15][j] — but C is row-major,
303/// so we need 16 individual loads and a gather. Instead, store the tile in
304/// a transposed layout: accumulate as column-major, then scatter-store at end.
305///
306/// Actually, for simplicity and performance, we'll load/store C row-major:
307/// acc[chunk][col] holds C[chunk*16 + i][col] for i=0..15 — but that requires
308/// gather/scatter since C rows are ldc apart.
309///
310/// Better approach: just use scalar loads into accumulators, accumulate, scalar store.
311/// No — the whole point is SIMD accumulation.
312///
313/// faer approach: C tile is stored in a local buffer (column-major), then written
314/// back to C row-major at the end. This avoids gather/scatter in the hot loop.
315///
316/// We'll follow faer: accumulate in registers (column-major tile), then write
317/// back to row-major C via scalar scatter at the end (NR is small, so scatter cost
318/// is amortized over many K iterations).
319fn generate_bcast_b_c_loads(
320    _mr: usize,
321    nr: usize,
322    a_regs: usize,
323    acc_idents: &[proc_macro2::Ident],
324) -> Vec<TokenStream2> {
325    let mut loads = Vec::new();
326    // Initialize accumulators to zero (we'll add existing C values at the end)
327    // This is simpler and avoids gather in the hot setup path.
328    for col in 0..nr {
329        for chunk in 0..a_regs {
330            let ident = &acc_idents[col * a_regs + chunk];
331            loads.push(quote! {
332                let mut #ident = _mm512_setzero_ps();
333            });
334        }
335    }
336    loads
337}
338
339/// Generate K-loop body for broadcast-B.
340fn generate_bcast_b_k_body(
341    mr: usize,
342    nr: usize,
343    a_regs: usize,
344    acc_idents: &[proc_macro2::Ident],
345) -> Vec<TokenStream2> {
346    let mut body = Vec::new();
347    let mr_val = mr;
348    let nr_val = nr;
349
350    // Load A: MR/16 zmm registers
351    for chunk in 0..a_regs {
352        let a_ident = format_ident!("a{}", chunk);
353        let offset = chunk * 16;
354        body.push(quote! {
355            let #a_ident = _mm512_loadu_ps(a.add(p * #mr_val + #offset));
356        });
357    }
358
359    // For each B column: broadcast B[k][j], FMA with all A chunks
360    for col in 0..nr {
361        let b_ident = format_ident!("b{}", col);
362        body.push(quote! {
363            let #b_ident = _mm512_set1_ps(*b.add(p * #nr_val + #col));
364        });
365        for chunk in 0..a_regs {
366            let c_ident = &acc_idents[col * a_regs + chunk];
367            let a_ident = format_ident!("a{}", chunk);
368            body.push(quote! {
369                #c_ident = _mm512_fmadd_ps(#a_ident, #b_ident, #c_ident);
370            });
371        }
372    }
373
374    body
375}
376
377/// Generate C store statements for broadcast-B.
378/// Accumulators are column-major: acc[col][chunk] holds 16 contiguous rows.
379/// C is row-major: C[i][j] = c[i*ldc + j].
380/// We must scatter: for each row i in chunk, store acc element to C[row][col].
381/// Since NR is small (6), we extract each f32 and store individually.
382/// This is the scatter cost — amortized over K iterations (typically 128-256).
383fn generate_bcast_b_c_stores(
384    mr: usize,
385    nr: usize,
386    a_regs: usize,
387    acc_idents: &[proc_macro2::Ident],
388) -> Vec<TokenStream2> {
389    let mut stores = Vec::new();
390
391    // For each accumulator, extract 16 f32 values and add to C row-major.
392    // Use _mm512_storeu_ps to a temp buffer, then scatter to C.
393    for col in 0..nr {
394        for chunk in 0..a_regs {
395            let ident = &acc_idents[col * a_regs + chunk];
396            let base_row = chunk * 16;
397
398            // Build scatter indices for this chunk
399            let mut scatter_stmts = Vec::new();
400            for i in 0..16 {
401                let row = base_row + i;
402                if row < mr {
403                    scatter_stmts.push(quote! {
404                        *c.add(#row * ldc + #col) += tmp[#i];
405                    });
406                }
407            }
408
409            stores.push(quote! {
410                {
411                    let mut tmp = [0.0f32; 16];
412                    _mm512_storeu_ps(tmp.as_mut_ptr(), #ident);
413                    #(#scatter_stmts)*
414                }
415            });
416        }
417    }
418
419    stores
420}
421
422/// Generate C store statements.
423fn generate_c_stores(
424    mr: usize,
425    b_regs: usize,
426    acc_idents: &[proc_macro2::Ident],
427) -> Vec<TokenStream2> {
428    let mut stores = Vec::new();
429    for row in 0..mr {
430        for h in 0..b_regs {
431            let ident = &acc_idents[row * b_regs + h];
432            let offset = if row == 0 && h == 0 {
433                quote! { c }
434            } else if h == 0 {
435                let row_val = row;
436                quote! { c.add(#row_val * ldc) }
437            } else {
438                let row_val = row;
439                let col_offset = h * 16;
440                quote! { c.add(#row_val * ldc + #col_offset) }
441            };
442            stores.push(quote! {
443                _mm512_storeu_ps(#offset, #ident);
444            });
445        }
446    }
447    stores
448}