sigma_enum_macros/
lib.rs

1use crate::attrs::extract_expansion;
2use crate::nice_type::Infallible;
3use crate::nice_type::NiceType;
4use attrs::ItemAttr;
5use heck::ToSnakeCase;
6use nice_type::NiceTypeLit;
7use proc_macro::TokenStream;
8use quote::ToTokens;
9use quote::TokenStreamExt;
10use quote::format_ident;
11use quote::quote;
12use std::collections::BTreeMap;
13use std::collections::BTreeSet;
14use syn::Attribute;
15use syn::Expr;
16use syn::Ident;
17use syn::LitStr;
18use syn::Token;
19use syn::Visibility;
20use syn::braced;
21use syn::parenthesized;
22use syn::parse::Parse;
23use syn::parse::ParseStream;
24use syn::parse_macro_input;
25use syn::spanned::Spanned;
26
27mod attrs;
28mod nice_type;
29
/// Identifier emitted for the catch-all (`binding => body`) arm of the generated
/// `match` macro; the pattern helper macro expands it to `_`.
const INTERNAL_FULL_WILDCARD: &str = "__INTERNAL_FULL_WILDCARD";
/// Common prefix of every hidden helper item generated for an enum
/// (see `SigmaEnum::internal_name`).
const INTERNAL_IDENT: &str = "__INTERNAL_IDENT";
32
33#[derive(Clone)]
34struct Variant {
35    ty: NiceType<Infallible>,
36    name: Ident,
37    attrs: Vec<Attribute>,
38    docs: proc_macro2::TokenStream,
39}
40
41#[derive(Clone)]
42struct SigmaEnum {
43    visibility: Visibility,
44    name: Ident,
45    variants: Vec<Variant>,
46    subattrs: Vec<Attribute>,
47    attr: ItemAttr,
48}
49
50impl SigmaEnum {
51    fn macro_match_name(&self) -> Ident {
52        self.attr.macro_match.name.as_ref().map_or_else(
53            || format_ident!("{}_match", self.name.to_string().to_snake_case()),
54            |name| format_ident!("{}", name),
55        )
56    }
57
58    fn macro_construct_name(&self) -> Ident {
59        self.attr.macro_construct.name.as_ref().map_or_else(
60            || format_ident!("{}_construct", self.name.to_string().to_snake_case()),
61            |name| format_ident!("{}", name),
62        )
63    }
64
65    fn into_trait_name(&self) -> Ident {
66        self.attr.into_trait.name.as_ref().map_or_else(
67            || format_ident!("Into{}", self.name),
68            |name| format_ident!("{}", name),
69        )
70    }
71
72    fn into_method_name(&self) -> Ident {
73        self.attr.into_method.name.as_ref().map_or_else(
74            || format_ident!("into_{}", self.name.to_string().to_snake_case()),
75            |name| format_ident!("{}", name),
76        )
77    }
78
79    fn try_from_method_name(&self) -> Ident {
80        self.attr.try_from_method.name.as_ref().map_or_else(
81            || format_ident!("try_from_{}", self.name.to_string().to_snake_case()),
82            |name| format_ident!("{}", name),
83        )
84    }
85
86    fn try_from_owned_method_name(&self) -> Ident {
87        self.attr.try_from_owned_method.name.as_ref().map_or_else(
88            || format_ident!("try_from_owned_{}", self.name.to_string().to_snake_case()),
89            |name| format_ident!("{}", name),
90        )
91    }
92
93    fn try_from_mut_method_name(&self) -> Ident {
94        self.attr.try_from_mut_method.name.as_ref().map_or_else(
95            || format_ident!("try_from_mut_{}", self.name.to_string().to_snake_case()),
96            |name| format_ident!("{}", name),
97        )
98    }
99
100    fn extract_method_name(&self) -> Ident {
101        self.attr.extract_method.name.as_ref().map_or_else(
102            || format_ident!("extract"),
103            |name| format_ident!("{}", name),
104        )
105    }
106
107    fn extract_owned_method_name(&self) -> Ident {
108        self.attr.extract_owned_method.name.as_ref().map_or_else(
109            || format_ident!("extract_owned"),
110            |name| format_ident!("{}", name),
111        )
112    }
113
114    fn extract_mut_method_name(&self) -> Ident {
115        self.attr.extract_mut_method.name.as_ref().map_or_else(
116            || format_ident!("extract_mut"),
117            |name| format_ident!("{}", name),
118        )
119    }
120
121    fn try_from_error_name(&self) -> Ident {
122        self.attr.try_from_error.name.as_ref().map_or_else(
123            || format_ident!("TryFrom{}Error", self.name.to_string()),
124            |name| format_ident!("{}", name),
125        )
126    }
127
128    fn internal_name(&self, which: &str, suffix: &str) -> Ident {
129        format_ident!(
130            "{INTERNAL_IDENT}_{}_{}{}",
131            self.name.to_string().to_snake_case(),
132            which,
133            suffix
134        )
135    }
136
137    fn to_tokens_macros(&self, tokens: &mut proc_macro2::TokenStream, export: bool, suffix: &str) {
138        let SigmaEnum {
139            visibility: _,
140            name,
141            variants,
142            subattrs: _,
143            attr,
144        } = &self;
145
146        let item_path = match &attr.path {
147            Some(path) => quote! { $ #path :: },
148            None => quote! {},
149        };
150        let macro_path = if export {
151            quote! { $crate :: }
152        } else {
153            item_path.clone()
154        };
155
156        let variants_btree: BTreeMap<_, _> = variants
157            .iter()
158            .map(|var| (var.ty.clone(), var.name.clone()))
159            .collect();
160        let variant_pats: Vec<_> = variants.iter().map(|var| var.ty.clone()).collect();
161
162        let macro_match = format_ident!("{}{}", self.macro_match_name(), suffix);
163        let macro_construct = format_ident!("{}{}", self.macro_construct_name(), suffix);
164        let macro_match_body = self.internal_name("body", suffix);
165        let macro_match_process_body = self.internal_name("process_body", suffix);
166        let macro_process_type = self.internal_name("process_type", suffix);
167        let macro_match_variant = self.internal_name("variant", suffix);
168        let macro_match_pattern = self.internal_name("pattern", suffix);
169        let macro_construct_inner = self.internal_name("construct_inner", suffix);
170
171        let macro_match_docstring = self.attr.macro_match.docstring();
172        let macro_construct_docstring = self.attr.macro_construct.docstring();
173
174        // https://github.com/rust-lang/rust/pull/52234#issuecomment-1417098097
175        let macro_match_export;
176        let macro_construct_export;
177        let macro_match_body_export;
178        let macro_match_process_body_export;
179        let macro_process_type_export;
180        let macro_match_variant_export;
181        let macro_match_pattern_export;
182        let macro_construct_inner_export;
183        let macro_match_pub_use;
184        let macro_construct_pub_use;
185        let macro_match_body_pub_use;
186        let macro_match_process_body_pub_use;
187        let macro_process_type_pub_use;
188        let macro_match_variant_pub_use;
189        let macro_match_pattern_pub_use;
190        let macro_construct_inner_pub_use;
191        if export {
192            macro_match_export = quote! { #macro_match_docstring #[macro_export] };
193            macro_construct_export = quote! { #macro_construct_docstring #[macro_export] };
194            macro_match_body_export = quote! { #[macro_export] };
195            macro_match_process_body_export = quote! { #[macro_export] };
196            macro_process_type_export = quote! { #[macro_export] };
197            macro_match_variant_export = quote! { #[macro_export] };
198            macro_match_pattern_export = quote! { #[macro_export] };
199            macro_construct_inner_export = quote! { #[macro_export] };
200            macro_match_pub_use = quote! {};
201            macro_construct_pub_use = quote! {};
202            macro_match_body_pub_use = quote! {};
203            macro_match_process_body_pub_use = quote! {};
204            macro_process_type_pub_use = quote! {};
205            macro_match_variant_pub_use = quote! {};
206            macro_match_pattern_pub_use = quote! {};
207            macro_construct_inner_pub_use = quote! {};
208        } else {
209            macro_match_export = quote! { #macro_match_docstring };
210            macro_construct_export = quote! { #macro_construct_docstring };
211            macro_match_body_export = quote! {};
212            macro_match_process_body_export = quote! {};
213            macro_process_type_export = quote! {};
214            macro_match_variant_export = quote! {};
215            macro_match_pattern_export = quote! {};
216            macro_construct_inner_export = quote! {};
217            macro_match_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #macro_match_docstring pub(crate) use #macro_match; };
218            macro_construct_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #macro_construct_docstring pub(crate) use #macro_construct; };
219            macro_match_body_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #[doc(hidden)] pub(crate) use #macro_match_body; };
220            macro_match_process_body_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #[doc(hidden)] pub(crate) use #macro_match_process_body; };
221            macro_process_type_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #[doc(hidden)] pub(crate) use #macro_process_type; };
222            macro_match_variant_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #[doc(hidden)] pub(crate) use #macro_match_variant; };
223            macro_match_pattern_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #[doc(hidden)] pub(crate) use #macro_match_pattern; };
224            macro_construct_inner_pub_use = quote! { #[allow(nonstandard_style)] #[allow(unused_imports)] #[doc(hidden)] pub(crate) use #macro_construct_inner; };
225        }
226
227        let internal_full_wildcard = format_ident!("{INTERNAL_FULL_WILDCARD}");
228
229        let mut patterns_map = BTreeMap::new();
230        patterns_map.insert(NiceType::PatternIdent(()), Vec::new());
231        for ty in &variant_pats {
232            for pat in ty.patterns_matching() {
233                let matches = patterns_map.entry(pat).or_insert(Vec::new());
234                matches.push(ty);
235            }
236        }
237
238        let patterns: Vec<_> = patterns_map.keys().collect();
239        let pat_variants: Vec<_> = patterns_map.values().collect();
240        let pat_variant_names: Vec<Vec<_>> = pat_variants
241            .iter()
242            .map(|v| v.iter().map(|ty| variants_btree[ty].clone()).collect())
243            .collect();
244
245        let patterns_vars: Vec<_> = patterns.iter().map(|pat| pat.index_patterns()).collect();
246        let patterns_vars_assoc: Vec<Vec<Vec<_>>> = pat_variants
247            .iter()
248            .zip(&patterns_vars)
249            .map(|(v, pat)| {
250                v.iter()
251                    .map(|ty| {
252                        ty.matches_map(&pat)
253                            .into_iter()
254                            .filter_map(|(ident, (ty, location))| {
255                                let NiceType::Literal(lit) = ty else {
256                                    return None;
257                                };
258                                // try block. sad
259                                let generic_ty = (|| {
260                                    let (parent, i) = location?;
261                                    self.attr.generics.get(&parent)?[i].as_ref()
262                                })();
263                                Some((ident, lit, generic_ty))
264                            })
265                            .collect()
266                    })
267                    .collect()
268            })
269            .collect();
270        // for each pattern, for each variant it matches, get the type pattern variables
271        // and their literals and locations, and generate let statements for them
272        let const_let_statements: Vec<Vec<proc_macro2::TokenStream>> = patterns_vars_assoc
273            .iter()
274            .map(|v| {
275                v.iter()
276                    .map(|v| {
277                        v.iter()
278                            .map(|(ident, lit, generic_ty)| match generic_ty {
279                                Some(generic_ty) => quote! { const $ #ident : #generic_ty = #lit; },
280                                None => quote! { let $ #ident = #lit; },
281                            })
282                            .map(|q| quote! { #[allow(nonstandard_style)] #[allow(unused_variables)] #q })
283                            .collect()
284                    })
285                    .collect()
286            })
287            .collect();
288
289        let pat_vars_params_eqs: Vec<Vec<Vec<_>>> = patterns_vars_assoc
290            .iter()
291            .map(|v| {
292                v.iter()
293                    .map(|v| {
294                        v.iter()
295                            .map(|(ident, lit, _generic_ty)| quote! { $ #ident == #lit })
296                            .collect()
297                    })
298                    .collect()
299            })
300            .collect();
301
302        let (pat_vars_names, pat_vars_params): (Vec<_>, Vec<_>) = patterns_vars
303            .iter()
304            .map(|pat| match pat {
305                NiceType::Ident(name, params) => (format_ident!("{}", name), {
306                    let params: Vec<_> = params
307                        .iter()
308                        .map(|param| param.map_pattern(|p| quote! { ? $ #p :ident }))
309                        .collect();
310                    (!params.is_empty())
311                        .then_some(params)
312                        .into_iter()
313                        .collect::<Vec<_>>()
314                }),
315                NiceType::PatternIdent(_p) => (
316                    internal_full_wildcard.clone(),
317                    None.into_iter().collect::<Vec<_>>(),
318                ),
319                _ => panic!("not ident {:?}", pat),
320            })
321            .unzip();
322
323        tokens.append_all(quote! {
324            #macro_match_export
325            #[allow(unused_macros)]
326            macro_rules! #macro_match {
327                ( match $( $rest:tt )* ) => {
328                    #macro_path #macro_match_body ! { (), ( $($rest)* ) }
329                };
330            }
331            #macro_match_pub_use
332        });
333
334        tokens.append_all(quote! {
335            #macro_match_body_export
336            #[doc(hidden)]
337            #[allow(nonstandard_style)]
338            macro_rules! #macro_match_body {
339                ( $what:tt, ({
340                    $( $rest:tt )*
341                }) ) => {
342                    #macro_path #macro_match_process_body !( $what, ( $($rest)* ), () )
343                };
344                ( ( $( $what:tt )* ), ( $next:tt $( $rest:tt )* ) ) => {
345                    #macro_path #macro_match_body ! { ( $($what)* $next ), ( $($rest)* ) }
346                };
347            }
348            #macro_match_body_pub_use
349        });
350
351        tokens.append_all(quote! {
352            #macro_match_process_body_export
353            #[doc(hidden)]
354            #[allow(nonstandard_style)]
355            macro_rules! #macro_match_process_body {
356                ( $what:tt, (), ( $( ( $ty:tt; $binding:pat => $body:expr ) )* ) ) => {
357                    {
358                        let what = $what;
359
360                        #[allow(unreachable_patterns)]
361                        match what {
362                            $( #macro_path #macro_match_pattern !($ty) => (), )*
363                        }
364
365                        #[allow(unused_labels)]
366                        'ma: {
367                            $( #macro_path #macro_match_variant !{$ty; what; 'ma; $binding => $body} )*
368                            ::core::unreachable!();
369                        }
370                    }
371                };
372                (
373                    $what:tt,
374                    ( $binding:ident => { $( $body:tt )* } $(,)? $( $rest:tt )* ),
375                    ( $( $matched:tt )* )
376                ) => {
377                    #macro_path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( (#internal_full_wildcard) ; $binding => { $( $body )* } ) ) )
378                };
379                (
380                    $what:tt,
381                    ( $binding:ident => $body:expr, $( $rest:tt )* ),
382                    ( $( $matched:tt )* )
383                ) => {
384                    #macro_path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( (#internal_full_wildcard) ; $binding => { $body } ) ) )
385                };
386                (
387                    $what:tt,
388                    ( $tyn:ident ( $binding:pat ) => { $( $body:tt )* } $(,)? $( $rest:tt )* ),
389                    ( $( $matched:tt )* )
390                ) => {
391                    #macro_path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn); $binding => { $($body)* } ) ) )
392                };
393                (
394                    $what:tt,
395                    ( $tyn:ident ( $binding:pat ) => $body:expr, $( $rest:tt )* ),
396                    ( $( $matched:tt )* )
397                ) => {
398                    #macro_path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn); $binding => { $body } ) ) )
399                };
400                (
401                    $what:tt,
402                    ( $tyn:ident ::< $( $rest:tt )* ),
403                    ( $( $matched:tt )* )
404                ) => {
405                    #macro_path #macro_process_type !( (@match, $what, $tyn, ($( $matched )*)), ($( $rest )*), (<), (<) )
406                };
407            }
408            #macro_match_process_body_pub_use
409        });
410
411        tokens.append_all(quote! {
412            #macro_process_type_export
413            #[doc(hidden)]
414            #[allow(nonstandard_style)]
415            macro_rules! #macro_process_type {
416                ( $bundle:tt, ($(,)? > $($rest:tt)*), ( $($params:tt)* ), (< $($counter:tt)*) ) => {
417                    #macro_path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* >), ($($counter)*) )
418                };
419                ( $bundle:tt, ($(,)? >> $($rest:tt)*), ( $($params:tt)* ), (< < $($counter:tt)*) ) => {
420                    #macro_path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* > >), ($($counter)*) )
421                };
422                ( $bundle:tt, ($(,)? > $($rest:tt)*), ( $($params:tt)* ), () ) => {
423                    ::core::compile_error!("imbalanced")
424                };
425                ( $bundle:tt, ($(,)? >> $($rest:tt)*), ( $($params:tt)* ), () ) => {
426                    ::core::compile_error!("imbalanced")
427                };
428                ( $bundle:tt, (< $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
429                    #macro_path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* <), (< $($counter)*) )
430                };
431                ( $bundle:tt, (<< $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
432                    #macro_path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* < <), (< < $($counter)*) )
433                };
434                ( (@match, $what:tt, $tyn:ident, ( $($matched:tt)* )), (( $binding:pat ) => { $( $body:tt )* } $(,)? $($rest:tt)*), ( $($params:tt)* ), () ) => {
435                    #macro_path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn :: $($params)+); $binding => { $($body)* } ) ) )
436                };
437                ( (@match, $what:tt, $tyn:ident, ( $($matched:tt)* )), (( $binding:pat ) => $body:expr, $($rest:tt)*), ( $($params:tt)* ), () ) => {
438                    #macro_path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn :: $($params)+); $binding => { $body } ) ) )
439                };
440                ( (@construct, $tyn:ident), (( $expr:expr )), ( $($params:tt)+ ), () ) => {
441                    #macro_path #macro_construct_inner !( ($tyn :: $($params)+); ( $expr ) )
442                };
443                ( $bundle:tt, (( $($any:tt)* ) $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
444                    ::core::compile_error!("imbalanced or something")
445                };
446                ( $bundle:tt, ($thing:tt $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
447                    #macro_path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* $thing), ( $($counter)*) )
448                };
449            }
450            #macro_process_type_pub_use
451        });
452
453        tokens.append_all(quote! {
454            #macro_match_variant_export
455            #[doc(hidden)]
456            #[allow(nonstandard_style)]
457            macro_rules! #macro_match_variant {
458                #( ( (#pat_vars_names #(::< #( #pat_vars_params ),* >)* ); $what:ident; $ma:lifetime; $binding:pat => $body:expr ) => {
459                    #( if let #item_path #name :: #pat_variant_names ($binding) = $what {
460                        #const_let_statements
461                        break $ma($body);
462                    } )*
463                }; )*
464            }
465            #macro_match_variant_pub_use
466        });
467
468        tokens.append_all(quote! {
469            #macro_match_pattern_export
470            #[doc(hidden)]
471            #[allow(nonstandard_style)]
472            macro_rules! #macro_match_pattern {
473                ( ( #internal_full_wildcard ) ) => { _ };
474                #( ( ( #pat_vars_names #(::< #( #pat_vars_params ),* >)* ) ) => {
475                    #( #item_path #name :: #pat_variant_names (_) )|*
476                }; )*
477            }
478            #macro_match_pattern_pub_use
479        });
480
481        tokens.append_all(quote! {
482            #macro_construct_export
483            #[allow(unused_macros)]
484            macro_rules! #macro_construct {
485                ( $tyn:ident ::< $($tt:tt)* ) => {
486                    #macro_path #macro_process_type !( (@construct, $tyn), ($($tt)*), (<), (<) )
487                };
488                ( $tyn:ident ( $body:expr ) ) => {
489                    #macro_path #macro_construct_inner !( ($tyn); ($body) )
490                };
491            }
492            #macro_construct_pub_use
493        });
494
495        tokens.append_all(quote! {
496            #macro_construct_inner_export
497            #[doc(hidden)]
498            #[allow(nonstandard_style)]
499            macro_rules! #macro_construct_inner {
500                #( ( (#pat_vars_names #(::< #( #pat_vars_params ),* >)* ); $body:expr ) => {
501                    'ma: {
502                        #( if true #(&& #pat_vars_params_eqs)* {
503                            #const_let_statements
504                            break 'ma ::core::option::Option::Some(#item_path #name :: #pat_variant_names($body));
505                        } )*
506                        ::core::option::Option::None
507                    }
508                }; )*
509            }
510            #macro_construct_inner_pub_use
511        });
512    }
513
514    fn to_tokens_traits(&self, tokens: &mut proc_macro2::TokenStream) {
515        let SigmaEnum {
516            visibility,
517            name,
518            variants,
519            subattrs: _,
520            attr,
521        } = &self;
522
523        let variant_types: Vec<_> = variants
524            .iter()
525            .map(|var| var.ty.to_tokens_aliased(&attr.alias))
526            .collect();
527        let variant_names: Vec<_> = variants.iter().map(|var| var.name.clone()).collect();
528
529        let into_trait = self.into_trait_name();
530        let into_trait_sealed_mod = self.internal_name("into_trait_sealed_mod", "");
531        let into_method = self.into_method_name();
532        let try_from_method = self.try_from_method_name();
533        let try_from_owned_method = self.try_from_owned_method_name();
534        let try_from_mut_method = self.try_from_mut_method_name();
535        let extract_method = self.extract_method_name();
536        let extract_owned_method = self.extract_owned_method_name();
537        let extract_mut_method = self.extract_mut_method_name();
538        let try_from_error = self.try_from_error_name();
539
540        let into_trait_docstring = self.attr.into_trait.docstring();
541        let into_method_docstring = self.attr.into_method.docstring();
542        let try_from_method_docstring = self.attr.try_from_method.docstring();
543        let try_from_owned_method_docstring = self.attr.try_from_owned_method.docstring();
544        let try_from_mut_method_docstring = self.attr.try_from_mut_method.docstring();
545        let extract_method_docstring = self.attr.extract_method.docstring();
546        let extract_owned_method_docstring = self.attr.extract_owned_method.docstring();
547        let extract_mut_method_docstring = self.attr.extract_mut_method.docstring();
548        let try_from_error_docstring = self.attr.try_from_error.docstring();
549
550        let methods = quote! {
551            #into_method_docstring
552            fn #into_method (self) -> #name;
553            #try_from_method_docstring
554            fn #try_from_method (value: & #name) -> ::core::option::Option<&Self>;
555            #try_from_owned_method_docstring
556            fn #try_from_owned_method (value: #name) -> ::core::option::Option<Self>
557                where Self: ::core::marker::Sized;
558            #try_from_mut_method_docstring
559            fn #try_from_mut_method (value: &mut #name) -> ::core::option::Option<&mut Self>;
560        };
561
562        tokens.append_all(quote! {
563            #into_trait_docstring
564            pub trait #into_trait : #into_trait_sealed_mod ::Sealed {
565                #methods
566            }
567
568            #[allow(nonstandard_style)]
569            mod #into_trait_sealed_mod {
570                pub trait Sealed {}
571            }
572
573            #(
574                #[automatically_derived]
575                impl #into_trait_sealed_mod ::Sealed for #variant_types {}
576            )*
577        });
578
579        tokens.append_all(quote! {
580            #(
581                #into_trait_docstring
582                #[automatically_derived]
583                impl #into_trait for #variant_types {
584                    fn #into_method (self) -> #name {
585                        #name :: #variant_names (self)
586                    }
587
588                    fn #try_from_method <'a>(value: &'a #name) -> ::core::option::Option<&'a Self> {
589                        if let #name :: #variant_names (out) = value {
590                            ::core::option::Option::Some(out)
591                        } else {
592                            ::core::option::Option::None
593                        }
594                    }
595
596                    fn #try_from_owned_method (value: #name) -> ::core::option::Option<Self>
597                        where Self: ::core::marker::Sized
598                    {
599                        if let #name :: #variant_names (out) = value {
600                            ::core::option::Option::Some(out)
601                        } else {
602                            ::core::option::Option::None
603                        }
604                    }
605
606                    fn #try_from_mut_method <'a>(value: &'a mut #name) -> ::core::option::Option<&'a mut Self> {
607                        if let #name :: #variant_names (out) = value {
608                            ::core::option::Option::Some(out)
609                        } else {
610                            ::core::option::Option::None
611                        }
612                    }
613                }
614
615                #[automatically_derived]
616                impl ::core::convert::From<#variant_types> for #name {
617                    fn from(value: #variant_types) -> Self {
618                        #into_trait :: #into_method (value)
619                    }
620                }
621
622                #[automatically_derived]
623                impl<'a> ::core::convert::TryFrom<&'a #name> for &'a #variant_types {
624                    type Error = #try_from_error;
625                    fn try_from(value: &'a #name) -> ::core::result::Result<&'a #variant_types, #try_from_error > {
626                       < #variant_types as #into_trait >:: #try_from_method (value).ok_or( #try_from_error )
627                    }
628                }
629
630                #[automatically_derived]
631                impl ::core::convert::TryFrom<#name> for #variant_types
632                        where Self: ::core::marker::Sized
633                {
634                    type Error = #try_from_error;
635                    fn try_from(value: #name) -> ::core::result::Result<#variant_types, #try_from_error > {
636                       < #variant_types as #into_trait >:: #try_from_owned_method (value).ok_or( #try_from_error )
637                    }
638                }
639            )*
640
641            impl #name {
642                #extract_method_docstring
643                #visibility fn #extract_method <T: #into_trait >(&self) -> ::core::option::Option<&T> {
644                    T:: #try_from_method (self)
645                }
646
647                #extract_owned_method_docstring
648                #visibility fn #extract_owned_method <T: #into_trait >(self) -> ::core::option::Option<T> {
649                    T:: #try_from_owned_method (self)
650                }
651
652                #extract_mut_method_docstring
653                #visibility fn #extract_mut_method <T: #into_trait >(&mut self) -> ::core::option::Option<&mut T> {
654                    T:: #try_from_mut_method (self)
655                }
656            }
657        });
658
659        tokens.append_all(quote! {
660            #try_from_error_docstring
661            pub struct #try_from_error;
662
663            #[automatically_derived]
664            impl ::core::fmt::Debug for #try_from_error {
665                #[inline]
666                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
667                    ::core::fmt::Formatter::write_str(f, ::core::stringify!(#try_from_error))
668                }
669            }
670            #[automatically_derived]
671            impl ::core::clone::Clone for #try_from_error {
672                #[inline]
673                fn clone(&self) -> #try_from_error {
674                    *self
675                }
676            }
677            #[automatically_derived]
678            impl ::core::marker::Copy for #try_from_error {}
679            #[automatically_derived]
680            impl ::core::cmp::PartialEq for #try_from_error {
681                #[inline]
682                fn eq(&self, other: & #try_from_error) -> bool {
683                    true
684                }
685            }
686            #[automatically_derived]
687            impl ::core::cmp::Eq for #try_from_error {}
688            #[automatically_derived]
689            impl ::core::hash::Hash for #try_from_error {
690                #[inline]
691                fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {}
692            }
693            #[automatically_derived]
694            impl ::core::cmp::PartialOrd for #try_from_error {
695                #[inline]
696                fn partial_cmp(&self, other: & #try_from_error) -> ::core::option::Option<::core::cmp::Ordering> {
697                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
698                }
699            }
700            #[automatically_derived]
701            impl ::core::cmp::Ord for #try_from_error {
702                #[inline]
703                fn cmp(&self, other: & #try_from_error) -> ::core::cmp::Ordering {
704                    ::core::cmp::Ordering::Equal
705                }
706            }
707
708            #[automatically_derived]
709            impl ::core::fmt::Display for #try_from_error {
710                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
711                    f.write_str("attempted to extract value from a ")?;
712                    f.write_str(::core::stringify!( #name ))?;
713                    f.write_str(" holding a different type")?;
714                    ::core::fmt::Result::Ok(())
715                }
716            }
717
718            #[automatically_derived]
719            impl ::core::error::Error for #try_from_error {}
720        });
721    }
722}
723
724impl ToTokens for SigmaEnum {
725    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
726        let SigmaEnum {
727            visibility,
728            name,
729            variants,
730            subattrs,
731            attr,
732        } = &self;
733
734        if attr.path.is_none()
735            && matches!(
736                visibility,
737                Visibility::Public(_) | Visibility::Restricted(_)
738            )
739        {
740            tokens.append_all(
741                quote! { ::core::compile_error!("public or restricted enum without path attribute"); },
742            );
743            return;
744        }
745
746        let variant_types: Vec<_> = variants
747            .iter()
748            .map(|var| var.ty.to_tokens_aliased(&attr.alias))
749            .collect();
750        let variant_names: Vec<_> = variants.iter().map(|var| var.name.clone()).collect();
751        let variant_attrs: Vec<_> = variants.iter().map(|var| var.attrs.clone()).collect();
752        let variant_docs: Vec<_> = variants.iter().map(|var| var.docs.clone()).collect();
753
754        tokens.append_all(quote! {
755            #(#subattrs)*
756            #visibility enum #name {
757                #(
758                    #variant_docs
759                    #(#variant_attrs)*
760                    #variant_names(#variant_types),
761                )*
762            }
763        });
764
765        match visibility {
766            Visibility::Public(_) => {
767                self.to_tokens_macros(tokens, true, "");
768                self.to_tokens_macros(tokens, false, "_crate");
769            }
770            _ => {
771                self.to_tokens_macros(tokens, false, "");
772            }
773        }
774        self.to_tokens_traits(tokens);
775    }
776}
777
778fn substitute_template(
779    template: &str,
780    assignments: &[(Ident, NiceTypeLit)],
781) -> syn::Result<String> {
782    let mut name = template.to_string();
783    for (var, val) in assignments {
784        name = name.replace(&format!("{{{}}}", var), &val.variant_name_string());
785    }
786    if name.contains('{') {
787        return Err(syn::Error::new(
788            template.span(),
789            "invalid metavariable in rename template",
790        ));
791    }
792    Ok(name)
793}
794
795impl Parse for SigmaEnum {
796    fn parse(input: ParseStream) -> syn::Result<Self> {
797        let subattrs = Attribute::parse_outer(input)?;
798        let visibility: Visibility = input.parse()?;
799        let _: Token![enum] = input.parse()?;
800        let name: Ident = input.parse()?;
801        let content;
802        braced!(content in input);
803        let mut variants = Vec::new();
804        let mut variant_tys = BTreeSet::new();
805        let mut attrs = Vec::new();
806        while !content.is_empty() {
807            let mut expand = BTreeMap::new();
808            let mut rename = None;
809            let mut docs = None;
810            if let Ok(attributes) = content.call(Attribute::parse_outer) {
811                for attr in &attributes {
812                    if attr.path().is_ident("sigma_enum") {
813                        attr.parse_nested_meta(|meta| {
814                            match meta.path.require_ident()?.to_string().as_str() {
815                                "expand" => {
816                                    meta.parse_nested_meta(|meta| {
817                                        let ident = meta.path.require_ident()?;
818                                        let value: Expr = meta.value()?.parse()?;
819                                        let value = extract_expansion(&value)?;
820                                        if expand.contains_key(ident) {
821                                            return Err(syn::Error::new(
822                                                meta.path.span(),
823                                                "duplicate expand attribute",
824                                            ));
825                                        }
826                                        expand.insert(ident.clone(), value);
827                                        Ok(())
828                                    })?;
829                                }
830                                "rename" => {
831                                    if rename.is_some() {
832                                        return Err(syn::Error::new(
833                                            meta.path.span(),
834                                            "duplicate rename attribute",
835                                        ));
836                                    }
837                                    let _: Token![=] = meta.input.parse()?;
838                                    if let Ok(ident) = meta.input.parse::<Ident>() {
839                                        rename = Some(ident.to_string());
840                                    } else if let Ok(template) = meta.input.parse::<LitStr>() {
841                                        rename = Some(template.value());
842                                    } else {
843                                        return Err(syn::Error::new(
844                                            meta.input.span(),
845                                            "invalid renaming template",
846                                        ));
847                                    }
848                                }
849                                "docs" => {
850                                    if docs.is_some() {
851                                        return Err(syn::Error::new(
852                                            meta.path.span(),
853                                            "duplicate docs attribute",
854                                        ));
855                                    }
856                                    let _: Token![=] = meta.input.parse()?;
857                                    if let Ok(template) = meta.input.parse::<LitStr>() {
858                                        docs = Some(template.value());
859                                    } else {
860                                        return Err(syn::Error::new(
861                                            meta.input.span(),
862                                            "invalid docstring template",
863                                        ));
864                                    }
865                                }
866                                _ => {
867                                    return Err(syn::Error::new(meta.path.span(), "invalid attr"));
868                                }
869                            }
870                            Ok(())
871                        })?;
872                    } else {
873                        attrs.push(attr.clone());
874                    }
875                }
876            }
877
878            // variant name
879            // we cannot have rename and variant name
880
881            let enum_var_name: Ident = content.parse()?;
882            let enum_var_name =
883                (!enum_var_name.to_string().starts_with("_")).then_some(enum_var_name);
884            if rename.is_some() && enum_var_name.is_some() {
885                return Err(syn::Error::new(
886                    enum_var_name.span(),
887                    "cannot use variant name and rename attribute",
888                ));
889            }
890            if !expand.is_empty() && enum_var_name.is_some() {
891                return Err(syn::Error::new(
892                    enum_var_name.span(),
893                    "cannot use variant name and expand attribute",
894                ));
895            }
896
897            let ty_paren;
898            parenthesized!(ty_paren in content);
899            let nice_type: NiceType<Infallible> = ty_paren.parse()?;
900            assert!(ty_paren.is_empty());
901            let _ = content.parse::<Token![,]>();
902
903            if rename.as_deref().is_some_and(|rename| {
904                !expand
905                    .keys()
906                    .all(|ident| rename.contains(&format!("{{{}}}", ident)))
907            }) {
908                return Err(syn::Error::new(
909                    enum_var_name.span(),
910                    "rename template does not have all metavariables",
911                ));
912            }
913
914            let cartesian: Vec<Vec<(Ident, NiceTypeLit)>> =
915                expand
916                    .into_iter()
917                    .fold(vec![Vec::new()], |accum, (ident, range)| {
918                        accum
919                            .into_iter()
920                            .flat_map(|a| {
921                                range.iter().map({
922                                    let ident = &ident;
923                                    move |r| {
924                                        let mut a = a.clone();
925                                        a.push((ident.clone(), r.clone()));
926                                        a
927                                    }
928                                })
929                            })
930                            .collect()
931                    });
932
933            for assignments in cartesian {
934                let mut var_type = nice_type.clone();
935                for (ident, r) in &assignments {
936                    var_type = var_type.replace_ident(&ident.to_string(), &r)
937                }
938                let name = match &rename {
939                    Some(template) => {
940                        format_ident!("{}", substitute_template(&template, &assignments)?)
941                    }
942                    None => match &enum_var_name {
943                        Some(enum_var_name) => enum_var_name.clone(),
944                        None => var_type.variant_name(),
945                    },
946                };
947                let docs = match &docs {
948                    Some(template) => {
949                        let docstring = substitute_template(&template, &assignments)?;
950                        quote! {#[doc = #docstring]}
951                    }
952                    None => quote! {},
953                };
954                if !variant_tys.insert(var_type.clone()) {
955                    return Err(syn::Error::new(var_type.span(), "duplicate variant types"));
956                }
957                variants.push(Variant {
958                    ty: var_type,
959                    name,
960                    docs,
961                    attrs: attrs.clone(),
962                });
963            }
964        }
965
966        Ok(SigmaEnum {
967            visibility,
968            name,
969            variants,
970            subattrs,
971            attr: ItemAttr::default(),
972        })
973    }
974}
975
976#[proc_macro_attribute]
977pub fn sigma_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
978    // Parse the input tokens into a syntax tree
979    let mut sigma_enum = parse_macro_input!(item as SigmaEnum);
980    let attr = parse_macro_input!(attr as ItemAttr);
981    sigma_enum.attr = attr;
982
983    // panic!("{}", quote! { #sigma_enum });
984    quote! { #sigma_enum }.into()
985}