enum_fields/
lib.rs

1// Copyright (C) 2023 - 2025 Tristan Gerritsen <tristan@thewoosh.org>
2// All Rights Reserved.
3
4//! # enum-fields
5//! Quickly access shared enum fields in Rust.
6//!
7//! ## Example
8//! The following example showcases an enum `Entity`, which contains two
9//! variants: `Company` and `Person`.
10//!
11//! ```rs
12//! /// An entity that can be either a `Company` or a `Person`.
13//! #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, enum_fields::EnumFields)]
14//! pub enum Entity {
15//!     Company {
16//!         name: String,
17//!         ceo: String,
18//!     },
19//!
20//!     Person {
21//!         name: String,
22//!     }
23//! }
24//! ```
25//!
26//! ### Field Accessor Functions (Getters)
//! Since `Entity` derives [`EnumFields`], it now has two field accessor
//! functions (getters), `Entity::name()` and `Entity::ceo()`, plus their
//! mutable counterparts `Entity::name_mut()` and `Entity::ceo_mut()`.
30//!
31//! ```rs
32//! let mut company = Entity::Company {
33//!     name: "Apple".into(),
34//!     ceo: "Tim Cook".into()
35//! };
36//!
37//! let person = Entity::Person {
38//!     name: "Tim Berners-Lee".into()
39//! };
40//!
41//! println!("Company with CEO: {} named: {}",
42//!     company.ceo().unwrap(),
43//!     company.name()
44//! );
45//!
46//! println!("Person named: {}", person.name());
47//! ```
48//!
49//! ### Shared Fields
//! Note that both `Company` and `Person` have a field named `name` with the
//! same type. Because every variant has it, `Entity::name()` returns a
//! reference directly (`&String`) instead of an `Option`.
52//!
53//! ```rs
//! // Since both variants of `Entity` have the `name` field,
//! // `Entity::name(&self)` returns `&String`.
56//! assert_eq!(company.name(), "Apple");
57//! assert_eq!(person.name(), "Tim Berners-Lee");
58//! ```
59//!
//! ### Optional Fields
//! However, only `Company` has a `ceo` field, so `Entity::ceo()` returns an
//! `Option<&String>`, which is `None` for any other variant.
63//!
64//! ```rs
//! // Only `Company` has a `ceo` field, so `ceo()` returns an `Option<&String>`;
//! // for a `Person` it returns `None`.
67//! assert_eq!(company.ceo(), Some(&"Tim Cook".into()));
68//! assert_eq!(person.ceo(), None);
69//!
70//! if let Some(ceo) = company.ceo_mut() {
71//!     ceo.push_str(" ?!");
72//! }
73//! assert_eq!(company.ceo(), Some(&"Tim Cook ?!".into()));
74//!
75//! *company.name_mut() = "Microsoft".into();
76//! assert_eq!(company.name(), "Microsoft");
77//! ```
78
79use std::collections::HashMap;
80
81use proc_macro::TokenStream;
82use proc_macro2::{Ident, Span};
83use quote::quote;
84use syn;
85
86#[proc_macro_derive(EnumFields)]
87pub fn enum_fields_macro_derive(input: TokenStream) -> TokenStream {
88    let ast = syn::parse(input).unwrap();
89    impl_for_input(&ast)
90}
91
92fn collect_available_fields(enum_data: &syn::DataEnum) -> HashMap<String, Vec<&syn::Field>> {
93    let mut fields = HashMap::new();
94
95    for variant in &enum_data.variants {
96        for field in &variant.fields {
97            if let Some(field_ident) = &field.ident {
98                let ident = field_ident.to_string();
99                fields.entry(ident)
100                    .or_insert(Vec::new())
101                    .push(field);
102            }
103        }
104    }
105
106    fields
107}
108
109fn impl_for_input(ast: &syn::DeriveInput) -> TokenStream {
110    let fail_message = "`EnumFields` is only applicable to `enum`s";
111    match &ast.data {
112        syn::Data::Enum(data_enum) => impl_for_enum(ast, &data_enum),
113        syn::Data::Union(data_union) => syn::Error::new(data_union.union_token.span, fail_message).to_compile_error().into(),
114        syn::Data::Struct(data_struct) => syn::Error::new(data_struct.struct_token.span, fail_message).to_compile_error().into(),
115    }
116}
117
118fn impl_for_enum(ast: &syn::DeriveInput, enum_data: &syn::DataEnum) -> TokenStream {
119    let name = &ast.ident;
120
121    // Collect available fields
122    let fields = collect_available_fields(enum_data);
123
124    let mut data = proc_macro2::TokenStream::new();
125
126
127    for (field_name, fields) in fields {
128        let field_present_everywhere = fields.len() == enum_data.variants.len()
129            && fields.iter().all(|x| x.ty == fields[0].ty);
130
131        let generics = &ast.generics;
132        let field_type = &fields[0].ty;
133        let field_name_ident = Ident::new(&field_name, Span::call_site());
134        let field_name_ident_mut = Ident::new(&format!("{field_name}_mut"), Span::call_site());
135
136        let mut variants = proc_macro2::TokenStream::new();
137
138        for variant in &enum_data.variants {
139            let name = &variant.ident;
140
141            let variant_field = variant.fields.iter()
142                .find(|variant_field| {
143                    if let Some(variant_field_ident) = &variant_field.ident {
144                        if variant_field_ident.to_string() == field_name {
145                            true
146                        } else {
147                            false
148                        }
149                    } else {
150                        false
151                    }
152                });
153
154            let variant_field_ident = variant_field.as_ref().and_then(|field| field.ident.as_ref());
155
156            match variant_field_ident {
157                Some(variant_field_ident) => {
158                    variants.extend(quote! {
159                        Self::#name{ #variant_field_ident, .. } => (#variant_field_ident).into(),
160                    });
161                }
162
163                None => {
164                    // Field not present in field list.
165                    if let Some(first_field) = variant.fields.iter().next() {
166                        if first_field.ident.is_some() {
167                            variants.extend(quote! {
168                                Self::#name{ .. } => None,
169                            });
170                        } else {
171                            variants.extend(quote! {
172                                Self::#name(..) => None,
173                            });
174                        }
175                    } else {
176                        variants.extend(quote! {
177                            Self::#name => None,
178                        });
179                    }
180                }
181            }
182        }
183
184        let ty = if field_present_everywhere {
185            quote! {
186                & #field_type
187            }
188        } else {
189            quote! {
190                Option<& #field_type>
191            }
192        };
193
194        let ty_mut = if field_present_everywhere {
195            quote! {
196                &mut #field_type
197            }
198        } else {
199            quote! {
200                Option<&mut #field_type>
201            }
202        };
203
204        data.extend(quote! {
205            impl #generics #name #generics {
206                pub fn #field_name_ident(&self) -> #ty {
207                    //! Get the property of this enum discriminant if it's available
208                    match self {
209                        #variants
210                    }
211                }
212
213                 pub fn #field_name_ident_mut(&mut self) -> #ty_mut {
214                    //! Get the mutable property of this enum discriminant if it's available
215                    match self {
216                        #variants
217                    }
218                }
219            }
220        });
221    }
222
223    data.into()
224}