1use proc_macro2::{Ident, Literal, Span, TokenStream};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::{
4 parenthesized,
5 parse::{Parse, ParseStream},
6 parse_quote,
7 punctuated::Punctuated,
8 spanned::Spanned,
9 token::Comma,
10 Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Token,
11 Type,
12};
13
14pub use syn;
15
/// Defines the `#[derive(ShaderType)]` proc-macro entry point in the invoking crate.
///
/// `$path` is an expression evaluating to the `syn::Path` that the generated code
/// uses as its root (the expansion refers to items under `<path>::private`).
#[macro_export]
macro_rules! implement {
    ($path:expr) => {
        #[proc_macro_derive(ShaderType, attributes(shader))]
        pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
            let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput);
            // Resolve through `$crate` (as is already done for `syn` above) instead of a
            // hard-coded crate name, so the macro keeps working if this crate is renamed
            // or re-exported by the invoking crate's dependencies.
            let expanded = $crate::derive_shader_type(input, &$path);
            proc_macro::TokenStream::from(expanded)
        }
    };
}
27
28fn get_named_struct_fields(data: &syn::Data) -> syn::Result<&FieldsNamed> {
29 match data {
30 Data::Struct(DataStruct {
31 fields: Fields::Named(fields),
32 ..
33 }) if !fields.named.is_empty() => Ok(fields),
34 _ => Err(Error::new(
35 Span::call_site(),
36 "Only non empty structs with named fields are supported!",
37 )),
38 }
39}
40
41struct FieldData {
42 pub field: syn::Field,
43 pub size: Option<(u32, Span)>,
44 pub align: Option<(u32, Span)>,
45}
46
47impl FieldData {
48 fn alignment(&self, root: &Path) -> TokenStream {
49 if let Some((alignment, _)) = self.align {
50 let alignment = Literal::u64_suffixed(alignment as u64);
51 quote! {
52 #root::AlignmentValue::new(#alignment)
53 }
54 } else {
55 let ty = &self.field.ty;
56 quote! {
57 <#ty as #root::ShaderType>::METADATA.alignment()
58 }
59 }
60 }
61
62 fn size(&self, root: &Path) -> TokenStream {
63 if let Some((size, _)) = self.size {
64 let size = Literal::u64_suffixed(size as u64);
65 quote! {
66 #size
67 }
68 } else {
69 let ty = &self.field.ty;
70 quote! {
71 <#ty as #root::ShaderSize>::SHADER_SIZE.get()
72 }
73 }
74 }
75
76 fn min_size(&self, root: &Path) -> TokenStream {
77 if let Some((size, _)) = self.size {
78 let size = Literal::u64_suffixed(size as u64);
79 quote! {
80 #size
81 }
82 } else {
83 let ty = &self.field.ty;
84 quote! {
85 <#ty as #root::ShaderType>::METADATA.min_size().get()
86 }
87 }
88 }
89
90 fn extra_padding(&self, root: &Path) -> Option<TokenStream> {
91 self.size.as_ref().map(|(size, _)| {
92 let size = Literal::u64_suffixed(*size as u64);
93 let ty = &self.field.ty;
94 let original_size = quote! { <#ty as #root::ShaderSize>::SHADER_SIZE.get() };
95 quote!(#size.saturating_sub(#original_size))
96 })
97 }
98
99 fn ident(&self) -> &Ident {
100 self.field.ident.as_ref().unwrap()
101 }
102}
103
104#[derive(Debug)]
105pub struct AlignmentAttr(u32);
106
107impl Parse for AlignmentAttr {
108 fn parse(input: ParseStream) -> syn::Result<Self> {
109 match input
110 .parse::<LitInt>()
111 .and_then(|lit| lit.base10_parse::<u32>())
112 {
113 Ok(num) if num.is_power_of_two() => Ok(Self(num)),
114 _ => Err(syn::Error::new(
115 input.span(),
116 "expected a power of 2 u32 literal",
117 )),
118 }
119 }
120}
121
122#[derive(Debug)]
123pub struct StaticSizeAttr(u32);
124
125impl Parse for StaticSizeAttr {
126 fn parse(input: ParseStream) -> syn::Result<Self> {
127 let span = input.span();
128 match input
129 .parse::<LitInt>()
130 .and_then(|lit| lit.base10_parse::<u32>())
131 {
132 Ok(num) => Ok(Self(num)),
133 _ => Err(syn::Error::new(span, "expected u32 literal")),
134 }
135 }
136}
137
138#[derive(Debug)]
139pub enum SizeAttr {
140 Static(StaticSizeAttr),
141 Runtime,
142}
143
144impl Parse for SizeAttr {
145 fn parse(input: ParseStream) -> syn::Result<Self> {
146 let span = input.span();
147 match input.parse::<StaticSizeAttr>() {
148 Ok(static_size) => Ok(SizeAttr::Static(static_size)),
149 _ => match input.parse::<Path>() {
150 Ok(ident) if ident.is_ident("runtime") => Ok(SizeAttr::Runtime),
151 _ => Err(syn::Error::new(
152 span,
153 "expected u32 literal or `runtime` identifier",
154 )),
155 },
156 }
157 }
158}
159
160#[derive(Debug)]
161pub enum ShaderAttr {
162 Align { attr: AlignmentAttr, span: Span },
163 Size { attr: SizeAttr, span: Span },
164}
165
166impl Parse for ShaderAttr {
167 fn parse(input: ParseStream) -> syn::Result<Self> {
168 let ident_span = input.span();
169 let Ok(ident) = input.parse::<Ident>() else {
170 return Err(syn::Error::new(ident_span, "expected `align` or `size`"));
171 };
172
173 match ident.to_string().as_str() {
174 "align" => {
175 if !input.peek(syn::token::Paren) {
176 return Err(syn::Error::new(
177 ident_span,
178 "expected attribute arguments in parentheses: `align(...)`",
179 ));
180 }
181
182 let args;
183 parenthesized!(args in input);
184 let attr_span = args.span();
185 let align_attr: AlignmentAttr = args.parse()?;
186 Ok(ShaderAttr::Align {
187 attr: align_attr,
188 span: attr_span,
189 })
190 }
191 "size" => {
192 if !input.peek(syn::token::Paren) {
193 return Err(syn::Error::new(
194 ident_span,
195 "expected attribute arguments in parentheses: `size(...)`",
196 ));
197 }
198
199 let args;
200 parenthesized!(args in input);
201 let attr_span = args.span();
202 let size_attr: SizeAttr = args.parse()?;
203 Ok(ShaderAttr::Size {
204 attr: size_attr,
205 span: attr_span,
206 })
207 }
208 _ => Err(syn::Error::new(
209 ident_span,
210 "unknown shader attribute, expected `align` or `size`",
211 )),
212 }
213 }
214}
215
216#[derive(Debug)]
217pub struct ShaderAttrList(Punctuated<ShaderAttr, Token![,]>);
218impl Parse for ShaderAttrList {
219 fn parse(input: ParseStream) -> syn::Result<Self> {
220 Ok(Self(input.parse_terminated(ShaderAttr::parse, Token![,])?))
221 }
222}
223
224struct Errors {
225 inner: Option<Error>,
226}
227
228impl Errors {
229 fn new() -> Self {
230 Self { inner: None }
231 }
232
233 fn append(&mut self, err: Error) {
234 if let Some(ex_error) = &mut self.inner {
235 ex_error.combine(err);
236 } else {
237 self.inner.replace(err);
238 }
239 }
240
241 fn into_compile_error(self) -> Option<TokenStream> {
242 self.inner.map(|e| e.into_compile_error())
243 }
244}
245
246pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
247 let root = &parse_quote!(#root::private);
248
249 let fields = match get_named_struct_fields(&input.data) {
250 Ok(fields) => fields,
251 Err(e) => return e.into_compile_error(),
252 };
253
254 let last_field_index = fields.named.len() - 1;
255
256 let mut errors = Errors::new();
257
258 let mut is_runtime_sized = false;
259
260 let field_data: Vec<_> = fields
261 .named
262 .iter()
263 .enumerate()
264 .map(|(i, field)| {
265 let mut data = FieldData {
266 field: field.clone(),
267 size: None,
268 align: None,
269 };
270
271 for attr in &field.attrs {
272 if !(attr.meta.path().is_ident("shader")) {
273 continue;
274 }
275
276 let shader_attrs = match attr.parse_args::<ShaderAttrList>() {
277 Ok(attrs) => attrs,
278 Err(err) => {
279 errors.append(err);
280 continue;
281 }
282 };
283
284 for shader_attr in shader_attrs.0 {
285 match shader_attr {
286 ShaderAttr::Align { attr, span } => {
287 if data.align.is_some() {
288 let err = syn::Error::new(span, "duplicate `align(X)` attribute");
289 errors.append(err);
290 continue;
291 }
292
293 data.align = Some((attr.0, span));
294 }
295 ShaderAttr::Size { attr, span } => {
296 if data.size.is_some() || is_runtime_sized {
297 let err = syn::Error::new(span, "duplicate `size(X)` attribute");
298 errors.append(err);
299 continue;
300 }
301
302 match attr {
303 SizeAttr::Runtime => {
304 if i == last_field_index {
305 is_runtime_sized = true;
306 } else {
307 let err = syn::Error::new(
308 span,
309 "only the last field can be `size(runtime)`",
310 );
311 errors.append(err);
312 continue;
313 }
314 }
315 SizeAttr::Static(attr) => {
316 data.size = Some((attr.0, span));
317 }
318 }
319 }
320 }
321 }
322 }
323 data
324 })
325 .collect();
326
327 let mut found = false;
328 let size_hint: &Path = &parse_quote!(#root::ArrayLength);
329 for field in &fields.named {
330 match &field.ty {
332 Type::Path(path)
333 if path.path.segments.last().unwrap().ident
334 == size_hint.segments.last().unwrap().ident =>
335 {
336 if found {
337 let err = syn::Error::new(
338 field.ty.span(),
339 "only one field can use the `ArrayLength` type!",
340 );
341 errors.append(err)
342 } else {
343 if !is_runtime_sized {
344 let err = syn::Error::new(
345 field.ty.span(),
346 "`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[shader(size(runtime))]`!",
347 );
348 errors.append(err)
349 }
350 found = true;
351 }
352 }
353 _ => {}
354 }
355 }
356
357 if let Some(ts) = errors.into_compile_error() {
358 return ts;
359 }
360
361 let nr_of_fields = &Literal::usize_suffixed(field_data.len());
362
363 let field_trait_constraints = generate_field_trait_constraints(
364 &input,
365 &field_data,
366 if is_runtime_sized {
367 quote!(#root::ShaderType + #root::RuntimeSizedArray)
368 } else {
369 quote!(#root::ShaderType + #root::ShaderSize)
370 },
371 quote!(#root::ShaderType + #root::ShaderSize),
372 );
373
374 let mut lifetimes = input.generics.clone();
375 lifetimes.params = lifetimes
376 .params
377 .into_iter()
378 .filter(|param| matches!(param, GenericParam::Lifetime(_)))
379 .collect::<Punctuated<GenericParam, Comma>>();
380
381 let align_check = {
382 let (impl_generics, _, _) = lifetimes.split_for_impl();
383 field_data
384 .iter()
385 .filter_map(|data| data.align.as_ref().map(|align| (&data.field.ty, align)))
386 .map(move |(ty, (align, span))| {
387 let align = Literal::u64_suffixed(*align as u64);
388 quote_spanned! {*span=>
389 const _: () = {
390 #[track_caller]
391 #[allow(clippy::extra_unused_lifetimes)]
392 const fn check #impl_generics () {
393 let alignment = <#ty as #root::ShaderType>::METADATA.alignment().get();
394 #root::concat_assert!(
395 alignment <= #align,
396 "shader(align) attribute value must be at least ", alignment, " (field's type alignment)"
397 )
398 }
399 check();
400 };
401 }
402 })
403 };
404
405 let size_check = {
406 let (impl_generics, _, _) = lifetimes.split_for_impl();
407 field_data
408 .iter()
409 .filter_map(|data| data.size.as_ref().map(|size| (&data.field.ty, size)))
410 .map(move |(ty, (size, span))| {
411 let size = Literal::u64_suffixed(*size as u64);
412 quote_spanned! {*span=>
413 const _: () = {
414 #[track_caller]
415 #[allow(clippy::extra_unused_lifetimes)]
416 const fn check #impl_generics () {
417 let size = <#ty as #root::ShaderSize>::SHADER_SIZE.get();
418 #root::concat_assert!(
419 size <= #size,
420 "size attribute value must be at least ", size, " (field's type size)"
421 )
422 }
423 check();
424 };
425 }
426 })
427 };
428
429 let uniform_check = field_data.iter().enumerate().map(|(i, data)| {
430 let ty = &data.field.ty;
431 let ty_check = quote_spanned! {ty.span()=>
432 <#ty as #root::ShaderType>::UNIFORM_COMPAT_ASSERT()
433 };
434 let ident = data.ident();
435 let name = ident.to_string();
436 let field_offset_check = quote_spanned! {ident.span()=>
437 if let ::core::option::Option::Some(min_alignment) =
438 <#ty as #root::ShaderType>::METADATA.uniform_min_alignment()
439 {
440 let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
441
442 #root::concat_assert!(
443 min_alignment.is_aligned(offset),
444 "offset of field '", #name, "' must be a multiple of ", min_alignment.get(),
445 " (current offset: ", offset, ")"
446 )
447 }
448 };
449 let field_offset_diff = if i != 0 {
450 let prev_field = &field_data[i - 1];
451 let prev_field_ty = &prev_field.field.ty;
452 let prev_ident_name = prev_field.ident().to_string();
453 quote_spanned! {ident.span()=>
454 if let ::core::option::Option::Some(min_alignment) =
455 <#prev_field_ty as #root::ShaderType>::METADATA.uniform_min_alignment()
456 {
457 let prev_offset = <Self as #root::ShaderType>::METADATA.offset(#i - 1);
458 let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
459 let diff = offset - prev_offset;
460
461 let prev_size = <#prev_field_ty as #root::ShaderSize>::SHADER_SIZE.get();
462 let prev_size = min_alignment.round_up(prev_size);
463
464 #root::concat_assert!(
465 diff >= prev_size,
466 "offset between fields '", #prev_ident_name, "' and '", #name, "' must be at least ",
467 min_alignment.get(), " (currently: ", diff, ")"
468 )
469 }
470 }
471 } else {
472 quote! {()}
473 };
474 quote! {
475 #ty_check,
476 #field_offset_check,
477 #field_offset_diff
478 }
479 });
480
481 let alignments = field_data.iter().map(|data| data.alignment(root));
482
483 let paddings = field_data.iter().enumerate().map(|(i, current)| {
484 let is_first = i == 0;
485 let is_last = i == field_data.len() - 1;
486
487 let mut out = TokenStream::new();
488
489 if !is_first {
490 let prev_i = i - 1;
491
492 let alignment = current.alignment(root);
493
494 let extra_padding = field_data
495 .get(prev_i)
496 .and_then(|prev| prev.extra_padding(root))
497 .map(|extra_padding| quote!(+ #extra_padding));
498
499 out.extend(quote! {
500 offsets[#i] = #alignment.round_up(offset);
501
502 let padding = #alignment.padding_needed_for(offset);
503 offset += padding;
504 paddings[#prev_i] = padding #extra_padding;
505 });
506 };
507
508 if is_last && is_runtime_sized {
509 return out;
510 }
511
512 let size = current.size(root);
513 out.extend(quote! {
514 offset += #size;
515 });
516
517 if is_last {
518 let extra_padding = current
519 .extra_padding(root)
520 .map(|extra_padding| quote!(+ #extra_padding));
521
522 out.extend(quote! {
523 paddings[#i] = struct_alignment.padding_needed_for(offset) #extra_padding;
524 });
525 }
526
527 out
528 });
529
530 fn gen_body<'a>(
531 field_data: &'a [FieldData],
532 root: &'a Path,
533 get_main: impl Fn(&Ident) -> TokenStream + 'a,
534 get_padding: impl Fn(TokenStream) -> TokenStream + 'a,
535 ) -> impl Iterator<Item = TokenStream> + 'a {
536 field_data.iter().enumerate().map(move |(i, data)| {
537 let ident = data.ident();
538
539 let padding = {
540 let i = Literal::usize_suffixed(i);
541 quote! { <Self as #root::ShaderType>::METADATA.padding(#i) }
542 };
543
544 let main = get_main(ident);
545 let padding = get_padding(padding);
546
547 quote! {
548 #main
549 #padding
550 }
551 })
552 }
553
554 let write_into_buffer_body = gen_body(
555 &field_data,
556 root,
557 |ident| {
558 quote! {
559 #root::WriteInto::write_into(&self.#ident, writer);
560 }
561 },
562 |padding| {
563 quote! {
564 #root::Writer::advance(writer, #padding as ::core::primitive::usize);
565 }
566 },
567 );
568
569 let read_from_buffer_body = gen_body(
570 &field_data,
571 root,
572 |ident| {
573 quote! {
574 #root::ReadFrom::read_from(&mut self.#ident, reader);
575 }
576 },
577 |padding| {
578 quote! {
579 #root::Reader::advance(reader, #padding as ::core::primitive::usize);
580 }
581 },
582 );
583
584 let create_from_buffer_body = gen_body(
585 &field_data,
586 root,
587 move |ident| {
588 quote! {
589 let #ident = #root::CreateFrom::create_from(reader);
590 }
591 },
592 |padding| {
593 quote! {
594 #root::Reader::advance(reader, #padding as ::core::primitive::usize);
595 }
596 },
597 );
598
599 let field_idents = field_data.iter().map(|data| data.ident());
600 let last_field = field_data.last().unwrap();
601 let last_field_min_size = last_field.min_size(root);
602 let last_field_ident = last_field.ident();
603
604 let field_types = field_data.iter().map(|data| &data.field.ty);
605 let field_types_2 = field_types.clone();
606 let field_types_3 = field_types.clone();
607 let field_types_4 = field_types.clone();
608 let all_other = field_types.clone().take(last_field_index);
609 let last_field_type = &last_field.field.ty;
610
611 let name = &input.ident;
612 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
613
614 let set_contained_rt_sized_array_length = if is_runtime_sized {
615 quote! {
616 writer.ctx.rts_array_length = ::core::option::Option::Some(
617 #root::RuntimeSizedArray::len(&self.#last_field_ident)
618 as ::core::primitive::u32
619 );
620 }
621 } else {
622 TokenStream::new()
623 };
624
625 let extra = match is_runtime_sized {
626 true => quote! {
627 impl #impl_generics #root::CalculateSizeFor for #name #ty_generics
628 where
629 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
630 #last_field_type: #root::CalculateSizeFor,
631 {
632 fn calculate_size_for(nr_of_el: ::core::primitive::u64) -> ::core::num::NonZeroU64 {
633 let mut offset = <Self as #root::ShaderType>::METADATA.last_offset();
634 offset += <#last_field_type as #root::CalculateSizeFor>::calculate_size_for(nr_of_el).get();
635 #root::SizeValue::new(<Self as #root::ShaderType>::METADATA.alignment().round_up(offset)).0
636 }
637 }
638 },
639 false => quote! {
640 impl #impl_generics #root::ShaderSize for #name #ty_generics
641 where
642 #( #field_types: #root::ShaderSize, )*
643 {}
644 },
645 };
646
647 quote! {
653 #( #field_trait_constraints )*
654
655 #( #align_check )*
656
657 #( #size_check )*
658
659 impl #impl_generics #root::ShaderType for #name #ty_generics #where_clause
660 where
661 #( #all_other: #root::ShaderType + #root::ShaderSize, )*
662 #last_field_type: #root::ShaderType,
663 {
664 type ExtraMetadata = #root::StructMetadata<#nr_of_fields>;
665 const METADATA: #root::Metadata<Self::ExtraMetadata> = {
666 let struct_alignment = #root::AlignmentValue::max([ #( #alignments, )* ]);
667
668 let extra = {
669 let mut paddings = [0; #nr_of_fields];
670 let mut offsets = [0; #nr_of_fields];
671 let mut offset = 0;
672 #( #paddings )*
673 #root::StructMetadata { offsets, paddings }
674 };
675
676 let min_size = {
677 let mut offset = extra.offsets[#nr_of_fields - 1];
678 offset += #last_field_min_size;
679 #root::SizeValue::new(struct_alignment.round_up(offset))
680 };
681
682 #root::Metadata {
683 alignment: struct_alignment,
684 has_uniform_min_alignment: true,
685 min_size,
686 is_pod: false,
687 extra,
688 }
689 };
690
691 const UNIFORM_COMPAT_ASSERT: fn() = || #root::consume_zsts([
692 #( #uniform_check, )*
693 ]);
694
695 fn size(&self) -> ::core::num::NonZeroU64 {
696 let mut offset = Self::METADATA.last_offset();
697 offset += #root::ShaderType::size(&self.#last_field_ident).get();
698 #root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0
699 }
700 }
701
702 impl #impl_generics #root::WriteInto for #name #ty_generics
703 where
704 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
705 #( for<'__> #field_types_2: #root::WriteInto, )*
706 {
707 #[inline]
708 fn write_into<B: #root::BufferMut>(&self, writer: &mut #root::Writer<B>) {
709 #set_contained_rt_sized_array_length
710 #( #write_into_buffer_body )*
711 }
712 }
713
714 impl #impl_generics #root::ReadFrom for #name #ty_generics
715 where
716 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
717 #( for<'__> #field_types_3: #root::ReadFrom, )*
718 {
719 #[inline]
720 fn read_from<B: #root::BufferRef>(&mut self, reader: &mut #root::Reader<B>) {
721 #( #read_from_buffer_body )*
722 }
723 }
724
725 impl #impl_generics #root::CreateFrom for #name #ty_generics
726 where
727 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
728 #( for<'__> #field_types_4: #root::CreateFrom, )*
729 {
730 #[inline]
731 fn create_from<B: #root::BufferRef>(reader: &mut #root::Reader<B>) -> Self {
732 #( #create_from_buffer_body )*
733
734 #root::build_struct!(Self, #( #field_idents ),*)
735 }
736 }
737
738 #extra
739 }
740}
741
742fn generate_field_trait_constraints<'a>(
743 input: &'a DeriveInput,
744 field_data: &'a [FieldData],
745 trait_for_last_field: TokenStream,
746 trait_for_all_other_fields: TokenStream,
747) -> impl Iterator<Item = TokenStream> + 'a {
748 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
749 field_data.iter().enumerate().map(move |(i, data)| {
750 let ty = &data.field.ty;
751
752 let t = if i == field_data.len() - 1 {
753 &trait_for_last_field
754 } else {
755 &trait_for_all_other_fields
756 };
757
758 if ty_generics.to_token_stream().is_empty() {
759 quote_spanned! {ty.span()=>
760 const _: fn() = || {
761 #[allow(clippy::extra_unused_lifetimes, clippy::missing_const_for_fn, clippy::extra_unused_type_parameters)]
762 fn check #impl_generics () #where_clause {
763 fn assert_impl<T: ?::core::marker::Sized + #t>() {}
764 assert_impl::<#ty>();
765 }
766 check ();
767 };
768 }
769 } else {
770 quote_spanned! {ty.span()=>
772 const _: fn() = || {};
773 }
774 }
775 })
776}