Skip to main content

spirv_std_macros/
lib.rs

1// FIXME(eddyb) update/review these lints.
2//
3// BEGIN - Embark standard lints v0.4
4// do not change or add/remove here, but one can add exceptions after this section
5// for more info see: <https://github.com/EmbarkStudios/rust-ecosystem/issues/59>
6#![deny(unsafe_code)]
7#![warn(
8    clippy::all,
9    clippy::await_holding_lock,
10    clippy::char_lit_as_u8,
11    clippy::checked_conversions,
12    clippy::dbg_macro,
13    clippy::debug_assert_with_mut_call,
14    clippy::doc_markdown,
15    clippy::empty_enums,
16    clippy::enum_glob_use,
17    clippy::exit,
18    clippy::expl_impl_clone_on_copy,
19    clippy::explicit_deref_methods,
20    clippy::explicit_into_iter_loop,
21    clippy::fallible_impl_from,
22    clippy::filter_map_next,
23    clippy::float_cmp_const,
24    clippy::fn_params_excessive_bools,
25    clippy::if_let_mutex,
26    clippy::implicit_clone,
27    clippy::imprecise_flops,
28    clippy::inefficient_to_string,
29    clippy::invalid_upcast_comparisons,
30    clippy::large_types_passed_by_value,
31    clippy::let_unit_value,
32    clippy::linkedlist,
33    clippy::lossy_float_literal,
34    clippy::macro_use_imports,
35    clippy::manual_ok_or,
36    clippy::map_err_ignore,
37    clippy::map_flatten,
38    clippy::map_unwrap_or,
39    clippy::match_same_arms,
40    clippy::match_wildcard_for_single_variants,
41    clippy::mem_forget,
42    clippy::mut_mut,
43    clippy::mutex_integer,
44    clippy::needless_borrow,
45    clippy::needless_continue,
46    clippy::option_option,
47    clippy::path_buf_push_overwrite,
48    clippy::ptr_as_ptr,
49    clippy::ref_option_ref,
50    clippy::rest_pat_in_fully_bound_structs,
51    clippy::same_functions_in_if_condition,
52    clippy::semicolon_if_nothing_returned,
53    clippy::string_add_assign,
54    clippy::string_add,
55    clippy::string_lit_as_bytes,
56    clippy::todo,
57    clippy::trait_duplication_in_bounds,
58    clippy::unimplemented,
59    clippy::unnested_or_patterns,
60    clippy::unused_self,
61    clippy::useless_transmute,
62    clippy::verbose_file_reads,
63    clippy::zero_sized_map_values,
64    future_incompatible,
65    nonstandard_style,
66    rust_2018_idioms
67)]
68// END - Embark standard lints v0.4
69// crate-specific exceptions:
70// #![allow()]
71#![doc = include_str!("../README.md")]
72
73mod debug_printf;
74mod image;
75mod sample_param_permutations;
76mod scalar_or_vector_composite;
77
78use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
79use proc_macro::TokenStream;
80use proc_macro2::{Delimiter, Group, Ident, TokenTree};
81use quote::{ToTokens, TokenStreamExt, format_ident, quote};
82use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
83
84/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
85/// `spirv_std::image::Image<...>` type.
86///
87/// The grammar for the macro is as follows:
88///
89/// ```rust,ignore
90/// Image!(
91///     <dimensionality>,
92///     <type=...|format=...>,
93///     [sampled[=<true|false>],]
94///     [multisampled[=<true|false>],]
95///     [arrayed[=<true|false>],]
96///     [depth[=<true|false>],]
97/// )
98/// ```
99///
100/// `=true` can be omitted as shorthand - e.g. `sampled` is short for `sampled=true`.
101///
102/// A basic example looks like this:
103/// ```rust,ignore
104/// #[spirv(vertex)]
105/// fn main(#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled)) {}
106/// ```
107///
108/// ## Arguments
109///
110/// - `dimensionality` — Dimensionality of an image.
111///   Accepted values: `1D`, `2D`, `3D`, `rect`, `cube`, `subpass`.
112/// - `type` — The sampled type of an image, mutually exclusive with `format`,
113///   when set the image format is unknown.
114///   Accepted values: `f32`, `f64`, `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`.
115/// - `format` — The image format of the image, mutually exclusive with `type`.
116///   Accepted values: Snake case versions of [`ImageFormat`] variants, e.g. `rgba32f`,
117///   `rgba8_snorm`.
118/// - `sampled` — Whether it is known that the image will be used with a sampler.
119///   Accepted values: `true` or `false`. Default: `unknown`.
120/// - `multisampled` — Whether the image contains multisampled content.
121///   Accepted values: `true` or `false`. Default: `false`.
122/// - `arrayed` — Whether the image contains arrayed content.
123///   Accepted values: `true` or `false`. Default: `false`.
124/// - `depth` — Whether it is known that the image is a depth image.
125///   Accepted values: `true` or `false`. Default: `unknown`.
126///
127/// [`ImageFormat`]: spirv_std_types::image_params::ImageFormat
128///
129/// Keep in mind that `sampled` here is a different concept than the `SampledImage` type:
130/// `sampled=true` means that this image requires a sampler to be able to access, while the
131/// `SampledImage` type bundles that sampler together with the image into a single type (e.g.
132/// `sampler2D` in GLSL, vs. `texture2D`).
133#[proc_macro]
134// The `Image` is supposed to be used in the type position, which
135// uses `PascalCase`.
136#[allow(nonstandard_style)]
137pub fn Image(item: TokenStream) -> TokenStream {
138    let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
139
140    output.into()
141}
142
143/// Replaces all (nested) occurrences of the `#[spirv(..)]` attribute with
144/// `#[cfg_attr(target_arch="spirv", rust_gpu::spirv(..))]`.
145#[proc_macro_attribute]
146pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
147    let spirv = format_ident!("{}", &spirv_attr_with_version());
148
149    // prepend with #[rust_gpu::spirv(..)]
150    let attr: proc_macro2::TokenStream = attr.into();
151    let mut tokens = quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] };
152
153    let item: proc_macro2::TokenStream = item.into();
154    // If the annotated item is a function without `pub`, automatically add it.
155    // SPIR-V entry points must be publicly visible to the codegen backend.
156    // Also emit `#[allow(missing_docs)]` so the forced-public visibility doesn't
157    // trigger the `missing_docs` lint on crates that have it enabled.
158    let item = if let Ok(mut func) = syn::parse2::<syn::ItemFn>(item.clone()) {
159        if !matches!(func.vis, syn::Visibility::Public(_)) {
160            func.vis = syn::parse_quote!(pub);
161            func.attrs.push(syn::parse_quote!(#[allow(missing_docs)]));
162        }
163        func.into_token_stream()
164    } else {
165        item
166    };
167    for tt in item {
168        match tt {
169            TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
170                let mut group_tokens = proc_macro2::TokenStream::new();
171                let mut last_token_hashtag = false;
172                for tt in group.stream() {
173                    let is_token_hashtag =
174                        matches!(&tt, TokenTree::Punct(punct) if punct.as_char() == '#');
175                    match tt {
176                        TokenTree::Group(group)
177                            if group.delimiter() == Delimiter::Bracket
178                                && last_token_hashtag
179                                && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
180                        {
181                            // group matches [spirv ...]
182                            // group stream doesn't include the brackets
183                            let inner = group
184                                .stream()
185                                .into_iter()
186                                .skip(1)
187                                .collect::<proc_macro2::TokenStream>();
188                            group_tokens.extend(
189                                quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] },
190                            );
191                        }
192                        _ => group_tokens.append(tt),
193                    }
194                    last_token_hashtag = is_token_hashtag;
195                }
196                let mut out = Group::new(Delimiter::Parenthesis, group_tokens);
197                out.set_span(group.span());
198                tokens.append(out);
199            }
200            _ => tokens.append(tt),
201        }
202    }
203    tokens.into()
204}
205
206/// For testing only! Is not reexported in `spirv-std`, but reachable via
207/// `spirv_std::macros::spirv_recursive_for_testing`.
208///
209/// May be more expensive than plain `spirv`, since we're checking a lot more symbols. So I've opted to
210/// have this be a separate macro, instead of modifying the standard `spirv` one.
211#[proc_macro_attribute]
212pub fn spirv_recursive_for_testing(attr: TokenStream, item: TokenStream) -> TokenStream {
213    fn recurse(spirv: &Ident, stream: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
214        let mut last_token_hashtag = false;
215        stream.into_iter().map(|tt| {
216            let mut is_token_hashtag = false;
217            let out = match tt {
218                TokenTree::Group(group)
219                if group.delimiter() == Delimiter::Bracket
220                    && last_token_hashtag
221                    && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
222                    {
223                        // group matches [spirv ...]
224                        // group stream doesn't include the brackets
225                        let inner = group
226                            .stream()
227                            .into_iter()
228                            .skip(1)
229                            .collect::<proc_macro2::TokenStream>();
230                        quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] }
231                    },
232                TokenTree::Group(group) => {
233                    let mut out = Group::new(group.delimiter(), recurse(spirv, group.stream()));
234                    out.set_span(group.span());
235                    TokenTree::Group(out).into()
236                },
237                TokenTree::Punct(punct) => {
238                    is_token_hashtag = punct.as_char() == '#';
239                    TokenTree::Punct(punct).into()
240                }
241                tt => tt.into(),
242            };
243            last_token_hashtag = is_token_hashtag;
244            out
245        }).collect()
246    }
247
248    let attr: proc_macro2::TokenStream = attr.into();
249    let item: proc_macro2::TokenStream = item.into();
250
251    // prepend with #[rust_gpu::spirv(..)]
252    let spirv = format_ident!("{}", &spirv_attr_with_version());
253    let inner = recurse(&spirv, item);
254    quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] #inner }.into()
255}
256
257/// Marks a function as runnable only on the GPU, and will panic on
258/// CPU platforms.
259#[proc_macro_attribute]
260pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
261    let syn::ItemFn {
262        attrs,
263        vis,
264        sig,
265        block,
266    } = syn::parse_macro_input!(item as syn::ItemFn);
267
268    let fn_name = sig.ident.clone();
269
270    let sig_cpu = syn::Signature {
271        abi: None,
272        ..sig.clone()
273    };
274
275    let output = quote::quote! {
276        // Don't warn on unused arguments on the CPU side.
277        #[cfg(not(target_arch="spirv"))]
278        #[allow(unused_variables)]
279        #(#attrs)* #vis #sig_cpu {
280            unimplemented!(
281                concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms.")
282            )
283        }
284
285        #[cfg(target_arch="spirv")]
286        #(#attrs)* #vis #sig {
287            #block
288        }
289    };
290
291    output.into()
292}
293
294/// Print a formatted string using the debug printf extension.
295///
296/// Examples:
297///
298/// ```rust,ignore
299/// debug_printf!("uv: %v2f\n", uv);
300/// debug_printf!("pos.x: %f, pos.z: %f, int: %i\n", pos.x, pos.z, int);
301/// ```
302///
303/// See <https://github.com/KhronosGroup/Vulkan-ValidationLayers/blob/main/docs/debug_printf.md#debug-printf-format-string> for formatting rules.
304#[proc_macro]
305pub fn debug_printf(input: TokenStream) -> TokenStream {
306    debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
307}
308
309/// Similar to `debug_printf` but appends a newline to the format string.
310#[proc_macro]
311pub fn debug_printfln(input: TokenStream) -> TokenStream {
312    let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
313    input.format_string.push('\n');
314    debug_printf_inner(input)
315}
316
317/// Generates permutations of an `ImageWithMethods` implementation containing sampling functions
318/// that have asm instruction ending with a placeholder `$PARAMS` operand. The last parameter
319/// of each function must be named `params`, its type will be rewritten. Relevant generic
320/// arguments are added to the impl generics.
321/// See `SAMPLE_PARAM_GENERICS` for a list of names you cannot use as generic arguments.
322#[proc_macro_attribute]
323#[doc(hidden)]
324pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
325    sample_param_permutations::gen_sample_param_permutations(item)
326}
327
328#[proc_macro_derive(ScalarComposite)]
329pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream {
330    scalar_or_vector_composite::derive(item.into())
331        .unwrap_or_else(syn::Error::into_compile_error)
332        .into()
333}