cuda_std_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote_spanned, ToTokens};
4use syn::{
5    parse::Parse, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, Error,
6    FnArg, Ident, ItemFn, ReturnType, Stmt, Token,
7};
8
9/// Registers a function as a gpu kernel.
10///
11/// This attribute must always be placed on gpu kernel functions.
12///
13/// This attribute does a couple of things:
14/// - Tells `rustc_codegen_nvvm` to mark this as a gpu kernel and to not remove it from the ptx file.
15/// - Marks the function as `no_mangle`.
16/// - Errors if the function is not unsafe.
17/// - Makes sure function parameters are all [`Copy`].
18/// - Makes sure the function doesn't return anything.
19///
20/// Note that this does not cfg the function for nvptx(64), that is explicit so that rust analyzer is able to
21/// offer intellisense by default.
22#[proc_macro_attribute]
23pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
24    let cloned = input.clone();
25    let _ = parse_macro_input!(input as KernelHints);
26    let input = parse_macro_input!(cloned as proc_macro2::TokenStream);
27    let mut item = parse_macro_input!(item as ItemFn);
28    let no_mangle = parse_quote!(#[no_mangle]);
29    item.attrs.push(no_mangle);
30    let internal = parse_quote!(#[cfg_attr(any(target_arch="nvptx", target_arch="nvptx64"), nvvm_internal(kernel(#input)))]);
31    item.attrs.push(internal);
32
33    // used to guarantee some things about how params are passed in the codegen.
34    item.sig.abi = Some(parse_quote!(extern "C"));
35
36    let check_fn = parse_quote! {
37        fn assert_kernel_parameter_is_copy<T: Copy>() {}
38    };
39    item.block.stmts.insert(0, check_fn);
40
41    for param in &item.sig.inputs {
42        let ty = match param {
43            FnArg::Receiver(_) => quote_spanned! {
44                param.span() => ::core::compile_error!("Kernel functions may not be struct methods");
45            },
46            FnArg::Typed(ty) => ty.ty.to_token_stream(),
47        };
48        let call = parse_quote! {
49            assert_kernel_parameter_is_copy::<#ty>();
50        };
51        item.block.stmts.insert(0, call);
52    }
53
54    let ret = item.sig.output.clone();
55    if let ReturnType::Type(_, _) = ret {
56        let err = quote_spanned! {
57            ret.span() => ::core::compile_err!("Kernel functions should not return anything");
58        }
59        .into();
60        item.block.stmts.insert(0, parse_macro_input!(err as Stmt));
61    }
62
63    if item.sig.unsafety.is_none() {
64        let err = quote_spanned! {
65            item.span() => ::core::compile_error!("Kernel functions must be marked as unsafe");
66        }
67        .into();
68        item.block.stmts.insert(0, parse_macro_input!(err as Stmt));
69    }
70
71    item.to_token_stream().into()
72}
73
74#[derive(Debug, Clone, Copy, PartialEq)]
75enum Dimension {
76    Dim1,
77    Dim2,
78    Dim3,
79}
80
81impl Parse for Dimension {
82    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
83        let val = Ident::parse(input)?;
84        let val = val.to_string();
85        match val.as_str() {
86            "1d" | "1D" => Ok(Self::Dim1),
87            "2d" | "2D" => Ok(Self::Dim2),
88            "3d" | "3D" => Ok(Self::Dim3),
89            _ => Err(syn::Error::new(Span::call_site(), "Invalid dimension")),
90        }
91    }
92}
93
94enum KernelHint {
95    GridDim(Dimension),
96    BlockDim(Dimension),
97}
98
99impl Parse for KernelHint {
100    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
101        let name = Ident::parse(input)?;
102        let key = name.to_string();
103        <Token![=]>::parse(input)?;
104        match key.as_str() {
105            "grid_dim" => {
106                let dim = Dimension::parse(input)?;
107                Ok(Self::GridDim(dim))
108            }
109            "block_dim" => {
110                let dim = Dimension::parse(input)?;
111                Ok(Self::BlockDim(dim))
112            }
113            _ => Err(Error::new(Span::call_site(), "Unrecognized option")),
114        }
115    }
116}
117
118#[derive(Debug, Default, Clone, PartialEq)]
119struct KernelHints {
120    grid_dim: Option<Dimension>,
121    block_dim: Option<Dimension>,
122}
123
124impl Parse for KernelHints {
125    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
126        let iter = Punctuated::<KernelHint, Token![,]>::parse_terminated(input)?;
127        let hints = iter
128            .into_pairs()
129            .map(|x| x.into_value())
130            .collect::<Vec<_>>();
131
132        let mut out = KernelHints::default();
133
134        for hint in hints {
135            match hint {
136                KernelHint::GridDim(dim) => out.grid_dim = Some(dim),
137                KernelHint::BlockDim(dim) => out.block_dim = Some(dim),
138            }
139        }
140
141        Ok(out)
142    }
143}
144
145// derived from rust-gpu's gpu_only
146
147/// Creates a cpu version of the function which panics and cfg-gates the function for only nvptx/nvptx64.
148#[proc_macro_attribute]
149pub fn gpu_only(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
150    let syn::ItemFn {
151        attrs,
152        vis,
153        sig,
154        block,
155    } = syn::parse_macro_input!(item as syn::ItemFn);
156
157    let mut cloned_attrs = attrs.clone();
158    cloned_attrs.retain(|a| {
159        !a.path
160            .get_ident()
161            .map(|x| *x == "nvvm_internal")
162            .unwrap_or_default()
163    });
164
165    let fn_name = sig.ident.clone();
166
167    let sig_cpu = syn::Signature {
168        abi: None,
169        ..sig.clone()
170    };
171
172    let output = quote::quote! {
173        #[cfg(not(any(target_arch="nvptx", target_arch="nvptx64")))]
174        #[allow(unused_variables)]
175        #(#cloned_attrs)* #vis #sig_cpu {
176            unimplemented!(concat!("`", stringify!(#fn_name), "` can only be used on the GPU with rustc_codegen_nvvm"))
177        }
178
179        #[cfg(any(target_arch="nvptx", target_arch="nvptx64"))]
180        #(#attrs)* #vis #sig {
181            #block
182        }
183    };
184
185    output.into()
186}
187
188/// Notifies the codegen that this function is externally visible and should not be
189/// removed if it is not used by a kernel. Usually used for linking with other PTX/cubin files.
190///
191/// # Panics
192///
193/// Panics if the function is not also no_mangle.
194#[proc_macro_attribute]
195pub fn externally_visible(
196    _attr: proc_macro::TokenStream,
197    item: proc_macro::TokenStream,
198) -> TokenStream {
199    let mut func = syn::parse_macro_input!(item as syn::ItemFn);
200
201    assert!(
202        func.attrs.iter().any(|a| a.path.is_ident("no_mangle")),
203        "#[externally_visible] function should also be #[no_mangle]"
204    );
205
206    let new_attr = parse_quote!(#[cfg_attr(target_os = "cuda", nvvm_internal(used))]);
207    func.attrs.push(new_attr);
208
209    func.into_token_stream().into()
210}
211
212/// Notifies the codegen to put a `static`/`static mut` inside of a specific memory address space.
213/// This is mostly for internal use and/or advanced users, as the codegen and `cuda_std` handle address space placement
214/// implicitly. **Improper use of this macro could yield weird or undefined behavior**.
215///
216/// This macro takes a single argument which can either be `global`, `shared`, `constant`, or `local`.
217///
218/// This macro does nothing on the CPU.
219#[proc_macro_attribute]
220pub fn address_space(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
221    let mut global = syn::parse_macro_input!(item as syn::ItemStatic);
222    let input = syn::parse_macro_input!(attr as Ident);
223
224    let addrspace_num = match input.to_string().as_str() {
225        "global" => 1,
226        // what did you do to address space 2 libnvvm??
227        "shared" => 3,
228        "constant" => 4,
229        "local" => 5,
230        addr => panic!("Invalid address space `{}`", addr),
231    };
232
233    let new_attr =
234        parse_quote!(#[cfg_attr(target_os = "cuda", nvvm_internal(addrspace(#addrspace_num)))]);
235    global.attrs.push(new_attr);
236
237    global.into_token_stream().into()
238}