enum_field_getter/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro_error::{abort_call_site, emit_warning, proc_macro_error};
5use quote::{format_ident, quote};
6use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
7
8use std::collections::{HashMap, HashSet};
9
10/// See top-level crate documentation.
11#[proc_macro_error]
12#[proc_macro_derive(EnumFieldGetter)]
13pub fn enum_field_getter(stream: TokenStream) -> TokenStream {
14    let info = parse_macro_input!(stream as DeriveInput);
15    if let Data::Enum(enum_data) = info.data {
16        let variants = enum_data.variants.iter();
17        let name = info.ident;
18        let mut field_info: HashMap<String, (Type, Vec<String>)> = HashMap::new();
19        let mut tuple_field_info: HashMap<usize, (Type, Vec<String>)> = HashMap::new();
20        let mut incompatible = HashSet::<String>::new();
21        let mut tuple_incompatible = HashSet::<usize>::new();
22        for variant in variants {
23            if let Fields::Named(_) = variant.fields {
24                for field in &variant.fields {
25                    let ident = field.ident.clone().unwrap().to_string();
26                    let field_ty = field.ty.clone();
27                    let df = (field_ty.clone(), vec![variant.ident.to_string()]);
28                    field_info.entry(ident.clone()).and_modify(|info| {
29                        let (ty, used_variants) = info;
30                        if quote!{#field_ty}.to_string() != quote!{#ty}.to_string() {
31                            emit_warning!(field, "fields must be the same type across all variants - no getter will be emitted for this field");
32                            incompatible.insert(ident.clone());
33                        } else {
34                            used_variants.push(variant.ident.to_string());
35                        }
36                    }).or_insert(df);
37                }
38            } else if let Fields::Unnamed(_) = variant.fields {
39                for (i, field) in variant.fields.iter().enumerate() {
40                    let field_ty = field.ty.clone();
41                    let df = (field_ty.clone(), vec![variant.ident.to_string()]);
42                    tuple_field_info.entry(i).and_modify(|info| {
43                        let (ty, used_variants) = info;
44                        if quote!{#field_ty}.to_string() != quote!{#ty}.to_string() {
45                            emit_warning!(field, "fields must be the same type across all variants - no getter will be emitted for this field");
46                            tuple_incompatible.insert(i);
47                        } else {
48                            used_variants.push(variant.ident.to_string());
49                        }
50                    }).or_insert(df);
51                }
52            }
53        }
54        for removeable in incompatible {
55            field_info.remove(&removeable);
56        }
57        for tuple_removeable in tuple_incompatible {
58            tuple_field_info.remove(&tuple_removeable);
59        }
60        let getters = field_info.keys().map(|k| format_ident!("{}", k));
61        let getters_mut = field_info.keys().map(|k| format_ident!("{}_mut", k));
62        let types = field_info.values().map(|v| v.0.clone());
63        let types_mut = types.clone();
64        let field_info_vec = field_info.iter().collect::<Vec<_>>();
65        let matches = field_info_vec.iter().map(|(k, v)| {
66            let variants =
67                v.1.clone()
68                    .iter()
69                    .map(|v| format_ident!("{}", v))
70                    .collect::<Vec<_>>();
71            let field = vec![format_ident!("{}", k); variants.len()];
72            quote! {
73                match self {
74                    #(
75                        Self::#variants { #field, .. } => Some(#field),
76                    )*
77                    _ => None,
78                }
79            }
80        });
81        let matches_mut = matches.clone();
82        let tuple_getters = tuple_field_info.keys().map(|k| format_ident!("get_{}", k));
83        let tuple_getters_mut = tuple_field_info
84            .keys()
85            .map(|k| format_ident!("get_{}_mut", k));
86        let tuple_types = tuple_field_info.values().map(|v| v.0.clone());
87        let tuple_types_mut = tuple_types.clone();
88        let tuple_field_info_vec = tuple_field_info.iter().collect::<Vec<_>>();
89        let tuple_matches = tuple_field_info_vec.iter().map(|(k, v)| {
90            let variants =
91                v.1.clone()
92                    .iter()
93                    .map(|v| format_ident!("{}", v))
94                    .collect::<Vec<_>>();
95            let preceding = vec![format_ident!("_"); **k];
96            let preceding_quote = vec![quote! { #(#preceding,)* }; variants.len()];
97            let field = vec![format_ident!("val_{}", k); variants.len()];
98            quote! {
99                match self {
100                    #(
101                        Self::#variants(#preceding_quote #field, .. ) => Some(#field),
102                    )*
103                    _ => None,
104                }
105            }
106        });
107        let tuple_matches_mut = tuple_matches.clone();
108        quote! {
109            impl #name {
110                #(
111                pub fn #getters (&self) -> Option<&#types> {
112                    #matches
113                }
114                )*
115                #(
116                pub fn #tuple_getters (&self) -> Option<&#tuple_types> {
117                    #tuple_matches
118                }
119                )*
120                #(
121                pub fn #getters_mut (&mut self) -> Option<&mut #types_mut> {
122                    #matches_mut
123                }
124                )*
125                #(
126                pub fn #tuple_getters_mut (&mut self) -> Option<&mut #tuple_types_mut> {
127                    #tuple_matches_mut
128                }
129                )*
130            }
131        }
132        .into()
133    } else {
134        abort_call_site!("macro can only be used on enums");
135    }
136}