1#![deny(unsafe_code)]
7#![warn(
8 clippy::all,
9 clippy::await_holding_lock,
10 clippy::char_lit_as_u8,
11 clippy::checked_conversions,
12 clippy::dbg_macro,
13 clippy::debug_assert_with_mut_call,
14 clippy::doc_markdown,
15 clippy::empty_enums,
16 clippy::enum_glob_use,
17 clippy::exit,
18 clippy::expl_impl_clone_on_copy,
19 clippy::explicit_deref_methods,
20 clippy::explicit_into_iter_loop,
21 clippy::fallible_impl_from,
22 clippy::filter_map_next,
23 clippy::float_cmp_const,
24 clippy::fn_params_excessive_bools,
25 clippy::if_let_mutex,
26 clippy::implicit_clone,
27 clippy::imprecise_flops,
28 clippy::inefficient_to_string,
29 clippy::invalid_upcast_comparisons,
30 clippy::large_types_passed_by_value,
31 clippy::let_unit_value,
32 clippy::linkedlist,
33 clippy::lossy_float_literal,
34 clippy::macro_use_imports,
35 clippy::manual_ok_or,
36 clippy::map_err_ignore,
37 clippy::map_flatten,
38 clippy::map_unwrap_or,
39 clippy::match_same_arms,
40 clippy::match_wildcard_for_single_variants,
41 clippy::mem_forget,
42 clippy::mut_mut,
43 clippy::mutex_integer,
44 clippy::needless_borrow,
45 clippy::needless_continue,
46 clippy::option_option,
47 clippy::path_buf_push_overwrite,
48 clippy::ptr_as_ptr,
49 clippy::ref_option_ref,
50 clippy::rest_pat_in_fully_bound_structs,
51 clippy::same_functions_in_if_condition,
52 clippy::semicolon_if_nothing_returned,
53 clippy::string_add_assign,
54 clippy::string_add,
55 clippy::string_lit_as_bytes,
56 clippy::todo,
57 clippy::trait_duplication_in_bounds,
58 clippy::unimplemented,
59 clippy::unnested_or_patterns,
60 clippy::unused_self,
61 clippy::useless_transmute,
62 clippy::verbose_file_reads,
63 clippy::zero_sized_map_values,
64 future_incompatible,
65 nonstandard_style,
66 rust_2018_idioms
67)]
68#![doc = include_str!("../README.md")]
72
73mod debug_printf;
74mod image;
75mod sample_param_permutations;
76mod scalar_or_vector_composite;
77
78use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
79use proc_macro::TokenStream;
80use proc_macro2::{Delimiter, Group, Ident, TokenTree};
81use quote::{ToTokens, TokenStreamExt, format_ident, quote};
82use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
83
84#[proc_macro]
134#[allow(nonstandard_style)]
137pub fn Image(item: TokenStream) -> TokenStream {
138 let output = syn::parse_macro_input!(item as image::ImageType).into_token_stream();
139
140 output.into()
141}
142
143#[proc_macro_attribute]
146pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
147 let spirv = format_ident!("{}", &spirv_attr_with_version());
148
149 let attr: proc_macro2::TokenStream = attr.into();
151 let mut tokens = quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] };
152
153 let item: proc_macro2::TokenStream = item.into();
154 let item = if let Ok(mut func) = syn::parse2::<syn::ItemFn>(item.clone()) {
159 if !matches!(func.vis, syn::Visibility::Public(_)) {
160 func.vis = syn::parse_quote!(pub);
161 func.attrs.push(syn::parse_quote!(#[allow(missing_docs)]));
162 }
163 func.into_token_stream()
164 } else {
165 item
166 };
167 for tt in item {
168 match tt {
169 TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
170 let mut group_tokens = proc_macro2::TokenStream::new();
171 let mut last_token_hashtag = false;
172 for tt in group.stream() {
173 let is_token_hashtag =
174 matches!(&tt, TokenTree::Punct(punct) if punct.as_char() == '#');
175 match tt {
176 TokenTree::Group(group)
177 if group.delimiter() == Delimiter::Bracket
178 && last_token_hashtag
179 && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
180 {
181 let inner = group
184 .stream()
185 .into_iter()
186 .skip(1)
187 .collect::<proc_macro2::TokenStream>();
188 group_tokens.extend(
189 quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] },
190 );
191 }
192 _ => group_tokens.append(tt),
193 }
194 last_token_hashtag = is_token_hashtag;
195 }
196 let mut out = Group::new(Delimiter::Parenthesis, group_tokens);
197 out.set_span(group.span());
198 tokens.append(out);
199 }
200 _ => tokens.append(tt),
201 }
202 }
203 tokens.into()
204}
205
206#[proc_macro_attribute]
212pub fn spirv_recursive_for_testing(attr: TokenStream, item: TokenStream) -> TokenStream {
213 fn recurse(spirv: &Ident, stream: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
214 let mut last_token_hashtag = false;
215 stream.into_iter().map(|tt| {
216 let mut is_token_hashtag = false;
217 let out = match tt {
218 TokenTree::Group(group)
219 if group.delimiter() == Delimiter::Bracket
220 && last_token_hashtag
221 && matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
222 {
223 let inner = group
226 .stream()
227 .into_iter()
228 .skip(1)
229 .collect::<proc_macro2::TokenStream>();
230 quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] }
231 },
232 TokenTree::Group(group) => {
233 let mut out = Group::new(group.delimiter(), recurse(spirv, group.stream()));
234 out.set_span(group.span());
235 TokenTree::Group(out).into()
236 },
237 TokenTree::Punct(punct) => {
238 is_token_hashtag = punct.as_char() == '#';
239 TokenTree::Punct(punct).into()
240 }
241 tt => tt.into(),
242 };
243 last_token_hashtag = is_token_hashtag;
244 out
245 }).collect()
246 }
247
248 let attr: proc_macro2::TokenStream = attr.into();
249 let item: proc_macro2::TokenStream = item.into();
250
251 let spirv = format_ident!("{}", &spirv_attr_with_version());
253 let inner = recurse(&spirv, item);
254 quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] #inner }.into()
255}
256
257#[proc_macro_attribute]
260pub fn gpu_only(_attr: TokenStream, item: TokenStream) -> TokenStream {
261 let syn::ItemFn {
262 attrs,
263 vis,
264 sig,
265 block,
266 } = syn::parse_macro_input!(item as syn::ItemFn);
267
268 let fn_name = sig.ident.clone();
269
270 let sig_cpu = syn::Signature {
271 abi: None,
272 ..sig.clone()
273 };
274
275 let output = quote::quote! {
276 #[cfg(not(target_arch="spirv"))]
278 #[allow(unused_variables)]
279 #(#attrs)* #vis #sig_cpu {
280 unimplemented!(
281 concat!("`", stringify!(#fn_name), "` is only available on SPIR-V platforms.")
282 )
283 }
284
285 #[cfg(target_arch="spirv")]
286 #(#attrs)* #vis #sig {
287 #block
288 }
289 };
290
291 output.into()
292}
293
294#[proc_macro]
305pub fn debug_printf(input: TokenStream) -> TokenStream {
306 debug_printf_inner(syn::parse_macro_input!(input as DebugPrintfInput))
307}
308
309#[proc_macro]
311pub fn debug_printfln(input: TokenStream) -> TokenStream {
312 let mut input = syn::parse_macro_input!(input as DebugPrintfInput);
313 input.format_string.push('\n');
314 debug_printf_inner(input)
315}
316
317#[proc_macro_attribute]
323#[doc(hidden)]
324pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
325 sample_param_permutations::gen_sample_param_permutations(item)
326}
327
328#[proc_macro_derive(ScalarComposite)]
329pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream {
330 scalar_or_vector_composite::derive(item.into())
331 .unwrap_or_else(syn::Error::into_compile_error)
332 .into()
333}