1use proc_macro::TokenStream;
17use proc_macro2::TokenStream as TokenStream2;
18use quote::{format_ident, quote};
19use syn::parse::{Parse, ParseStream};
20use syn::{LitInt, Token};
21
/// Tile shape parsed from a microkernel macro invocation such as
/// `avx512_microkernel!(mr = 4, nr = 32)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MicrokernelParams {
    /// Number of rows of the C tile (M dimension).
    mr: usize,
    /// Number of columns of the C tile (N dimension).
    nr: usize,
}
27
28impl Parse for MicrokernelParams {
29 fn parse(input: ParseStream) -> syn::Result<Self> {
30 let _mr_ident: syn::Ident = input.parse()?;
32 let _eq: Token![=] = input.parse()?;
33 let mr_lit: LitInt = input.parse()?;
34 let _comma: Token![,] = input.parse()?;
35 let _nr_ident: syn::Ident = input.parse()?;
36 let _eq2: Token![=] = input.parse()?;
37 let nr_lit: LitInt = input.parse()?;
38
39 Ok(MicrokernelParams {
40 mr: mr_lit.base10_parse()?,
41 nr: nr_lit.base10_parse()?,
42 })
43 }
44}
45
46#[proc_macro]
59pub fn avx512_microkernel(input: TokenStream) -> TokenStream {
60 let params = syn::parse_macro_input!(input as MicrokernelParams);
61 let mr = params.mr;
62 let nr = params.nr;
63
64 let b_regs = nr.div_ceil(16); let acc_count = mr * b_regs; let total_regs = acc_count + b_regs + 4; if total_regs > 32 {
69 return syn::Error::new(
70 proc_macro2::Span::call_site(),
71 format!(
72 "C-CODEGEN-004: {mr}x{nr} needs {total_regs} zmm registers (max 32). \
73 Reduce MR or NR. accumulators={acc_count}, B_loads={b_regs}"
74 ),
75 )
76 .to_compile_error()
77 .into();
78 }
79
80 let fn_name = format_ident!("microkernel_{}x{}_avx512_gen", mr, nr);
81
82 let mut acc_idents = Vec::new();
84 for row in 0..mr {
85 for h in 0..b_regs {
86 acc_idents.push(format_ident!("c{}_{}", row, h));
87 }
88 }
89
90 let c_loads = generate_c_loads(mr, b_regs, &acc_idents);
92
93 let k_body = generate_k_body(mr, nr, b_regs, &acc_idents);
95
96 let c_stores = generate_c_stores(mr, b_regs, &acc_idents);
98
99 let doc = format!(
100 "Generated {mr}x{nr} AVX-512 microkernel ({acc_count} zmm accumulators, \
101 {b_regs} B loads, {} FMAs/K-step). Contract: cgp-gemm-codegen-v1.yaml.",
102 mr * b_regs
103 );
104
105 let output = quote! {
106 #[doc = #doc]
107 #[cfg(target_arch = "x86_64")]
108 #[target_feature(enable = "avx512f", enable = "fma")]
109 pub unsafe fn #fn_name(
110 k: usize,
111 a: *const f32,
112 b: *const f32,
113 c: *mut f32,
114 ldc: usize,
115 ) {
116 use std::arch::x86_64::*;
117
118 #(#c_loads)*
120
121 for p in 0..k {
123 #(#k_body)*
124 }
125
126 #(#c_stores)*
128 }
129 };
130
131 output.into()
132}
133
134fn generate_c_loads(
136 mr: usize,
137 b_regs: usize,
138 acc_idents: &[proc_macro2::Ident],
139) -> Vec<TokenStream2> {
140 let mut loads = Vec::new();
141 for row in 0..mr {
142 for h in 0..b_regs {
143 let ident = &acc_idents[row * b_regs + h];
144 let offset = if row == 0 && h == 0 {
145 quote! { c }
146 } else if h == 0 {
147 let row_val = row;
148 quote! { c.add(#row_val * ldc) }
149 } else {
150 let row_val = row;
151 let col_offset = h * 16;
152 quote! { c.add(#row_val * ldc + #col_offset) }
153 };
154 loads.push(quote! {
155 let mut #ident = _mm512_loadu_ps(#offset);
156 });
157 }
158 }
159 loads
160}
161
162fn generate_k_body(
164 mr: usize,
165 nr: usize,
166 b_regs: usize,
167 acc_idents: &[proc_macro2::Ident],
168) -> Vec<TokenStream2> {
169 let mut body = Vec::new();
170
171 let nr_val = nr;
173 for h in 0..b_regs {
174 let b_ident = format_ident!("b{}", h);
175 let offset = h * 16;
176 body.push(quote! {
177 let #b_ident = _mm512_loadu_ps(b.add(p * #nr_val + #offset));
178 });
179 }
180
181 let mr_val = mr;
183 for row in 0..mr {
184 let a_ident = format_ident!("a{}", row);
185 body.push(quote! {
186 let #a_ident = _mm512_set1_ps(*a.add(p * #mr_val + #row));
187 });
188 for h in 0..b_regs {
189 let c_ident = &acc_idents[row * b_regs + h];
190 let b_ident = format_ident!("b{}", h);
191 body.push(quote! {
192 #c_ident = _mm512_fmadd_ps(#a_ident, #b_ident, #c_ident);
193 });
194 }
195 }
196
197 body
198}
199
200#[proc_macro]
215pub fn avx512_microkernel_broadcast_b(input: TokenStream) -> TokenStream {
216 let params = syn::parse_macro_input!(input as MicrokernelParams);
217 let mr = params.mr;
218 let nr = params.nr;
219
220 if mr % 16 != 0 {
221 return syn::Error::new(
222 proc_macro2::Span::call_site(),
223 format!("broadcast-B requires MR divisible by 16, got MR={mr}"),
224 )
225 .to_compile_error()
226 .into();
227 }
228
229 let a_regs = mr / 16; let acc_count = a_regs * nr; let total_regs = acc_count + a_regs + 4; if total_regs > 32 {
234 return syn::Error::new(
235 proc_macro2::Span::call_site(),
236 format!(
237 "broadcast-B {mr}x{nr} needs {total_regs} zmm registers (max 32). \
238 Reduce MR or NR. accumulators={acc_count}, A_loads={a_regs}"
239 ),
240 )
241 .to_compile_error()
242 .into();
243 }
244
245 let fn_name = format_ident!("microkernel_{}x{}_avx512_bcast_b", mr, nr);
246
247 let mut acc_idents = Vec::new();
249 for col in 0..nr {
250 for chunk in 0..a_regs {
251 acc_idents.push(format_ident!("c{}_{}", chunk, col));
252 }
253 }
254
255 let c_loads = generate_bcast_b_c_loads(mr, nr, a_regs, &acc_idents);
257
258 let k_body = generate_bcast_b_k_body(mr, nr, a_regs, &acc_idents);
260
261 let c_stores = generate_bcast_b_c_stores(mr, nr, a_regs, &acc_idents);
263
264 let doc = format!(
265 "Generated {mr}x{nr} AVX-512 broadcast-B microkernel ({acc_count} zmm accumulators, \
266 {a_regs} A loads, {nr} B broadcasts, {} FMAs/K-step). \
267 Faer-style: small NR keeps B panel tiny, large KC.",
268 a_regs * nr
269 );
270
271 let output = quote! {
272 #[doc = #doc]
273 #[cfg(target_arch = "x86_64")]
274 #[target_feature(enable = "avx512f", enable = "fma")]
275 pub unsafe fn #fn_name(
276 k: usize,
277 a: *const f32,
278 b: *const f32,
279 c: *mut f32,
280 ldc: usize,
281 ) {
282 use std::arch::x86_64::*;
283
284 #(#c_loads)*
286
287 for p in 0..k {
289 #(#k_body)*
290 }
291
292 #(#c_stores)*
294 }
295 };
296
297 output.into()
298}
299
300fn generate_bcast_b_c_loads(
320 _mr: usize,
321 nr: usize,
322 a_regs: usize,
323 acc_idents: &[proc_macro2::Ident],
324) -> Vec<TokenStream2> {
325 let mut loads = Vec::new();
326 for col in 0..nr {
329 for chunk in 0..a_regs {
330 let ident = &acc_idents[col * a_regs + chunk];
331 loads.push(quote! {
332 let mut #ident = _mm512_setzero_ps();
333 });
334 }
335 }
336 loads
337}
338
339fn generate_bcast_b_k_body(
341 mr: usize,
342 nr: usize,
343 a_regs: usize,
344 acc_idents: &[proc_macro2::Ident],
345) -> Vec<TokenStream2> {
346 let mut body = Vec::new();
347 let mr_val = mr;
348 let nr_val = nr;
349
350 for chunk in 0..a_regs {
352 let a_ident = format_ident!("a{}", chunk);
353 let offset = chunk * 16;
354 body.push(quote! {
355 let #a_ident = _mm512_loadu_ps(a.add(p * #mr_val + #offset));
356 });
357 }
358
359 for col in 0..nr {
361 let b_ident = format_ident!("b{}", col);
362 body.push(quote! {
363 let #b_ident = _mm512_set1_ps(*b.add(p * #nr_val + #col));
364 });
365 for chunk in 0..a_regs {
366 let c_ident = &acc_idents[col * a_regs + chunk];
367 let a_ident = format_ident!("a{}", chunk);
368 body.push(quote! {
369 #c_ident = _mm512_fmadd_ps(#a_ident, #b_ident, #c_ident);
370 });
371 }
372 }
373
374 body
375}
376
377fn generate_bcast_b_c_stores(
384 mr: usize,
385 nr: usize,
386 a_regs: usize,
387 acc_idents: &[proc_macro2::Ident],
388) -> Vec<TokenStream2> {
389 let mut stores = Vec::new();
390
391 for col in 0..nr {
394 for chunk in 0..a_regs {
395 let ident = &acc_idents[col * a_regs + chunk];
396 let base_row = chunk * 16;
397
398 let mut scatter_stmts = Vec::new();
400 for i in 0..16 {
401 let row = base_row + i;
402 if row < mr {
403 scatter_stmts.push(quote! {
404 *c.add(#row * ldc + #col) += tmp[#i];
405 });
406 }
407 }
408
409 stores.push(quote! {
410 {
411 let mut tmp = [0.0f32; 16];
412 _mm512_storeu_ps(tmp.as_mut_ptr(), #ident);
413 #(#scatter_stmts)*
414 }
415 });
416 }
417 }
418
419 stores
420}
421
422fn generate_c_stores(
424 mr: usize,
425 b_regs: usize,
426 acc_idents: &[proc_macro2::Ident],
427) -> Vec<TokenStream2> {
428 let mut stores = Vec::new();
429 for row in 0..mr {
430 for h in 0..b_regs {
431 let ident = &acc_idents[row * b_regs + h];
432 let offset = if row == 0 && h == 0 {
433 quote! { c }
434 } else if h == 0 {
435 let row_val = row;
436 quote! { c.add(#row_val * ldc) }
437 } else {
438 let row_val = row;
439 let col_offset = h * 16;
440 quote! { c.add(#row_val * ldc + #col_offset) }
441 };
442 stores.push(quote! {
443 _mm512_storeu_ps(#offset, #ident);
444 });
445 }
446 }
447 stores
448}