linearize_derive/
lib.rs

1use {
2    proc_macro2::{Ident, Literal, Span, TokenStream, TokenTree},
3    quote::{quote, quote_spanned},
4    syn::{
5        parse::{Parse, ParseStream},
6        parse_macro_input, parse_quote,
7        spanned::Spanned,
8        Attribute, Error, Generics, Item, ItemEnum, ItemStruct, LitInt, Path, Token, Type,
9    },
10};
11
12/// A proc macro to derive the `Linearize` trait.
13///
14/// This macro can be used to derive the `Linearize` trait for structs and enums.
15///
16/// The structure of these types can be arbitrary except that all contained fields must
17/// also implement the `Linearize` trait.
18///
19/// # Using different crate names
20///
21/// If you use the linearize crate under a name other than `linearize`, you can use the
22/// `crate` attribute to have the proc macro reference the correct crate. For example,
23/// if you import the linearize crate like this:
24///
25/// ```toml
26/// linearize-0_1 = { package = "linearize", version = "0.1" }
27/// ```
28///
29/// Then you can use this attribute as follows:
30///
31/// ```rust,ignore
32/// #[derive(Linearize)]
33/// #[linearize(crate = linearize_0_1)]
34/// struct S;
35/// ```
36///
37/// <div class="warning">
38///
39/// If you import the linearize crate under a name other than `linearize` or use the crate
40/// attribute, you must ensure that these two names are in sync. Otherwise the macro
41/// might not uphold the invariants of the `Linearize` trait.
42///
43/// </div>
44///
45/// # Implementing const functions
46///
47/// If you want to use the forms of the `static_map` and `static_copy_map` macros that
48/// work in constants and statics, you must enable the `const` attribute:
49///
50/// ```rust,ignore
51/// #[derive(Linearize)]
52/// #[linearize(const)]
53/// struct S;
54/// ```
55///
56/// In this case, your type must only contain fields that also enabled this attribute. In
57/// particular, you cannot use any of the standard types `u8`, `bool`, etc.
58///
59/// # Performance
60///
61/// If the type is a C-style enum with default discriminants, the derived functions will
62/// be compiled to a jump table in debug mode and will be completely optimized away in
63/// release mode.
64///
65/// If the type contains fields, the generated code will still be reasonably efficient.
66///
67/// # Limitations
68///
69/// While this macro fully supports types with generics, the generated output will not
70/// compile. This is due to limitations of the rust type system. If a future version of
71/// the rust compiler lifts these limitations, this macro will automatically start working
72/// for generic types.
73#[proc_macro_derive(Linearize, attributes(linearize))]
74pub fn derive_linearize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
75    let mut input: Input = parse_macro_input!(input as Input);
76    let crate_name = &input.attributes.crate_name;
77    let FullyLinearized {
78        linearize,
79        delinearize,
80        const_linearize,
81        const_delinearize,
82        const_names,
83        consts,
84        max_len,
85    } = input.build_linearize();
86    let where_clause = input.generics.make_where_clause();
87    for ty in &input.critical_types {
88        where_clause
89            .predicates
90            .push(parse_quote!(#ty: #crate_name::Linearize));
91    }
92    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
93    let ident = input.ident;
94    let mut const_impl = quote! {};
95    if input.attributes.enable_const {
96        const_impl = quote! {
97            #[doc(hidden)]
98            impl #impl_generics #ident #type_generics #where_clause {
99                #[inline]
100                pub const fn __linearize_d66aa8fa_6974_4651_b2b7_75291a9e7105(&self) -> usize {
101                    #const_linearize
102                }
103
104                #[inline]
105                pub const unsafe fn __from_linear_unchecked_fb2f0b31_5b5a_48b4_9264_39d0bdf94f1d(linear: usize) -> Self {
106                    #const_delinearize
107                }
108            }
109        };
110    }
111    let res = quote_spanned! { input.span =>
112        #[allow(clippy::modulo_one, clippy::manual_range_contains)]
113        const _: () = {
114            trait __C {
115                #(const #const_names: usize;)*
116            }
117
118            impl #impl_generics __C for #ident #type_generics #where_clause {
119                #(#consts)*
120            }
121
122            // SAFETY:
123            //
124            // Storage and CopyStorage obviously are the required type.
125            //
126            // The bodies if `linearize` and `from_linear_unchecked` are generated as follows:
127            //
128            // First, consider a struct s = { a1: T1, ..., an: Tn }. The calculated LENGTH
129            // is the product of the lengths of the Ti. We write |T| for the LENGTH of T.
130            // Write Bi = |T{i+1}| * ... * |Tn|, the product of the LENGTHs of the later types.
131            // Write linear(v) for the linearization of v. Then we define
132            // linear(s) = \sum_{i} linear(ai) * Bi.
133            // It is easy to see that linear(s) / Bi % Ti = linear(ai).
134            // Therefore we have created a bijection between the struct and [0, B0).
135            //
136            // Now consider an enum e = { V1, ..., Vn } where each variant can have fields.
137            // Each Vi can be treated like a struct and we can define a bijection between
138            // the enum and [0, |V1| + ... + |Vn|) by mapping V1 to [0, |V1|), V2 to
139            // [|V1|, |V1| + |V2|), and so on.
140            #[automatically_derived]
141            unsafe impl #impl_generics
142            #crate_name::Linearize for #ident #type_generics
143            #where_clause
144            {
145                type Storage<__T> = [__T; <Self as #crate_name::Linearize>::LENGTH];
146
147                type CopyStorage<__T> = [__T; <Self as #crate_name::Linearize>::LENGTH] where __T: Copy;
148
149                const LENGTH: usize = <Self as __C>::#max_len;
150
151                #[inline]
152                fn linearize(&self) -> usize {
153                    #linearize
154                }
155
156                #[inline]
157                unsafe fn from_linear_unchecked(linear: usize) -> Self {
158                    #delinearize
159                }
160            }
161
162            #const_impl
163        };
164    };
165    res.into()
166}
167
168struct Input {
169    span: Span,
170    ident: Ident,
171    generics: Generics,
172    critical_types: Vec<Type>,
173    kind: Kind,
174    attributes: InputAttributes,
175}
176
177struct InputAttributes {
178    crate_name: Path,
179    enable_const: bool,
180}
181
182#[derive(Default)]
183struct InputAttributesOpt {
184    crate_name: Option<Path>,
185    enable_const: bool,
186}
187
188enum Kind {
189    Struct(StructInput),
190    Enum(EnumInput),
191}
192
193struct StructInput {
194    fields: Vec<StructField>,
195}
196
197struct EnumInput {
198    variants: Vec<EnumVariant>,
199}
200
201struct EnumVariant {
202    ident: Ident,
203    fields: Vec<StructField>,
204}
205
206struct PartialLinearized {
207    linearize: TokenStream,
208    delinearize: TokenStream,
209    const_linearize: TokenStream,
210    const_delinearize: TokenStream,
211    max_len: Option<TokenStream>,
212}
213
214struct FullyLinearized {
215    linearize: TokenStream,
216    delinearize: TokenStream,
217    const_linearize: TokenStream,
218    const_delinearize: TokenStream,
219    const_names: Vec<Ident>,
220    consts: Vec<TokenStream>,
221    max_len: Ident,
222}
223
224struct StructField {
225    original_name: Option<Ident>,
226    generated_name: Option<Ident>,
227    ty: Type,
228}
229
230fn build_linearize_struct(
231    input: &Input,
232    fields: &[StructField],
233    base: &Ident,
234) -> PartialLinearized {
235    let crate_name = &input.attributes.crate_name;
236    let mut linearize_parts = vec![];
237    let mut delinearize_parts = vec![];
238    let mut const_linearize_parts = vec![];
239    let mut const_delinearize_parts = vec![];
240    let mut max_len = quote!(1usize);
241    for (idx, field) in fields.iter().enumerate().rev() {
242        let idx = LitInt::new(&idx.to_string(), Span::call_site());
243        let ref_name = match &field.generated_name {
244            Some(i) => quote! {#i},
245            None => match &field.original_name {
246                Some(i) => quote! { &self.#i },
247                None => quote! { &self.#idx },
248            },
249        };
250        let mut_name = match &field.original_name {
251            Some(i) => quote! { #i },
252            None => quote! { #idx },
253        };
254        let ty = &field.ty;
255        linearize_parts.push(quote! {
256            res = res.wrapping_add(<#ty as #crate_name::Linearize>::linearize(#ref_name).wrapping_mul(const { #max_len }));
257        });
258        delinearize_parts.push(quote! {
259            #mut_name: {
260                let idx = (linear / const { #max_len }) % <#ty as #crate_name::Linearize>::LENGTH;
261                <#ty as #crate_name::Linearize>::from_linear_unchecked(idx)
262            },
263        });
264        if input.attributes.enable_const {
265            const_linearize_parts.push(quote! {
266                res = res.wrapping_add(<#ty>::__linearize_d66aa8fa_6974_4651_b2b7_75291a9e7105(#ref_name).wrapping_mul(const { #max_len }));
267            });
268            const_delinearize_parts.push(quote! {
269                #mut_name: {
270                    let idx = (linear / const { #max_len }) % <#ty as #crate_name::Linearize>::LENGTH;
271                    <#ty>::__from_linear_unchecked_fb2f0b31_5b5a_48b4_9264_39d0bdf94f1d(idx)
272                },
273            });
274        }
275        max_len = quote! {
276            #max_len * <#ty as #crate_name::Linearize>::LENGTH
277        };
278    }
279    delinearize_parts.reverse();
280    const_delinearize_parts.reverse();
281    let make_linearize = |parts: &[TokenStream]| {
282        if fields.is_empty() {
283            quote! { <Self as __C>::#base }
284        } else {
285            quote! {
286                let mut res = <Self as __C>::#base;
287                #(#parts)*
288                res
289            }
290        }
291    };
292    let make_delinearize = |parts: &[TokenStream]| {
293        quote! {
294            { #(#parts)* }
295        }
296    };
297    let mut max_len = Some(max_len);
298    if fields.is_empty() {
299        max_len = None;
300    }
301    PartialLinearized {
302        linearize: make_linearize(&linearize_parts),
303        delinearize: make_delinearize(&delinearize_parts),
304        const_linearize: make_linearize(&const_linearize_parts),
305        const_delinearize: make_delinearize(&const_delinearize_parts),
306        max_len,
307    }
308}
309
310impl StructInput {
311    fn build_linearize(&self, input: &Input) -> FullyLinearized {
312        let b0 = Ident::new("B0", Span::mixed_site());
313        let b1 = Ident::new("B1", Span::mixed_site());
314        let PartialLinearized {
315            linearize,
316            delinearize,
317            const_linearize,
318            const_delinearize,
319            max_len,
320        } = build_linearize_struct(input, &self.fields, &b0);
321        let max_len = max_len.unwrap_or_else(|| quote!(1usize));
322        let mut consts = vec![];
323        consts.push(quote! { const B0: usize = 0; });
324        consts.push(quote! { const B1: usize = #max_len; });
325        FullyLinearized {
326            linearize,
327            delinearize: quote! { Self #delinearize },
328            const_linearize,
329            const_delinearize: quote! { Self #const_delinearize },
330            max_len: b1.clone(),
331            consts,
332            const_names: vec![b0, b1],
333        }
334    }
335}
336
337impl EnumInput {
338    fn build_linearize(&self, input: &Input) -> FullyLinearized {
339        let mut linearize_cases = vec![];
340        let mut delinearize_cases = vec![];
341        let mut const_linearize_cases = vec![];
342        let mut const_delinearize_cases = vec![];
343        let mut consts = vec![];
344        consts.push(quote! { const B0: usize = 0; });
345        let mut prev_const_name = Ident::new("B0", Span::mixed_site());
346        let mut const_base = prev_const_name.clone();
347        let mut const_base_offset = 0;
348        let mut const_names = vec![prev_const_name.clone()];
349        for (variant_idx, variant) in self.variants.iter().enumerate() {
350            let mut exposition = vec![];
351            for (idx, field) in variant.fields.iter().enumerate() {
352                let idx = LitInt::new(&idx.to_string(), Span::call_site());
353                let generated_name = field.generated_name.as_ref().unwrap();
354                match &field.original_name {
355                    None => exposition.push(quote! { #idx: #generated_name }),
356                    Some(i) => exposition.push(quote! { #i: #generated_name }),
357                }
358            }
359            let exposition = quote! {
360                { #(#exposition),* }
361            };
362            let PartialLinearized {
363                linearize,
364                delinearize,
365                const_linearize,
366                const_delinearize,
367                max_len,
368            } = build_linearize_struct(input, &variant.fields, &prev_const_name);
369            let next_base = match &max_len {
370                Some(len) => quote! { <Self as __C>::#prev_const_name + #len },
371                None => {
372                    const_base_offset += 1;
373                    let offset = Literal::usize_unsuffixed(const_base_offset);
374                    quote! { <Self as __C>::#const_base + #offset }
375                }
376            };
377            let ident = &variant.ident;
378            linearize_cases.push(quote! {
379                Self::#ident #exposition => {
380                    #linearize
381                }
382            });
383            if input.attributes.enable_const {
384                const_linearize_cases.push(quote! {
385                    Self::#ident #exposition => {
386                        #const_linearize
387                    }
388                });
389            }
390            let const_name = Ident::new(&format!("B{}", variant_idx + 1), Span::mixed_site());
391            consts.push(quote! { const #const_name: usize = #next_base; });
392            if variant.fields.is_empty() {
393                let guard = if input.generics.params.is_empty() {
394                    quote! {
395                        <Self as __C>::#prev_const_name
396                    }
397                } else {
398                    quote! {
399                        n if n == <Self as __C>::#prev_const_name
400                    }
401                };
402                delinearize_cases.push(quote! {
403                    #guard => Self::#ident { },
404                });
405                if input.attributes.enable_const {
406                    const_delinearize_cases.push(quote! {
407                        #guard => Self::#ident { },
408                    });
409                }
410            } else {
411                let make_case = |delinearize: &TokenStream| {
412                    quote! {
413                        #[allow(clippy::impossible_comparisons)]
414                        n if n >= <Self as __C>::#prev_const_name && n < <Self as __C>::#const_name => {
415                            let linear = linear.wrapping_sub(<Self as __C>::#prev_const_name);
416                            Self::#ident #delinearize
417                        },
418                    }
419                };
420                delinearize_cases.push(make_case(&delinearize));
421                if input.attributes.enable_const {
422                    const_delinearize_cases.push(make_case(&const_delinearize));
423                }
424            }
425            prev_const_name = const_name;
426            const_names.push(prev_const_name.clone());
427            if max_len.is_some() {
428                const_base = prev_const_name.clone();
429                const_base_offset = 0;
430            }
431        }
432        let make_linearize = |cases: &[TokenStream]| {
433            if self.variants.is_empty() {
434                quote! {
435                    #[cold]
436                    const fn unreachable() -> ! {
437                        unsafe { core::hint::unreachable_unchecked() }
438                    }
439                    unreachable()
440                }
441            } else {
442                quote! {
443                    match self {
444                        #(#cases)*
445                    }
446                }
447            }
448        };
449        let make_delinearize = |cases: &[TokenStream]| {
450            quote! {
451                match linear {
452                    #(#cases)*
453                    _ => {
454                        #[cold]
455                        const fn unreachable() -> ! {
456                            unsafe { core::hint::unreachable_unchecked() }
457                        }
458                        unreachable()
459                    },
460                }
461            }
462        };
463        FullyLinearized {
464            linearize: make_linearize(&linearize_cases),
465            const_linearize: make_linearize(&const_linearize_cases),
466            delinearize: make_delinearize(&delinearize_cases),
467            const_delinearize: make_delinearize(&const_delinearize_cases),
468            max_len: prev_const_name,
469            const_names,
470            consts,
471        }
472    }
473}
474
475impl Input {
476    fn parse_enum(input: ItemEnum) -> syn::Result<Self> {
477        let span = input.span();
478        let mut critical_types = Vec::new();
479        let mut variants = vec![];
480        let mut i = 0;
481        for variant in input.variants {
482            let mut fields = vec![];
483            for field in variant.fields {
484                critical_types.push(field.ty.clone());
485                let name = Ident::new(&format!("f{i}"), Span::mixed_site());
486                i += 1;
487                fields.push(StructField {
488                    original_name: field.ident,
489                    generated_name: Some(name),
490                    ty: field.ty,
491                })
492            }
493            variants.push(EnumVariant {
494                ident: variant.ident,
495                fields,
496            });
497        }
498        Ok(Self {
499            span,
500            ident: input.ident,
501            generics: input.generics,
502            critical_types,
503            kind: Kind::Enum(EnumInput { variants }),
504            attributes: parse_attributes(&input.attrs)?,
505        })
506    }
507
508    fn parse_struct(input: ItemStruct) -> syn::Result<Self> {
509        let span = input.span();
510        let mut critical_types = Vec::new();
511        let mut fields = vec![];
512        for field in input.fields {
513            critical_types.push(field.ty.clone());
514            fields.push(StructField {
515                original_name: field.ident,
516                generated_name: None,
517                ty: field.ty,
518            });
519        }
520        Ok(Self {
521            span,
522            ident: input.ident,
523            generics: input.generics,
524            critical_types,
525            kind: Kind::Struct(StructInput { fields }),
526            attributes: parse_attributes(&input.attrs)?,
527        })
528    }
529
530    fn build_linearize(&self) -> FullyLinearized {
531        match &self.kind {
532            Kind::Struct(s) => s.build_linearize(self),
533            Kind::Enum(e) => e.build_linearize(self),
534        }
535    }
536}
537
538fn parse_attributes(attrs: &[Attribute]) -> syn::Result<InputAttributes> {
539    let mut res = InputAttributesOpt::default();
540    for attr in attrs {
541        if !attr.meta.path().is_ident("linearize") {
542            continue;
543        }
544        let new: InputAttributesOpt = attr.meta.require_list()?.parse_args()?;
545        res.enable_const |= new.enable_const;
546        macro_rules! opt {
547            ($name:ident) => {
548                if new.$name.is_some() {
549                    res.$name = new.$name;
550                }
551            };
552        }
553        opt!(crate_name);
554    }
555    Ok(InputAttributes {
556        crate_name: res.crate_name.unwrap_or_else(|| parse_quote!(::linearize)),
557        enable_const: res.enable_const,
558    })
559}
560
561impl Parse for InputAttributesOpt {
562    fn parse(input: ParseStream) -> syn::Result<Self> {
563        let mut res = Self::default();
564        while !input.is_empty() {
565            let key: TokenTree = input.parse()?;
566            match key.to_string().as_str() {
567                "crate" => {
568                    let _: Token![=] = input.parse()?;
569                    let path: Path = input.parse()?;
570                    res.crate_name = Some(path);
571                }
572                "const" => {
573                    res.enable_const = true;
574                }
575                _ => {
576                    return Err(Error::new(
577                        key.span(),
578                        format!("Unknown attribute: {}", key),
579                    ))
580                }
581            }
582            if !input.is_empty() {
583                let _: Token![,] = input.parse()?;
584            }
585        }
586        Ok(res)
587    }
588}
589
590impl Parse for Input {
591    fn parse(input: ParseStream) -> syn::Result<Self> {
592        let item: Item = input.parse()?;
593        match item {
594            Item::Enum(e) => Self::parse_enum(e),
595            Item::Struct(s) => Self::parse_struct(s),
596            _ => Err(Error::new(item.span(), "expected enum or struct")),
597        }
598    }
599}