use proc_macro2::TokenStream;
use syn::parse::{Parse, ParseStream};
use syn::{Ident, LitInt, Token};
use crate::kernel_ir::KernelConfig;
/// Raw arguments parsed from the attribute token stream handed to
/// `parse_kernel_config`, before any range validation is applied.
struct GpuKernelAttrs {
/// Value of the `block_size = <int>` argument paired with the span of its
/// integer literal (kept so later diagnostics can point at the literal);
/// `None` when the argument was not supplied.
block_size: Option<(u32, proc_macro2::Span)>,
}
/// Parses a comma-separated list of `key = value` arguments.
///
/// The only recognised key is `block_size`, whose value must be an integer
/// literal that fits in a `u32`. A trailing comma is allowed.
///
/// # Errors
///
/// Returns a spanned `syn::Error` when:
/// - a key other than `block_size` is encountered,
/// - `block_size` is given more than once,
/// - the value is not an integer literal or does not fit in a `u32`,
/// - two entries are not separated by a comma.
impl Parse for GpuKernelAttrs {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut block_size = None;

        while !input.is_empty() {
            let key: Ident = input.parse()?;
            input.parse::<Token![=]>()?;

            match key.to_string().as_str() {
                "block_size" => {
                    // Reject repeats instead of silently letting the last one win.
                    if block_size.is_some() {
                        return Err(syn::Error::new(
                            key.span(),
                            "duplicate attribute `block_size`",
                        ));
                    }
                    let lit: LitInt = input.parse()?;
                    let span = lit.span();
                    let value: u32 = lit.base10_parse().map_err(|_| {
                        syn::Error::new(span, "block_size must be a positive integer")
                    })?;
                    block_size = Some((value, span));
                }
                other => {
                    return Err(syn::Error::new(
                        key.span(),
                        format!("unknown attribute `{other}`. Expected: block_size"),
                    ));
                }
            }

            // Entries must be comma-separated; only the final comma is optional.
            // Without this, `a = 1 b = 2` would be accepted as two entries.
            if input.is_empty() {
                break;
            }
            input.parse::<Token![,]>()?;
        }

        Ok(GpuKernelAttrs { block_size })
    }
}
pub fn parse_kernel_config(attr: TokenStream) -> syn::Result<KernelConfig> {
let attrs: GpuKernelAttrs = syn::parse2(attr)?;
let (block_size, block_size_span) = attrs.block_size.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"`block_size` is required in `#[gpu_kernel(...)]`",
)
})?;
if !block_size.is_power_of_two() {
return Err(syn::Error::new(
block_size_span,
format!("`block_size` must be a power of 2, got {block_size}"),
));
}
if block_size > 1024 {
return Err(syn::Error::new(
block_size_span,
format!("`block_size` cannot exceed 1024, got {block_size}"),
));
}
Ok(KernelConfig {
block_size,
block_size_span,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Parses `tokens` and returns the resulting block size; panics if the
    /// tokens are rejected.
    fn accepted_block_size(tokens: TokenStream) -> u32 {
        parse_kernel_config(tokens).unwrap().block_size
    }

    /// Parses `tokens` and returns the rendered error message; panics if the
    /// tokens are accepted.
    fn rejection_message(tokens: TokenStream) -> String {
        parse_kernel_config(tokens).unwrap_err().to_string()
    }

    #[test]
    fn parse_block_size_256() {
        assert_eq!(accepted_block_size(quote! { block_size = 256 }), 256);
    }

    #[test]
    fn parse_block_size_1() {
        assert_eq!(accepted_block_size(quote! { block_size = 1 }), 1);
    }

    #[test]
    fn parse_block_size_1024() {
        assert_eq!(accepted_block_size(quote! { block_size = 1024 }), 1024);
    }

    #[test]
    fn parse_block_size_with_trailing_comma() {
        assert_eq!(accepted_block_size(quote! { block_size = 128, }), 128);
    }

    #[test]
    fn reject_missing_block_size() {
        let message = rejection_message(quote! {});
        assert!(message.contains("block_size"));
    }

    #[test]
    fn reject_non_power_of_two() {
        let message = rejection_message(quote! { block_size = 100 });
        assert!(message.contains("power of 2"));
    }

    #[test]
    fn reject_exceeds_1024() {
        let message = rejection_message(quote! { block_size = 2048 });
        assert!(message.contains("1024"));
    }

    #[test]
    fn reject_unknown_attribute() {
        let message = rejection_message(quote! { block_size = 256, warp_size = 32 });
        assert!(message.contains("unknown attribute"));
    }
}