1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::{format_ident, quote, quote_spanned};
6use syn::{Fields, ItemEnum, parse::Parse, punctuated::Punctuated, spanned::Spanned, token::Comma};
7
8struct PunctedNamedFields(Punctuated<syn::Field, Comma>);
9struct PunctedUnnamedFields(Punctuated<syn::Field, Comma>);
10
11impl std::ops::Deref for PunctedNamedFields {
12 type Target = Punctuated<syn::Field, Comma>;
13
14 fn deref(&self) -> &Self::Target {
15 &self.0
16 }
17}
18
19impl Parse for PunctedNamedFields {
20 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
21 input.parse_terminated(syn::Field::parse_named, Comma)
22 .map(Self)
23 }
24}
25
26impl Parse for PunctedUnnamedFields {
27 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
28 input.parse_terminated(syn::Field::parse_unnamed, Comma)
29 .map(Self)
30 }
31}
32
33#[proc_macro_attribute]
63pub fn fields(attr: TokenStream, adt: TokenStream) -> TokenStream {
64 let mut item_enum = match syn::parse::<ItemEnum>(adt) {
65 Ok(x) => x,
66 Err(e) => return e.into_compile_error().into(),
67 };
68 let fields = match syn::parse::<PunctedNamedFields>(attr.clone()) {
69 Ok(it) => it,
70 Err(err) => return err.into_compile_error().into(),
71 };
72 item_enum.variants.iter_mut().for_each(|variant| {
73 add_fields(&mut variant.fields, &fields);
74 });
75
76 let ItemEnum {
77 attrs,
78 vis,
79 enum_token,
80 ident,
81 generics,
82 brace_token: _,
83 variants,
84 } = item_enum;
85
86 let (impl_generics,
87 type_generics,
88 where_clause) = generics.split_for_impl();
89
90 let methods = generate_methods(&vis, &fields, &variants);
91
92 quote! {
93 #(#attrs)*
94 #vis #enum_token #ident #generics {
95 #variants
96 }
97 impl #impl_generics #ident #type_generics #where_clause {
98 #(#methods)*
99 }
100 }.into()
101}
102
103fn generate_methods(
104 vis: &syn::Visibility,
105 fields: &PunctedNamedFields,
106 variants: &Punctuated<syn::Variant, Comma>,
107) -> Vec<pm2::TokenStream> {
108 fields.pairs()
109 .map(|pair| pair.into_value())
110 .enumerate()
111 .map(|(i, field)| {
112 let i_field = pm2::Literal::usize_unsuffixed(i);
113 let name = field.ident.as_ref().expect("empty field");
114 let colon = field.colon_token.as_ref().expect("empty colon token");
115 let ty = &field.ty;
116
117 let attrs = field.attrs.iter()
118 .filter(allowed_field_attr)
119 .collect::<Vec<_>>();
120
121 let field_name = lose_span(name);
122 let method_span = colon.span.span();
123
124 let immutable_getter = format_ident!("{field_name}", span = method_span);
125 let mutable_getter = format_ident!("{field_name}_mut", span = method_span);
126 let owned_getter = format_ident!("into_{field_name}", span = method_span);
127
128 let variants_pat = variants.iter()
129 .map(|it| {
130 let body = match it.fields {
131 Fields::Named(_) => quote! {
132 { #field_name, .. }
133 },
134 Fields::Unnamed(_) => quote! {
135 { #i_field: #field_name, .. }
136 },
137 Fields::Unit => quote! {},
138 };
139 let variant_name = lose_span(&it.ident);
140 quote! {
141 Self::#variant_name #body
142 }
143 })
144 .collect::<Vec<_>>();
145 let match_arms = if variants_pat.is_empty() {
146 quote! {
147 _ => loop {}
148 }
149 } else {
150 quote! {
151 #(| #variants_pat)*
152 => #field_name,
153 }
154 };
155
156 quote! {
157 #(#attrs)*
158 #[allow(unused)]
159 #vis fn #immutable_getter(&self) -> &#ty {
160 match self {
161 #match_arms
162 }
163 }
164 #(#attrs)*
165 #[allow(unused)]
166 #vis fn #mutable_getter(&mut self) -> &mut #ty {
167 match self {
168 #match_arms
169 }
170 }
171 #(#attrs)*
172 #[allow(unused)]
173 #vis fn #owned_getter(self) -> #ty {
174 match self {
175 #match_arms
176 }
177 }
178 }
179 })
180 .collect()
181}
182
183fn allowed_field_attr(attr: &&syn::Attribute) -> bool {
184 attr.path().is_ident("doc") && attr.meta.require_name_value().is_ok()
185 || attr.path().is_ident("cfg") && attr.meta.require_list().is_ok()
186}
187
188fn lose_span(ident: &pm2::Ident) -> pm2::Ident {
189 pm2::Ident::new(&ident.to_string(), pm2::Span::call_site())
190}
191
192fn add_fields(variant_fields: &mut Fields, fields: &PunctedNamedFields) {
193 let needs_comma = !fields.trailing_punct() && !fields.is_empty();
194 match variant_fields {
195 Fields::Unit => {
196 let mut tokens = pm2::Group::new(pm2::Delimiter::Brace, pm2::TokenStream::new());
197 tokens.set_span(variant_fields.span());
198 *variant_fields = Fields::Named(syn::parse2(pm2::TokenTree::from(tokens).into()).unwrap());
199 add_fields(variant_fields, fields)
200 },
201 Fields::Named(syn::FieldsNamed { named, .. }) => {
202 let fields_iter = fields.pairs();
203 let tokens = if needs_comma {
204 quote_spanned! { fields.span() => #(#fields_iter)* , #named }
205 } else {
206 quote_spanned! { fields.span() => #(#fields_iter)* #named }
207 };
208 *named = syn::parse2::<PunctedNamedFields>(tokens).unwrap().0;
209 },
210 Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
211 let fields_iter = fields.0.clone().into_pairs().map(|mut pair| {
212 pair.value_mut().ident.take();
213 pair.value_mut().colon_token.take();
214 pair
215 });
216 let tokens = if needs_comma {
217 quote_spanned! { fields.span() => #(#fields_iter)* , #unnamed }
218 } else {
219 quote_spanned! { fields.span() => #(#fields_iter)* #unnamed }
220 };
221 *unnamed = syn::parse2::<PunctedUnnamedFields>(tokens).unwrap().0;
222 },
223 }
224}