1#![doc = include_str!("../README.md")]
2
3use functions_builder::EnumFunctionsBuilder;
4use proc_macro::TokenStream;
5use proc_macro2::{Ident, Span};
6use quote::{quote, ToTokens};
7use ref_enum_builder::RefEnumBuilder;
8use syn::{
9 parse::Parser,
10 parse_macro_input,
11 punctuated::Punctuated,
12 token::{self},
13 Expr, Fields, ItemEnum, ItemFn, Token, Type, TypeTuple, Variant, Visibility,
14};
15use tag_enum_builder::TagEnumBuilder;
16
17pub(crate) mod functions_builder;
18pub(crate) mod ref_enum_builder;
19pub(crate) mod tag_enum_builder;
20
21#[proc_macro_attribute]
22pub fn generate_enum_helper(attr: TokenStream, item: TokenStream) -> TokenStream {
23 let mut enum_stream = item.clone();
24
25 let parser = Punctuated::<Ident, Token![,]>::parse_separated_nonempty;
26 let attributes = parser.parse(attr).unwrap();
27 let input = parse_macro_input!(item as ItemEnum);
28
29 let mut generate_tag_enum = false;
30 let mut generate_ref_enum = false;
31 let mut generate_mut_enum = false;
32
33 let mut create_is_functions = false;
34 let mut create_unwrap_functions = false;
35 let mut create_unwrap_ref_functions = false;
36 let mut create_unwrap_ref_mut_functions = false;
37 let mut create_to_tag_functions = false;
38 let mut create_as_ref_functions = false;
39 let mut create_as_mut_functions = false;
40 let mut create_get_functions = false;
41 let mut create_get_ref_functions = false;
42 let mut create_get_mut_functions = false;
43 for item in attributes {
44 match item.to_string().as_str() {
45 "TagEnum" => generate_tag_enum = true,
46 "RefEnum" => generate_ref_enum = true,
47 "MutEnum" => generate_mut_enum = true,
48 "is" => create_is_functions = true,
49 "unwrap" => create_unwrap_functions = true,
50 "unwrap_ref" => create_unwrap_ref_functions = true,
51 "unwrap_mut" => create_unwrap_ref_mut_functions = true,
52 "to_tag" => create_to_tag_functions = true,
53 "as_ref" => create_as_ref_functions = true,
54 "as_mut" => create_as_mut_functions = true,
55 "get" => create_get_functions = true,
56 "get_ref" => create_get_ref_functions = true,
57 "get_mut" => create_get_mut_functions = true,
58 _ => panic!(),
59 }
60 }
61
62 let input_enum = InputEnum(input);
63 if create_is_functions
64 || create_unwrap_functions
65 || create_unwrap_ref_functions
66 || create_unwrap_ref_mut_functions
67 || create_to_tag_functions
68 || create_as_ref_functions
69 || create_as_mut_functions
70 || create_get_functions
71 || create_get_ref_functions
72 || create_get_mut_functions
73 {
74 let mut functions_builder = EnumFunctionsBuilder::new(&input_enum);
75 if create_is_functions {
76 functions_builder.is_functions();
77 }
78 if create_unwrap_functions {
79 functions_builder.unwrap_functions();
80 }
81 if create_unwrap_ref_functions {
82 functions_builder.unwrap_ref_functions();
83 }
84 if create_unwrap_ref_mut_functions {
85 functions_builder.unwrap_mut_functions();
86 }
87 if create_to_tag_functions {
88 functions_builder.to_tag_function();
89 }
90 if create_as_ref_functions {
91 functions_builder.as_ref_functions();
92 }
93 if create_as_mut_functions {
94 functions_builder.as_mut_functions();
95 }
96 if create_get_functions {
97 functions_builder.get_functions();
98 }
99 if create_get_ref_functions {
100 functions_builder.get_ref_functions();
101 }
102 if create_get_mut_functions {
103 functions_builder.get_mut_functions();
104 }
105
106 let ts = functions_builder.token_stream();
107 enum_stream.extend([ts]);
108 }
109
110 if generate_tag_enum {
111 let mut tag_enum_builder = TagEnumBuilder::new(&input_enum);
112 if create_is_functions {
113 tag_enum_builder.is_functions();
114 }
115 let ts = tag_enum_builder.token_stream();
116 enum_stream.extend([ts]);
117 }
118
119 if generate_ref_enum {
120 let mut ref_enum_builder = RefEnumBuilder::new(&input_enum, false);
121 if create_is_functions {
122 ref_enum_builder.is_functions();
123 }
124 if create_unwrap_functions {
125 ref_enum_builder.unwrap_functions();
126 }
127 if create_to_tag_functions {
128 ref_enum_builder.to_tag_functions();
129 }
130 if create_get_functions {
131 ref_enum_builder.get_functions();
132 }
133 let ts = ref_enum_builder.token_stream();
134 enum_stream.extend([ts]);
135 }
136
137 if generate_mut_enum {
138 let mut ref_enum_builder = RefEnumBuilder::new(&input_enum, true);
139 if create_is_functions {
140 ref_enum_builder.is_functions();
141 }
142 if create_unwrap_functions {
143 ref_enum_builder.unwrap_functions();
144 }
145 if create_to_tag_functions {
146 ref_enum_builder.to_tag_functions();
147 }
148 if create_get_functions {
149 ref_enum_builder.get_functions();
150 }
151 let ts = ref_enum_builder.token_stream();
152 enum_stream.extend([ts]);
153 }
154
155 enum_stream
156}
157
158pub(crate) struct InputEnum(ItemEnum);
159
160impl InputEnum {
161 fn vis(&self) -> &Visibility {
162 &self.0.vis
163 }
164
165 fn name(&self) -> String {
166 format!("{}", self.0.ident)
167 }
168
169 fn variant_snake_case_name(&self, i: usize) -> String {
170 let variant_name = self.0.variants[i].ident.to_string();
171 let mut snake_case_name = String::new();
172 for c in variant_name.chars() {
173 if c.is_uppercase() && snake_case_name.is_empty() {
174 snake_case_name += format!("{}", c.to_ascii_lowercase()).as_str();
175 } else if c.is_uppercase() {
176 snake_case_name += format!("_{}", c.to_ascii_lowercase()).as_str();
177 } else {
178 snake_case_name += format!("{c}").as_str();
179 }
180 }
181 snake_case_name
182 }
183
184 fn generics(&self) -> &syn::Generics {
185 &self.0.generics
186 }
187
188 fn attributes(&self) -> &Vec<syn::Attribute> {
189 &self.0.attrs
190 }
191
192 fn iter_variants(&self) -> impl Iterator<Item = &Variant> {
193 self.0.variants.iter()
194 }
195
196 fn variant_count(&self) -> usize {
197 self.0.variants.len()
198 }
199
200 fn variant(&self, i: usize) -> &Variant {
201 &self.0.variants[i]
202 }
203
204 fn variant_type(&self, i: usize) -> Type {
205 let elems: Punctuated<_, _> = self.0.variants[i]
206 .fields
207 .iter()
208 .map(|f| f.ty.clone())
209 .collect();
210
211 if elems.len() == 1 {
212 return (*elems.first().unwrap()).clone();
213 }
214
215 let group = proc_macro2::Group::new(
216 proc_macro2::Delimiter::Parenthesis,
217 proc_macro2::TokenStream::new(),
218 );
219 syn::Type::Tuple(TypeTuple {
220 paren_token: token::Paren {
221 span: group.delim_span(),
222 },
223 elems,
224 })
225 }
226
227 fn match_variant(&self, i: usize, enum_ident: Option<Ident>) -> syn::Pat {
228 let variant = self.variant(i);
229 let enum_name = enum_ident.as_ref().unwrap_or(&self.0.ident);
230 let variant_name = &self.variant(i).ident;
231 let pattern = match &variant.fields {
232 Fields::Unit => {
233 quote! {
234 #enum_name :: #variant_name
235 }
236 }
237 Fields::Named(_) => {
238 quote! {
239 #enum_name :: #variant_name { .. }
240 }
241 }
242 Fields::Unnamed(fields) => {
243 let wild_pattern = vec![
244 syn::Pat::Wild(syn::PatWild {
245 attrs: vec![],
246 underscore_token: token::Underscore {
247 spans: [Span::call_site(); 1]
248 },
249 });
250 fields.unnamed.len()
251 ];
252
253 quote! {
254 #enum_name :: #variant_name ( #(#wild_pattern ,)* )
255 }
256 }
257 };
258
259 syn::Pat::Verbatim(pattern)
260 }
261
262 fn match_variant_to_tuple(&self, i: usize, enum_ident: Option<Ident>) -> syn::Arm {
263 let (pat, body) = match &self.variant(i).fields {
264 Fields::Unit => {
265 let group = proc_macro2::Group::new(
266 proc_macro2::Delimiter::Parenthesis,
267 proc_macro2::TokenStream::new(),
268 );
269
270 (
271 self.match_variant(i, None),
272 Box::new(Expr::Tuple(syn::ExprTuple {
273 attrs: vec![],
274 paren_token: token::Paren {
275 span: group.delim_span(),
276 },
277 elems: Punctuated::new(),
278 })),
279 )
280 }
281 Fields::Unnamed(fields) => {
282 let mut patterns = Punctuated::new();
283 let mut elements = Punctuated::new();
284
285 for (index, _field) in fields.unnamed.iter().enumerate() {
286 let name = format!("e{index}");
287 let ident = Ident::new(name.as_str(), Span::call_site());
288 patterns.push(syn::Pat::Path(syn::PatPath {
289 attrs: vec![],
290 qself: None,
291 path: syn::PathSegment {
292 arguments: syn::PathArguments::None,
293 ident: ident.clone(),
294 }
295 .into(),
296 }));
297
298 elements.push(Expr::Path(syn::ExprPath {
299 attrs: vec![],
300 qself: None,
301 path: syn::PathSegment {
302 arguments: syn::PathArguments::None,
303 ident,
304 }
305 .into(),
306 }))
307 }
308
309 let pattern_path = {
310 let mut punctuated = Punctuated::new();
311 punctuated.push(syn::PathSegment {
312 ident: enum_ident.unwrap_or(self.0.ident.clone()),
313 arguments: syn::PathArguments::None,
314 });
315 punctuated.push(syn::PathSegment {
316 ident: self.variant(i).ident.clone(),
317 arguments: syn::PathArguments::None,
318 });
319
320 syn::Path {
321 leading_colon: None,
322 segments: punctuated,
323 }
324 };
325
326 let group = proc_macro2::Group::new(
327 proc_macro2::Delimiter::Parenthesis,
328 proc_macro2::TokenStream::new(),
329 );
330 let pat = syn::Pat::TupleStruct(syn::PatTupleStruct {
331 attrs: vec![],
332 qself: None,
333 path: pattern_path,
334 paren_token: token::Paren {
335 span: group.delim_span(),
336 },
337 elems: patterns,
338 });
339
340 let body = if elements.len() == 1 {
341 let syn::Expr::Path(syn::ExprPath { path, ..}) = elements.first().unwrap() else {
342 panic!()
343 };
344 Box::new(Expr::Path(syn::ExprPath {
345 attrs: vec![],
346 qself: None,
347 path: path.clone(),
348 }))
349 } else {
350 Box::new(Expr::Tuple(syn::ExprTuple {
351 attrs: vec![],
352 paren_token: token::Paren {
353 span: group.delim_span(),
354 },
355 elems: elements,
356 }))
357 };
358
359 (pat, body)
360 }
361 Fields::Named(fields) => {
362 let mut patterns = Punctuated::new();
364 let mut elements = Punctuated::new();
365
366 for field in fields.named.iter() {
367 patterns.push(syn::FieldPat {
368 attrs: vec![],
369 member: syn::Member::Named(field.ident.clone().unwrap()),
370 colon_token: None, pat: Box::new(syn::Pat::Path(syn::PatPath {
372 attrs: vec![],
373 qself: None,
374 path: syn::PathSegment {
375 arguments: syn::PathArguments::None,
376 ident: field.ident.clone().unwrap(),
377 }
378 .into(),
379 })),
380 });
381
382 elements.push(Expr::Path(syn::ExprPath {
383 attrs: vec![],
384 qself: None,
385 path: syn::PathSegment {
386 arguments: syn::PathArguments::None,
387 ident: field.ident.clone().unwrap(),
388 }
389 .into(),
390 }))
391 }
392
393 let pattern_path = {
394 let mut punctuated = Punctuated::new();
395 punctuated.push(syn::PathSegment {
396 ident: enum_ident.unwrap_or(self.0.ident.clone()),
397 arguments: syn::PathArguments::None,
398 });
399 punctuated.push(syn::PathSegment {
400 ident: self.variant(i).ident.clone(),
401 arguments: syn::PathArguments::None,
402 });
403
404 syn::Path {
405 leading_colon: None,
406 segments: punctuated,
407 }
408 };
409
410 let group = proc_macro2::Group::new(
411 proc_macro2::Delimiter::Parenthesis,
412 proc_macro2::TokenStream::new(),
413 );
414 let pat = syn::Pat::Struct(syn::PatStruct {
415 attrs: vec![],
416 qself: None,
417 path: pattern_path,
418 brace_token: token::Brace {
419 span: group.delim_span(),
420 },
421 fields: patterns,
422 rest: None,
423 });
424
425 let body = if elements.len() == 1 {
426 let syn::Expr::Path(syn::ExprPath { path, ..}) = elements.first().unwrap() else {
427 panic!()
428 };
429 Box::new(Expr::Path(syn::ExprPath {
430 attrs: vec![],
431 qself: None,
432 path: path.clone(),
433 }))
434 } else {
435 Box::new(Expr::Tuple(syn::ExprTuple {
436 attrs: vec![],
437 paren_token: token::Paren {
438 span: group.delim_span(),
439 },
440 elems: elements,
441 }))
442 };
443
444 (pat, body)
445 }
446 };
447
448 syn::Arm {
449 attrs: vec![],
450 guard: None,
451 fat_arrow_token: token::FatArrow {
452 spans: [Span::call_site(); 2],
453 },
454 comma: Some(token::Comma {
455 spans: [Span::call_site(); 1],
456 }),
457 pat,
458 body,
459 }
460 }
461}
462
463pub(crate) fn parse_function(
464 ts: proc_macro2::TokenStream,
465 ifn: &mut Option<ItemFn>,
466) -> TokenStream {
467 let r = TokenStream::from(ts);
468 let r2 = r.clone();
469 let pifn = parse_macro_input!(r2 as ItemFn);
470 *ifn = Some(pifn);
471 r
472}
473
474fn filter_derive_attributes(
475 attrs: &[syn::Attribute],
476 filtered_out: &[&str],
477) -> Vec<syn::Attribute> {
478 let mut result = vec![];
479 for attr in attrs {
480 match &attr.meta {
481 syn::Meta::List(ml) if ml.path.to_token_stream().to_string() == "derive" => {
482 let punctuated_parser = Punctuated::<syn::Path, Token![,]>::parse_terminated;
483 let punctuated = punctuated_parser.parse2(ml.tokens.clone()).unwrap();
484
485 let mut punctuated_result = Punctuated::<_, Token![,]>::new();
486 for item in punctuated.into_iter() {
487 let last_segment = item.segments.last().unwrap().ident.to_string();
488 if filtered_out.contains(&last_segment.as_str()) {
489 continue;
490 }
491 punctuated_result.push(item);
492 }
493 result.push(syn::Attribute {
494 pound_token: attr.pound_token,
495 style: attr.style,
496 bracket_token: attr.bracket_token,
497 meta: syn::Meta::List(syn::MetaList {
498 path: ml.path.clone(),
499 delimiter: ml.delimiter.clone(),
500 tokens: punctuated_result.to_token_stream(),
501 }),
502 });
503 }
504 _ => result.push(attr.clone()),
505 }
506 }
507 result
508}