1use quote::{format_ident, quote};
3use syn::{
4 spanned::Spanned, Data, DataEnum, DataStruct, DeriveInput, Fields, FieldsNamed, FieldsUnnamed,
5 Ident, Index, Type, TypeTuple, WhereClause, WherePredicate,
6};
7
8enum FieldName {
9 Index(Index),
10 Ident(Ident),
11}
12
13struct FieldParams {
14 field_tys: Vec<Type>,
15 field_names: Vec<FieldName>,
16}
17
18impl FieldParams {
19 fn new(fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>) -> Self {
20 let field_tys: Vec<_> = fields.iter().map(|field| field.ty.clone()).collect();
21 let field_names: Vec<_> = fields
22 .iter()
23 .enumerate()
24 .map(|(i, field)| {
25 field
26 .ident
27 .clone()
28 .map(FieldName::Ident)
29 .unwrap_or_else(|| {
30 FieldName::Index(Index {
31 index: i as u32,
32 span: field.span(),
33 })
34 })
35 })
36 .collect();
37 Self {
38 field_tys,
39 field_names,
40 }
41 }
42}
43
44fn get_struct_params(ds: &DataStruct) -> FieldParams {
45 let empty_punctuated = syn::punctuated::Punctuated::new();
46 let fields = match ds {
47 DataStruct {
48 fields: Fields::Named(FieldsNamed { named: ref x, .. }),
49 ..
50 } => x,
51 DataStruct {
52 fields: Fields::Unnamed(FieldsUnnamed { unnamed: ref x, .. }),
53 ..
54 } => x,
55 DataStruct {
56 fields: Fields::Unit,
57 ..
58 } => &empty_punctuated,
59 };
60
61 FieldParams::new(fields)
62}
63
64struct EnumVariant {
65 variant: syn::Variant,
66 fields: FieldParams,
67}
68
69struct EnumParams {
70 variants: Vec<EnumVariant>,
71 slab_size: proc_macro2::TokenStream,
72}
73
74fn get_enum_params(de: &DataEnum) -> EnumParams {
75 let DataEnum {
76 enum_token: _,
77 brace_token: _,
78 variants,
79 } = de;
80 let variants = variants
81 .iter()
82 .map(|variant| {
83 let empty_fields = syn::punctuated::Punctuated::new();
84 let fields = match &variant.fields {
85 Fields::Named(FieldsNamed { named: ref x, .. }) => x,
86 Fields::Unnamed(FieldsUnnamed { unnamed: ref x, .. }) => x,
87 Fields::Unit => &empty_fields,
88 };
89 let fields = FieldParams::new(fields);
90 EnumVariant {
91 variant: variant.clone(),
92 fields,
93 }
94 })
95 .collect::<Vec<_>>();
96 let slab_size_def = quote! {
97 let mut __size = 0usize;
98 };
99 let slab_size_increments = variants
100 .iter()
101 .map(|variant| {
102 let tys = &variant.fields.field_tys;
103 if tys.is_empty() {
104 quote! {}
105 } else {
106 quote! {{
107 let __field_size = #( <#tys as crabslab::SlabItem>::SLAB_SIZE )+*;
108 __size += crabslab::__saturating_sub(__field_size,__size);
109 }}
110 }
111 })
112 .collect::<Vec<_>>();
113 EnumParams {
114 slab_size: quote! {
115 #slab_size_def
116 #(#slab_size_increments)*
117 __size + 1
119 },
120 variants,
121 }
122}
123
124enum Params {
125 Struct(FieldParams),
126 Enum(EnumParams),
127}
128
129fn get_params(input: &DeriveInput) -> syn::Result<Params> {
130 match &input.data {
131 Data::Struct(ds) => Ok(Params::Struct(get_struct_params(ds))),
132 Data::Enum(de) => Ok(Params::Enum(get_enum_params(de))),
133 _ => Err(syn::Error::new(
134 input.span(),
135 "deriving SlabItem does not support unions".to_string(),
136 )),
137 }
138}
139
140#[proc_macro_derive(SlabItem, attributes(offsets))]
201pub fn derive_from_slab(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
202 let input: DeriveInput = syn::parse_macro_input!(input);
203
204 let gen_offsets_span = input.attrs.iter().find_map(|attr| {
205 let path = attr.path();
206 if path.is_ident("offsets") {
207 Some(path.span())
208 } else {
209 None
210 }
211 });
212
213 match get_params(&input) {
214 Ok(Params::Struct(p)) => derive_from_slab_struct(input, p, gen_offsets_span.is_some()),
215 Ok(Params::Enum(p)) => {
216 if let Some(span) = gen_offsets_span {
217 syn::Error::new(span, "Deriving field offsets is not supported for enums")
218 .into_compile_error()
219 .into()
220 } else {
221 derive_from_slab_enum(input, p)
222 }
223 }
224 Err(e) => e.into_compile_error().into(),
225 }
226}
227
228fn derive_from_slab_enum(input: DeriveInput, params: EnumParams) -> proc_macro::TokenStream {
229 let EnumParams {
230 variants,
231 slab_size,
232 } = params;
233 let name = &input.ident;
234 let field_tys = variants
235 .iter()
236 .flat_map(|v| v.fields.field_tys.clone())
237 .collect::<Vec<_>>();
238 let mut generics = input.generics;
239 {
240 fn constrain_system_data_types(clause: &mut WhereClause, tys: &[Type]) {
241 for ty in tys.iter() {
242 let where_predicate: WherePredicate = syn::parse_quote!(#ty : crabslab::SlabItem);
243 clause.predicates.push(where_predicate);
244 }
245 }
246
247 let where_clause = generics.make_where_clause();
248 constrain_system_data_types(where_clause, &field_tys)
249 }
250 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
251
252 let variant_reads = variants.iter().map(|variant| {
253 let ident = &variant.variant.ident;
254 let field_names = variant
255 .fields
256 .field_names
257 .iter()
258 .map(|name| match name {
259 FieldName::Index(i) => Ident::new(&format!("__{}", i.index), i.span),
260 FieldName::Ident(field) => field.clone(),
261 })
262 .collect::<Vec<_>>();
263 let field_tys = &variant.fields.field_tys;
264 let num_fields = field_names.len();
265 let reads = field_names
266 .iter()
267 .zip(field_tys.iter())
268 .enumerate()
269 .map(|(i, (name, ty))| {
270 let def = quote! {
271 let #name = <#ty as crabslab::SlabItem>::read_slab(index, slab);
272 };
273 let increment_index = if i + 1 < num_fields {
274 quote! {
275 index += <#ty as crabslab::SlabItem>::SLAB_SIZE;
276 }
277 } else {
278 quote! {}
279 };
280 quote! {
281 #def
282 #increment_index
283 }
284 })
285 .collect::<Vec<_>>();
286
287 match variant.variant.fields {
288 Fields::Named(_) => {
289 quote! {{
290 #(#reads)*
291 #name::#ident {
292 #(#field_names),*
293 }
294 }}
295 }
296 Fields::Unnamed(_) => {
297 quote! {{
298 #(#reads)*
299 #name::#ident(
300 #(#field_names),*
301 )
302 }}
303 }
304 Fields::Unit => quote! {
305 #name::#ident,
306 },
307 }
308 });
309 let read_variants_matches: Vec<proc_macro2::TokenStream> = variants
310 .iter()
311 .enumerate()
312 .zip(variant_reads)
313 .map(|((i, variant), read)| {
314 let hash = syn::LitInt::new(&i.to_string(), variant.variant.span());
315 quote! {
316 #hash => #read
317 }
318 })
319 .collect();
320 let variant_writes = variants.iter().map(|variant| {
321 let field_names = variant
322 .fields
323 .field_names
324 .iter()
325 .map(|name| match name {
326 FieldName::Index(i) => Ident::new(&format!("__{}", i.index), i.span),
327 FieldName::Ident(field) => field.clone(),
328 })
329 .collect::<Vec<_>>();
330 quote! {
331 #(let index = #field_names.write_slab(index, slab);)*
332 }
333 });
334 let write_variants_matches: Vec<proc_macro2::TokenStream> = variants
335 .iter()
336 .enumerate()
337 .zip(variant_writes)
338 .map(|((i, variant), write)| {
339 let hash = syn::LitInt::new(&i.to_string(), variant.variant.span());
340 let field_names = variant
341 .fields
342 .field_names
343 .iter()
344 .map(|name| match name {
345 FieldName::Index(i) => Ident::new(&format!("__{}", i.index), i.span),
346 FieldName::Ident(field) => field.clone(),
347 })
348 .collect::<Vec<_>>();
349 let ident = &variant.variant.ident;
350 let pat_match = match variant.variant.fields {
351 Fields::Named(_) => {
352 quote! {
353 #name::#ident {
354 #(#field_names,)*
355 }
356 }
357 }
358 Fields::Unnamed(_) => {
359 quote! {
360 #name::#ident(
361 #(#field_names,)*
362 )
363 }
364 }
365 Fields::Unit => quote! {
366 #name::#ident
367 },
368 };
369 quote! {
370 #pat_match => {
371 let __hash: u32 = #hash;
372 let index = __hash.write_slab(index, slab);
373 #write
374 original_index + slab_size
375 }
376 }
377 })
378 .collect();
379
380 let output = quote! {
381 #[automatically_derived]
382 impl #impl_generics crabslab::SlabItem for #name #ty_generics #where_clause
383 {
384 const SLAB_SIZE: usize = {#slab_size};
385
386 fn read_slab(mut index: usize, slab: &[u32]) -> Self {
387 let hash = u32::read_slab(index, slab);
389 index += 1;
390 match hash {
391 #(#read_variants_matches)*
392 _ => Default::default(),
393 }
394 }
395
396 fn write_slab(&self, index: usize, slab: &mut [u32]) -> usize {
397 let slab_size = Self::SLAB_SIZE;
398 let original_index = index;
399 match self {
400 #(#write_variants_matches)*
401 }
402 }
403 }
404 };
405 output.into()
406}
407
408fn derive_from_slab_struct(
409 input: DeriveInput,
410 params: FieldParams,
411 gen_offsets: bool,
413) -> proc_macro::TokenStream {
414 let FieldParams {
415 field_tys,
416 field_names,
417 } = params;
418
419 let name = &input.ident;
420 let is_struct_style = !matches!(field_names.first(), Some(FieldName::Index(_)));
421 let mut generics = input.generics;
422 {
423 fn constrain_system_data_types(clause: &mut WhereClause, tys: &[Type]) {
425 for ty in tys.iter() {
426 let where_predicate: WherePredicate = syn::parse_quote!(#ty : crabslab::SlabItem);
427 clause.predicates.push(where_predicate);
428 }
429 }
430
431 let where_clause = generics.make_where_clause();
432 constrain_system_data_types(where_clause, &field_tys)
433 }
434 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
435 let read_field_names = field_names
436 .iter()
437 .zip(field_tys.iter())
438 .enumerate()
439 .map(|(i, (name, ty))| {
440 let var = Ident::new(&format!("__{i}"), ty.span());
441 let inner = quote! {{
442 let #var = <#ty as crabslab::SlabItem>::read_slab(index, slab);
443 index += <#ty as crabslab::SlabItem>::SLAB_SIZE;
444 #var
445 }};
446 match name {
447 FieldName::Index(_) => inner,
448 FieldName::Ident(n) => {
449 quote! {
450 #n: #inner
451 }
452 }
453 }
454 })
455 .collect::<Vec<_>>();
456 let read_impl = if is_struct_style {
457 quote! {
458 Self { #(#read_field_names),* }
459 }
460 } else {
461 quote! {
462 Self( #(#read_field_names),* )
463 }
464 };
465 let write_index_field_names = field_names
466 .iter()
467 .map(|name| match name {
468 FieldName::Index(i) => quote! {
469 let index = self.#i.write_slab(index, slab);
470 },
471 FieldName::Ident(field) => quote! {
472 let index = self.#field.write_slab(index, slab);
473 },
474 })
475 .collect::<Vec<_>>();
476
477 let mut offset_tys = vec![];
478 let mut offsets = vec![];
479 for (name, ty) in field_names.iter().zip(field_tys.iter()) {
480 let (offset_of_ident, slab_size_of_ident) = match name {
481 FieldName::Index(i) => (
482 Ident::new(&format!("OFFSET_OF_{}", i.index), i.span),
483 Ident::new(&format!("SLAB_SIZE_OF_{}", i.index), i.span),
484 ),
485 FieldName::Ident(field) => (
486 Ident::new(
487 &format!("OFFSET_OF_{}", field.to_string().to_uppercase()),
488 field.span(),
489 ),
490 Ident::new(
491 &format!("SLAB_SIZE_OF_{}", field.to_string().to_uppercase()),
492 field.span(),
493 ),
494 ),
495 };
496 offsets.push(quote! {
497 pub const #offset_of_ident: crabslab::offset::Offset<#ty, Self> = {
498 crabslab::offset::Offset::new(
499 #(<#offset_tys as crabslab::SlabItem>::SLAB_SIZE+)*
500 0
501 )
502 };
503 pub const #slab_size_of_ident: usize = {
504 <#ty as crabslab::SlabItem>::SLAB_SIZE
505 };
506 });
507 offset_tys.push(ty.clone());
508 }
509
510 let offsets_output = if gen_offsets {
511 quote! {
512 #[automatically_derived]
513 impl #impl_generics #name #ty_generics {
515 #(#offsets)*
516 }
517 }
518 } else {
519 quote! {}
520 };
521
522 let output = quote! {
523 #[automatically_derived]
524 impl #impl_generics crabslab::SlabItem for #name #ty_generics #where_clause
525 {
526 const SLAB_SIZE: usize = {
527 #( <#field_tys as crabslab::SlabItem>::SLAB_SIZE )+*
528 };
529
530 fn read_slab(mut index: usize, slab: &[u32]) -> Self {
531 #read_impl
532 }
533
534 fn write_slab(&self, index: usize, slab: &mut [u32]) -> usize {
535 #(#write_index_field_names)*
536 index
537 }
538 }
539 #offsets_output
540 };
541 output.into()
542}
543
544#[proc_macro]
545pub fn impl_slabitem_tuples(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
546 let tuple: TypeTuple = syn::parse_macro_input!(input);
547 let tys = tuple.elems.iter().collect::<Vec<_>>();
548 let indices = tys
549 .iter()
550 .enumerate()
551 .map(|(i, _)| Index::from(i))
552 .collect::<Vec<_>>();
553 let reads = tys
554 .iter()
555 .enumerate()
556 .map(|(i, ty)| {
557 let var = Ident::new(&format!("__{i}"), ty.span());
558 quote! {{
559 let #var = <#ty as crabslab::SlabItem>::read_slab(index, slab);
560 index += <#ty as crabslab::SlabItem>::SLAB_SIZE;
561 #var
562 }}
563 })
564 .collect::<Vec<_>>();
565 let output = quote! {
566 impl<#(#tys),*> crabslab::SlabItem for #tuple
567 where
568 #(#tys: crabslab::SlabItem),*,
569 {
570 const SLAB_SIZE: usize = {
571 #(#tys::SLAB_SIZE )+*
572 };
573 fn read_slab(mut index: usize, slab: &[u32]) -> Self {
574 (
575 #( #reads ,)*
576 )
577 }
578 fn write_slab(&self, index: usize, slab: &mut [u32]) -> usize {
579 #(let index = self.#indices.write_slab(index, slab);)*
580 index
581 }
582 }
583 };
584 output.into()
585}
586
587#[proc_macro_derive(IsContainer, attributes(proxy, skip_proxy_definition, array))]
597pub fn impl_derive_is_container(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
598 let input: DeriveInput = syn::parse_macro_input!(input);
599 let ident = input.ident.clone();
600 let proxy = input
601 .attrs
602 .iter()
603 .find_map(|att| {
604 if att.path().is_ident("proxy") {
605 let ident: Ident = att.parse_args().ok()?;
606 Some(ident)
607 } else {
608 None
609 }
610 })
611 .unwrap_or_else(|| format_ident!("{}Container", input.ident));
612 let is_array = input.attrs.iter().any(|att| att.path().is_ident("array"));
613 let (pointer_ty, get_pointer_impl) = if is_array {
614 (
615 quote! {
616 type Pointer<T> = Array<T>;
617 },
618 quote! {
619 fn get_pointer<T>(container: &Self::Container<T>) -> Self::Pointer<T> {
620 container.array()
621 }
622 },
623 )
624 } else {
625 (
626 quote! {
627 type Pointer<T> = Id<T>;
628 },
629 quote! {
630 fn get_pointer<T>(container: &Self::Container<T>) -> Self::Pointer<T> {
631 container.id()
632 }
633 },
634 )
635 };
636
637 let should_define_proxy = !input
638 .attrs
639 .iter()
640 .any(|att| att.path().is_ident("skip_proxy_definition"));
641 let proxy_def = if should_define_proxy {
642 quote! {
643 #[derive(Clone, Copy, Debug)]
644 pub struct #proxy;
645 }
646 } else {
647 quote! {}
648 };
649
650 quote! {
651 #proxy_def
652 impl IsContainer for #proxy {
653 type Container<T> = #ident<T>;
654 #pointer_ty
655
656 #get_pointer_impl
657 }
658 }
659 .into()
660}