frunk_enum_derive/
lib.rs

1//! This crate adds the `LabelledGenericEnum` derive which integrates enums with the frunk
2//! transmogrification function.
3
4extern crate proc_macro;
5
6use quote::quote;
7use syn::parse_macro_input;
8
9use syn::spanned::Spanned as _;
10
11/// A representation of a variant's member field.  This is far simpler than the one presented by
12/// `syn` as it always has a field name, and it doesn't track things like access permissions.
13struct Field {
14    pub ident: syn::Ident,
15    pub ty: syn::Type,
16}
17
18/// Convert the fields of a variant into a simpler, more consistent form.  This also creates
19/// artificial names for "tuple structs" (`_0`, `_1`, ...) so downstream code can just read the
20/// `ident` field and get the correct name.
21fn simplify_fields(fields: &syn::Fields) -> Vec<Field> {
22    use syn::Fields::*;
23    match fields {
24        Unit => Vec::new(),
25        Named(named) => named
26            .named
27            .iter()
28            .map(|f| Field {
29                ident: f.ident.as_ref().unwrap().clone(),
30                ty: f.ty.clone(),
31            })
32            .collect(),
33        Unnamed(unnamed) => unnamed
34            .unnamed
35            .iter()
36            .enumerate()
37            .map(|(i, f)| Field {
38                ident: syn::Ident::new(&format!("_{}", i), f.span()),
39                ty: f.ty.clone(),
40            })
41            .collect(),
42    }
43}
44
45/// Recursively pack up variants into a chain of `HCons` types (with associated field names).
46fn create_hlist_repr<'a>(mut fields: impl Iterator<Item = &'a Field>) -> proc_macro2::TokenStream {
47    match fields.next() {
48        None => quote!(frunk::HNil),
49        Some(Field { ref ident, ref ty }) => {
50            let tail = create_hlist_repr(fields);
51            let ident = frunk_proc_macro_helpers::build_label_type(ident);
52            quote!(frunk::HCons<frunk::labelled::Field<#ident, #ty>, #tail>)
53        }
54    }
55}
56
57/// Recursively pack up the variants into a chain of `HEither` generic enums (with associated
58/// variant names).
59fn create_repr_for0<'a>(
60    mut variants: impl Iterator<Item = &'a syn::Variant>,
61) -> proc_macro2::TokenStream {
62    match variants.next() {
63        None => quote!(frunk_enum_core::Void),
64        Some(v) => {
65            let ident_ty = frunk_proc_macro_helpers::build_label_type(&v.ident);
66            let fields = simplify_fields(&v.fields);
67            let hlist = create_hlist_repr(fields.iter());
68            let tail = create_repr_for0(variants);
69            quote! {
70                frunk_enum_core::HEither<frunk_enum_core::Variant<#ident_ty, #hlist>, #tail>
71            }
72        }
73    }
74}
75
76/// Generates the `Repr` for a given `enum` definition.
77///
78/// ```ignore
79/// type Repr = HEither<Variant<(f,i,r,s,t), Hlist![Field<(_0), A>]>, HEither<Variant<(s,e,c,o,n,d), Hlist![Field<(_0), B>]>, Void>>;
80/// ```
81fn create_repr_for(input: &syn::DataEnum) -> proc_macro2::TokenStream {
82    let repr = create_repr_for0(input.variants.iter());
83    quote!(type Repr = #repr;)
84}
85
86/// Create the body of the cases in the `into()` implementation for a given variant.  Assumes that
87/// the captured fields are in bindings named as per the field names.
88///
89/// The `depth` argument indicates how many `Tail` wrappers to add (e.g. how far down the `HEither`
90/// chain this variant lies).
91fn create_into_case_body_for<'a>(
92    ident: &syn::Ident,
93    fields: impl Iterator<Item = &'a Field>,
94    depth: usize,
95) -> proc_macro2::TokenStream {
96    let fields = fields.map(|f| {
97        let ident = &f.ident;
98        let ident_ty = frunk_proc_macro_helpers::build_label_type(ident);
99        quote!(frunk::field!(#ident_ty, #ident, stringify!(#ident)))
100    });
101    let ident_ty = frunk_proc_macro_helpers::build_label_type(ident);
102    let mut inner = quote!(frunk_enum_core::HEither::Head(
103        frunk_enum_core::variant!(#ident_ty, frunk::hlist![#(#fields),*], stringify!(#ident))
104    ));
105    for _ in 0..depth {
106        inner = quote!(frunk_enum_core::HEither::Tail(#inner))
107    }
108    inner
109}
110
111/// Create cases for the variants for the `into()` implementation.  Captures the fields of the
112/// variant into bindings corresponding to the field names.
113fn create_into_cases_for<'a>(
114    enum_ident: &'a syn::Ident,
115    variants: impl Iterator<Item = &'a syn::Variant> + 'a,
116) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
117    use syn::Fields::*;
118    variants.enumerate().map(move |(idx, v)| {
119        let variant_ident = &v.ident;
120        let labelled_fields = simplify_fields(&v.fields);
121        let pattern_vars = labelled_fields.iter().map(|f| &f.ident);
122        let body = create_into_case_body_for(variant_ident, labelled_fields.iter(), idx);
123
124        // Tediously patterns are rendered differently for the three styles so add appropriate wrapping
125        // here.
126        let pattern_vars = match v.fields {
127            Unit => quote!(),
128            Unnamed(_) => quote!((#(#pattern_vars),*)),
129            Named(_) => quote!({#(#pattern_vars),*}),
130        };
131
132        quote!(#enum_ident::#variant_ident #pattern_vars => #body)
133    })
134}
135
136/// Generate the implementation of `into()` for the given enum.
137///
138/// ```ignore
139///  fn into(self) -> Self::Repr {
140///     match self {
141///         First(v) => HEither::Head(variant!((f, i, r, s, t), hlist!(field!((_0), v)))),
142///         Second(v) => HEither::Tail(HEither::Head(variant!((s, e, c, o, n, d), hlist!(field!((_0), v))))),
143///     }
144/// }
145/// ```
146fn create_into_for(ident: &syn::Ident, input: &syn::DataEnum) -> proc_macro2::TokenStream {
147    let cases = create_into_cases_for(ident, input.variants.iter());
148    quote! {
149        fn into(self) -> Self::Repr {
150            match self {
151                #(#cases),*
152            }
153        }
154    }
155}
156
157/// Create the pattern of the case statement for the `from()` implementation.  This captures the
158/// fields from the `Repr` into bindings to be re-constituted into the `Self` in the case body.
159///
160/// The `depth` argument indicates how many `Tail` wrappers to unpack (e.g. how far down the `HEither`
161/// chain this variant lies).
162fn create_from_case_pattern_for<'a>(
163    fields: impl Iterator<Item = &'a Field>,
164    depth: usize,
165) -> proc_macro2::TokenStream {
166    let fields = fields.map(|f| &f.ident);
167    let mut inner = quote!(frunk_enum_core::HEither::Head(frunk_enum_core::Variant {
168        value: frunk::hlist_pat!(#(#fields),*),
169        ..
170    }));
171    for _ in 0..depth {
172        inner = quote!(frunk_enum_core::HEither::Tail(#inner));
173    }
174    inner
175}
176
177/// Create the body of the case statement for the `from()` implementation.  This builds the output
178/// variant from the captured fields.  It assumes that the fields are captured in the pattern as
179/// variables named as per the field identifiers.
180fn create_from_case_body_for<'a>(
181    ident: &syn::Ident,
182    variant: &syn::Variant,
183    fields: impl Iterator<Item = &'a Field>,
184) -> proc_macro2::TokenStream {
185    use syn::Fields::*;
186    let variant_ident = &variant.ident;
187    let fields = fields.map(|f| &f.ident);
188    let fields = match variant.fields {
189        Unit => quote!(),
190        Unnamed(_) => quote!((#(#fields.value),*)),
191        Named(_) => {
192            let fields = fields.map(|f| quote!(#f: #f.value));
193            quote!({#(#fields),*})
194        }
195    };
196    quote!(#ident::#variant_ident #fields)
197}
198
199/// Generate a case for the `from()` implementation for each variant (equivalently, for each `Repr`
200/// variant).  These cases are not complete (as they don't cover the "all Tail" case).
201fn create_from_cases_for<'a>(
202    enum_ident: &'a syn::Ident,
203    variants: impl Iterator<Item = &'a syn::Variant> + 'a,
204) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
205    variants.enumerate().map(move |(idx, variant)| {
206        let labelled_fields = simplify_fields(&variant.fields);
207        let pattern = create_from_case_pattern_for(labelled_fields.iter(), idx);
208        let body = create_from_case_body_for(enum_ident, variant, labelled_fields.iter());
209
210        quote!(#pattern => #body)
211    })
212}
213
214/// Generate an unreachable case for unpacking a `Repr` (the `Tail(Tail(...(Void)...))` case).
215fn create_void_from_case(depth: usize) -> proc_macro2::TokenStream {
216    let mut pattern = quote!(void);
217    for _ in 0..depth {
218        pattern = quote!(frunk_enum_core::HEither::Tail(#pattern));
219    }
220    quote!(#pattern => match void {})
221}
222
223/// Generate the implementation of `from()` for the given enum.
224///
225/// ```ignore
226/// fn from(repr: Self::Repr) -> Self {
227///     match repr {
228///         HEither::Head(Variant { value: hlist_pat!(v), .. }) => First(v.value),
229///         HEither::Tail(HEither::Head(Variant { value: hlist_pat!(v), .. }))=> Second(e.value),
230///         HEither::Tail(HEither::Tail(void)) => match void {}, // Unreachable
231///     }
232/// }
233/// ```
234///
235/// The final case is needed for match-completeness, but fortunately it's uninhabited, so it'll
236/// never be hit.
237fn create_from_for(ident: &syn::Ident, input: &syn::DataEnum) -> proc_macro2::TokenStream {
238    let cases = create_from_cases_for(ident, input.variants.iter());
239    let void_case = create_void_from_case(input.variants.len());
240    quote! {
241        fn from(repr: Self::Repr) -> Self {
242            match repr {
243                #(#cases),*,
244                #void_case,
245            }
246        }
247    }
248}
249
250/// Generates the complete derived code for an enum.  This is the main functional entrypoint for
251/// this crate, and is the entry point used for testing (as the proc-macro entrypoint cannot).
252fn generate_for_derive_input(
253    ident: &syn::Ident,
254    generics: &syn::Generics,
255    enum_: &syn::DataEnum,
256) -> proc_macro2::TokenStream {
257    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
258    let repr = create_repr_for(enum_);
259    let into = create_into_for(ident, enum_);
260    let from = create_from_for(ident, enum_);
261
262    quote! {
263        impl #impl_generics frunk::LabelledGeneric for #ident #ty_generics #where_clause {
264            #repr
265            #into
266            #from
267        }
268    }
269}
270
271/// ```edition2018
272/// #[derive(frunk_enum_derive::LabelledGenericEnum)]
273/// enum Foo<A, B> {
274///   Bar,
275///   Baz(u32, A, String),
276///   Quux { name: String, inner: B },
277/// }
278/// ```
279#[proc_macro_derive(LabelledGenericEnum)]
280pub fn derive_labelled_generic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
281    let input = parse_macro_input!(input as syn::DeriveInput);
282    match input.data {
283        syn::Data::Enum(e) => generate_for_derive_input(&input.ident, &input.generics, &e).into(),
284        syn::Data::Struct(_) | syn::Data::Union(_) => quote!(compile_error!("#[derive(LabelledGenericEnum]] is only applicable for enum types");).into(),
285    }
286}
287
/// Checks that the code generated for an enum with one unit, one tuple and one struct variant
/// parses as an `impl` item.
#[test]
fn test_generate_for_enum() {
    let parsed: syn::DeriveInput = syn::parse_str(
        r#"
        enum Foo<C, E> {
            A,
            B(C, C, C),
            D { foo: E, bar: E },
        }
    "#,
    )
    .unwrap();

    let enum_ = match parsed.data {
        syn::Data::Enum(ref data) => data,
        _ => unreachable!(),
    };

    let derived = generate_for_derive_input(&parsed.ident, &parsed.generics, enum_);
    let reparsed = syn::parse_str::<syn::ItemImpl>(&derived.to_string());

    assert!(reparsed.is_ok());
}