encase_derive_impl/
lib.rs

1use proc_macro2::{Ident, Literal, Span, TokenStream};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_quote,
6    punctuated::Punctuated,
7    spanned::Spanned,
8    token::Comma,
9    Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Type,
10};
11
12pub use syn;
13
/// Defines the `#[derive(ShaderType)]` entry point inside a proc-macro crate.
///
/// `$path` is an expression evaluating to the `syn::Path` of the crate that provides the
/// runtime items. It is forwarded to [`derive_shader_type`]. The generated derive
/// registers `align` and `size` as helper attributes.
#[macro_export]
macro_rules! implement {
    ($path:expr) => {
        #[proc_macro_derive(ShaderType, attributes(align, size))]
        pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
            let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput);
            // `$crate` (instead of a hard-coded crate name) keeps this resolving correctly
            // even when the dependency is renamed in the invoking crate's Cargo.toml.
            let expanded = $crate::derive_shader_type(input, &$path);
            proc_macro::TokenStream::from(expanded)
        }
    };
}
25
26fn get_named_struct_fields(data: &syn::Data) -> syn::Result<&FieldsNamed> {
27    match data {
28        Data::Struct(DataStruct {
29            fields: Fields::Named(fields),
30            ..
31        }) if !fields.named.is_empty() => Ok(fields),
32        _ => Err(Error::new(
33            Span::call_site(),
34            "Only non empty structs with named fields are supported!",
35        )),
36    }
37}
38
39struct FieldData {
40    pub field: syn::Field,
41    pub size: Option<(u32, Span)>,
42    pub align: Option<(u32, Span)>,
43}
44
45impl FieldData {
46    fn alignment(&self, root: &Path) -> TokenStream {
47        if let Some((alignment, _)) = self.align {
48            let alignment = Literal::u64_suffixed(alignment as u64);
49            quote! {
50                #root::AlignmentValue::new(#alignment)
51            }
52        } else {
53            let ty = &self.field.ty;
54            quote! {
55                <#ty as #root::ShaderType>::METADATA.alignment()
56            }
57        }
58    }
59
60    fn size(&self, root: &Path) -> TokenStream {
61        if let Some((size, _)) = self.size {
62            let size = Literal::u64_suffixed(size as u64);
63            quote! {
64                #size
65            }
66        } else {
67            let ty = &self.field.ty;
68            quote! {
69                <#ty as #root::ShaderSize>::SHADER_SIZE.get()
70            }
71        }
72    }
73
74    fn min_size(&self, root: &Path) -> TokenStream {
75        if let Some((size, _)) = self.size {
76            let size = Literal::u64_suffixed(size as u64);
77            quote! {
78                #size
79            }
80        } else {
81            let ty = &self.field.ty;
82            quote! {
83                <#ty as #root::ShaderType>::METADATA.min_size().get()
84            }
85        }
86    }
87
88    fn extra_padding(&self, root: &Path) -> Option<TokenStream> {
89        self.size.as_ref().map(|(size, _)| {
90            let size = Literal::u64_suffixed(*size as u64);
91            let ty = &self.field.ty;
92            let original_size = quote! { <#ty as #root::ShaderSize>::SHADER_SIZE.get() };
93            quote!(#size.saturating_sub(#original_size))
94        })
95    }
96
97    fn ident(&self) -> &Ident {
98        self.field.ident.as_ref().unwrap()
99    }
100}
101
102struct AlignmentAttr(u32);
103
104impl Parse for AlignmentAttr {
105    fn parse(input: ParseStream) -> syn::Result<Self> {
106        match input
107            .parse::<LitInt>()
108            .and_then(|lit| lit.base10_parse::<u32>())
109        {
110            Ok(num) if num.is_power_of_two() => Ok(Self(num)),
111            _ => Err(syn::Error::new(
112                input.span(),
113                "expected a power of 2 u32 literal",
114            )),
115        }
116    }
117}
118
119struct StaticSizeAttr(u32);
120
121impl Parse for StaticSizeAttr {
122    fn parse(input: ParseStream) -> syn::Result<Self> {
123        match input
124            .parse::<LitInt>()
125            .and_then(|lit| lit.base10_parse::<u32>())
126        {
127            Ok(num) => Ok(Self(num)),
128            _ => Err(syn::Error::new(input.span(), "expected u32 literal")),
129        }
130    }
131}
132
133enum SizeAttr {
134    Static(StaticSizeAttr),
135    Runtime,
136}
137
138impl Parse for SizeAttr {
139    fn parse(input: ParseStream) -> syn::Result<Self> {
140        match input.parse::<StaticSizeAttr>() {
141            Ok(static_size) => Ok(SizeAttr::Static(static_size)),
142            _ => match input.parse::<Path>() {
143                Ok(ident) if ident.is_ident("runtime") => Ok(SizeAttr::Runtime),
144                _ => Err(syn::Error::new(
145                    input.span(),
146                    "expected u32 literal or `runtime` identifier",
147                )),
148            },
149        }
150    }
151}
152
153struct Errors {
154    inner: Option<Error>,
155}
156
157impl Errors {
158    fn new() -> Self {
159        Self { inner: None }
160    }
161
162    fn append(&mut self, err: Error) {
163        if let Some(ex_error) = &mut self.inner {
164            ex_error.combine(err);
165        } else {
166            self.inner.replace(err);
167        }
168    }
169
170    fn into_compile_error(self) -> Option<TokenStream> {
171        self.inner.map(|e| e.into_compile_error())
172    }
173}
174
175pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
176    let root = &parse_quote!(#root::private);
177
178    let fields = match get_named_struct_fields(&input.data) {
179        Ok(fields) => fields,
180        Err(e) => return e.into_compile_error(),
181    };
182
183    let last_field_index = fields.named.len() - 1;
184
185    let mut errors = Errors::new();
186
187    let mut is_runtime_sized = false;
188
189    let field_data: Vec<_> = fields
190        .named
191        .iter()
192        .enumerate()
193        .map(|(i, field)| {
194            let mut data = FieldData {
195                field: field.clone(),
196                size: None,
197                align: None,
198            };
199            for attr in &field.attrs {
200                if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("align")) {
201                    continue;
202                }
203                match attr.meta.require_list() {
204                    Ok(meta_list) => {
205                        let span = meta_list.tokens.span();
206                        if meta_list.path.is_ident("align") {
207                            let res = attr.parse_args::<AlignmentAttr>();
208                            match res {
209                                Ok(val) => data.align = Some((val.0, span)),
210                                Err(err) => errors.append(err),
211                            }
212                        } else if meta_list.path.is_ident("size") {
213                            let res = if i == last_field_index {
214                                attr.parse_args::<SizeAttr>().map(|val| match val {
215                                    SizeAttr::Runtime => {
216                                        is_runtime_sized = true;
217                                        None
218                                    }
219                                    SizeAttr::Static(size) => Some((size.0, span)),
220                                })
221                            } else {
222                                attr.parse_args::<StaticSizeAttr>()
223                                    .map(|val| Some((val.0, span)))
224                            };
225                            match res {
226                                Ok(val) => data.size = val,
227                                Err(err) => errors.append(err),
228                            }
229                        }
230                    }
231                    Err(err) => errors.append(err),
232                };
233            }
234            data
235        })
236        .collect();
237
238    let mut found = false;
239    let size_hint: &Path = &parse_quote!(#root::ArrayLength);
240    for field in &fields.named {
241        // TODO: rethink how to check type equality here
242        match &field.ty {
243            Type::Path(path)
244                if path.path.segments.last().unwrap().ident
245                    == size_hint.segments.last().unwrap().ident =>
246            {
247                if found {
248                    let err = syn::Error::new(
249                        field.ty.span(),
250                        "only one field can use the `ArrayLength` type!",
251                    );
252                    errors.append(err)
253                } else {
254                    if !is_runtime_sized {
255                        let err = syn::Error::new(
256                                field.ty.span(),
257                                "`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[size(runtime)]`!",
258                            );
259                        errors.append(err)
260                    }
261                    found = true;
262                }
263            }
264            _ => {}
265        }
266    }
267
268    if let Some(ts) = errors.into_compile_error() {
269        return ts;
270    }
271
272    let nr_of_fields = &Literal::usize_suffixed(field_data.len());
273
274    let field_trait_constraints = generate_field_trait_constraints(
275        &input,
276        &field_data,
277        if is_runtime_sized {
278            quote!(#root::ShaderType + #root::RuntimeSizedArray)
279        } else {
280            quote!(#root::ShaderType + #root::ShaderSize)
281        },
282        quote!(#root::ShaderType + #root::ShaderSize),
283    );
284
285    let mut lifetimes = input.generics.clone();
286    lifetimes.params = lifetimes
287        .params
288        .into_iter()
289        .filter(|param| matches!(param, GenericParam::Lifetime(_)))
290        .collect::<Punctuated<GenericParam, Comma>>();
291
292    let align_check = {
293        let (impl_generics, _, _) = lifetimes.split_for_impl();
294        field_data
295            .iter()
296            .filter_map(|data| data.align.as_ref().map(|align| (&data.field.ty, align)))
297            .map(move |(ty, (align, span))| {
298                let align = Literal::u64_suffixed(*align as u64);
299                quote_spanned! {*span=>
300                    const _: () = {
301                        #[track_caller]
302                        #[allow(clippy::extra_unused_lifetimes)]
303                        const fn check #impl_generics () {
304                            let alignment = <#ty as #root::ShaderType>::METADATA.alignment().get();
305                            #root::concat_assert!(
306                                alignment <= #align,
307                                "align attribute value must be at least ", alignment, " (field's type alignment)"
308                            )
309                        }
310                        check();
311                    };
312                }
313            })
314    };
315
316    let size_check = {
317        let (impl_generics, _, _) = lifetimes.split_for_impl();
318        field_data
319            .iter()
320            .filter_map(|data| data.size.as_ref().map(|size| (&data.field.ty, size)))
321            .map(move |(ty, (size, span))| {
322                let size = Literal::u64_suffixed(*size as u64);
323                quote_spanned! {*span=>
324                    const _: () = {
325                        #[track_caller]
326                        #[allow(clippy::extra_unused_lifetimes)]
327                        const fn check #impl_generics () {
328                            let size = <#ty as #root::ShaderSize>::SHADER_SIZE.get();
329                            #root::concat_assert!(
330                                size <= #size,
331                                "size attribute value must be at least ", size, " (field's type size)"
332                            )
333                        }
334                        check();
335                    };
336                }
337            })
338    };
339
340    let uniform_check = field_data.iter().enumerate().map(|(i, data)| {
341        let ty = &data.field.ty;
342        let ty_check = quote_spanned! {ty.span()=>
343            <#ty as #root::ShaderType>::UNIFORM_COMPAT_ASSERT()
344        };
345        let ident = data.ident();
346        let name = ident.to_string();
347        let field_offset_check = quote_spanned! {ident.span()=>
348            if let ::core::option::Option::Some(min_alignment) =
349                <#ty as #root::ShaderType>::METADATA.uniform_min_alignment()
350            {
351                let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
352
353                #root::concat_assert!(
354                    min_alignment.is_aligned(offset),
355                    "offset of field '", #name, "' must be a multiple of ", min_alignment.get(),
356                    " (current offset: ", offset, ")"
357                )
358            }
359        };
360        let field_offset_diff = if i != 0 {
361            let prev_field = &field_data[i - 1];
362            let prev_field_ty = &prev_field.field.ty;
363            let prev_ident_name = prev_field.ident().to_string();
364            quote_spanned! {ident.span()=>
365                if let ::core::option::Option::Some(min_alignment) =
366                    <#prev_field_ty as #root::ShaderType>::METADATA.uniform_min_alignment()
367                {
368                    let prev_offset = <Self as #root::ShaderType>::METADATA.offset(#i - 1);
369                    let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
370                    let diff = offset - prev_offset;
371
372                    let prev_size = <#prev_field_ty as #root::ShaderSize>::SHADER_SIZE.get();
373                    let prev_size = min_alignment.round_up(prev_size);
374
375                    #root::concat_assert!(
376                        diff >= prev_size,
377                        "offset between fields '", #prev_ident_name, "' and '", #name, "' must be at least ",
378                        min_alignment.get(), " (currently: ", diff, ")"
379                    )
380                }
381            }
382        } else {
383            quote! {()}
384        };
385        quote! {
386            #ty_check,
387            #field_offset_check,
388            #field_offset_diff
389        }
390    });
391
392    let alignments = field_data.iter().map(|data| data.alignment(root));
393
394    let paddings = field_data.iter().enumerate().map(|(i, current)| {
395        let is_first = i == 0;
396        let is_last = i == field_data.len() - 1;
397
398        let mut out = TokenStream::new();
399
400        if !is_first {
401            let prev_i = i - 1;
402
403            let alignment = current.alignment(root);
404
405            let extra_padding = field_data
406                .get(prev_i)
407                .and_then(|prev| prev.extra_padding(root))
408                .map(|extra_padding| quote!(+ #extra_padding));
409
410            out.extend(quote! {
411                offsets[#i] = #alignment.round_up(offset);
412
413                let padding = #alignment.padding_needed_for(offset);
414                offset += padding;
415                paddings[#prev_i] = padding #extra_padding;
416            });
417        };
418
419        if is_last && is_runtime_sized {
420            return out;
421        }
422
423        let size = current.size(root);
424        out.extend(quote! {
425            offset += #size;
426        });
427
428        if is_last {
429            let extra_padding = current
430                .extra_padding(root)
431                .map(|extra_padding| quote!(+ #extra_padding));
432
433            out.extend(quote! {
434                paddings[#i] = struct_alignment.padding_needed_for(offset) #extra_padding;
435            });
436        }
437
438        out
439    });
440
441    fn gen_body<'a>(
442        field_data: &'a [FieldData],
443        root: &'a Path,
444        get_main: impl Fn(&Ident) -> TokenStream + 'a,
445        get_padding: impl Fn(TokenStream) -> TokenStream + 'a,
446    ) -> impl Iterator<Item = TokenStream> + 'a {
447        field_data.iter().enumerate().map(move |(i, data)| {
448            let ident = data.ident();
449
450            let padding = {
451                let i = Literal::usize_suffixed(i);
452                quote! { <Self as #root::ShaderType>::METADATA.padding(#i) }
453            };
454
455            let main = get_main(ident);
456            let padding = get_padding(padding);
457
458            quote! {
459                #main
460                #padding
461            }
462        })
463    }
464
465    let write_into_buffer_body = gen_body(
466        &field_data,
467        root,
468        |ident| {
469            quote! {
470                #root::WriteInto::write_into(&self.#ident, writer);
471            }
472        },
473        |padding| {
474            quote! {
475                #root::Writer::advance(writer, #padding as ::core::primitive::usize);
476            }
477        },
478    );
479
480    let read_from_buffer_body = gen_body(
481        &field_data,
482        root,
483        |ident| {
484            quote! {
485                #root::ReadFrom::read_from(&mut self.#ident, reader);
486            }
487        },
488        |padding| {
489            quote! {
490                #root::Reader::advance(reader, #padding as ::core::primitive::usize);
491            }
492        },
493    );
494
495    let create_from_buffer_body = gen_body(
496        &field_data,
497        root,
498        move |ident| {
499            quote! {
500                let #ident = #root::CreateFrom::create_from(reader);
501            }
502        },
503        |padding| {
504            quote! {
505                #root::Reader::advance(reader, #padding as ::core::primitive::usize);
506            }
507        },
508    );
509
510    let field_idents = field_data.iter().map(|data| data.ident());
511    let last_field = field_data.last().unwrap();
512    let last_field_min_size = last_field.min_size(root);
513    let last_field_ident = last_field.ident();
514
515    let field_types = field_data.iter().map(|data| &data.field.ty);
516    let field_types_2 = field_types.clone();
517    let field_types_3 = field_types.clone();
518    let field_types_4 = field_types.clone();
519    let all_other = field_types.clone().take(last_field_index);
520    let last_field_type = &last_field.field.ty;
521
522    let name = &input.ident;
523    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
524
525    let set_contained_rt_sized_array_length = if is_runtime_sized {
526        quote! {
527            writer.ctx.rts_array_length = ::core::option::Option::Some(
528                #root::RuntimeSizedArray::len(&self.#last_field_ident)
529                as ::core::primitive::u32
530            );
531        }
532    } else {
533        TokenStream::new()
534    };
535
536    let extra = match is_runtime_sized {
537        true => quote! {
538            impl #impl_generics #root::CalculateSizeFor for #name #ty_generics
539            where
540                Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
541                #last_field_type: #root::CalculateSizeFor,
542            {
543                fn calculate_size_for(nr_of_el: ::core::primitive::u64) -> ::core::num::NonZeroU64 {
544                    let mut offset = <Self as #root::ShaderType>::METADATA.last_offset();
545                    offset += <#last_field_type as #root::CalculateSizeFor>::calculate_size_for(nr_of_el).get();
546                    #root::SizeValue::new(<Self as #root::ShaderType>::METADATA.alignment().round_up(offset)).0
547                }
548            }
549        },
550        false => quote! {
551            impl #impl_generics #root::ShaderSize for #name #ty_generics
552            where
553                #( #field_types: #root::ShaderSize, )*
554            {}
555        },
556    };
557
558    // Note:
559    // The unused HRTBs on WriteInto, ReadFrom and CreateFrom are there
560    // to avoid #![feature(trivial_bounds)].
561    // Workaround found here: https://github.com/rust-lang/rust/issues/48214#issuecomment-1150463333
562
563    quote! {
564        #( #field_trait_constraints )*
565
566        #( #align_check )*
567
568        #( #size_check )*
569
570        impl #impl_generics #root::ShaderType for #name #ty_generics #where_clause
571        where
572            #( #all_other: #root::ShaderType + #root::ShaderSize, )*
573            #last_field_type: #root::ShaderType,
574        {
575            type ExtraMetadata = #root::StructMetadata<#nr_of_fields>;
576            const METADATA: #root::Metadata<Self::ExtraMetadata> = {
577                let struct_alignment = #root::AlignmentValue::max([ #( #alignments, )* ]);
578
579                let extra = {
580                    let mut paddings = [0; #nr_of_fields];
581                    let mut offsets = [0; #nr_of_fields];
582                    let mut offset = 0;
583                    #( #paddings )*
584                    #root::StructMetadata { offsets, paddings }
585                };
586
587                let min_size = {
588                    let mut offset = extra.offsets[#nr_of_fields - 1];
589                    offset += #last_field_min_size;
590                    #root::SizeValue::new(struct_alignment.round_up(offset))
591                };
592
593                #root::Metadata {
594                    alignment: struct_alignment,
595                    has_uniform_min_alignment: true,
596                    min_size,
597                    is_pod: false,
598                    extra,
599                }
600            };
601
602            const UNIFORM_COMPAT_ASSERT: fn() = || #root::consume_zsts([
603                #( #uniform_check, )*
604            ]);
605
606            fn size(&self) -> ::core::num::NonZeroU64 {
607                let mut offset = Self::METADATA.last_offset();
608                offset += #root::ShaderType::size(&self.#last_field_ident).get();
609                #root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0
610            }
611        }
612
613        impl #impl_generics #root::WriteInto for #name #ty_generics
614        where
615            Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
616            #( for<'__> #field_types_2: #root::WriteInto, )*
617        {
618            #[inline]
619            fn write_into<B: #root::BufferMut>(&self, writer: &mut #root::Writer<B>) {
620                #set_contained_rt_sized_array_length
621                #( #write_into_buffer_body )*
622            }
623        }
624
625        impl #impl_generics #root::ReadFrom for #name #ty_generics
626        where
627            Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
628            #( for<'__> #field_types_3: #root::ReadFrom, )*
629        {
630            #[inline]
631            fn read_from<B: #root::BufferRef>(&mut self, reader: &mut #root::Reader<B>) {
632                #( #read_from_buffer_body )*
633            }
634        }
635
636        impl #impl_generics #root::CreateFrom for #name #ty_generics
637        where
638            Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
639            #( for<'__> #field_types_4: #root::CreateFrom, )*
640        {
641            #[inline]
642            fn create_from<B: #root::BufferRef>(reader: &mut #root::Reader<B>) -> Self {
643                #( #create_from_buffer_body )*
644
645                #root::build_struct!(Self, #( #field_idents ),*)
646            }
647        }
648
649        #extra
650    }
651}
652
653fn generate_field_trait_constraints<'a>(
654    input: &'a DeriveInput,
655    field_data: &'a [FieldData],
656    trait_for_last_field: TokenStream,
657    trait_for_all_other_fields: TokenStream,
658) -> impl Iterator<Item = TokenStream> + 'a {
659    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
660    field_data.iter().enumerate().map(move |(i, data)| {
661        let ty = &data.field.ty;
662
663        let t = if i == field_data.len() - 1 {
664            &trait_for_last_field
665        } else {
666            &trait_for_all_other_fields
667        };
668
669        if ty_generics.to_token_stream().is_empty() {
670            quote_spanned! {ty.span()=>
671                const _: fn() = || {
672                    #[allow(clippy::extra_unused_lifetimes, clippy::missing_const_for_fn, clippy::extra_unused_type_parameters)]
673                    fn check #impl_generics () #where_clause {
674                        fn assert_impl<T: ?::core::marker::Sized + #t>() {}
675                        assert_impl::<#ty>();
676                    }
677                    check ();
678                };
679            }
680        } else {
681            // Case with type generics is not checked for now
682            quote_spanned! {ty.span()=>
683                const _: fn() = || {};
684            }
685        }
686    })
687}