// enum_utility_macros/lib.rs

1#![doc = include_str!("../README.md")]
2
3use functions_builder::EnumFunctionsBuilder;
4use proc_macro::TokenStream;
5use proc_macro2::{Ident, Span};
6use quote::{quote, ToTokens};
7use ref_enum_builder::RefEnumBuilder;
8use syn::{
9    parse::Parser,
10    parse_macro_input,
11    punctuated::Punctuated,
12    token::{self},
13    Expr, Fields, ItemEnum, ItemFn, Token, Type, TypeTuple, Variant, Visibility,
14};
15use tag_enum_builder::TagEnumBuilder;
16
17pub(crate) mod functions_builder;
18pub(crate) mod ref_enum_builder;
19pub(crate) mod tag_enum_builder;
20
21#[proc_macro_attribute]
22pub fn generate_enum_helper(attr: TokenStream, item: TokenStream) -> TokenStream {
23    let mut enum_stream = item.clone();
24
25    let parser = Punctuated::<Ident, Token![,]>::parse_separated_nonempty;
26    let attributes = parser.parse(attr).unwrap();
27    let input = parse_macro_input!(item as ItemEnum);
28
29    let mut generate_tag_enum = false;
30    let mut generate_ref_enum = false;
31    let mut generate_mut_enum = false;
32
33    let mut create_is_functions = false;
34    let mut create_unwrap_functions = false;
35    let mut create_unwrap_ref_functions = false;
36    let mut create_unwrap_ref_mut_functions = false;
37    let mut create_to_tag_functions = false;
38    let mut create_as_ref_functions = false;
39    let mut create_as_mut_functions = false;
40    let mut create_get_functions = false;
41    let mut create_get_ref_functions = false;
42    let mut create_get_mut_functions = false;
43    for item in attributes {
44        match item.to_string().as_str() {
45            "TagEnum" => generate_tag_enum = true,
46            "RefEnum" => generate_ref_enum = true,
47            "MutEnum" => generate_mut_enum = true,
48            "is" => create_is_functions = true,
49            "unwrap" => create_unwrap_functions = true,
50            "unwrap_ref" => create_unwrap_ref_functions = true,
51            "unwrap_mut" => create_unwrap_ref_mut_functions = true,
52            "to_tag" => create_to_tag_functions = true,
53            "as_ref" => create_as_ref_functions = true,
54            "as_mut" => create_as_mut_functions = true,
55            "get" => create_get_functions = true,
56            "get_ref" => create_get_ref_functions = true,
57            "get_mut" => create_get_mut_functions = true,
58            _ => panic!(),
59        }
60    }
61
62    let input_enum = InputEnum(input);
63    if create_is_functions
64        || create_unwrap_functions
65        || create_unwrap_ref_functions
66        || create_unwrap_ref_mut_functions
67        || create_to_tag_functions
68        || create_as_ref_functions
69        || create_as_mut_functions
70        || create_get_functions
71        || create_get_ref_functions
72        || create_get_mut_functions
73    {
74        let mut functions_builder = EnumFunctionsBuilder::new(&input_enum);
75        if create_is_functions {
76            functions_builder.is_functions();
77        }
78        if create_unwrap_functions {
79            functions_builder.unwrap_functions();
80        }
81        if create_unwrap_ref_functions {
82            functions_builder.unwrap_ref_functions();
83        }
84        if create_unwrap_ref_mut_functions {
85            functions_builder.unwrap_mut_functions();
86        }
87        if create_to_tag_functions {
88            functions_builder.to_tag_function();
89        }
90        if create_as_ref_functions {
91            functions_builder.as_ref_functions();
92        }
93        if create_as_mut_functions {
94            functions_builder.as_mut_functions();
95        }
96        if create_get_functions {
97            functions_builder.get_functions();
98        }
99        if create_get_ref_functions {
100            functions_builder.get_ref_functions();
101        }
102        if create_get_mut_functions {
103            functions_builder.get_mut_functions();
104        }
105
106        let ts = functions_builder.token_stream();
107        enum_stream.extend([ts]);
108    }
109
110    if generate_tag_enum {
111        let mut tag_enum_builder = TagEnumBuilder::new(&input_enum);
112        if create_is_functions {
113            tag_enum_builder.is_functions();
114        }
115        let ts = tag_enum_builder.token_stream();
116        enum_stream.extend([ts]);
117    }
118
119    if generate_ref_enum {
120        let mut ref_enum_builder = RefEnumBuilder::new(&input_enum, false);
121        if create_is_functions {
122            ref_enum_builder.is_functions();
123        }
124        if create_unwrap_functions {
125            ref_enum_builder.unwrap_functions();
126        }
127        if create_to_tag_functions {
128            ref_enum_builder.to_tag_functions();
129        }
130        if create_get_functions {
131            ref_enum_builder.get_functions();
132        }
133        let ts = ref_enum_builder.token_stream();
134        enum_stream.extend([ts]);
135    }
136
137    if generate_mut_enum {
138        let mut ref_enum_builder = RefEnumBuilder::new(&input_enum, true);
139        if create_is_functions {
140            ref_enum_builder.is_functions();
141        }
142        if create_unwrap_functions {
143            ref_enum_builder.unwrap_functions();
144        }
145        if create_to_tag_functions {
146            ref_enum_builder.to_tag_functions();
147        }
148        if create_get_functions {
149            ref_enum_builder.get_functions();
150        }
151        let ts = ref_enum_builder.token_stream();
152        enum_stream.extend([ts]);
153    }
154
155    enum_stream
156}
157
158pub(crate) struct InputEnum(ItemEnum);
159
160impl InputEnum {
161    fn vis(&self) -> &Visibility {
162        &self.0.vis
163    }
164
165    fn name(&self) -> String {
166        format!("{}", self.0.ident)
167    }
168
169    fn variant_snake_case_name(&self, i: usize) -> String {
170        let variant_name = self.0.variants[i].ident.to_string();
171        let mut snake_case_name = String::new();
172        for c in variant_name.chars() {
173            if c.is_uppercase() && snake_case_name.is_empty() {
174                snake_case_name += format!("{}", c.to_ascii_lowercase()).as_str();
175            } else if c.is_uppercase() {
176                snake_case_name += format!("_{}", c.to_ascii_lowercase()).as_str();
177            } else {
178                snake_case_name += format!("{c}").as_str();
179            }
180        }
181        snake_case_name
182    }
183
184    fn generics(&self) -> &syn::Generics {
185        &self.0.generics
186    }
187
188    fn attributes(&self) -> &Vec<syn::Attribute> {
189        &self.0.attrs
190    }
191
192    fn iter_variants(&self) -> impl Iterator<Item = &Variant> {
193        self.0.variants.iter()
194    }
195
196    fn variant_count(&self) -> usize {
197        self.0.variants.len()
198    }
199
200    fn variant(&self, i: usize) -> &Variant {
201        &self.0.variants[i]
202    }
203
204    fn variant_type(&self, i: usize) -> Type {
205        let elems: Punctuated<_, _> = self.0.variants[i]
206            .fields
207            .iter()
208            .map(|f| f.ty.clone())
209            .collect();
210
211        if elems.len() == 1 {
212            return (*elems.first().unwrap()).clone();
213        }
214
215        let group = proc_macro2::Group::new(
216            proc_macro2::Delimiter::Parenthesis,
217            proc_macro2::TokenStream::new(),
218        );
219        syn::Type::Tuple(TypeTuple {
220            paren_token: token::Paren {
221                span: group.delim_span(),
222            },
223            elems,
224        })
225    }
226
227    fn match_variant(&self, i: usize, enum_ident: Option<Ident>) -> syn::Pat {
228        let variant = self.variant(i);
229        let enum_name = enum_ident.as_ref().unwrap_or(&self.0.ident);
230        let variant_name = &self.variant(i).ident;
231        let pattern = match &variant.fields {
232            Fields::Unit => {
233                quote! {
234                    #enum_name :: #variant_name
235                }
236            }
237            Fields::Named(_) => {
238                quote! {
239                    #enum_name :: #variant_name { .. }
240                }
241            }
242            Fields::Unnamed(fields) => {
243                let wild_pattern = vec![
244                    syn::Pat::Wild(syn::PatWild {
245                        attrs: vec![],
246                        underscore_token: token::Underscore {
247                            spans: [Span::call_site(); 1]
248                        },
249                    });
250                    fields.unnamed.len()
251                ];
252
253                quote! {
254                    #enum_name :: #variant_name ( #(#wild_pattern ,)* )
255                }
256            }
257        };
258
259        syn::Pat::Verbatim(pattern)
260    }
261
262    fn match_variant_to_tuple(&self, i: usize, enum_ident: Option<Ident>) -> syn::Arm {
263        let (pat, body) = match &self.variant(i).fields {
264            Fields::Unit => {
265                let group = proc_macro2::Group::new(
266                    proc_macro2::Delimiter::Parenthesis,
267                    proc_macro2::TokenStream::new(),
268                );
269
270                (
271                    self.match_variant(i, None),
272                    Box::new(Expr::Tuple(syn::ExprTuple {
273                        attrs: vec![],
274                        paren_token: token::Paren {
275                            span: group.delim_span(),
276                        },
277                        elems: Punctuated::new(),
278                    })),
279                )
280            }
281            Fields::Unnamed(fields) => {
282                let mut patterns = Punctuated::new();
283                let mut elements = Punctuated::new();
284
285                for (index, _field) in fields.unnamed.iter().enumerate() {
286                    let name = format!("e{index}");
287                    let ident = Ident::new(name.as_str(), Span::call_site());
288                    patterns.push(syn::Pat::Path(syn::PatPath {
289                        attrs: vec![],
290                        qself: None,
291                        path: syn::PathSegment {
292                            arguments: syn::PathArguments::None,
293                            ident: ident.clone(),
294                        }
295                        .into(),
296                    }));
297
298                    elements.push(Expr::Path(syn::ExprPath {
299                        attrs: vec![],
300                        qself: None,
301                        path: syn::PathSegment {
302                            arguments: syn::PathArguments::None,
303                            ident,
304                        }
305                        .into(),
306                    }))
307                }
308
309                let pattern_path = {
310                    let mut punctuated = Punctuated::new();
311                    punctuated.push(syn::PathSegment {
312                        ident: enum_ident.unwrap_or(self.0.ident.clone()),
313                        arguments: syn::PathArguments::None,
314                    });
315                    punctuated.push(syn::PathSegment {
316                        ident: self.variant(i).ident.clone(),
317                        arguments: syn::PathArguments::None,
318                    });
319
320                    syn::Path {
321                        leading_colon: None,
322                        segments: punctuated,
323                    }
324                };
325
326                let group = proc_macro2::Group::new(
327                    proc_macro2::Delimiter::Parenthesis,
328                    proc_macro2::TokenStream::new(),
329                );
330                let pat = syn::Pat::TupleStruct(syn::PatTupleStruct {
331                    attrs: vec![],
332                    qself: None,
333                    path: pattern_path,
334                    paren_token: token::Paren {
335                        span: group.delim_span(),
336                    },
337                    elems: patterns,
338                });
339
340                let body = if elements.len() == 1 {
341                    let syn::Expr::Path(syn::ExprPath { path, ..}) = elements.first().unwrap() else {
342                        panic!()
343                    };
344                    Box::new(Expr::Path(syn::ExprPath {
345                        attrs: vec![],
346                        qself: None,
347                        path: path.clone(),
348                    }))
349                } else {
350                    Box::new(Expr::Tuple(syn::ExprTuple {
351                        attrs: vec![],
352                        paren_token: token::Paren {
353                            span: group.delim_span(),
354                        },
355                        elems: elements,
356                    }))
357                };
358
359                (pat, body)
360            }
361            Fields::Named(fields) => {
362                // Unnify with unnamed
363                let mut patterns = Punctuated::new();
364                let mut elements = Punctuated::new();
365
366                for field in fields.named.iter() {
367                    patterns.push(syn::FieldPat {
368                        attrs: vec![],
369                        member: syn::Member::Named(field.ident.clone().unwrap()),
370                        colon_token: None, // Shorthand field pattern
371                        pat: Box::new(syn::Pat::Path(syn::PatPath {
372                            attrs: vec![],
373                            qself: None,
374                            path: syn::PathSegment {
375                                arguments: syn::PathArguments::None,
376                                ident: field.ident.clone().unwrap(),
377                            }
378                            .into(),
379                        })),
380                    });
381
382                    elements.push(Expr::Path(syn::ExprPath {
383                        attrs: vec![],
384                        qself: None,
385                        path: syn::PathSegment {
386                            arguments: syn::PathArguments::None,
387                            ident: field.ident.clone().unwrap(),
388                        }
389                        .into(),
390                    }))
391                }
392
393                let pattern_path = {
394                    let mut punctuated = Punctuated::new();
395                    punctuated.push(syn::PathSegment {
396                        ident: enum_ident.unwrap_or(self.0.ident.clone()),
397                        arguments: syn::PathArguments::None,
398                    });
399                    punctuated.push(syn::PathSegment {
400                        ident: self.variant(i).ident.clone(),
401                        arguments: syn::PathArguments::None,
402                    });
403
404                    syn::Path {
405                        leading_colon: None,
406                        segments: punctuated,
407                    }
408                };
409
410                let group = proc_macro2::Group::new(
411                    proc_macro2::Delimiter::Parenthesis,
412                    proc_macro2::TokenStream::new(),
413                );
414                let pat = syn::Pat::Struct(syn::PatStruct {
415                    attrs: vec![],
416                    qself: None,
417                    path: pattern_path,
418                    brace_token: token::Brace {
419                        span: group.delim_span(),
420                    },
421                    fields: patterns,
422                    rest: None,
423                });
424
425                let body = if elements.len() == 1 {
426                    let syn::Expr::Path(syn::ExprPath { path, ..}) = elements.first().unwrap() else {
427                        panic!()
428                    };
429                    Box::new(Expr::Path(syn::ExprPath {
430                        attrs: vec![],
431                        qself: None,
432                        path: path.clone(),
433                    }))
434                } else {
435                    Box::new(Expr::Tuple(syn::ExprTuple {
436                        attrs: vec![],
437                        paren_token: token::Paren {
438                            span: group.delim_span(),
439                        },
440                        elems: elements,
441                    }))
442                };
443
444                (pat, body)
445            }
446        };
447
448        syn::Arm {
449            attrs: vec![],
450            guard: None,
451            fat_arrow_token: token::FatArrow {
452                spans: [Span::call_site(); 2],
453            },
454            comma: Some(token::Comma {
455                spans: [Span::call_site(); 1],
456            }),
457            pat,
458            body,
459        }
460    }
461}
462
463pub(crate) fn parse_function(
464    ts: proc_macro2::TokenStream,
465    ifn: &mut Option<ItemFn>,
466) -> TokenStream {
467    let r = TokenStream::from(ts);
468    let r2 = r.clone();
469    let pifn = parse_macro_input!(r2 as ItemFn);
470    *ifn = Some(pifn);
471    r
472}
473
474fn filter_derive_attributes(
475    attrs: &[syn::Attribute],
476    filtered_out: &[&str],
477) -> Vec<syn::Attribute> {
478    let mut result = vec![];
479    for attr in attrs {
480        match &attr.meta {
481            syn::Meta::List(ml) if ml.path.to_token_stream().to_string() == "derive" => {
482                let punctuated_parser = Punctuated::<syn::Path, Token![,]>::parse_terminated;
483                let punctuated = punctuated_parser.parse2(ml.tokens.clone()).unwrap();
484
485                let mut punctuated_result = Punctuated::<_, Token![,]>::new();
486                for item in punctuated.into_iter() {
487                    let last_segment = item.segments.last().unwrap().ident.to_string();
488                    if filtered_out.contains(&last_segment.as_str()) {
489                        continue;
490                    }
491                    punctuated_result.push(item);
492                }
493                result.push(syn::Attribute {
494                    pound_token: attr.pound_token,
495                    style: attr.style,
496                    bracket_token: attr.bracket_token,
497                    meta: syn::Meta::List(syn::MetaList {
498                        path: ml.path.clone(),
499                        delimiter: ml.delimiter.clone(),
500                        tokens: punctuated_result.to_token_stream(),
501                    }),
502                });
503            }
504            _ => result.push(attr.clone()),
505        }
506    }
507    result
508}