crabslab_derive/
lib.rs

//! Provides derive macros for `crabslab`, along with the `impl_slabitem_tuples!` helper macro.
2use quote::{format_ident, quote};
3use syn::{
4    spanned::Spanned, Data, DataEnum, DataStruct, DeriveInput, Fields, FieldsNamed, FieldsUnnamed,
5    Ident, Index, Type, TypeTuple, WhereClause, WherePredicate,
6};
7
8enum FieldName {
9    Index(Index),
10    Ident(Ident),
11}
12
13struct FieldParams {
14    field_tys: Vec<Type>,
15    field_names: Vec<FieldName>,
16}
17
18impl FieldParams {
19    fn new(fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>) -> Self {
20        let field_tys: Vec<_> = fields.iter().map(|field| field.ty.clone()).collect();
21        let field_names: Vec<_> = fields
22            .iter()
23            .enumerate()
24            .map(|(i, field)| {
25                field
26                    .ident
27                    .clone()
28                    .map(FieldName::Ident)
29                    .unwrap_or_else(|| {
30                        FieldName::Index(Index {
31                            index: i as u32,
32                            span: field.span(),
33                        })
34                    })
35            })
36            .collect();
37        Self {
38            field_tys,
39            field_names,
40        }
41    }
42}
43
44fn get_struct_params(ds: &DataStruct) -> FieldParams {
45    let empty_punctuated = syn::punctuated::Punctuated::new();
46    let fields = match ds {
47        DataStruct {
48            fields: Fields::Named(FieldsNamed { named: ref x, .. }),
49            ..
50        } => x,
51        DataStruct {
52            fields: Fields::Unnamed(FieldsUnnamed { unnamed: ref x, .. }),
53            ..
54        } => x,
55        DataStruct {
56            fields: Fields::Unit,
57            ..
58        } => &empty_punctuated,
59    };
60
61    FieldParams::new(fields)
62}
63
64struct EnumVariant {
65    variant: syn::Variant,
66    fields: FieldParams,
67}
68
69struct EnumParams {
70    variants: Vec<EnumVariant>,
71    slab_size: proc_macro2::TokenStream,
72}
73
74fn get_enum_params(de: &DataEnum) -> EnumParams {
75    let DataEnum {
76        enum_token: _,
77        brace_token: _,
78        variants,
79    } = de;
80    let variants = variants
81        .iter()
82        .map(|variant| {
83            let empty_fields = syn::punctuated::Punctuated::new();
84            let fields = match &variant.fields {
85                Fields::Named(FieldsNamed { named: ref x, .. }) => x,
86                Fields::Unnamed(FieldsUnnamed { unnamed: ref x, .. }) => x,
87                Fields::Unit => &empty_fields,
88            };
89            let fields = FieldParams::new(fields);
90            EnumVariant {
91                variant: variant.clone(),
92                fields,
93            }
94        })
95        .collect::<Vec<_>>();
96    let slab_size_def = quote! {
97        let mut __size = 0usize;
98    };
99    let slab_size_increments = variants
100        .iter()
101        .map(|variant| {
102            let tys = &variant.fields.field_tys;
103            if tys.is_empty() {
104                quote! {}
105            } else {
106                quote! {{
107                    let __field_size = #( <#tys as crabslab::SlabItem>::SLAB_SIZE )+*;
108                    __size += crabslab::__saturating_sub(__field_size,__size);
109                }}
110            }
111        })
112        .collect::<Vec<_>>();
113    EnumParams {
114        slab_size: quote! {
115            #slab_size_def
116            #(#slab_size_increments)*
117            // Add one for the enum variant
118            __size + 1
119        },
120        variants,
121    }
122}
123
124enum Params {
125    Struct(FieldParams),
126    Enum(EnumParams),
127}
128
129fn get_params(input: &DeriveInput) -> syn::Result<Params> {
130    match &input.data {
131        Data::Struct(ds) => Ok(Params::Struct(get_struct_params(ds))),
132        Data::Enum(de) => Ok(Params::Enum(get_enum_params(de))),
133        _ => Err(syn::Error::new(
134            input.span(),
135            "deriving SlabItem does not support unions".to_string(),
136        )),
137    }
138}
139
140/// Derives `SlabItem` for a struct.
141///
142/// ```rust
143/// use crabslab::{CpuSlab, GrowableSlab, Id, Slab, SlabItem};
144///
145/// #[derive(Debug, Default, PartialEq, SlabItem)]
146/// struct Foo {
147///     a: u32,
148///     b: u32,
149///     c: u32,
150/// }
151///
152/// let foo_one = Foo { a: 1, b: 2, c: 3 };
153/// let foo_two = Foo { a: 4, b: 5, c: 6 };
154///
155/// let mut slab = CpuSlab::new(vec![]);
156/// let foo_one_id = slab.append(&foo_one);
157/// let foo_two_id = slab.append(&foo_two);
158///
159/// // Overwrite the second item of the second `Foo`:
160/// slab.write(Id::<u32>::new(foo_two_id.inner() + 1), &42);
161/// assert_eq!(Foo { a: 4, b: 42, c: 6 }, slab.read(foo_two_id));
162/// ```
163///
164/// No such offsets are derived for enums.
165///
166/// ```rust
167/// use crabslab::{CpuSlab, GrowableSlab, Slab, SlabItem};
168///
169/// #[derive(Debug, Default, PartialEq, SlabItem)]
170/// struct Bar {
171///     a: u32,
172/// }
173///
174/// #[derive(Debug, Default, PartialEq, SlabItem)]
175/// enum Baz {
176///     #[default]
177///     One,
178///     Two {
179///         a: u32,
180///         b: u32,
181///     },
182///     Three(u32, u32),
183///     Four(Bar),
184/// }
185///
186/// assert_eq!(3, Baz::SLAB_SIZE);
187///
188/// let mut slab = CpuSlab::new(vec![]);
189///
190/// let one_id = slab.append(&Baz::One);
191/// let two_id = slab.append(&Baz::Two { a: 1, b: 2 });
192/// let three_id = slab.append(&Baz::Three(3, 4));
193/// let four_id = slab.append(&Baz::Four(Bar { a: 5 }));
194///
195/// assert_eq!(Baz::One, slab.read(one_id));
196/// assert_eq!(Baz::Two { a: 1, b: 2 }, slab.read(two_id));
197/// assert_eq!(Baz::Three(3, 4), slab.read(three_id));
198/// assert_eq!(Baz::Four(Bar { a: 5 }), slab.read(four_id));
199/// ```
200#[proc_macro_derive(SlabItem, attributes(offsets))]
201pub fn derive_from_slab(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
202    let input: DeriveInput = syn::parse_macro_input!(input);
203
204    let gen_offsets_span = input.attrs.iter().find_map(|attr| {
205        let path = attr.path();
206        if path.is_ident("offsets") {
207            Some(path.span())
208        } else {
209            None
210        }
211    });
212
213    match get_params(&input) {
214        Ok(Params::Struct(p)) => derive_from_slab_struct(input, p, gen_offsets_span.is_some()),
215        Ok(Params::Enum(p)) => {
216            if let Some(span) = gen_offsets_span {
217                syn::Error::new(span, "Deriving field offsets is not supported for enums")
218                    .into_compile_error()
219                    .into()
220            } else {
221                derive_from_slab_enum(input, p)
222            }
223        }
224        Err(e) => e.into_compile_error().into(),
225    }
226}
227
228fn derive_from_slab_enum(input: DeriveInput, params: EnumParams) -> proc_macro::TokenStream {
229    let EnumParams {
230        variants,
231        slab_size,
232    } = params;
233    let name = &input.ident;
234    let field_tys = variants
235        .iter()
236        .flat_map(|v| v.fields.field_tys.clone())
237        .collect::<Vec<_>>();
238    let mut generics = input.generics;
239    {
240        fn constrain_system_data_types(clause: &mut WhereClause, tys: &[Type]) {
241            for ty in tys.iter() {
242                let where_predicate: WherePredicate = syn::parse_quote!(#ty : crabslab::SlabItem);
243                clause.predicates.push(where_predicate);
244            }
245        }
246
247        let where_clause = generics.make_where_clause();
248        constrain_system_data_types(where_clause, &field_tys)
249    }
250    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
251
252    let variant_reads = variants.iter().map(|variant| {
253        let ident = &variant.variant.ident;
254        let field_names = variant
255            .fields
256            .field_names
257            .iter()
258            .map(|name| match name {
259                FieldName::Index(i) => Ident::new(&format!("__{}", i.index), i.span),
260                FieldName::Ident(field) => field.clone(),
261            })
262            .collect::<Vec<_>>();
263        let field_tys = &variant.fields.field_tys;
264        let num_fields = field_names.len();
265        let reads = field_names
266            .iter()
267            .zip(field_tys.iter())
268            .enumerate()
269            .map(|(i, (name, ty))| {
270                let def = quote! {
271                    let #name = <#ty as crabslab::SlabItem>::read_slab(index, slab);
272                };
273                let increment_index = if i + 1 < num_fields {
274                    quote! {
275                        index += <#ty as crabslab::SlabItem>::SLAB_SIZE;
276                    }
277                } else {
278                    quote! {}
279                };
280                quote! {
281                    #def
282                    #increment_index
283                }
284            })
285            .collect::<Vec<_>>();
286
287        match variant.variant.fields {
288            Fields::Named(_) => {
289                quote! {{
290                    #(#reads)*
291                     #name::#ident {
292                         #(#field_names),*
293                     }
294                }}
295            }
296            Fields::Unnamed(_) => {
297                quote! {{
298                    #(#reads)*
299                    #name::#ident(
300                        #(#field_names),*
301                    )
302                }}
303            }
304            Fields::Unit => quote! {
305                #name::#ident,
306            },
307        }
308    });
309    let read_variants_matches: Vec<proc_macro2::TokenStream> = variants
310        .iter()
311        .enumerate()
312        .zip(variant_reads)
313        .map(|((i, variant), read)| {
314            let hash = syn::LitInt::new(&i.to_string(), variant.variant.span());
315            quote! {
316                #hash => #read
317            }
318        })
319        .collect();
320    let variant_writes = variants.iter().map(|variant| {
321        let field_names = variant
322            .fields
323            .field_names
324            .iter()
325            .map(|name| match name {
326                FieldName::Index(i) => Ident::new(&format!("__{}", i.index), i.span),
327                FieldName::Ident(field) => field.clone(),
328            })
329            .collect::<Vec<_>>();
330        quote! {
331            #(let index = #field_names.write_slab(index, slab);)*
332        }
333    });
334    let write_variants_matches: Vec<proc_macro2::TokenStream> = variants
335        .iter()
336        .enumerate()
337        .zip(variant_writes)
338        .map(|((i, variant), write)| {
339            let hash = syn::LitInt::new(&i.to_string(), variant.variant.span());
340            let field_names = variant
341                .fields
342                .field_names
343                .iter()
344                .map(|name| match name {
345                    FieldName::Index(i) => Ident::new(&format!("__{}", i.index), i.span),
346                    FieldName::Ident(field) => field.clone(),
347                })
348                .collect::<Vec<_>>();
349            let ident = &variant.variant.ident;
350            let pat_match = match variant.variant.fields {
351                Fields::Named(_) => {
352                    quote! {
353                        #name::#ident {
354                            #(#field_names,)*
355                        }
356                    }
357                }
358                Fields::Unnamed(_) => {
359                    quote! {
360                        #name::#ident(
361                            #(#field_names,)*
362                        )
363                    }
364                }
365                Fields::Unit => quote! {
366                    #name::#ident
367                },
368            };
369            quote! {
370                #pat_match => {
371                    let __hash: u32 = #hash;
372                    let index = __hash.write_slab(index, slab);
373                    #write
374                    original_index + slab_size
375                }
376            }
377        })
378        .collect();
379
380    let output = quote! {
381        #[automatically_derived]
382        impl #impl_generics crabslab::SlabItem for #name #ty_generics #where_clause
383        {
384            const SLAB_SIZE: usize = {#slab_size};
385
386            fn read_slab(mut index: usize, slab: &[u32]) -> Self {
387                // Read the hash to tell which variant we're in.
388                let hash =  u32::read_slab(index, slab);
389                index += 1;
390                match hash {
391                    #(#read_variants_matches)*
392                    _ => Default::default(),
393                }
394            }
395
396            fn write_slab(&self, index: usize, slab: &mut [u32]) -> usize {
397                let slab_size = Self::SLAB_SIZE;
398                let original_index = index;
399                match self {
400                    #(#write_variants_matches)*
401                }
402            }
403        }
404    };
405    output.into()
406}
407
408fn derive_from_slab_struct(
409    input: DeriveInput,
410    params: FieldParams,
411    // Whether to generate field offsets
412    gen_offsets: bool,
413) -> proc_macro::TokenStream {
414    let FieldParams {
415        field_tys,
416        field_names,
417    } = params;
418
419    let name = &input.ident;
420    let is_struct_style = !matches!(field_names.first(), Some(FieldName::Index(_)));
421    let mut generics = input.generics;
422    {
423        /// Adds a `CanFetch<'lt>` bound on each of the system data types.
424        fn constrain_system_data_types(clause: &mut WhereClause, tys: &[Type]) {
425            for ty in tys.iter() {
426                let where_predicate: WherePredicate = syn::parse_quote!(#ty : crabslab::SlabItem);
427                clause.predicates.push(where_predicate);
428            }
429        }
430
431        let where_clause = generics.make_where_clause();
432        constrain_system_data_types(where_clause, &field_tys)
433    }
434    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
435    let read_field_names = field_names
436        .iter()
437        .zip(field_tys.iter())
438        .enumerate()
439        .map(|(i, (name, ty))| {
440            let var = Ident::new(&format!("__{i}"), ty.span());
441            let inner = quote! {{
442                let #var = <#ty as crabslab::SlabItem>::read_slab(index, slab);
443                index += <#ty as crabslab::SlabItem>::SLAB_SIZE;
444                #var
445            }};
446            match name {
447                FieldName::Index(_) => inner,
448                FieldName::Ident(n) => {
449                    quote! {
450                        #n: #inner
451                    }
452                }
453            }
454        })
455        .collect::<Vec<_>>();
456    let read_impl = if is_struct_style {
457        quote! {
458            Self { #(#read_field_names),* }
459        }
460    } else {
461        quote! {
462            Self( #(#read_field_names),* )
463        }
464    };
465    let write_index_field_names = field_names
466        .iter()
467        .map(|name| match name {
468            FieldName::Index(i) => quote! {
469                let index = self.#i.write_slab(index, slab);
470            },
471            FieldName::Ident(field) => quote! {
472                let index = self.#field.write_slab(index, slab);
473            },
474        })
475        .collect::<Vec<_>>();
476
477    let mut offset_tys = vec![];
478    let mut offsets = vec![];
479    for (name, ty) in field_names.iter().zip(field_tys.iter()) {
480        let (offset_of_ident, slab_size_of_ident) = match name {
481            FieldName::Index(i) => (
482                Ident::new(&format!("OFFSET_OF_{}", i.index), i.span),
483                Ident::new(&format!("SLAB_SIZE_OF_{}", i.index), i.span),
484            ),
485            FieldName::Ident(field) => (
486                Ident::new(
487                    &format!("OFFSET_OF_{}", field.to_string().to_uppercase()),
488                    field.span(),
489                ),
490                Ident::new(
491                    &format!("SLAB_SIZE_OF_{}", field.to_string().to_uppercase()),
492                    field.span(),
493                ),
494            ),
495        };
496        offsets.push(quote! {
497            pub const #offset_of_ident: crabslab::offset::Offset<#ty, Self> = {
498                crabslab::offset::Offset::new(
499                    #(<#offset_tys as crabslab::SlabItem>::SLAB_SIZE+)*
500                    0
501                )
502            };
503            pub const #slab_size_of_ident: usize = {
504                <#ty as crabslab::SlabItem>::SLAB_SIZE
505            };
506        });
507        offset_tys.push(ty.clone());
508    }
509
510    let offsets_output = if gen_offsets {
511        quote! {
512            #[automatically_derived]
513            /// Offsets into the slab buffer for each field.
514            impl #impl_generics #name #ty_generics {
515                #(#offsets)*
516            }
517        }
518    } else {
519        quote! {}
520    };
521
522    let output = quote! {
523        #[automatically_derived]
524        impl #impl_generics crabslab::SlabItem for #name #ty_generics #where_clause
525        {
526            const SLAB_SIZE: usize = {
527                #( <#field_tys as crabslab::SlabItem>::SLAB_SIZE )+*
528            };
529
530            fn read_slab(mut index: usize, slab: &[u32]) -> Self {
531                #read_impl
532            }
533
534            fn write_slab(&self, index: usize, slab: &mut [u32]) -> usize {
535                #(#write_index_field_names)*
536                index
537            }
538        }
539        #offsets_output
540    };
541    output.into()
542}
543
544#[proc_macro]
545pub fn impl_slabitem_tuples(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
546    let tuple: TypeTuple = syn::parse_macro_input!(input);
547    let tys = tuple.elems.iter().collect::<Vec<_>>();
548    let indices = tys
549        .iter()
550        .enumerate()
551        .map(|(i, _)| Index::from(i))
552        .collect::<Vec<_>>();
553    let reads = tys
554        .iter()
555        .enumerate()
556        .map(|(i, ty)| {
557            let var = Ident::new(&format!("__{i}"), ty.span());
558            quote! {{
559                    let #var = <#ty as crabslab::SlabItem>::read_slab(index, slab);
560                    index += <#ty as crabslab::SlabItem>::SLAB_SIZE;
561                    #var
562            }}
563        })
564        .collect::<Vec<_>>();
565    let output = quote! {
566        impl<#(#tys),*> crabslab::SlabItem for #tuple
567        where
568            #(#tys: crabslab::SlabItem),*,
569        {
570            const SLAB_SIZE: usize = {
571                #(#tys::SLAB_SIZE )+*
572            };
573            fn read_slab(mut index: usize, slab: &[u32]) -> Self {
574                (
575                    #( #reads ,)*
576                )
577            }
578            fn write_slab(&self, index: usize, slab: &mut [u32]) -> usize {
579                #(let index = self.#indices.write_slab(index, slab);)*
580                index
581            }
582        }
583    };
584    output.into()
585}
586
587/// Creates a proxy type to implement `IsContainer` for, where `IsContainer::Container`
588/// resolves to _the type being derived_.
589///
590/// That may be a bit confusing. In other words - invoking this derive macro on a type `A`
591/// creates an impl of `IsContainer` for a proxy type `B`, where `B::Container = A`.
592///
593/// ## Attributes:
594/// * **`proxy`** - If present, the generated type will be the argument of this attribute.
595/// * **`skip_proxy_definition`** - If present the generated type will not be defined.
596#[proc_macro_derive(IsContainer, attributes(proxy, skip_proxy_definition, array))]
597pub fn impl_derive_is_container(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
598    let input: DeriveInput = syn::parse_macro_input!(input);
599    let ident = input.ident.clone();
600    let proxy = input
601        .attrs
602        .iter()
603        .find_map(|att| {
604            if att.path().is_ident("proxy") {
605                let ident: Ident = att.parse_args().ok()?;
606                Some(ident)
607            } else {
608                None
609            }
610        })
611        .unwrap_or_else(|| format_ident!("{}Container", input.ident));
612    let is_array = input.attrs.iter().any(|att| att.path().is_ident("array"));
613    let (pointer_ty, get_pointer_impl) = if is_array {
614        (
615            quote! {
616                type Pointer<T> = Array<T>;
617            },
618            quote! {
619                fn get_pointer<T>(container: &Self::Container<T>) -> Self::Pointer<T> {
620                    container.array()
621                }
622            },
623        )
624    } else {
625        (
626            quote! {
627                type Pointer<T> = Id<T>;
628            },
629            quote! {
630                fn get_pointer<T>(container: &Self::Container<T>) -> Self::Pointer<T> {
631                    container.id()
632                }
633            },
634        )
635    };
636
637    let should_define_proxy = !input
638        .attrs
639        .iter()
640        .any(|att| att.path().is_ident("skip_proxy_definition"));
641    let proxy_def = if should_define_proxy {
642        quote! {
643            #[derive(Clone, Copy, Debug)]
644            pub struct #proxy;
645        }
646    } else {
647        quote! {}
648    };
649
650    quote! {
651        #proxy_def
652        impl IsContainer for #proxy {
653            type Container<T> = #ident<T>;
654            #pointer_ty
655
656            #get_pointer_impl
657        }
658    }
659    .into()
660}