use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, FnArg, ItemFn, Pat};
#[proc_macro_attribute]
pub fn warp_kernel(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn);
let name = &input.sig.ident;
let params = &input.sig.inputs;
let body = &input.block;
let vis = &input.vis;
let attrs = &input.attrs;
if let syn::ReturnType::Type(_, ref ty) = input.sig.output {
let msg = "warp_kernel: GPU kernels must return `()`. \
PTX kernel entry points are always void.";
return syn::Error::new_spanned(ty, msg).to_compile_error().into();
}
if !input.sig.generics.params.is_empty() {
let msg = "warp_kernel: GPU kernels cannot be generic. \
PTX entry points require concrete types.";
return syn::Error::new_spanned(&input.sig.generics, msg)
.to_compile_error()
.into();
}
for param in params.iter() {
if let FnArg::Typed(pat_type) = param {
if let Err(err) = validate_kernel_param(&pat_type.ty, &pat_type.pat) {
return err;
}
}
}
let expanded = quote! {
#(#attrs)*
#[no_mangle]
#vis unsafe extern "ptx-kernel" fn #name(#params) #body
};
TokenStream::from(expanded)
}
/// Checks that one kernel parameter has a type accepted as a PTX kernel argument.
///
/// Accepted types:
/// - any raw pointer (`*const T` / `*mut T`); the pointee type is not inspected;
/// - a bare, single-segment type name listed in `GPU_SCALAR_TYPES`.
///
/// Two wrappers that do not change the type are looked through:
/// - parentheses (`(u32)`);
/// - the invisible groups that `macro_rules!` places around `$t:ty` substitutions.
///
/// As a result, a type is accepted the same way whether it is written directly or
/// forwarded from a declarative macro.
///
/// `pat` is used only to name the parameter in diagnostics.
///
/// # Errors
///
/// Returns a `compile_error!` token stream spanned on the offending type in these cases:
/// - a multi-segment path (e.g. `core::primitive::u32`);
/// - a single-segment path that is not a listed scalar;
/// - any other kind of type (references, slices, arrays, tuples, ...).
fn validate_kernel_param(ty: &syn::Type, pat: &Pat) -> Result<(), TokenStream> {
    /// Bare type names accepted as by-value kernel arguments.
    const GPU_SCALAR_TYPES: [&str; 11] = [
        "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64", "bool",
    ];

    match ty {
        syn::Type::Ptr(_) => Ok(()),
        // Transparent wrappers: validate the type they contain.
        syn::Type::Paren(paren) => validate_kernel_param(&paren.elem, pat),
        syn::Type::Group(group) => validate_kernel_param(&group.elem, pat),
        syn::Type::Path(tp) => {
            if tp.path.segments.len() > 1 {
                let msg = format!(
                    "warp_kernel: parameter `{}` uses qualified type `{}`. \
                     Use unqualified scalar types (u32, i32, f32, etc.) for kernel parameters.",
                    quote!(#pat),
                    quote!(#ty)
                );
                return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
            }
            if let Some(seg) = tp.path.segments.last() {
                let name = seg.ident.to_string();
                if !GPU_SCALAR_TYPES.contains(&name.as_str()) {
                    let msg = format!(
                        "warp_kernel: parameter `{}` has type `{}` which is not a GPU-compatible type. \
                         Use raw pointers (*const T, *mut T) for device memory or scalar types (u32, i32, f32, etc.).",
                        quote!(#pat),
                        name
                    );
                    return Err(syn::Error::new_spanned(ty, msg).to_compile_error().into());
                }
            }
            Ok(())
        }
        _ => {
            let msg = format!(
                "warp_kernel: parameter `{}` has unsupported type `{}`. \
                 Kernel parameters must be raw pointers or scalar types.",
                quote!(#pat),
                quote!(#ty)
            );
            Err(syn::Error::new_spanned(ty, msg).to_compile_error().into())
        }
    }
}