encase_derive_impl/
lib.rs

1use proc_macro2::{Ident, Literal, Span, TokenStream};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::{
4    parenthesized,
5    parse::{Parse, ParseStream},
6    parse_quote,
7    punctuated::Punctuated,
8    spanned::Spanned,
9    token::Comma,
10    Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Token,
11    Type,
12};
13
14pub use syn;
15
/// Defines the `ShaderType` derive entry point inside the calling proc-macro crate.
///
/// `$path` is an expression whose reference is passed as the `root` argument of
/// [`derive_shader_type`]; the generated code resolves every runtime item through
/// `<root>::private`.
///
/// The expansion reaches this crate exclusively through `$crate`, so it keeps working
/// when the dependency is renamed or re-exported by the caller.
#[macro_export]
macro_rules! implement {
    ($path:expr) => {
        #[proc_macro_derive(ShaderType, attributes(shader))]
        pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
            let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput);
            let expanded = $crate::derive_shader_type(input, &$path);
            proc_macro::TokenStream::from(expanded)
        }
    };
}
27
28fn get_named_struct_fields(data: &syn::Data) -> syn::Result<&FieldsNamed> {
29    match data {
30        Data::Struct(DataStruct {
31            fields: Fields::Named(fields),
32            ..
33        }) if !fields.named.is_empty() => Ok(fields),
34        _ => Err(Error::new(
35            Span::call_site(),
36            "Only non empty structs with named fields are supported!",
37        )),
38    }
39}
40
41struct FieldData {
42    pub field: syn::Field,
43    pub size: Option<(u32, Span)>,
44    pub align: Option<(u32, Span)>,
45}
46
47impl FieldData {
48    fn alignment(&self, root: &Path) -> TokenStream {
49        if let Some((alignment, _)) = self.align {
50            let alignment = Literal::u64_suffixed(alignment as u64);
51            quote! {
52                #root::AlignmentValue::new(#alignment)
53            }
54        } else {
55            let ty = &self.field.ty;
56            quote! {
57                <#ty as #root::ShaderType>::METADATA.alignment()
58            }
59        }
60    }
61
62    fn size(&self, root: &Path) -> TokenStream {
63        if let Some((size, _)) = self.size {
64            let size = Literal::u64_suffixed(size as u64);
65            quote! {
66                #size
67            }
68        } else {
69            let ty = &self.field.ty;
70            quote! {
71                <#ty as #root::ShaderSize>::SHADER_SIZE.get()
72            }
73        }
74    }
75
76    fn min_size(&self, root: &Path) -> TokenStream {
77        if let Some((size, _)) = self.size {
78            let size = Literal::u64_suffixed(size as u64);
79            quote! {
80                #size
81            }
82        } else {
83            let ty = &self.field.ty;
84            quote! {
85                <#ty as #root::ShaderType>::METADATA.min_size().get()
86            }
87        }
88    }
89
90    fn extra_padding(&self, root: &Path) -> Option<TokenStream> {
91        self.size.as_ref().map(|(size, _)| {
92            let size = Literal::u64_suffixed(*size as u64);
93            let ty = &self.field.ty;
94            let original_size = quote! { <#ty as #root::ShaderSize>::SHADER_SIZE.get() };
95            quote!(#size.saturating_sub(#original_size))
96        })
97    }
98
99    fn ident(&self) -> &Ident {
100        self.field.ident.as_ref().unwrap()
101    }
102}
103
104#[derive(Debug)]
105pub struct AlignmentAttr(u32);
106
107impl Parse for AlignmentAttr {
108    fn parse(input: ParseStream) -> syn::Result<Self> {
109        match input
110            .parse::<LitInt>()
111            .and_then(|lit| lit.base10_parse::<u32>())
112        {
113            Ok(num) if num.is_power_of_two() => Ok(Self(num)),
114            _ => Err(syn::Error::new(
115                input.span(),
116                "expected a power of 2 u32 literal",
117            )),
118        }
119    }
120}
121
122#[derive(Debug)]
123pub struct StaticSizeAttr(u32);
124
125impl Parse for StaticSizeAttr {
126    fn parse(input: ParseStream) -> syn::Result<Self> {
127        let span = input.span();
128        match input
129            .parse::<LitInt>()
130            .and_then(|lit| lit.base10_parse::<u32>())
131        {
132            Ok(num) => Ok(Self(num)),
133            _ => Err(syn::Error::new(span, "expected u32 literal")),
134        }
135    }
136}
137
138#[derive(Debug)]
139pub enum SizeAttr {
140    Static(StaticSizeAttr),
141    Runtime,
142}
143
144impl Parse for SizeAttr {
145    fn parse(input: ParseStream) -> syn::Result<Self> {
146        let span = input.span();
147        match input.parse::<StaticSizeAttr>() {
148            Ok(static_size) => Ok(SizeAttr::Static(static_size)),
149            _ => match input.parse::<Path>() {
150                Ok(ident) if ident.is_ident("runtime") => Ok(SizeAttr::Runtime),
151                _ => Err(syn::Error::new(
152                    span,
153                    "expected u32 literal or `runtime` identifier",
154                )),
155            },
156        }
157    }
158}
159
160#[derive(Debug)]
161pub enum ShaderAttr {
162    Align { attr: AlignmentAttr, span: Span },
163    Size { attr: SizeAttr, span: Span },
164}
165
166impl Parse for ShaderAttr {
167    fn parse(input: ParseStream) -> syn::Result<Self> {
168        let ident_span = input.span();
169        let Ok(ident) = input.parse::<Ident>() else {
170            return Err(syn::Error::new(ident_span, "expected `align` or `size`"));
171        };
172
173        match ident.to_string().as_str() {
174            "align" => {
175                if !input.peek(syn::token::Paren) {
176                    return Err(syn::Error::new(
177                        ident_span,
178                        "expected attribute arguments in parentheses: `align(...)`",
179                    ));
180                }
181
182                let args;
183                parenthesized!(args in input);
184                let attr_span = args.span();
185                let align_attr: AlignmentAttr = args.parse()?;
186                Ok(ShaderAttr::Align {
187                    attr: align_attr,
188                    span: attr_span,
189                })
190            }
191            "size" => {
192                if !input.peek(syn::token::Paren) {
193                    return Err(syn::Error::new(
194                        ident_span,
195                        "expected attribute arguments in parentheses: `size(...)`",
196                    ));
197                }
198
199                let args;
200                parenthesized!(args in input);
201                let attr_span = args.span();
202                let size_attr: SizeAttr = args.parse()?;
203                Ok(ShaderAttr::Size {
204                    attr: size_attr,
205                    span: attr_span,
206                })
207            }
208            _ => Err(syn::Error::new(
209                ident_span,
210                "unknown shader attribute, expected `align` or `size`",
211            )),
212        }
213    }
214}
215
216#[derive(Debug)]
217pub struct ShaderAttrList(Punctuated<ShaderAttr, Token![,]>);
218impl Parse for ShaderAttrList {
219    fn parse(input: ParseStream) -> syn::Result<Self> {
220        Ok(Self(input.parse_terminated(ShaderAttr::parse, Token![,])?))
221    }
222}
223
224struct Errors {
225    inner: Option<Error>,
226}
227
228impl Errors {
229    fn new() -> Self {
230        Self { inner: None }
231    }
232
233    fn append(&mut self, err: Error) {
234        if let Some(ex_error) = &mut self.inner {
235            ex_error.combine(err);
236        } else {
237            self.inner.replace(err);
238        }
239    }
240
241    fn into_compile_error(self) -> Option<TokenStream> {
242        self.inner.map(|e| e.into_compile_error())
243    }
244}
245
246pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
247    let root = &parse_quote!(#root::private);
248
249    let fields = match get_named_struct_fields(&input.data) {
250        Ok(fields) => fields,
251        Err(e) => return e.into_compile_error(),
252    };
253
254    let last_field_index = fields.named.len() - 1;
255
256    let mut errors = Errors::new();
257
258    let mut is_runtime_sized = false;
259
260    let field_data: Vec<_> = fields
261        .named
262        .iter()
263        .enumerate()
264        .map(|(i, field)| {
265            let mut data = FieldData {
266                field: field.clone(),
267                size: None,
268                align: None,
269            };
270
271            for attr in &field.attrs {
272                if !(attr.meta.path().is_ident("shader")) {
273                    continue;
274                }
275
276                let shader_attrs = match attr.parse_args::<ShaderAttrList>() {
277                    Ok(attrs) => attrs,
278                    Err(err) => {
279                        errors.append(err);
280                        continue;
281                    }
282                };
283
284                for shader_attr in shader_attrs.0 {
285                    match shader_attr {
286                        ShaderAttr::Align { attr, span } => {
287                            if data.align.is_some() {
288                                let err = syn::Error::new(span, "duplicate `align(X)` attribute");
289                                errors.append(err);
290                                continue;
291                            }
292
293                            data.align = Some((attr.0, span));
294                        }
295                        ShaderAttr::Size { attr, span } => {
296                            if data.size.is_some() || is_runtime_sized {
297                                let err = syn::Error::new(span, "duplicate `size(X)` attribute");
298                                errors.append(err);
299                                continue;
300                            }
301
302                            match attr {
303                                SizeAttr::Runtime => {
304                                    if i == last_field_index {
305                                        is_runtime_sized = true;
306                                    } else {
307                                        let err = syn::Error::new(
308                                            span,
309                                            "only the last field can be `size(runtime)`",
310                                        );
311                                        errors.append(err);
312                                        continue;
313                                    }
314                                }
315                                SizeAttr::Static(attr) => {
316                                    data.size = Some((attr.0, span));
317                                }
318                            }
319                        }
320                    }
321                }
322            }
323            data
324        })
325        .collect();
326
327    let mut found = false;
328    let size_hint: &Path = &parse_quote!(#root::ArrayLength);
329    for field in &fields.named {
330        // TODO: rethink how to check type equality here
331        match &field.ty {
332            Type::Path(path)
333                if path.path.segments.last().unwrap().ident
334                    == size_hint.segments.last().unwrap().ident =>
335            {
336                if found {
337                    let err = syn::Error::new(
338                        field.ty.span(),
339                        "only one field can use the `ArrayLength` type!",
340                    );
341                    errors.append(err)
342                } else {
343                    if !is_runtime_sized {
344                        let err = syn::Error::new(
345                                field.ty.span(),
346                                "`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[shader(size(runtime))]`!",
347                            );
348                        errors.append(err)
349                    }
350                    found = true;
351                }
352            }
353            _ => {}
354        }
355    }
356
357    if let Some(ts) = errors.into_compile_error() {
358        return ts;
359    }
360
361    let nr_of_fields = &Literal::usize_suffixed(field_data.len());
362
363    let field_trait_constraints = generate_field_trait_constraints(
364        &input,
365        &field_data,
366        if is_runtime_sized {
367            quote!(#root::ShaderType + #root::RuntimeSizedArray)
368        } else {
369            quote!(#root::ShaderType + #root::ShaderSize)
370        },
371        quote!(#root::ShaderType + #root::ShaderSize),
372    );
373
374    let mut lifetimes = input.generics.clone();
375    lifetimes.params = lifetimes
376        .params
377        .into_iter()
378        .filter(|param| matches!(param, GenericParam::Lifetime(_)))
379        .collect::<Punctuated<GenericParam, Comma>>();
380
381    let align_check = {
382        let (impl_generics, _, _) = lifetimes.split_for_impl();
383        field_data
384            .iter()
385            .filter_map(|data| data.align.as_ref().map(|align| (&data.field.ty, align)))
386            .map(move |(ty, (align, span))| {
387                let align = Literal::u64_suffixed(*align as u64);
388                quote_spanned! {*span=>
389                    const _: () = {
390                        #[track_caller]
391                        #[allow(clippy::extra_unused_lifetimes)]
392                        const fn check #impl_generics () {
393                            let alignment = <#ty as #root::ShaderType>::METADATA.alignment().get();
394                            #root::concat_assert!(
395                                alignment <= #align,
396                                "shader(align) attribute value must be at least ", alignment, " (field's type alignment)"
397                            )
398                        }
399                        check();
400                    };
401                }
402            })
403    };
404
405    let size_check = {
406        let (impl_generics, _, _) = lifetimes.split_for_impl();
407        field_data
408            .iter()
409            .filter_map(|data| data.size.as_ref().map(|size| (&data.field.ty, size)))
410            .map(move |(ty, (size, span))| {
411                let size = Literal::u64_suffixed(*size as u64);
412                quote_spanned! {*span=>
413                    const _: () = {
414                        #[track_caller]
415                        #[allow(clippy::extra_unused_lifetimes)]
416                        const fn check #impl_generics () {
417                            let size = <#ty as #root::ShaderSize>::SHADER_SIZE.get();
418                            #root::concat_assert!(
419                                size <= #size,
420                                "size attribute value must be at least ", size, " (field's type size)"
421                            )
422                        }
423                        check();
424                    };
425                }
426            })
427    };
428
429    let uniform_check = field_data.iter().enumerate().map(|(i, data)| {
430        let ty = &data.field.ty;
431        let ty_check = quote_spanned! {ty.span()=>
432            <#ty as #root::ShaderType>::UNIFORM_COMPAT_ASSERT()
433        };
434        let ident = data.ident();
435        let name = ident.to_string();
436        let field_offset_check = quote_spanned! {ident.span()=>
437            if let ::core::option::Option::Some(min_alignment) =
438                <#ty as #root::ShaderType>::METADATA.uniform_min_alignment()
439            {
440                let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
441
442                #root::concat_assert!(
443                    min_alignment.is_aligned(offset),
444                    "offset of field '", #name, "' must be a multiple of ", min_alignment.get(),
445                    " (current offset: ", offset, ")"
446                )
447            }
448        };
449        let field_offset_diff = if i != 0 {
450            let prev_field = &field_data[i - 1];
451            let prev_field_ty = &prev_field.field.ty;
452            let prev_ident_name = prev_field.ident().to_string();
453            quote_spanned! {ident.span()=>
454                if let ::core::option::Option::Some(min_alignment) =
455                    <#prev_field_ty as #root::ShaderType>::METADATA.uniform_min_alignment()
456                {
457                    let prev_offset = <Self as #root::ShaderType>::METADATA.offset(#i - 1);
458                    let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
459                    let diff = offset - prev_offset;
460
461                    let prev_size = <#prev_field_ty as #root::ShaderSize>::SHADER_SIZE.get();
462                    let prev_size = min_alignment.round_up(prev_size);
463
464                    #root::concat_assert!(
465                        diff >= prev_size,
466                        "offset between fields '", #prev_ident_name, "' and '", #name, "' must be at least ",
467                        min_alignment.get(), " (currently: ", diff, ")"
468                    )
469                }
470            }
471        } else {
472            quote! {()}
473        };
474        quote! {
475            #ty_check,
476            #field_offset_check,
477            #field_offset_diff
478        }
479    });
480
481    let alignments = field_data.iter().map(|data| data.alignment(root));
482
483    let paddings = field_data.iter().enumerate().map(|(i, current)| {
484        let is_first = i == 0;
485        let is_last = i == field_data.len() - 1;
486
487        let mut out = TokenStream::new();
488
489        if !is_first {
490            let prev_i = i - 1;
491
492            let alignment = current.alignment(root);
493
494            let extra_padding = field_data
495                .get(prev_i)
496                .and_then(|prev| prev.extra_padding(root))
497                .map(|extra_padding| quote!(+ #extra_padding));
498
499            out.extend(quote! {
500                offsets[#i] = #alignment.round_up(offset);
501
502                let padding = #alignment.padding_needed_for(offset);
503                offset += padding;
504                paddings[#prev_i] = padding #extra_padding;
505            });
506        };
507
508        if is_last && is_runtime_sized {
509            return out;
510        }
511
512        let size = current.size(root);
513        out.extend(quote! {
514            offset += #size;
515        });
516
517        if is_last {
518            let extra_padding = current
519                .extra_padding(root)
520                .map(|extra_padding| quote!(+ #extra_padding));
521
522            out.extend(quote! {
523                paddings[#i] = struct_alignment.padding_needed_for(offset) #extra_padding;
524            });
525        }
526
527        out
528    });
529
530    fn gen_body<'a>(
531        field_data: &'a [FieldData],
532        root: &'a Path,
533        get_main: impl Fn(&Ident) -> TokenStream + 'a,
534        get_padding: impl Fn(TokenStream) -> TokenStream + 'a,
535    ) -> impl Iterator<Item = TokenStream> + 'a {
536        field_data.iter().enumerate().map(move |(i, data)| {
537            let ident = data.ident();
538
539            let padding = {
540                let i = Literal::usize_suffixed(i);
541                quote! { <Self as #root::ShaderType>::METADATA.padding(#i) }
542            };
543
544            let main = get_main(ident);
545            let padding = get_padding(padding);
546
547            quote! {
548                #main
549                #padding
550            }
551        })
552    }
553
554    let write_into_buffer_body = gen_body(
555        &field_data,
556        root,
557        |ident| {
558            quote! {
559                #root::WriteInto::write_into(&self.#ident, writer);
560            }
561        },
562        |padding| {
563            quote! {
564                #root::Writer::advance(writer, #padding as ::core::primitive::usize);
565            }
566        },
567    );
568
569    let read_from_buffer_body = gen_body(
570        &field_data,
571        root,
572        |ident| {
573            quote! {
574                #root::ReadFrom::read_from(&mut self.#ident, reader);
575            }
576        },
577        |padding| {
578            quote! {
579                #root::Reader::advance(reader, #padding as ::core::primitive::usize);
580            }
581        },
582    );
583
584    let create_from_buffer_body = gen_body(
585        &field_data,
586        root,
587        move |ident| {
588            quote! {
589                let #ident = #root::CreateFrom::create_from(reader);
590            }
591        },
592        |padding| {
593            quote! {
594                #root::Reader::advance(reader, #padding as ::core::primitive::usize);
595            }
596        },
597    );
598
599    let field_idents = field_data.iter().map(|data| data.ident());
600    let last_field = field_data.last().unwrap();
601    let last_field_min_size = last_field.min_size(root);
602    let last_field_ident = last_field.ident();
603
604    let field_types = field_data.iter().map(|data| &data.field.ty);
605    let field_types_2 = field_types.clone();
606    let field_types_3 = field_types.clone();
607    let field_types_4 = field_types.clone();
608    let all_other = field_types.clone().take(last_field_index);
609    let last_field_type = &last_field.field.ty;
610
611    let name = &input.ident;
612    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
613
614    let set_contained_rt_sized_array_length = if is_runtime_sized {
615        quote! {
616            writer.ctx.rts_array_length = ::core::option::Option::Some(
617                #root::RuntimeSizedArray::len(&self.#last_field_ident)
618                as ::core::primitive::u32
619            );
620        }
621    } else {
622        TokenStream::new()
623    };
624
625    let extra = match is_runtime_sized {
626        true => quote! {
627            impl #impl_generics #root::CalculateSizeFor for #name #ty_generics
628            where
629                Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
630                #last_field_type: #root::CalculateSizeFor,
631            {
632                fn calculate_size_for(nr_of_el: ::core::primitive::u64) -> ::core::num::NonZeroU64 {
633                    let mut offset = <Self as #root::ShaderType>::METADATA.last_offset();
634                    offset += <#last_field_type as #root::CalculateSizeFor>::calculate_size_for(nr_of_el).get();
635                    #root::SizeValue::new(<Self as #root::ShaderType>::METADATA.alignment().round_up(offset)).0
636                }
637            }
638        },
639        false => quote! {
640            impl #impl_generics #root::ShaderSize for #name #ty_generics
641            where
642                #( #field_types: #root::ShaderSize, )*
643            {}
644        },
645    };
646
647    // Note:
648    // The unused HRTBs on WriteInto, ReadFrom and CreateFrom are there
649    // to avoid #![feature(trivial_bounds)].
650    // Workaround found here: https://github.com/rust-lang/rust/issues/48214#issuecomment-1150463333
651
652    quote! {
653        #( #field_trait_constraints )*
654
655        #( #align_check )*
656
657        #( #size_check )*
658
659        impl #impl_generics #root::ShaderType for #name #ty_generics #where_clause
660        where
661            #( #all_other: #root::ShaderType + #root::ShaderSize, )*
662            #last_field_type: #root::ShaderType,
663        {
664            type ExtraMetadata = #root::StructMetadata<#nr_of_fields>;
665            const METADATA: #root::Metadata<Self::ExtraMetadata> = {
666                let struct_alignment = #root::AlignmentValue::max([ #( #alignments, )* ]);
667
668                let extra = {
669                    let mut paddings = [0; #nr_of_fields];
670                    let mut offsets = [0; #nr_of_fields];
671                    let mut offset = 0;
672                    #( #paddings )*
673                    #root::StructMetadata { offsets, paddings }
674                };
675
676                let min_size = {
677                    let mut offset = extra.offsets[#nr_of_fields - 1];
678                    offset += #last_field_min_size;
679                    #root::SizeValue::new(struct_alignment.round_up(offset))
680                };
681
682                #root::Metadata {
683                    alignment: struct_alignment,
684                    has_uniform_min_alignment: true,
685                    min_size,
686                    is_pod: false,
687                    extra,
688                }
689            };
690
691            const UNIFORM_COMPAT_ASSERT: fn() = || #root::consume_zsts([
692                #( #uniform_check, )*
693            ]);
694
695            fn size(&self) -> ::core::num::NonZeroU64 {
696                let mut offset = Self::METADATA.last_offset();
697                offset += #root::ShaderType::size(&self.#last_field_ident).get();
698                #root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0
699            }
700        }
701
702        impl #impl_generics #root::WriteInto for #name #ty_generics
703        where
704            Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
705            #( for<'__> #field_types_2: #root::WriteInto, )*
706        {
707            #[inline]
708            fn write_into<B: #root::BufferMut>(&self, writer: &mut #root::Writer<B>) {
709                #set_contained_rt_sized_array_length
710                #( #write_into_buffer_body )*
711            }
712        }
713
714        impl #impl_generics #root::ReadFrom for #name #ty_generics
715        where
716            Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
717            #( for<'__> #field_types_3: #root::ReadFrom, )*
718        {
719            #[inline]
720            fn read_from<B: #root::BufferRef>(&mut self, reader: &mut #root::Reader<B>) {
721                #( #read_from_buffer_body )*
722            }
723        }
724
725        impl #impl_generics #root::CreateFrom for #name #ty_generics
726        where
727            Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
728            #( for<'__> #field_types_4: #root::CreateFrom, )*
729        {
730            #[inline]
731            fn create_from<B: #root::BufferRef>(reader: &mut #root::Reader<B>) -> Self {
732                #( #create_from_buffer_body )*
733
734                #root::build_struct!(Self, #( #field_idents ),*)
735            }
736        }
737
738        #extra
739    }
740}
741
742fn generate_field_trait_constraints<'a>(
743    input: &'a DeriveInput,
744    field_data: &'a [FieldData],
745    trait_for_last_field: TokenStream,
746    trait_for_all_other_fields: TokenStream,
747) -> impl Iterator<Item = TokenStream> + 'a {
748    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
749    field_data.iter().enumerate().map(move |(i, data)| {
750        let ty = &data.field.ty;
751
752        let t = if i == field_data.len() - 1 {
753            &trait_for_last_field
754        } else {
755            &trait_for_all_other_fields
756        };
757
758        if ty_generics.to_token_stream().is_empty() {
759            quote_spanned! {ty.span()=>
760                const _: fn() = || {
761                    #[allow(clippy::extra_unused_lifetimes, clippy::missing_const_for_fn, clippy::extra_unused_type_parameters)]
762                    fn check #impl_generics () #where_clause {
763                        fn assert_impl<T: ?::core::marker::Sized + #t>() {}
764                        assert_impl::<#ty>();
765                    }
766                    check ();
767                };
768            }
769        } else {
770            // Case with type generics is not checked for now
771            quote_spanned! {ty.span()=>
772                const _: fn() = || {};
773            }
774        }
775    })
776}