1#![allow(missing_docs, reason = "not intended to be used directly")]
7
8use std::error::Error;
9use std::fmt;
10use std::fs;
11use std::path::PathBuf;
12
13use quote::quote;
14use syn::Token;
15
16use naga_rust_back::Config;
17use naga_rust_back::naga;
18
19#[proc_macro]
20pub fn include_wgsl_mr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
21 let ConfigAndStr {
22 config,
23 string: path_literal,
24 } = syn::parse_macro_input!(input as ConfigAndStr);
25
26 match include_wgsl_mr_impl(config, &path_literal) {
27 Ok(expansion) => expansion.into(),
28 Err(error) => error.to_compile_error().into(),
29 }
30}
31
32#[proc_macro]
33pub fn wgsl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34 let ConfigAndStr {
35 config,
36 string: source_literal,
37 } = syn::parse_macro_input!(input as ConfigAndStr);
38
39 match parse_and_translate(config, source_literal.span(), &source_literal.value()) {
40 Ok(expansion) => expansion.into(),
41 Err(error) => error.to_compile_error().into(),
42 }
43}
44
45#[proc_macro_attribute]
47pub fn dummy_attribute(
48 _meta: proc_macro::TokenStream,
49 input: proc_macro::TokenStream,
50) -> proc_macro::TokenStream {
51 input
52}
53
54struct ConfigAndStr {
59 config: Config,
60 string: syn::LitStr,
61}
62
63impl syn::parse::Parse for ConfigAndStr {
64 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
65 let mut config = macro_default_config();
66 loop {
67 let not_a_string_error = match input.parse::<syn::LitStr>() {
69 Ok(string) => {
70 if !input.is_empty() {
72 input.parse::<Token![,]>()?;
73 }
74 return Ok(Self { config, string });
75 }
76 Err(e) => e,
77 };
78
79 let option_name = input.parse::<syn::Ident>().map_err(|mut e| {
80 e.combine(not_a_string_error);
81 e
82 })?;
83 input.parse::<Token![=]>()?;
84 match &*option_name.to_string() {
85 "allow_unimplemented" => {
88 config = config.allow_unimplemented(input.parse::<syn::LitBool>()?.value);
89 }
90 "explicit_types" => {
91 config = config.explicit_types(input.parse::<syn::LitBool>()?.value);
92 }
93 "public_items" => {
94 config = config.public_items(input.parse::<syn::LitBool>()?.value);
95 }
96 "global_struct" => {
103 config = config.global_struct(input.parse::<syn::Ident>()?.to_string());
104 }
105 "resource_struct" => {
106 config = config.resource_struct(input.parse::<syn::Ident>()?.to_string());
107 }
108 _ => {
109 return Err(syn::Error::new_spanned(
110 option_name,
111 "unrecognized configuration option name",
112 ));
113 }
114 }
115 input.parse::<Token![,]>()?;
116 }
117 }
118}
119
120fn macro_default_config() -> Config {
121 Config::default()
122 .runtime_path("::naga_rust_embed::rt")
123 .explicit_types(true)
126}
127
128fn include_wgsl_mr_impl(
131 config: Config,
132 path_literal: &syn::LitStr,
133) -> Result<proc_macro2::TokenStream, syn::Error> {
134 let mut absolute_path: PathBuf = PathBuf::from(
137 std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR must be set by Cargo"),
138 );
139 absolute_path.push(path_literal.value());
140
141 let absolute_path_str = absolute_path.to_str().ok_or_else(|| {
143 syn::Error::new_spanned(
144 path_literal,
145 format_args!(
146 "absolute path “{p:?}” must be UTF-8",
147 p = absolute_path.display()
148 ),
149 )
150 })?;
151
152 let wgsl_source_text: String = fs::read_to_string(&absolute_path).map_err(|error| {
153 syn::Error::new_spanned(
154 path_literal,
155 format_args!("failed to read “{absolute_path_str}”: {error}"),
156 )
157 })?;
158
159 let translated_tokens = parse_and_translate(config, path_literal.span(), &wgsl_source_text)?;
160
161 Ok(quote! {
162 const _: &str = include_str!(#absolute_path_str);
165
166 #translated_tokens
167 })
168}
169
170fn parse_and_translate(
171 config: Config,
172 wgsl_source_span: proc_macro2::Span,
173 wgsl_source_text: &str,
174) -> Result<proc_macro2::TokenStream, syn::Error> {
175 let module: naga::Module = naga::front::wgsl::parse_str(wgsl_source_text).map_err(|error| {
176 syn::Error::new(
177 wgsl_source_span,
178 format_args!("failed to parse WGSL text: {}", ErrorChain(&error)),
179 )
180 })?;
181
182 let module_info: naga::valid::ModuleInfo = naga::valid::Validator::new(
184 naga::valid::ValidationFlags::all(),
185 naga_rust_back::CAPABILITIES,
186 )
187 .subgroup_stages(naga::valid::ShaderStages::all())
188 .subgroup_operations(naga::valid::SubgroupOperationSet::empty())
190 .validate(&module)
191 .map_err(|error| {
192 syn::Error::new(
193 wgsl_source_span,
194 format_args!("failed to validate WGSL: {}", ErrorChain(&error)),
195 )
196 })?;
197
198 let translated_source: String = naga_rust_back::write_string(&module, &module_info, config)
199 .map_err(|error| {
200 syn::Error::new(
201 wgsl_source_span,
202 format_args!("failed to translate shader to Rust: {}", ErrorChain(&error)),
203 )
204 })?;
205
206 let translated_tokens: proc_macro2::TokenStream =
207 translated_source.parse().map_err(|error| {
208 syn::Error::new(
209 wgsl_source_span,
210 format_args!(
211 "internal error: translator did not produce valid Rust: {}",
212 ErrorChain(&error)
213 ),
214 )
215 })?;
216
217 Ok(translated_tokens)
218}
219
220#[derive(Clone, Copy, Debug)]
226struct ErrorChain<'a>(&'a (dyn Error + 'a));
227
228impl fmt::Display for ErrorChain<'_> {
229 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
230 format_error_chain(fmt, self.0)
231 }
232}
233
234fn format_error_chain(fmt: &mut fmt::Formatter<'_>, mut error: &(dyn Error + '_)) -> fmt::Result {
235 write!(fmt, "{error}")?;
236 while let Some(source) = error.source() {
237 error = source;
238 write!(fmt, "\n↳ {error}")?;
239 }
240
241 Ok(())
242}