Skip to main content

naga_rust_macros/
lib.rs

//! This is a proc-macro helper library. Don't use this library directly; use [`naga_rust_embed`]
//! instead.
//!
//! [`naga_rust_embed`]: https://docs.rs/naga-rust-embed
5
6#![allow(missing_docs, reason = "not intended to be used directly")]
7
8use std::error::Error;
9use std::fmt;
10use std::fs;
11use std::path::PathBuf;
12
13use quote::quote;
14use syn::Token;
15
16use naga_rust_back::Config;
17use naga_rust_back::naga;
18
19#[proc_macro]
20pub fn include_wgsl_mr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
21    let ConfigAndStr {
22        config,
23        string: path_literal,
24    } = syn::parse_macro_input!(input as ConfigAndStr);
25
26    match include_wgsl_mr_impl(config, &path_literal) {
27        Ok(expansion) => expansion.into(),
28        Err(error) => error.to_compile_error().into(),
29    }
30}
31
32#[proc_macro]
33pub fn wgsl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34    let ConfigAndStr {
35        config,
36        string: source_literal,
37    } = syn::parse_macro_input!(input as ConfigAndStr);
38
39    match parse_and_translate(config, source_literal.span(), &source_literal.value()) {
40        Ok(expansion) => expansion.into(),
41        Err(error) => error.to_compile_error().into(),
42    }
43}
44
45/// Returns the input unchanged.
46#[proc_macro_attribute]
47pub fn dummy_attribute(
48    _meta: proc_macro::TokenStream,
49    input: proc_macro::TokenStream,
50) -> proc_macro::TokenStream {
51    input
52}
53
54// -------------------------------------------------------------------------------------------------
55
56/// Parsed syntax for the [`wgsl`] or [`include_wgsl_mr`] macros, which consist of configuration
57/// options `name = value_expr` followed by a string literal which is either source code or a path.
58struct ConfigAndStr {
59    config: Config,
60    string: syn::LitStr,
61}
62
63impl syn::parse::Parse for ConfigAndStr {
64    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
65        let mut config = macro_default_config();
66        loop {
67            // Try parsing the final string literal.
68            let not_a_string_error = match input.parse::<syn::LitStr>() {
69                Ok(string) => {
70                    // Accept a final optional comma after the string.
71                    if !input.is_empty() {
72                        input.parse::<Token![,]>()?;
73                    }
74                    return Ok(Self { config, string });
75                }
76                Err(e) => e,
77            };
78
79            let option_name = input.parse::<syn::Ident>().map_err(|mut e| {
80                e.combine(not_a_string_error);
81                e
82            })?;
83            input.parse::<Token![=]>()?;
84            match &*option_name.to_string() {
85                // The options parsed by this match should also be documented in
86                // `embed/src/configuration_syntax.md`.
87                "allow_unimplemented" => {
88                    config = config.allow_unimplemented(input.parse::<syn::LitBool>()?.value);
89                }
90                "explicit_types" => {
91                    config = config.explicit_types(input.parse::<syn::LitBool>()?.value);
92                }
93                "public_items" => {
94                    config = config.public_items(input.parse::<syn::LitBool>()?.value);
95                }
96                // TODO: raw_pointers doesn’t actually work, and will need to be marked unsafe
97                // when it is implemented. So, we don’t offer it yet.
98                //
99                // "raw_pointers" => {
100                //     config = config.raw_pointers(input.parse::<syn::LitBool>()?.value);
101                // }
102                "global_struct" => {
103                    config = config.global_struct(input.parse::<syn::Ident>()?.to_string());
104                }
105                "resource_struct" => {
106                    config = config.resource_struct(input.parse::<syn::Ident>()?.to_string());
107                }
108                _ => {
109                    return Err(syn::Error::new_spanned(
110                        option_name,
111                        "unrecognized configuration option name",
112                    ));
113                }
114            }
115            input.parse::<Token![,]>()?;
116        }
117    }
118}
119
120fn macro_default_config() -> Config {
121    Config::default()
122        .runtime_path("::naga_rust_embed::rt")
123        // Helps give better errors when the generated code is wrong.
124        // TODO: Consider turning this back off for efficiency? Measure impact?
125        .explicit_types(true)
126}
127
128// -------------------------------------------------------------------------------------------------
129
130fn include_wgsl_mr_impl(
131    config: Config,
132    path_literal: &syn::LitStr,
133) -> Result<proc_macro2::TokenStream, syn::Error> {
134    // We use manifest-relative paths because currently, there is no way to arrange for
135    // source-file-relative paths.
136    let mut absolute_path: PathBuf = PathBuf::from(
137        std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR must be set by Cargo"),
138    );
139    absolute_path.push(path_literal.value());
140
141    // If this fails then we can't generate the `include_str!` we must generate.
142    let absolute_path_str = absolute_path.to_str().ok_or_else(|| {
143        syn::Error::new_spanned(
144            path_literal,
145            format_args!(
146                "absolute path “{p:?}” must be UTF-8",
147                p = absolute_path.display()
148            ),
149        )
150    })?;
151
152    let wgsl_source_text: String = fs::read_to_string(&absolute_path).map_err(|error| {
153        syn::Error::new_spanned(
154            path_literal,
155            format_args!("failed to read “{absolute_path_str}”: {error}"),
156        )
157    })?;
158
159    let translated_tokens = parse_and_translate(config, path_literal.span(), &wgsl_source_text)?;
160
161    Ok(quote! {
162        // Dummy include_str! call tells the compiler that we depend on this file,
163        // which it would not notice otherwise.
164        const _: &str = include_str!(#absolute_path_str);
165
166        #translated_tokens
167    })
168}
169
170fn parse_and_translate(
171    config: Config,
172    wgsl_source_span: proc_macro2::Span,
173    wgsl_source_text: &str,
174) -> Result<proc_macro2::TokenStream, syn::Error> {
175    let module: naga::Module = naga::front::wgsl::parse_str(wgsl_source_text).map_err(|error| {
176        syn::Error::new(
177            wgsl_source_span,
178            format_args!("failed to parse WGSL text: {}", ErrorChain(&error)),
179        )
180    })?;
181
182    // TODO: allow the user of the macro to configure which validation is done.
183    let module_info: naga::valid::ModuleInfo = naga::valid::Validator::new(
184        naga::valid::ValidationFlags::all(),
185        naga_rust_back::CAPABILITIES,
186    )
187    .subgroup_stages(naga::valid::ShaderStages::all())
188    // TODO: Add support for subgroup operations, then update this.
189    .subgroup_operations(naga::valid::SubgroupOperationSet::empty())
190    .validate(&module)
191    .map_err(|error| {
192        syn::Error::new(
193            wgsl_source_span,
194            format_args!("failed to validate WGSL: {}", ErrorChain(&error)),
195        )
196    })?;
197
198    let translated_source: String = naga_rust_back::write_string(&module, &module_info, config)
199        .map_err(|error| {
200            syn::Error::new(
201                wgsl_source_span,
202                format_args!("failed to translate shader to Rust: {}", ErrorChain(&error)),
203            )
204        })?;
205
206    let translated_tokens: proc_macro2::TokenStream =
207        translated_source.parse().map_err(|error| {
208            syn::Error::new(
209                wgsl_source_span,
210                format_args!(
211                    "internal error: translator did not produce valid Rust: {}",
212                    ErrorChain(&error)
213                ),
214            )
215        })?;
216
217    Ok(translated_tokens)
218}
219
220// -------------------------------------------------------------------------------------------------
221
222/// Formatting wrapper which prints an [`Error`] together with its `source()` chain.
223///
224/// The text begins with the [`fmt::Display`] format of the error.
225#[derive(Clone, Copy, Debug)]
226struct ErrorChain<'a>(&'a (dyn Error + 'a));
227
228impl fmt::Display for ErrorChain<'_> {
229    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
230        format_error_chain(fmt, self.0)
231    }
232}
233
234fn format_error_chain(fmt: &mut fmt::Formatter<'_>, mut error: &(dyn Error + '_)) -> fmt::Result {
235    write!(fmt, "{error}")?;
236    while let Some(source) = error.source() {
237        error = source;
238        write!(fmt, "\n↳ {error}")?;
239    }
240
241    Ok(())
242}