naga_to_tokenstream/
types.rs

1use std::collections::{HashMap, HashSet};
2
3use proc_macro2::TokenStream;
4
5use crate::ModuleToTokensConfig;
6
7/// Returns a base Rust or `glam` type that corresponds to a TypeInner, if one exists.
8fn rust_type(type_inner: &naga::TypeInner, args: &ModuleToTokensConfig) -> Option<syn::Type> {
9    match type_inner {
10        naga::TypeInner::Scalar(naga::Scalar { kind, width }) => match (kind, width) {
11            (naga::ScalarKind::Bool, 1) => Some(syn::parse_quote!(bool)),
12            (naga::ScalarKind::Float, 4) => Some(syn::parse_quote!(f32)),
13            (naga::ScalarKind::Float, 8) => Some(syn::parse_quote!(f64)),
14            (naga::ScalarKind::Sint, 4) => Some(syn::parse_quote!(i32)),
15            (naga::ScalarKind::Sint, 8) => Some(syn::parse_quote!(i64)),
16            (naga::ScalarKind::Uint, 4) => Some(syn::parse_quote!(u32)),
17            (naga::ScalarKind::Uint, 8) => Some(syn::parse_quote!(u64)),
18            _ => None,
19        },
20        naga::TypeInner::Vector {
21            size,
22            scalar: naga::Scalar { kind, width },
23        } => {
24            if args.gen_glam {
25                match (size, kind, width) {
26                    (naga::VectorSize::Bi, naga::ScalarKind::Bool, 1) => {
27                        Some(syn::parse_quote!(glam::bool::BVec2))
28                    }
29                    (naga::VectorSize::Tri, naga::ScalarKind::Bool, 1) => {
30                        Some(syn::parse_quote!(glam::bool::BVec3))
31                    }
32                    (naga::VectorSize::Quad, naga::ScalarKind::Bool, 1) => {
33                        Some(syn::parse_quote!(glam::bool::BVec4))
34                    }
35                    (naga::VectorSize::Bi, naga::ScalarKind::Float, 4) => {
36                        Some(syn::parse_quote!(glam::f32::Vec2))
37                    }
38                    (naga::VectorSize::Tri, naga::ScalarKind::Float, 4) => {
39                        Some(syn::parse_quote!(glam::f32::Vec3))
40                    }
41                    (naga::VectorSize::Quad, naga::ScalarKind::Float, 4) => {
42                        Some(syn::parse_quote!(glam::f32::Vec4))
43                    }
44                    (naga::VectorSize::Bi, naga::ScalarKind::Float, 8) => {
45                        Some(syn::parse_quote!(glam::f64::DVec2))
46                    }
47                    (naga::VectorSize::Tri, naga::ScalarKind::Float, 8) => {
48                        Some(syn::parse_quote!(glam::f64::DVec3))
49                    }
50                    (naga::VectorSize::Quad, naga::ScalarKind::Float, 8) => {
51                        Some(syn::parse_quote!(glam::f64::DVec4))
52                    }
53                    (naga::VectorSize::Bi, naga::ScalarKind::Sint, 4) => {
54                        Some(syn::parse_quote!(glam::i32::IVec2))
55                    }
56                    (naga::VectorSize::Tri, naga::ScalarKind::Sint, 4) => {
57                        Some(syn::parse_quote!(glam::i32::IVec3))
58                    }
59                    (naga::VectorSize::Quad, naga::ScalarKind::Sint, 4) => {
60                        Some(syn::parse_quote!(glam::i32::IVec4))
61                    }
62                    (naga::VectorSize::Bi, naga::ScalarKind::Sint, 8) => {
63                        Some(syn::parse_quote!(glam::i64::I64Vec2))
64                    }
65                    (naga::VectorSize::Tri, naga::ScalarKind::Sint, 8) => {
66                        Some(syn::parse_quote!(glam::i64::I64Vec3))
67                    }
68                    (naga::VectorSize::Quad, naga::ScalarKind::Sint, 8) => {
69                        Some(syn::parse_quote!(glam::i64::I64Vec4))
70                    }
71                    (naga::VectorSize::Bi, naga::ScalarKind::Uint, 4) => {
72                        Some(syn::parse_quote!(glam::u32::UVec2))
73                    }
74                    (naga::VectorSize::Tri, naga::ScalarKind::Uint, 4) => {
75                        Some(syn::parse_quote!(glam::u32::UVec3))
76                    }
77                    (naga::VectorSize::Quad, naga::ScalarKind::Uint, 4) => {
78                        Some(syn::parse_quote!(glam::u32::UVec4))
79                    }
80                    (naga::VectorSize::Bi, naga::ScalarKind::Uint, 8) => {
81                        Some(syn::parse_quote!(glam::u64::U64Vec2))
82                    }
83                    (naga::VectorSize::Tri, naga::ScalarKind::Uint, 8) => {
84                        Some(syn::parse_quote!(glam::u64::U64Vec3))
85                    }
86                    (naga::VectorSize::Quad, naga::ScalarKind::Uint, 8) => {
87                        Some(syn::parse_quote!(glam::u64::U64Vec4))
88                    }
89                    _ => None,
90                }
91            } else {
92                match (size, kind, width) {
93                    (naga::VectorSize::Bi, naga::ScalarKind::Bool, 1) => {
94                        Some(syn::parse_quote!([bool; 2]))
95                    }
96                    (naga::VectorSize::Tri, naga::ScalarKind::Bool, 1) => {
97                        Some(syn::parse_quote!([bool; 3]))
98                    }
99                    (naga::VectorSize::Quad, naga::ScalarKind::Bool, 1) => {
100                        Some(syn::parse_quote!([bool; 4]))
101                    }
102                    (naga::VectorSize::Bi, naga::ScalarKind::Float, 4) => {
103                        Some(syn::parse_quote!([f32; 2]))
104                    }
105                    (naga::VectorSize::Tri, naga::ScalarKind::Float, 4) => {
106                        Some(syn::parse_quote!([f32; 3]))
107                    }
108                    (naga::VectorSize::Quad, naga::ScalarKind::Float, 4) => {
109                        Some(syn::parse_quote!([f32; 4]))
110                    }
111                    (naga::VectorSize::Bi, naga::ScalarKind::Float, 8) => {
112                        Some(syn::parse_quote!([f64; 2]))
113                    }
114                    (naga::VectorSize::Tri, naga::ScalarKind::Float, 8) => {
115                        Some(syn::parse_quote!([f64; 3]))
116                    }
117                    (naga::VectorSize::Quad, naga::ScalarKind::Float, 8) => {
118                        Some(syn::parse_quote!([f64; 4]))
119                    }
120                    (naga::VectorSize::Bi, naga::ScalarKind::Sint, 4) => {
121                        Some(syn::parse_quote!([i32; 2]))
122                    }
123                    (naga::VectorSize::Tri, naga::ScalarKind::Sint, 4) => {
124                        Some(syn::parse_quote!([i32; 3]))
125                    }
126                    (naga::VectorSize::Quad, naga::ScalarKind::Sint, 4) => {
127                        Some(syn::parse_quote!([i32; 4]))
128                    }
129                    (naga::VectorSize::Bi, naga::ScalarKind::Sint, 8) => {
130                        Some(syn::parse_quote!([i64; 2]))
131                    }
132                    (naga::VectorSize::Tri, naga::ScalarKind::Sint, 8) => {
133                        Some(syn::parse_quote!([i64; 3]))
134                    }
135                    (naga::VectorSize::Quad, naga::ScalarKind::Sint, 8) => {
136                        Some(syn::parse_quote!([i64; 4]))
137                    }
138                    (naga::VectorSize::Bi, naga::ScalarKind::Uint, 4) => {
139                        Some(syn::parse_quote!([u32; 2]))
140                    }
141                    (naga::VectorSize::Tri, naga::ScalarKind::Uint, 4) => {
142                        Some(syn::parse_quote!([u32; 3]))
143                    }
144                    (naga::VectorSize::Quad, naga::ScalarKind::Uint, 4) => {
145                        Some(syn::parse_quote!([u32; 4]))
146                    }
147                    (naga::VectorSize::Bi, naga::ScalarKind::Uint, 8) => {
148                        Some(syn::parse_quote!([u64; 2]))
149                    }
150                    (naga::VectorSize::Tri, naga::ScalarKind::Uint, 8) => {
151                        Some(syn::parse_quote!([u64; 3]))
152                    }
153                    (naga::VectorSize::Quad, naga::ScalarKind::Uint, 8) => {
154                        Some(syn::parse_quote!([u64; 4]))
155                    }
156                    _ => None,
157                }
158            }
159        }
160        naga::TypeInner::Matrix {
161            columns,
162            rows,
163            scalar: naga::Scalar { kind, width },
164        } => {
165            if !args.gen_glam {
166                return None;
167            }
168            if columns != rows {
169                return None;
170            }
171            match (kind, width) {
172                (naga::ScalarKind::Float, 4) => match columns {
173                    naga::VectorSize::Bi => Some(syn::parse_quote!(glam::f32::Mat2)),
174                    naga::VectorSize::Tri => Some(syn::parse_quote!(glam::f32::Mat3)),
175                    naga::VectorSize::Quad => Some(syn::parse_quote!(glam::f32::Mat4)),
176                },
177                (naga::ScalarKind::Float, 8) => match columns {
178                    naga::VectorSize::Bi => Some(syn::parse_quote!(glam::f64::Mat2)),
179                    naga::VectorSize::Tri => Some(syn::parse_quote!(glam::f64::Mat3)),
180                    naga::VectorSize::Quad => Some(syn::parse_quote!(glam::f64::Mat4)),
181                },
182                _ => None,
183            }
184        }
185        naga::TypeInner::Atomic(scalar) => rust_type(&naga::TypeInner::Scalar(*scalar), args),
186        _ => None,
187    }
188}
189
190/// A builder for type definition and identifier pairs.
191pub struct TypesDefinitions {
192    definitions: Vec<syn::ItemStruct>,
193    references: HashMap<naga::Handle<naga::Type>, syn::Type>,
194    structs_filter: Option<HashSet<String>>,
195}
196
197impl TypesDefinitions {
198    /// Constructs a new type definition collator, with a given filter for type names.
199    pub fn new(
200        module: &naga::Module,
201        structs_filter: Option<HashSet<String>>,
202        args: &ModuleToTokensConfig,
203    ) -> Self {
204        let mut res = Self {
205            definitions: Vec::new(),
206            references: HashMap::new(),
207            structs_filter,
208        };
209
210        for (ty_handle, _) in module.types.iter() {
211            if let Some(new_ty_ident) = res.try_make_type(ty_handle, module, args) {
212                res.references.insert(ty_handle, new_ty_ident.clone());
213            }
214        }
215
216        res
217    }
218
219    fn try_make_type(
220        &mut self,
221        ty_handle: naga::Handle<naga::Type>,
222        module: &naga::Module,
223        args: &ModuleToTokensConfig,
224    ) -> Option<syn::Type> {
225        let ty = match module.types.get_handle(ty_handle) {
226            Err(_) => return None,
227            Ok(ty) => ty,
228        };
229        if let Some(ty_ident) = rust_type(&ty.inner, args) {
230            return Some(ty_ident);
231        };
232
233        match &ty.inner {
234            naga::TypeInner::Array { base, size, .. }
235            | naga::TypeInner::BindingArray { base, size } => {
236                let base_type = self.rust_type_ident(*base, module, args)?;
237                match size {
238                    naga::ArraySize::Constant(size) => {
239                        let size = size.get();
240                        Some(syn::parse_quote!([#base_type; #size as usize]))
241                    }
242                    naga::ArraySize::Dynamic => Some(syn::parse_quote!(Vec<#base_type>)),
243                    naga::ArraySize::Pending(_) => None,
244                }
245            }
246            naga::TypeInner::Struct { members, .. } => {
247                let struct_name = ty.name.as_ref();
248                let struct_name = struct_name?;
249
250                // Apply filter
251                if let Some(struct_name_filter) = &self.structs_filter {
252                    if !struct_name_filter.contains(struct_name) {
253                        return None;
254                    }
255                }
256
257                let members_have_names = members.iter().all(|member| member.name.is_some());
258                let members: Option<Vec<_>> = members
259                    .iter()
260                    .enumerate()
261                    .map(|(i_member, member)| {
262                        let member_name = if members_have_names {
263                            let member_name =
264                                member.name.as_ref().expect("all members had names").clone();
265                            syn::parse_str::<syn::Ident>(&member_name)
266                        } else {
267                            syn::parse_str::<syn::Ident>(&format!("v{}", i_member))
268                        };
269                        let member_ty = self.rust_type_ident(member.ty, module, args);
270
271                        let mut attributes = proc_macro2::TokenStream::new();
272                        // Runtime-sized fields must be marked as such when using encase
273                        if args.gen_encase {
274                            let ty = module.types.get_handle(member.ty);
275                            if let Ok(naga::Type {
276                                inner:
277                                    naga::TypeInner::Array {
278                                        size: naga::ArraySize::Dynamic,
279                                        ..
280                                    }
281                                    | naga::TypeInner::BindingArray {
282                                        size: naga::ArraySize::Dynamic,
283                                        ..
284                                    },
285                                ..
286                            }) = ty
287                            {
288                                attributes.extend(quote::quote!(#[size(runtime)]))
289                            }
290                        }
291
292                        member_ty.and_then(|member_ty| {
293                            member_name.ok().map(|member_name| {
294                                quote::quote! {
295                                    #attributes
296                                    pub #member_name: #member_ty
297                                }
298                            })
299                        })
300                    })
301                    .collect();
302                let struct_name = syn::parse_str::<syn::Ident>(struct_name).ok();
303                match (members, struct_name) {
304                    (Some(members), Some(struct_name)) => {
305                        #[allow(unused_mut)]
306                        let mut bonus_struct_derives = TokenStream::new();
307                        if args.gen_encase {
308                            bonus_struct_derives.extend(quote::quote!(encase::ShaderType,))
309                        }
310
311                        self.definitions.push(syn::parse_quote! {
312                            #[allow(unused, non_camel_case_types)]
313                            #[derive(Debug, PartialEq, Clone, #bonus_struct_derives)]
314                            pub struct #struct_name {
315                                #(#members ,)*
316                            }
317                        });
318                        Some(syn::parse_quote!(#struct_name))
319                    }
320                    _ => None,
321                }
322            }
323            _ => None,
324        }
325    }
326
327    /// Takes a handle to a type, and a module where the type resides, and tries to return an identifier
328    /// of that type, in Rust. Note that for structs this will be an identifier in to the set of structs generated
329    /// by calling `TypesDefinitions::definitions()`, so your output should make sure to include everything from
330    /// there in the scope where the returned identifier is used.
331    pub fn rust_type_ident(
332        &mut self,
333        ty_handle: naga::Handle<naga::Type>,
334        module: &naga::Module,
335        args: &ModuleToTokensConfig,
336    ) -> Option<syn::Type> {
337        if let Some(ident) = self.references.get(&ty_handle).cloned() {
338            return Some(ident);
339        }
340
341        if let Some(built) = self.try_make_type(ty_handle, module, args) {
342            self.references.insert(ty_handle, built.clone());
343            return Some(built);
344        }
345
346        None
347    }
348
349    /// Gives the set of definitions required by the identifiers generated by this object. These should be
350    /// emitted somewhere accessible by the places that the identifiers were used.
351    pub fn definitions(self) -> Vec<syn::Item> {
352        self.definitions
353            .into_iter()
354            .map(syn::Item::Struct)
355            .collect()
356    }
357}