1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote_spanned, ToTokens};
4use syn::{
5 parse::Parse, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, Error,
6 FnArg, Ident, ItemFn, ReturnType, Stmt, Token,
7};
8
9#[proc_macro_attribute]
23pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
24 let cloned = input.clone();
25 let _ = parse_macro_input!(input as KernelHints);
26 let input = parse_macro_input!(cloned as proc_macro2::TokenStream);
27 let mut item = parse_macro_input!(item as ItemFn);
28 let no_mangle = parse_quote!(#[no_mangle]);
29 item.attrs.push(no_mangle);
30 let internal = parse_quote!(#[cfg_attr(any(target_arch="nvptx", target_arch="nvptx64"), nvvm_internal(kernel(#input)))]);
31 item.attrs.push(internal);
32
33 item.sig.abi = Some(parse_quote!(extern "C"));
35
36 let check_fn = parse_quote! {
37 fn assert_kernel_parameter_is_copy<T: Copy>() {}
38 };
39 item.block.stmts.insert(0, check_fn);
40
41 for param in &item.sig.inputs {
42 let ty = match param {
43 FnArg::Receiver(_) => quote_spanned! {
44 param.span() => ::core::compile_error!("Kernel functions may not be struct methods");
45 },
46 FnArg::Typed(ty) => ty.ty.to_token_stream(),
47 };
48 let call = parse_quote! {
49 assert_kernel_parameter_is_copy::<#ty>();
50 };
51 item.block.stmts.insert(0, call);
52 }
53
54 let ret = item.sig.output.clone();
55 if let ReturnType::Type(_, _) = ret {
56 let err = quote_spanned! {
57 ret.span() => ::core::compile_err!("Kernel functions should not return anything");
58 }
59 .into();
60 item.block.stmts.insert(0, parse_macro_input!(err as Stmt));
61 }
62
63 if item.sig.unsafety.is_none() {
64 let err = quote_spanned! {
65 item.span() => ::core::compile_error!("Kernel functions must be marked as unsafe");
66 }
67 .into();
68 item.block.stmts.insert(0, parse_macro_input!(err as Stmt));
69 }
70
71 item.to_token_stream().into()
72}
73
/// Dimensionality value carried by a kernel hint (see [`KernelHint`]).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Dimension {
    /// Written as `1d` or `1D`.
    Dim1,
    /// Written as `2d` or `2D`.
    Dim2,
    /// Written as `3d` or `3D`.
    Dim3,
}
80
81impl Parse for Dimension {
82 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
83 let val = Ident::parse(input)?;
84 let val = val.to_string();
85 match val.as_str() {
86 "1d" | "1D" => Ok(Self::Dim1),
87 "2d" | "2D" => Ok(Self::Dim2),
88 "3d" | "3D" => Ok(Self::Dim3),
89 _ => Err(syn::Error::new(Span::call_site(), "Invalid dimension")),
90 }
91 }
92}
93
/// A single `key = value` hint parsed from the attribute arguments.
enum KernelHint {
    /// Produced by the `grid_dim` key.
    GridDim(Dimension),
    /// Produced by the `block_dim` key.
    BlockDim(Dimension),
}
98
99impl Parse for KernelHint {
100 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
101 let name = Ident::parse(input)?;
102 let key = name.to_string();
103 <Token![=]>::parse(input)?;
104 match key.as_str() {
105 "grid_dim" => {
106 let dim = Dimension::parse(input)?;
107 Ok(Self::GridDim(dim))
108 }
109 "block_dim" => {
110 let dim = Dimension::parse(input)?;
111 Ok(Self::BlockDim(dim))
112 }
113 _ => Err(Error::new(Span::call_site(), "Unrecognized option")),
114 }
115 }
116}
117
/// The collected hints of one attribute invocation; a field is `None` when the
/// corresponding hint was not given.
#[derive(Debug, Default, Clone, PartialEq)]
struct KernelHints {
    /// Value of the `grid_dim` hint, if present.
    grid_dim: Option<Dimension>,
    /// Value of the `block_dim` hint, if present.
    block_dim: Option<Dimension>,
}
123
124impl Parse for KernelHints {
125 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
126 let iter = Punctuated::<KernelHint, Token![,]>::parse_terminated(input)?;
127 let hints = iter
128 .into_pairs()
129 .map(|x| x.into_value())
130 .collect::<Vec<_>>();
131
132 let mut out = KernelHints::default();
133
134 for hint in hints {
135 match hint {
136 KernelHint::GridDim(dim) => out.grid_dim = Some(dim),
137 KernelHint::BlockDim(dim) => out.block_dim = Some(dim),
138 }
139 }
140
141 Ok(out)
142 }
143}
144
145#[proc_macro_attribute]
149pub fn gpu_only(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
150 let syn::ItemFn {
151 attrs,
152 vis,
153 sig,
154 block,
155 } = syn::parse_macro_input!(item as syn::ItemFn);
156
157 let mut cloned_attrs = attrs.clone();
158 cloned_attrs.retain(|a| {
159 !a.path
160 .get_ident()
161 .map(|x| *x == "nvvm_internal")
162 .unwrap_or_default()
163 });
164
165 let fn_name = sig.ident.clone();
166
167 let sig_cpu = syn::Signature {
168 abi: None,
169 ..sig.clone()
170 };
171
172 let output = quote::quote! {
173 #[cfg(not(any(target_arch="nvptx", target_arch="nvptx64")))]
174 #[allow(unused_variables)]
175 #(#cloned_attrs)* #vis #sig_cpu {
176 unimplemented!(concat!("`", stringify!(#fn_name), "` can only be used on the GPU with rustc_codegen_nvvm"))
177 }
178
179 #[cfg(any(target_arch="nvptx", target_arch="nvptx64"))]
180 #(#attrs)* #vis #sig {
181 #block
182 }
183 };
184
185 output.into()
186}
187
188#[proc_macro_attribute]
195pub fn externally_visible(
196 _attr: proc_macro::TokenStream,
197 item: proc_macro::TokenStream,
198) -> TokenStream {
199 let mut func = syn::parse_macro_input!(item as syn::ItemFn);
200
201 assert!(
202 func.attrs.iter().any(|a| a.path.is_ident("no_mangle")),
203 "#[externally_visible] function should also be #[no_mangle]"
204 );
205
206 let new_attr = parse_quote!(#[cfg_attr(target_os = "cuda", nvvm_internal(used))]);
207 func.attrs.push(new_attr);
208
209 func.into_token_stream().into()
210}
211
212#[proc_macro_attribute]
220pub fn address_space(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
221 let mut global = syn::parse_macro_input!(item as syn::ItemStatic);
222 let input = syn::parse_macro_input!(attr as Ident);
223
224 let addrspace_num = match input.to_string().as_str() {
225 "global" => 1,
226 "shared" => 3,
228 "constant" => 4,
229 "local" => 5,
230 addr => panic!("Invalid address space `{}`", addr),
231 };
232
233 let new_attr =
234 parse_quote!(#[cfg_attr(target_os = "cuda", nvvm_internal(addrspace(#addrspace_num)))]);
235 global.attrs.push(new_attr);
236
237 global.into_token_stream().into()
238}