enum_fields/lib.rs
1// Copyright (C) 2023 Tristan Gerritsen <tristan@thewoosh.org>
2// All Rights Reserved.
3
4//! # enum-fields
5//! Quickly access shared enum fields in Rust.
6//!
7//! ## Example
8//! The following example showcases an enum `Entity`, which contains two
9//! variants: `Company` and `Person`.
10//!
11//! ```rs
12//! /// An entity that can be either a `Company` or a `Person`.
13//! #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, enum_fields::EnumFields)]
14//! pub enum Entity {
15//! Company {
16//! name: String,
17//! ceo: String,
18//! },
19//!
20//! Person {
21//! name: String,
22//! }
23//! }
24//! ```
25//!
26//! ### Field Accessor Functions (Getters)
//! Since `Entity` derives `enum_fields::EnumFields`, it now has two field
//! accessor functions (getters): `Entity::name()` and
//! `Entity::ceo()`.
30//!
31//! ```rs
32//! let company = Entity::Company {
33//! name: "Apple".into(),
34//! ceo: "Tim Cook".into()
35//! };
36//!
37//! let person = Entity::Person {
38//! name: "Tim Berners-Lee".into()
39//! };
40//!
41//! println!("Company with CEO: {} named: {}",
42//! company.ceo().unwrap(),
43//! company.name()
44//! );
45//!
46//! println!("Person named: {}", person.name());
47//! ```
48//!
49//! ### Shared Fields
//! Note that both `Company` and `Person` have a field named `name`. Because
//! every variant has this field, `Entity::name()` returns a reference to it
//! directly (`&String`) instead of wrapping it in an `Option`.
52//!
53//! ```rs
54//! // Since [`Entity`] has two variants that both have the `name` field,
55//! // `Entity::name(&self)` returns the `&String`.
56//! assert_eq!(company.name(), "Apple");
57//! assert_eq!(person.name(), "Tim Berners-Lee");
58//! ```
59//!
//! ### Optional Fields
//! However, only `Company` has the field `ceo`, so `Entity::ceo()` returns
//! an `Option<&String>`: `Some` for a `Company` and `None` for a `Person`.
63//!
64//! ```rs
65//! // Only `Company` has field `ceo`, so it returns an `Option<&String>`,
66//! // since a `Person` returns [`None`].
67//! assert_eq!(company.ceo(), Some(&"Tim Cook".into()));
68//! assert_eq!(person.ceo(), None);
69//! ```
70
71use std::collections::HashMap;
72
73use proc_macro::TokenStream;
74use proc_macro2::{Ident, Span};
75use quote::quote;
76use syn;
77
78#[proc_macro_derive(EnumFields)]
79pub fn enum_fields_macro_derive(input: TokenStream) -> TokenStream {
80 let ast = syn::parse(input).unwrap();
81 self::impl_for_input(&ast)
82}
83
84fn collect_available_fields<'input>(enum_data: &'input syn::DataEnum) -> HashMap<String, Vec<&'input syn::Field>> {
85 let mut fields = HashMap::new();
86
87 for variant in &enum_data.variants {
88 for field in &variant.fields {
89 if let Some(field_ident) = &field.ident {
90 let ident = field_ident.to_string();
91 fields.entry(ident)
92 .or_insert(Vec::new())
93 .push(field);
94 }
95 }
96 }
97
98 fields
99}
100
101fn impl_for_input(ast: &syn::DeriveInput) -> TokenStream {
102 let fail_message = "`EnumFields` is only applicable to `enum`s";
103 match &ast.data {
104 syn::Data::Enum(data_enum) => impl_for_enum(ast, &data_enum),
105 syn::Data::Union(data_union) => syn::Error::new(data_union.union_token.span, fail_message).to_compile_error().into(),
106 syn::Data::Struct(data_struct) => syn::Error::new(data_struct.struct_token.span, fail_message).to_compile_error().into(),
107 }
108}
109
110fn impl_for_enum(ast: &syn::DeriveInput, enum_data: &syn::DataEnum) -> TokenStream {
111 let name = &ast.ident;
112
113 // Collect available fields
114 let fields = collect_available_fields(enum_data);
115
116 let mut data = proc_macro2::TokenStream::new();
117
118
119 for (field_name, fields) in fields {
120 let field_present_everywhere = fields.len() == enum_data.variants.len();
121
122 let generics = &ast.generics;
123 let field_type = &fields[0].ty;
124 let field_name_ident = Ident::new(&field_name, Span::call_site());
125
126 let mut variants = proc_macro2::TokenStream::new();
127
128 for variant in &enum_data.variants {
129 let name = &variant.ident;
130
131 let variant_field_ident = variant.fields.iter()
132 .find(|variant_field| {
133 if let Some(variant_field_ident) = &variant_field.ident {
134 if variant_field_ident.to_string() == field_name {
135 true
136 } else {
137 false
138 }
139 } else {
140 false
141 }
142 })
143 .map(|field| {
144 field.ident.as_ref().unwrap()
145 });
146
147 match variant_field_ident {
148 Some(variant_field_ident) => {
149 if field_present_everywhere {
150 variants.extend(quote! {
151 Self::#name{ #variant_field_ident, .. } => & #variant_field_ident,
152 });
153 } else {
154 variants.extend(quote! {
155 Self::#name{ #variant_field_ident, .. } => Some(& #variant_field_ident),
156 });
157 }
158 }
159
160 None => {
161 // Field not present in field list.
162 if let Some(first_field) = variant.fields.iter().next() {
163 if first_field.ident.is_some() {
164 variants.extend(quote! {
165 Self::#name{ .. } => None,
166 });
167 } else {
168 variants.extend(quote! {
169 Self::#name(..) => None,
170 });
171 }
172 } else {
173 variants.extend(quote! {
174 Self::#name => None,
175 });
176 }
177 }
178 }
179 }
180
181 let ty = if field_present_everywhere {
182 quote! {
183 & #field_type
184 }
185 } else {
186 quote! {
187 Option<& #field_type>
188 }
189 };
190
191 data.extend(quote! {
192 impl #generics #name #generics {
193 pub fn #field_name_ident(&self) -> #ty {
194 //! Get the property of this enum discriminant if it's available
195 match &self {
196 #variants
197 }
198 }
199 }
200 });
201 }
202
203 data.into()
204}