enum_struct/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::{format_ident, quote, quote_spanned};
6use syn::{Fields, ItemEnum, parse::Parse, punctuated::Punctuated, spanned::Spanned, token::Comma};
7
8struct PunctedNamedFields(Punctuated<syn::Field, Comma>);
9struct PunctedUnnamedFields(Punctuated<syn::Field, Comma>);
10
11impl std::ops::Deref for PunctedNamedFields {
12    type Target = Punctuated<syn::Field, Comma>;
13
14    fn deref(&self) -> &Self::Target {
15        &self.0
16    }
17}
18
19impl Parse for PunctedNamedFields {
20    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
21        input.parse_terminated(syn::Field::parse_named, Comma)
22            .map(Self)
23    }
24}
25
26impl Parse for PunctedUnnamedFields {
27    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
28        input.parse_terminated(syn::Field::parse_unnamed, Comma)
29            .map(Self)
30    }
31}
32
33/// Insert fields for each variant, and generate getter for each fields
34///
35/// # Example
36///
37/// ```
38/// #[enum_struct::fields {
39///     id: u64,
40/// }]
41/// #[derive(Debug, PartialEq)]
42/// enum Foo {
43///     Named(String),
44///     Complex { name: String, age: u32, level: u16 },
45///     Empty,
46/// }
47///
48/// let named = Foo::Named(2, "jack".into());
49/// let complex = Foo::Complex { id: 3, name: "john".into(), age: 22, level: 4 };
50/// let empty = Foo::Empty { id: 4 };
51///
52/// assert_eq!(named.id(), &2);
53/// assert_eq!(complex.id(), &3);
54/// assert_eq!(empty.id(), &4);
55///
56/// let mut named = named;
57///
58/// *named.id_mut() = 8;
59/// assert_eq!(named.id(), &8);
60/// assert_eq!(named, Foo::Named(8, "jack".into()));
61/// ```
62#[proc_macro_attribute]
63pub fn fields(attr: TokenStream, adt: TokenStream) -> TokenStream {
64    let mut item_enum = match syn::parse::<ItemEnum>(adt) {
65        Ok(x) => x,
66        Err(e) => return e.into_compile_error().into(),
67    };
68    let fields = match syn::parse::<PunctedNamedFields>(attr.clone()) {
69        Ok(it) => it,
70        Err(err) => return err.into_compile_error().into(),
71    };
72    item_enum.variants.iter_mut().for_each(|variant| {
73        add_fields(&mut variant.fields, &fields);
74    });
75
76    let ItemEnum {
77        attrs,
78        vis,
79        enum_token,
80        ident,
81        generics,
82        brace_token: _,
83        variants,
84    } = item_enum;
85
86    let (impl_generics,
87         type_generics,
88         where_clause) = generics.split_for_impl();
89
90    let methods = generate_methods(&vis, &fields, &variants);
91
92    quote! {
93        #(#attrs)*
94        #vis #enum_token #ident #generics {
95            #variants
96        }
97        impl #impl_generics #ident #type_generics #where_clause {
98            #(#methods)*
99        }
100    }.into()
101}
102
103fn generate_methods(
104    vis: &syn::Visibility,
105    fields: &PunctedNamedFields,
106    variants: &Punctuated<syn::Variant, Comma>,
107) -> Vec<pm2::TokenStream> {
108    fields.pairs()
109        .map(|pair| pair.into_value())
110        .enumerate()
111        .map(|(i, field)| {
112            let i_field = pm2::Literal::usize_unsuffixed(i);
113            let name = field.ident.as_ref().expect("empty field");
114            let colon = field.colon_token.as_ref().expect("empty colon token");
115            let ty = &field.ty;
116
117            let attrs = field.attrs.iter()
118                .filter(allowed_field_attr)
119                .collect::<Vec<_>>();
120
121            let field_name = lose_span(name);
122            let method_span = colon.span.span();
123
124            let immutable_getter = format_ident!("{field_name}", span = method_span);
125            let mutable_getter = format_ident!("{field_name}_mut", span = method_span);
126            let owned_getter = format_ident!("into_{field_name}", span = method_span);
127
128            let variants_pat = variants.iter()
129                .map(|it| {
130                    let body = match it.fields {
131                        Fields::Named(_) => quote! {
132                            { #field_name, .. }
133                        },
134                        Fields::Unnamed(_) => quote! {
135                            { #i_field: #field_name, .. }
136                        },
137                        Fields::Unit => quote! {},
138                    };
139                    let variant_name = lose_span(&it.ident);
140                    quote! {
141                        Self::#variant_name #body
142                    }
143                })
144                .collect::<Vec<_>>();
145            let match_arms = if variants_pat.is_empty() {
146                quote! {
147                    _ => loop {}
148                }
149            } else {
150                quote! {
151                    #(| #variants_pat)*
152                    => #field_name,
153                }
154            };
155
156            quote! {
157                #(#attrs)*
158                #[allow(unused)]
159                #vis fn #immutable_getter(&self) -> &#ty {
160                    match self {
161                        #match_arms
162                    }
163                }
164                #(#attrs)*
165                #[allow(unused)]
166                #vis fn #mutable_getter(&mut self) -> &mut #ty {
167                    match self {
168                        #match_arms
169                    }
170                }
171                #(#attrs)*
172                #[allow(unused)]
173                #vis fn #owned_getter(self) -> #ty {
174                    match self {
175                        #match_arms
176                    }
177                }
178            }
179        })
180        .collect()
181}
182
183fn allowed_field_attr(attr: &&syn::Attribute) -> bool {
184    attr.path().is_ident("doc") && attr.meta.require_name_value().is_ok()
185        || attr.path().is_ident("cfg") && attr.meta.require_list().is_ok()
186}
187
188fn lose_span(ident: &pm2::Ident) -> pm2::Ident {
189    pm2::Ident::new(&ident.to_string(), pm2::Span::call_site())
190}
191
192fn add_fields(variant_fields: &mut Fields, fields: &PunctedNamedFields) {
193    let needs_comma = !fields.trailing_punct() && !fields.is_empty();
194    match variant_fields {
195        Fields::Unit => {
196            let mut tokens = pm2::Group::new(pm2::Delimiter::Brace, pm2::TokenStream::new());
197            tokens.set_span(variant_fields.span());
198            *variant_fields = Fields::Named(syn::parse2(pm2::TokenTree::from(tokens).into()).unwrap());
199            add_fields(variant_fields, fields)
200        },
201        Fields::Named(syn::FieldsNamed { named, .. }) => {
202            let fields_iter = fields.pairs();
203            let tokens = if needs_comma {
204                quote_spanned! { fields.span() => #(#fields_iter)* , #named }
205            } else {
206                quote_spanned! { fields.span() => #(#fields_iter)*   #named }
207            };
208            *named = syn::parse2::<PunctedNamedFields>(tokens).unwrap().0;
209        },
210        Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
211            let fields_iter = fields.0.clone().into_pairs().map(|mut pair| {
212                pair.value_mut().ident.take();
213                pair.value_mut().colon_token.take();
214                pair
215            });
216            let tokens = if needs_comma {
217                quote_spanned! { fields.span() => #(#fields_iter)* , #unnamed }
218            } else {
219                quote_spanned! { fields.span() => #(#fields_iter)*   #unnamed }
220            };
221            *unnamed = syn::parse2::<PunctedUnnamedFields>(tokens).unwrap().0;
222        },
223    }
224}