use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::{LitInt, Token};
/// Parsed arguments of the microkernel macros: `mr = <int>, nr = <int>`.
///
/// `mr` is the number of rows and `nr` the number of columns of the C tile
/// that the generated kernel updates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MicrokernelParams {
    /// Rows of the C tile (MR).
    mr: usize,
    /// Columns of the C tile (NR).
    nr: usize,
}
impl Parse for MicrokernelParams {
fn parse(input: ParseStream) -> syn::Result<Self> {
let _mr_ident: syn::Ident = input.parse()?;
let _eq: Token![=] = input.parse()?;
let mr_lit: LitInt = input.parse()?;
let _comma: Token![,] = input.parse()?;
let _nr_ident: syn::Ident = input.parse()?;
let _eq2: Token![=] = input.parse()?;
let nr_lit: LitInt = input.parse()?;
Ok(MicrokernelParams {
mr: mr_lit.base10_parse()?,
nr: nr_lit.base10_parse()?,
})
}
}
/// Generates an AVX-512 f32 GEMM microkernel for an `mr` x `nr` tile of C.
///
/// Usage: `avx512_microkernel!(mr = 4, nr = 32)` defines
/// `pub unsafe fn microkernel_4x32_avx512_gen(k, a, b, c, ldc)`.
///
/// The generated kernel keeps the whole C tile in zmm accumulators. It first
/// loads the tile from `c` (row stride `ldc`). Then, for every `p in 0..k`, it:
/// - loads `nr / 16` vectors from `b.add(p * nr)`,
/// - broadcasts `*a.add(p * mr + row)` for each row,
/// - accumulates with `_mm512_fmadd_ps`.
///
/// Finally it stores the tile back, i.e. C += A·B.
///
/// A compile error is emitted instead when `mr` or `nr` is zero, when `nr` is
/// not a multiple of 16, or when the tile needs more than 32 zmm registers.
#[proc_macro]
pub fn avx512_microkernel(input: TokenStream) -> TokenStream {
    let params = syn::parse_macro_input!(input as MicrokernelParams);
    let mr = params.mr;
    let nr = params.nr;
    if mr == 0 || nr == 0 {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            format!("AVX-512 microkernel requires non-zero MR and NR, got MR={mr}, NR={nr}"),
        )
        .to_compile_error()
        .into();
    }
    // Every B load and C load/store moves a full 16-lane zmm vector. With a
    // partial last chunk the kernel would read past each packed B row
    // (stride `nr`) and write past column `nr` of the C tile.
    if nr % 16 != 0 {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            format!("AVX-512 microkernel requires NR divisible by 16, got NR={nr}"),
        )
        .to_compile_error()
        .into();
    }
    let b_regs = nr / 16;
    let acc_count = mr * b_regs;
    // Accumulators + one register per B chunk + 4 headroom registers
    // (presumably for the A broadcast and temporaries).
    let total_regs = acc_count + b_regs + 4;
    if total_regs > 32 {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "C-CODEGEN-004: {mr}x{nr} needs {total_regs} zmm registers (max 32). \
                 Reduce MR or NR. accumulators={acc_count}, B_loads={b_regs}"
            ),
        )
        .to_compile_error()
        .into();
    }
    let fn_name = format_ident!("microkernel_{}x{}_avx512_gen", mr, nr);
    // Accumulator `c{row}_{h}` holds row `row`, columns `h*16 .. h*16+16`.
    let mut acc_idents = Vec::with_capacity(acc_count);
    for row in 0..mr {
        for h in 0..b_regs {
            acc_idents.push(format_ident!("c{}_{}", row, h));
        }
    }
    let c_loads = generate_c_loads(mr, b_regs, &acc_idents);
    let k_body = generate_k_body(mr, nr, b_regs, &acc_idents);
    let c_stores = generate_c_stores(mr, b_regs, &acc_idents);
    let doc = format!(
        "Generated {mr}x{nr} AVX-512 microkernel ({acc_count} zmm accumulators, \
         {b_regs} B loads, {} FMAs/K-step). Contract: cgp-gemm-codegen-v1.yaml.\n\n\
         # Safety\n\n\
         The CPU must support `avx512f` and `fma`. `a` must be valid for reading \
         `k * {mr}` f32 values and `b` for reading `k * {nr}` f32 values. For every \
         row `r < {mr}`, `c.add(r * ldc)` must be valid for reading and writing \
         {nr} f32 values.",
        mr * b_regs
    );
    let output = quote! {
        #[doc = #doc]
        #[cfg(target_arch = "x86_64")]
        #[target_feature(enable = "avx512f", enable = "fma")]
        pub unsafe fn #fn_name(
            k: usize,
            a: *const f32,
            b: *const f32,
            c: *mut f32,
            ldc: usize,
        ) {
            use std::arch::x86_64::*;
            #(#c_loads)*
            for p in 0..k {
                #(#k_body)*
            }
            #(#c_stores)*
        }
    };
    output.into()
}
/// Builds the prologue that loads the current C tile into the accumulators.
///
/// For row `R` and 16-lane chunk `H` it emits
/// `let mut cR_H = _mm512_loadu_ps(<c + R * ldc + H * 16>);`. The pointer
/// expression is simplified to `c` or `c.add(R * ldc)` when the offsets are zero.
/// `acc_idents` is indexed as `row * b_regs + h`.
fn generate_c_loads(
    mr: usize,
    b_regs: usize,
    acc_idents: &[proc_macro2::Ident],
) -> Vec<TokenStream2> {
    (0..mr)
        .flat_map(|row| (0..b_regs).map(move |h| (row, h)))
        .map(|(row, h)| {
            let acc = &acc_idents[row * b_regs + h];
            let src = match (row, h) {
                (0, 0) => quote! { c },
                (_, 0) => quote! { c.add(#row * ldc) },
                _ => {
                    let col_offset = h * 16;
                    quote! { c.add(#row * ldc + #col_offset) }
                }
            };
            quote! {
                let mut #acc = _mm512_loadu_ps(#src);
            }
        })
        .collect()
}
/// Builds the statements executed once per k-step `p` of the standard kernel.
///
/// First it loads every 16-lane B chunk `bH` from `b + p * nr + H * 16`. Then,
/// for each row `R`, it broadcasts `aR = a[p * mr + R]` and issues one FMA
/// `cR_H = aR * bH + cR_H` per chunk. `acc_idents` is indexed as
/// `row * b_regs + h`.
fn generate_k_body(
    mr: usize,
    nr: usize,
    b_regs: usize,
    acc_idents: &[proc_macro2::Ident],
) -> Vec<TokenStream2> {
    let b_names: Vec<proc_macro2::Ident> =
        (0..b_regs).map(|h| format_ident!("b{}", h)).collect();

    // B chunks are loaded once per k-step and reused by every row.
    let mut stmts: Vec<TokenStream2> = b_names
        .iter()
        .enumerate()
        .map(|(h, b_name)| {
            let lane_offset = h * 16;
            quote! {
                let #b_name = _mm512_loadu_ps(b.add(p * #nr + #lane_offset));
            }
        })
        .collect();

    for row in 0..mr {
        let a_name = format_ident!("a{}", row);
        stmts.push(quote! {
            let #a_name = _mm512_set1_ps(*a.add(p * #mr + #row));
        });
        stmts.extend(b_names.iter().enumerate().map(|(h, b_name)| {
            let acc = &acc_idents[row * b_regs + h];
            quote! {
                #acc = _mm512_fmadd_ps(#a_name, #b_name, #acc);
            }
        }));
    }
    stmts
}
/// Generates an AVX-512 "broadcast-B" f32 GEMM microkernel for an `mr` x `nr`
/// tile of C.
///
/// Usage: `avx512_microkernel_broadcast_b!(mr = 32, nr = 4)` defines
/// `pub unsafe fn microkernel_32x4_avx512_bcast_b(k, a, b, c, ldc)`.
///
/// The generated kernel starts from zeroed zmm accumulators. For each
/// `p in 0..k` it:
/// - loads `mr / 16` vectors from `a.add(p * mr)`,
/// - broadcasts `*b.add(p * nr + col)` for every column,
/// - accumulates with `_mm512_fmadd_ps`.
///
/// Afterwards every accumulator lane is added into `c[row * ldc + col]` with
/// scalar stores, i.e. C += A·B.
///
/// A compile error is emitted instead when `mr` or `nr` is zero, when `mr` is
/// not a multiple of 16, or when the tile needs more than 32 zmm registers.
#[proc_macro]
pub fn avx512_microkernel_broadcast_b(input: TokenStream) -> TokenStream {
    let params = syn::parse_macro_input!(input as MicrokernelParams);
    let mr = params.mr;
    let nr = params.nr;
    // MR = 0 would pass the divisibility check below and yield a kernel that
    // computes nothing.
    if mr == 0 || nr == 0 {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            format!("broadcast-B requires non-zero MR and NR, got MR={mr}, NR={nr}"),
        )
        .to_compile_error()
        .into();
    }
    if mr % 16 != 0 {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            format!("broadcast-B requires MR divisible by 16, got MR={mr}"),
        )
        .to_compile_error()
        .into();
    }
    let a_regs = mr / 16;
    let acc_count = a_regs * nr;
    // Accumulators + one register per A chunk + 4 headroom registers
    // (presumably for the B broadcast and temporaries).
    let total_regs = acc_count + a_regs + 4;
    if total_regs > 32 {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "broadcast-B {mr}x{nr} needs {total_regs} zmm registers (max 32). \
                 Reduce MR or NR. accumulators={acc_count}, A_loads={a_regs}"
            ),
        )
        .to_compile_error()
        .into();
    }
    let fn_name = format_ident!("microkernel_{}x{}_avx512_bcast_b", mr, nr);
    // Accumulator `c{chunk}_{col}` holds rows `chunk*16 .. chunk*16+16` of column `col`;
    // stored column-major, so index = col * a_regs + chunk.
    let mut acc_idents = Vec::with_capacity(acc_count);
    for col in 0..nr {
        for chunk in 0..a_regs {
            acc_idents.push(format_ident!("c{}_{}", chunk, col));
        }
    }
    let c_loads = generate_bcast_b_c_loads(mr, nr, a_regs, &acc_idents);
    let k_body = generate_bcast_b_k_body(mr, nr, a_regs, &acc_idents);
    let c_stores = generate_bcast_b_c_stores(mr, nr, a_regs, &acc_idents);
    let doc = format!(
        "Generated {mr}x{nr} AVX-512 broadcast-B microkernel ({acc_count} zmm accumulators, \
         {a_regs} A loads, {nr} B broadcasts, {} FMAs/K-step). \
         Faer-style: small NR keeps B panel tiny, large KC.\n\n\
         # Safety\n\n\
         The CPU must support `avx512f` and `fma`. `a` must be valid for reading \
         `k * {mr}` f32 values and `b` for reading `k * {nr}` f32 values. For every \
         row `r < {mr}`, `c.add(r * ldc)` must be valid for reading and writing \
         {nr} f32 values.",
        a_regs * nr
    );
    let output = quote! {
        #[doc = #doc]
        #[cfg(target_arch = "x86_64")]
        #[target_feature(enable = "avx512f", enable = "fma")]
        pub unsafe fn #fn_name(
            k: usize,
            a: *const f32,
            b: *const f32,
            c: *mut f32,
            ldc: usize,
        ) {
            use std::arch::x86_64::*;
            #(#c_loads)*
            for p in 0..k {
                #(#k_body)*
            }
            #(#c_stores)*
        }
    };
    output.into()
}
/// Builds the broadcast-B prologue: every accumulator starts at zero.
///
/// It emits `let mut <acc> = _mm512_setzero_ps();` for the first
/// `nr * a_regs` identifiers, in column-major order (`col * a_regs + chunk`).
/// The C tile is not read here; the epilogue adds the result into C.
fn generate_bcast_b_c_loads(
    _mr: usize,
    nr: usize,
    a_regs: usize,
    acc_idents: &[proc_macro2::Ident],
) -> Vec<TokenStream2> {
    acc_idents[..nr * a_regs]
        .iter()
        .map(|acc| {
            quote! {
                let mut #acc = _mm512_setzero_ps();
            }
        })
        .collect()
}
/// Builds the statements executed once per k-step `p` of the broadcast-B kernel.
///
/// First it loads each 16-row A chunk `aC` from `a + p * mr + C * 16`. Then,
/// for every column `col`, it broadcasts `b{col} = b[p * nr + col]` and issues
/// one FMA per A chunk into accumulator `col * a_regs + chunk`.
fn generate_bcast_b_k_body(
    mr: usize,
    nr: usize,
    a_regs: usize,
    acc_idents: &[proc_macro2::Ident],
) -> Vec<TokenStream2> {
    let a_names: Vec<proc_macro2::Ident> =
        (0..a_regs).map(|chunk| format_ident!("a{}", chunk)).collect();

    // A chunks are loaded once per k-step and reused by every column.
    let mut stmts: Vec<TokenStream2> = a_names
        .iter()
        .enumerate()
        .map(|(chunk, a_name)| {
            let row_offset = chunk * 16;
            quote! {
                let #a_name = _mm512_loadu_ps(a.add(p * #mr + #row_offset));
            }
        })
        .collect();

    for col in 0..nr {
        let b_name = format_ident!("b{}", col);
        stmts.push(quote! {
            let #b_name = _mm512_set1_ps(*b.add(p * #nr + #col));
        });
        stmts.extend(a_names.iter().enumerate().map(|(chunk, a_name)| {
            let acc = &acc_idents[col * a_regs + chunk];
            quote! {
                #acc = _mm512_fmadd_ps(#a_name, #b_name, #acc);
            }
        }));
    }
    stmts
}
/// Builds the broadcast-B epilogue that adds each accumulator into C.
///
/// For every accumulator (column `col`, 16-row chunk `chunk`) it emits a scoped
/// block. The block spills the vector to a 16-element stack array, then adds
/// each lane into `c[(chunk * 16 + lane) * ldc + col]`, skipping rows `>= mr`.
fn generate_bcast_b_c_stores(
    mr: usize,
    nr: usize,
    a_regs: usize,
    acc_idents: &[proc_macro2::Ident],
) -> Vec<TokenStream2> {
    let mut stores = Vec::with_capacity(nr * a_regs);
    for col in 0..nr {
        for chunk in 0..a_regs {
            let acc = &acc_idents[col * a_regs + chunk];
            let first_row = chunk * 16;
            let scatter = (0..16usize)
                .map(|lane| (lane, first_row + lane))
                .filter(|&(_, row)| row < mr)
                .map(|(lane, row)| {
                    quote! {
                        *c.add(#row * ldc + #col) += tmp[#lane];
                    }
                });
            stores.push(quote! {
                {
                    let mut tmp = [0.0f32; 16];
                    _mm512_storeu_ps(tmp.as_mut_ptr(), #acc);
                    #(#scatter)*
                }
            });
        }
    }
    stores
}
/// Builds the epilogue that writes the accumulators back to the C tile.
///
/// For row `R` and 16-lane chunk `H` it emits
/// `_mm512_storeu_ps(<c + R * ldc + H * 16>, cR_H);`. The pointer expression is
/// simplified to `c` or `c.add(R * ldc)` when the offsets are zero.
/// `acc_idents` is indexed as `row * b_regs + h`.
fn generate_c_stores(
    mr: usize,
    b_regs: usize,
    acc_idents: &[proc_macro2::Ident],
) -> Vec<TokenStream2> {
    let mut out = Vec::with_capacity(mr * b_regs);
    for row in 0..mr {
        for h in 0..b_regs {
            let dst = match (row, h) {
                (0, 0) => quote! { c },
                (_, 0) => quote! { c.add(#row * ldc) },
                _ => {
                    let col_offset = h * 16;
                    quote! { c.add(#row * ldc + #col_offset) }
                }
            };
            let acc = &acc_idents[row * b_regs + h];
            out.push(quote! {
                _mm512_storeu_ps(#dst, #acc);
            });
        }
    }
    out
}