1use proc_macro2::{Ident, Literal, Span, TokenStream};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_quote,
6 punctuated::Punctuated,
7 spanned::Spanned,
8 token::Comma,
9 Data, DataStruct, DeriveInput, Error, Fields, FieldsNamed, GenericParam, LitInt, Path, Type,
10};
11
12pub use syn;
13
/// Expands to the `#[proc_macro_derive(ShaderType, attributes(align, size))]` entry point
/// of a proc-macro crate.
///
/// `$path` must be an expression evaluating to a `syn::Path`; it is passed by reference as
/// the `root` argument of [`derive_shader_type`], whose generated code refers to runtime
/// items through `#root::private`.
#[macro_export]
macro_rules! implement {
    ($path:expr) => {
        #[proc_macro_derive(ShaderType, attributes(align, size))]
        pub fn derive_shader_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
            let input = $crate::syn::parse_macro_input!(input as $crate::syn::DeriveInput);
            // Resolve through `$crate` (like `$crate::syn` above) instead of a hard-coded
            // crate name, so the macro keeps working if the invoking crate renames this
            // dependency in its Cargo.toml.
            let expanded = $crate::derive_shader_type(input, &$path);
            proc_macro::TokenStream::from(expanded)
        }
    };
}
25
26fn get_named_struct_fields(data: &syn::Data) -> syn::Result<&FieldsNamed> {
27 match data {
28 Data::Struct(DataStruct {
29 fields: Fields::Named(fields),
30 ..
31 }) if !fields.named.is_empty() => Ok(fields),
32 _ => Err(Error::new(
33 Span::call_site(),
34 "Only non empty structs with named fields are supported!",
35 )),
36 }
37}
38
39struct FieldData {
40 pub field: syn::Field,
41 pub size: Option<(u32, Span)>,
42 pub align: Option<(u32, Span)>,
43}
44
45impl FieldData {
46 fn alignment(&self, root: &Path) -> TokenStream {
47 if let Some((alignment, _)) = self.align {
48 let alignment = Literal::u64_suffixed(alignment as u64);
49 quote! {
50 #root::AlignmentValue::new(#alignment)
51 }
52 } else {
53 let ty = &self.field.ty;
54 quote! {
55 <#ty as #root::ShaderType>::METADATA.alignment()
56 }
57 }
58 }
59
60 fn size(&self, root: &Path) -> TokenStream {
61 if let Some((size, _)) = self.size {
62 let size = Literal::u64_suffixed(size as u64);
63 quote! {
64 #size
65 }
66 } else {
67 let ty = &self.field.ty;
68 quote! {
69 <#ty as #root::ShaderSize>::SHADER_SIZE.get()
70 }
71 }
72 }
73
74 fn min_size(&self, root: &Path) -> TokenStream {
75 if let Some((size, _)) = self.size {
76 let size = Literal::u64_suffixed(size as u64);
77 quote! {
78 #size
79 }
80 } else {
81 let ty = &self.field.ty;
82 quote! {
83 <#ty as #root::ShaderType>::METADATA.min_size().get()
84 }
85 }
86 }
87
88 fn extra_padding(&self, root: &Path) -> Option<TokenStream> {
89 self.size.as_ref().map(|(size, _)| {
90 let size = Literal::u64_suffixed(*size as u64);
91 let ty = &self.field.ty;
92 let original_size = quote! { <#ty as #root::ShaderSize>::SHADER_SIZE.get() };
93 quote!(#size.saturating_sub(#original_size))
94 })
95 }
96
97 fn ident(&self) -> &Ident {
98 self.field.ident.as_ref().unwrap()
99 }
100}
101
102struct AlignmentAttr(u32);
103
104impl Parse for AlignmentAttr {
105 fn parse(input: ParseStream) -> syn::Result<Self> {
106 match input
107 .parse::<LitInt>()
108 .and_then(|lit| lit.base10_parse::<u32>())
109 {
110 Ok(num) if num.is_power_of_two() => Ok(Self(num)),
111 _ => Err(syn::Error::new(
112 input.span(),
113 "expected a power of 2 u32 literal",
114 )),
115 }
116 }
117}
118
119struct StaticSizeAttr(u32);
120
121impl Parse for StaticSizeAttr {
122 fn parse(input: ParseStream) -> syn::Result<Self> {
123 match input
124 .parse::<LitInt>()
125 .and_then(|lit| lit.base10_parse::<u32>())
126 {
127 Ok(num) => Ok(Self(num)),
128 _ => Err(syn::Error::new(input.span(), "expected u32 literal")),
129 }
130 }
131}
132
133enum SizeAttr {
134 Static(StaticSizeAttr),
135 Runtime,
136}
137
138impl Parse for SizeAttr {
139 fn parse(input: ParseStream) -> syn::Result<Self> {
140 match input.parse::<StaticSizeAttr>() {
141 Ok(static_size) => Ok(SizeAttr::Static(static_size)),
142 _ => match input.parse::<Path>() {
143 Ok(ident) if ident.is_ident("runtime") => Ok(SizeAttr::Runtime),
144 _ => Err(syn::Error::new(
145 input.span(),
146 "expected u32 literal or `runtime` identifier",
147 )),
148 },
149 }
150 }
151}
152
153struct Errors {
154 inner: Option<Error>,
155}
156
157impl Errors {
158 fn new() -> Self {
159 Self { inner: None }
160 }
161
162 fn append(&mut self, err: Error) {
163 if let Some(ex_error) = &mut self.inner {
164 ex_error.combine(err);
165 } else {
166 self.inner.replace(err);
167 }
168 }
169
170 fn into_compile_error(self) -> Option<TokenStream> {
171 self.inner.map(|e| e.into_compile_error())
172 }
173}
174
175pub fn derive_shader_type(input: DeriveInput, root: &Path) -> TokenStream {
176 let root = &parse_quote!(#root::private);
177
178 let fields = match get_named_struct_fields(&input.data) {
179 Ok(fields) => fields,
180 Err(e) => return e.into_compile_error(),
181 };
182
183 let last_field_index = fields.named.len() - 1;
184
185 let mut errors = Errors::new();
186
187 let mut is_runtime_sized = false;
188
189 let field_data: Vec<_> = fields
190 .named
191 .iter()
192 .enumerate()
193 .map(|(i, field)| {
194 let mut data = FieldData {
195 field: field.clone(),
196 size: None,
197 align: None,
198 };
199 for attr in &field.attrs {
200 if !(attr.meta.path().is_ident("size") || attr.meta.path().is_ident("align")) {
201 continue;
202 }
203 match attr.meta.require_list() {
204 Ok(meta_list) => {
205 let span = meta_list.tokens.span();
206 if meta_list.path.is_ident("align") {
207 let res = attr.parse_args::<AlignmentAttr>();
208 match res {
209 Ok(val) => data.align = Some((val.0, span)),
210 Err(err) => errors.append(err),
211 }
212 } else if meta_list.path.is_ident("size") {
213 let res = if i == last_field_index {
214 attr.parse_args::<SizeAttr>().map(|val| match val {
215 SizeAttr::Runtime => {
216 is_runtime_sized = true;
217 None
218 }
219 SizeAttr::Static(size) => Some((size.0, span)),
220 })
221 } else {
222 attr.parse_args::<StaticSizeAttr>()
223 .map(|val| Some((val.0, span)))
224 };
225 match res {
226 Ok(val) => data.size = val,
227 Err(err) => errors.append(err),
228 }
229 }
230 }
231 Err(err) => errors.append(err),
232 };
233 }
234 data
235 })
236 .collect();
237
238 let mut found = false;
239 let size_hint: &Path = &parse_quote!(#root::ArrayLength);
240 for field in &fields.named {
241 match &field.ty {
243 Type::Path(path)
244 if path.path.segments.last().unwrap().ident
245 == size_hint.segments.last().unwrap().ident =>
246 {
247 if found {
248 let err = syn::Error::new(
249 field.ty.span(),
250 "only one field can use the `ArrayLength` type!",
251 );
252 errors.append(err)
253 } else {
254 if !is_runtime_sized {
255 let err = syn::Error::new(
256 field.ty.span(),
257 "`ArrayLength` type can only be used within a struct containing a runtime-sized array marked as `#[size(runtime)]`!",
258 );
259 errors.append(err)
260 }
261 found = true;
262 }
263 }
264 _ => {}
265 }
266 }
267
268 if let Some(ts) = errors.into_compile_error() {
269 return ts;
270 }
271
272 let nr_of_fields = &Literal::usize_suffixed(field_data.len());
273
274 let field_trait_constraints = generate_field_trait_constraints(
275 &input,
276 &field_data,
277 if is_runtime_sized {
278 quote!(#root::ShaderType + #root::RuntimeSizedArray)
279 } else {
280 quote!(#root::ShaderType + #root::ShaderSize)
281 },
282 quote!(#root::ShaderType + #root::ShaderSize),
283 );
284
285 let mut lifetimes = input.generics.clone();
286 lifetimes.params = lifetimes
287 .params
288 .into_iter()
289 .filter(|param| matches!(param, GenericParam::Lifetime(_)))
290 .collect::<Punctuated<GenericParam, Comma>>();
291
292 let align_check = {
293 let (impl_generics, _, _) = lifetimes.split_for_impl();
294 field_data
295 .iter()
296 .filter_map(|data| data.align.as_ref().map(|align| (&data.field.ty, align)))
297 .map(move |(ty, (align, span))| {
298 let align = Literal::u64_suffixed(*align as u64);
299 quote_spanned! {*span=>
300 const _: () = {
301 #[track_caller]
302 #[allow(clippy::extra_unused_lifetimes)]
303 const fn check #impl_generics () {
304 let alignment = <#ty as #root::ShaderType>::METADATA.alignment().get();
305 #root::concat_assert!(
306 alignment <= #align,
307 "align attribute value must be at least ", alignment, " (field's type alignment)"
308 )
309 }
310 check();
311 };
312 }
313 })
314 };
315
316 let size_check = {
317 let (impl_generics, _, _) = lifetimes.split_for_impl();
318 field_data
319 .iter()
320 .filter_map(|data| data.size.as_ref().map(|size| (&data.field.ty, size)))
321 .map(move |(ty, (size, span))| {
322 let size = Literal::u64_suffixed(*size as u64);
323 quote_spanned! {*span=>
324 const _: () = {
325 #[track_caller]
326 #[allow(clippy::extra_unused_lifetimes)]
327 const fn check #impl_generics () {
328 let size = <#ty as #root::ShaderSize>::SHADER_SIZE.get();
329 #root::concat_assert!(
330 size <= #size,
331 "size attribute value must be at least ", size, " (field's type size)"
332 )
333 }
334 check();
335 };
336 }
337 })
338 };
339
340 let uniform_check = field_data.iter().enumerate().map(|(i, data)| {
341 let ty = &data.field.ty;
342 let ty_check = quote_spanned! {ty.span()=>
343 <#ty as #root::ShaderType>::UNIFORM_COMPAT_ASSERT()
344 };
345 let ident = data.ident();
346 let name = ident.to_string();
347 let field_offset_check = quote_spanned! {ident.span()=>
348 if let ::core::option::Option::Some(min_alignment) =
349 <#ty as #root::ShaderType>::METADATA.uniform_min_alignment()
350 {
351 let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
352
353 #root::concat_assert!(
354 min_alignment.is_aligned(offset),
355 "offset of field '", #name, "' must be a multiple of ", min_alignment.get(),
356 " (current offset: ", offset, ")"
357 )
358 }
359 };
360 let field_offset_diff = if i != 0 {
361 let prev_field = &field_data[i - 1];
362 let prev_field_ty = &prev_field.field.ty;
363 let prev_ident_name = prev_field.ident().to_string();
364 quote_spanned! {ident.span()=>
365 if let ::core::option::Option::Some(min_alignment) =
366 <#prev_field_ty as #root::ShaderType>::METADATA.uniform_min_alignment()
367 {
368 let prev_offset = <Self as #root::ShaderType>::METADATA.offset(#i - 1);
369 let offset = <Self as #root::ShaderType>::METADATA.offset(#i);
370 let diff = offset - prev_offset;
371
372 let prev_size = <#prev_field_ty as #root::ShaderSize>::SHADER_SIZE.get();
373 let prev_size = min_alignment.round_up(prev_size);
374
375 #root::concat_assert!(
376 diff >= prev_size,
377 "offset between fields '", #prev_ident_name, "' and '", #name, "' must be at least ",
378 min_alignment.get(), " (currently: ", diff, ")"
379 )
380 }
381 }
382 } else {
383 quote! {()}
384 };
385 quote! {
386 #ty_check,
387 #field_offset_check,
388 #field_offset_diff
389 }
390 });
391
392 let alignments = field_data.iter().map(|data| data.alignment(root));
393
394 let paddings = field_data.iter().enumerate().map(|(i, current)| {
395 let is_first = i == 0;
396 let is_last = i == field_data.len() - 1;
397
398 let mut out = TokenStream::new();
399
400 if !is_first {
401 let prev_i = i - 1;
402
403 let alignment = current.alignment(root);
404
405 let extra_padding = field_data
406 .get(prev_i)
407 .and_then(|prev| prev.extra_padding(root))
408 .map(|extra_padding| quote!(+ #extra_padding));
409
410 out.extend(quote! {
411 offsets[#i] = #alignment.round_up(offset);
412
413 let padding = #alignment.padding_needed_for(offset);
414 offset += padding;
415 paddings[#prev_i] = padding #extra_padding;
416 });
417 };
418
419 if is_last && is_runtime_sized {
420 return out;
421 }
422
423 let size = current.size(root);
424 out.extend(quote! {
425 offset += #size;
426 });
427
428 if is_last {
429 let extra_padding = current
430 .extra_padding(root)
431 .map(|extra_padding| quote!(+ #extra_padding));
432
433 out.extend(quote! {
434 paddings[#i] = struct_alignment.padding_needed_for(offset) #extra_padding;
435 });
436 }
437
438 out
439 });
440
441 fn gen_body<'a>(
442 field_data: &'a [FieldData],
443 root: &'a Path,
444 get_main: impl Fn(&Ident) -> TokenStream + 'a,
445 get_padding: impl Fn(TokenStream) -> TokenStream + 'a,
446 ) -> impl Iterator<Item = TokenStream> + 'a {
447 field_data.iter().enumerate().map(move |(i, data)| {
448 let ident = data.ident();
449
450 let padding = {
451 let i = Literal::usize_suffixed(i);
452 quote! { <Self as #root::ShaderType>::METADATA.padding(#i) }
453 };
454
455 let main = get_main(ident);
456 let padding = get_padding(padding);
457
458 quote! {
459 #main
460 #padding
461 }
462 })
463 }
464
465 let write_into_buffer_body = gen_body(
466 &field_data,
467 root,
468 |ident| {
469 quote! {
470 #root::WriteInto::write_into(&self.#ident, writer);
471 }
472 },
473 |padding| {
474 quote! {
475 #root::Writer::advance(writer, #padding as ::core::primitive::usize);
476 }
477 },
478 );
479
480 let read_from_buffer_body = gen_body(
481 &field_data,
482 root,
483 |ident| {
484 quote! {
485 #root::ReadFrom::read_from(&mut self.#ident, reader);
486 }
487 },
488 |padding| {
489 quote! {
490 #root::Reader::advance(reader, #padding as ::core::primitive::usize);
491 }
492 },
493 );
494
495 let create_from_buffer_body = gen_body(
496 &field_data,
497 root,
498 move |ident| {
499 quote! {
500 let #ident = #root::CreateFrom::create_from(reader);
501 }
502 },
503 |padding| {
504 quote! {
505 #root::Reader::advance(reader, #padding as ::core::primitive::usize);
506 }
507 },
508 );
509
510 let field_idents = field_data.iter().map(|data| data.ident());
511 let last_field = field_data.last().unwrap();
512 let last_field_min_size = last_field.min_size(root);
513 let last_field_ident = last_field.ident();
514
515 let field_types = field_data.iter().map(|data| &data.field.ty);
516 let field_types_2 = field_types.clone();
517 let field_types_3 = field_types.clone();
518 let field_types_4 = field_types.clone();
519 let all_other = field_types.clone().take(last_field_index);
520 let last_field_type = &last_field.field.ty;
521
522 let name = &input.ident;
523 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
524
525 let set_contained_rt_sized_array_length = if is_runtime_sized {
526 quote! {
527 writer.ctx.rts_array_length = ::core::option::Option::Some(
528 #root::RuntimeSizedArray::len(&self.#last_field_ident)
529 as ::core::primitive::u32
530 );
531 }
532 } else {
533 TokenStream::new()
534 };
535
536 let extra = match is_runtime_sized {
537 true => quote! {
538 impl #impl_generics #root::CalculateSizeFor for #name #ty_generics
539 where
540 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
541 #last_field_type: #root::CalculateSizeFor,
542 {
543 fn calculate_size_for(nr_of_el: ::core::primitive::u64) -> ::core::num::NonZeroU64 {
544 let mut offset = <Self as #root::ShaderType>::METADATA.last_offset();
545 offset += <#last_field_type as #root::CalculateSizeFor>::calculate_size_for(nr_of_el).get();
546 #root::SizeValue::new(<Self as #root::ShaderType>::METADATA.alignment().round_up(offset)).0
547 }
548 }
549 },
550 false => quote! {
551 impl #impl_generics #root::ShaderSize for #name #ty_generics
552 where
553 #( #field_types: #root::ShaderSize, )*
554 {}
555 },
556 };
557
558 quote! {
564 #( #field_trait_constraints )*
565
566 #( #align_check )*
567
568 #( #size_check )*
569
570 impl #impl_generics #root::ShaderType for #name #ty_generics #where_clause
571 where
572 #( #all_other: #root::ShaderType + #root::ShaderSize, )*
573 #last_field_type: #root::ShaderType,
574 {
575 type ExtraMetadata = #root::StructMetadata<#nr_of_fields>;
576 const METADATA: #root::Metadata<Self::ExtraMetadata> = {
577 let struct_alignment = #root::AlignmentValue::max([ #( #alignments, )* ]);
578
579 let extra = {
580 let mut paddings = [0; #nr_of_fields];
581 let mut offsets = [0; #nr_of_fields];
582 let mut offset = 0;
583 #( #paddings )*
584 #root::StructMetadata { offsets, paddings }
585 };
586
587 let min_size = {
588 let mut offset = extra.offsets[#nr_of_fields - 1];
589 offset += #last_field_min_size;
590 #root::SizeValue::new(struct_alignment.round_up(offset))
591 };
592
593 #root::Metadata {
594 alignment: struct_alignment,
595 has_uniform_min_alignment: true,
596 min_size,
597 is_pod: false,
598 extra,
599 }
600 };
601
602 const UNIFORM_COMPAT_ASSERT: fn() = || #root::consume_zsts([
603 #( #uniform_check, )*
604 ]);
605
606 fn size(&self) -> ::core::num::NonZeroU64 {
607 let mut offset = Self::METADATA.last_offset();
608 offset += #root::ShaderType::size(&self.#last_field_ident).get();
609 #root::SizeValue::new(Self::METADATA.alignment().round_up(offset)).0
610 }
611 }
612
613 impl #impl_generics #root::WriteInto for #name #ty_generics
614 where
615 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
616 #( for<'__> #field_types_2: #root::WriteInto, )*
617 {
618 #[inline]
619 fn write_into<B: #root::BufferMut>(&self, writer: &mut #root::Writer<B>) {
620 #set_contained_rt_sized_array_length
621 #( #write_into_buffer_body )*
622 }
623 }
624
625 impl #impl_generics #root::ReadFrom for #name #ty_generics
626 where
627 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
628 #( for<'__> #field_types_3: #root::ReadFrom, )*
629 {
630 #[inline]
631 fn read_from<B: #root::BufferRef>(&mut self, reader: &mut #root::Reader<B>) {
632 #( #read_from_buffer_body )*
633 }
634 }
635
636 impl #impl_generics #root::CreateFrom for #name #ty_generics
637 where
638 Self: #root::ShaderType<ExtraMetadata = #root::StructMetadata<#nr_of_fields>>,
639 #( for<'__> #field_types_4: #root::CreateFrom, )*
640 {
641 #[inline]
642 fn create_from<B: #root::BufferRef>(reader: &mut #root::Reader<B>) -> Self {
643 #( #create_from_buffer_body )*
644
645 #root::build_struct!(Self, #( #field_idents ),*)
646 }
647 }
648
649 #extra
650 }
651}
652
653fn generate_field_trait_constraints<'a>(
654 input: &'a DeriveInput,
655 field_data: &'a [FieldData],
656 trait_for_last_field: TokenStream,
657 trait_for_all_other_fields: TokenStream,
658) -> impl Iterator<Item = TokenStream> + 'a {
659 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
660 field_data.iter().enumerate().map(move |(i, data)| {
661 let ty = &data.field.ty;
662
663 let t = if i == field_data.len() - 1 {
664 &trait_for_last_field
665 } else {
666 &trait_for_all_other_fields
667 };
668
669 if ty_generics.to_token_stream().is_empty() {
670 quote_spanned! {ty.span()=>
671 const _: fn() = || {
672 #[allow(clippy::extra_unused_lifetimes, clippy::missing_const_for_fn, clippy::extra_unused_type_parameters)]
673 fn check #impl_generics () #where_clause {
674 fn assert_impl<T: ?::core::marker::Sized + #t>() {}
675 assert_impl::<#ty>();
676 }
677 check ();
678 };
679 }
680 } else {
681 quote_spanned! {ty.span()=>
683 const _: fn() = || {};
684 }
685 }
686 })
687}