enum_fields/
lib.rs

1// Copyright (C) 2023 Tristan Gerritsen <tristan@thewoosh.org>
2// All Rights Reserved.
3
4//! # enum-fields
5//! Quickly access shared enum fields in Rust.
6//!
7//! ## Example
8//! The following example showcases an enum `Entity`, which contains two
9//! variants: `Company` and `Person`.
10//!
11//! ```rs
12//! /// An entity that can be either a `Company` or a `Person`.
13//! #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, enum_fields::EnumFields)]
14//! pub enum Entity {
15//!     Company {
16//!         name: String,
17//!         ceo: String,
18//!     },
19//!
20//!     Person {
21//!         name: String,
22//!     }
23//! }
24//! ```
25//!
26//! ### Field Accessor Functions (Getters)
27//! Since `Entity` derives from [`enum_fields::EnumFields`], it now contains
28//! two field accessor functions (getters): `Entity::name()` and
29//! `Entity::ceo()`.
30//!
31//! ```rs
32//! let company = Entity::Company {
33//!     name: "Apple".into(),
34//!     ceo: "Tim Cook".into()
35//! };
36//!
37//! let person = Entity::Person {
38//!     name: "Tim Berners-Lee".into()
39//! };
40//!
41//! println!("Company with CEO: {} named: {}",
42//!     company.ceo().unwrap(),
43//!     company.name()
44//! );
45//!
46//! println!("Person named: {}", person.name());
47//! ```
48//!
49//! ### Shared Fields
//! Note that both `Company` and `Person` have a field named `name`. Because the
//! field is present in every variant, `Entity::name()` returns `&String` directly.
52//!
53//! ```rs
54//! // Since [`Entity`] has two variants that both have the `name` field,
55//! // `Entity::name(&self)` returns the `&String`.
56//! assert_eq!(company.name(), "Apple");
57//! assert_eq!(person.name(), "Tim Berners-Lee");
58//! ```
59//!
60//! ### Shared Fields (Optional)
61//! However, only `Company` has field `ceo`, which therefore makes
62//! `Entity::ceo()` return an optional getter: `Option<&String>`.
63//!
64//! ```rs
65//! // Only `Company` has field `ceo`, so it returns an `Option<&String>`,
66//! // since a `Person` returns [`None`].
67//! assert_eq!(company.ceo(), Some(&"Tim Cook".into()));
68//! assert_eq!(person.ceo(), None);
69//! ```
70
71use std::collections::HashMap;
72
73use proc_macro::TokenStream;
74use proc_macro2::{Ident, Span};
75use quote::quote;
76use syn;
77
78#[proc_macro_derive(EnumFields)]
79pub fn enum_fields_macro_derive(input: TokenStream) -> TokenStream {
80    let ast = syn::parse(input).unwrap();
81    self::impl_for_input(&ast)
82}
83
84fn collect_available_fields<'input>(enum_data: &'input syn::DataEnum) -> HashMap<String, Vec<&'input syn::Field>> {
85    let mut fields = HashMap::new();
86
87    for variant in &enum_data.variants {
88        for field in &variant.fields {
89            if let Some(field_ident) = &field.ident {
90                let ident = field_ident.to_string();
91                fields.entry(ident)
92                    .or_insert(Vec::new())
93                    .push(field);
94            }
95        }
96    }
97
98    fields
99}
100
101fn impl_for_input(ast: &syn::DeriveInput) -> TokenStream {
102    let fail_message = "`EnumFields` is only applicable to `enum`s";
103    match &ast.data {
104        syn::Data::Enum(data_enum) => impl_for_enum(ast, &data_enum),
105        syn::Data::Union(data_union) => syn::Error::new(data_union.union_token.span, fail_message).to_compile_error().into(),
106        syn::Data::Struct(data_struct) => syn::Error::new(data_struct.struct_token.span, fail_message).to_compile_error().into(),
107    }
108}
109
110fn impl_for_enum(ast: &syn::DeriveInput, enum_data: &syn::DataEnum) -> TokenStream {
111    let name = &ast.ident;
112
113    // Collect available fields
114    let fields = collect_available_fields(enum_data);
115
116    let mut data = proc_macro2::TokenStream::new();
117
118
119    for (field_name, fields) in fields {
120        let field_present_everywhere = fields.len() == enum_data.variants.len();
121
122        let generics = &ast.generics;
123        let field_type = &fields[0].ty;
124        let field_name_ident = Ident::new(&field_name, Span::call_site());
125
126        let mut variants = proc_macro2::TokenStream::new();
127
128        for variant in &enum_data.variants {
129            let name = &variant.ident;
130
131            let variant_field_ident = variant.fields.iter()
132                .find(|variant_field| {
133                    if let Some(variant_field_ident) = &variant_field.ident {
134                        if variant_field_ident.to_string() == field_name {
135                            true
136                        } else {
137                            false
138                        }
139                    } else {
140                        false
141                    }
142                })
143                .map(|field| {
144                    field.ident.as_ref().unwrap()
145                });
146
147            match variant_field_ident {
148                Some(variant_field_ident) => {
149                    if field_present_everywhere {
150                        variants.extend(quote! {
151                            Self::#name{ #variant_field_ident, .. } => & #variant_field_ident,
152                        });
153                    } else {
154                        variants.extend(quote! {
155                            Self::#name{ #variant_field_ident, .. } => Some(& #variant_field_ident),
156                        });
157                    }
158                }
159
160                None => {
161                    // Field not present in field list.
162                    if let Some(first_field) = variant.fields.iter().next() {
163                        if first_field.ident.is_some() {
164                            variants.extend(quote! {
165                                Self::#name{ .. } => None,
166                            });
167                        } else {
168                            variants.extend(quote! {
169                                Self::#name(..) => None,
170                            });
171                        }
172                    } else {
173                        variants.extend(quote! {
174                            Self::#name => None,
175                        });
176                    }
177                }
178            }
179        }
180
181        let ty = if field_present_everywhere {
182            quote! {
183                & #field_type
184            }
185        } else {
186            quote! {
187                Option<& #field_type>
188            }
189        };
190
191        data.extend(quote! {
192            impl #generics #name #generics {
193                pub fn #field_name_ident(&self) -> #ty {
194                    //! Get the property of this enum discriminant if it's available
195                    match &self {
196                        #variants
197                    }
198                }
199            }
200        });
201    }
202
203    data.into()
204}