// enum_typer/lib.rs

1mod codegen;
2mod enum_parser;
3mod helpers;
4mod pattern_parser;
5mod type_analysis;
6mod variant_gen;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use std::collections::HashSet;
11
12use codegen::apply_type_hint_to_pattern;
13use enum_parser::ParsedEnum;
14use helpers::{add_static_bounds, collect_ordered_type_params};
15use pattern_parser::{extract_generics_from_type_hint, extract_type_and_pattern, parse_match_t};
16use variant_gen::generate_variant_code;
17
18/// Function-like macro for converting enums to traits with struct variants.
19/// It supports optional type indexing per variant and method definitions with
20/// pattern/body arms and existential return types.
21///
22/// # Example
23///
24/// Lift an enum definition into a trait with struct variants.
25///
26/// ```ignore
27/// type_enum! {
28///     pub enum Either<A, E> {
29///         Right(A),
30///         Left(E),
31///     }
32/// }
33/// ```
34///
35/// Or with indexed types. It is a feature similar to GADTs in other languages,
36/// where each variant can refine the overall type with specific type arguments.
37///
38/// ```ignore
39/// type_enum! {
40///    enum Expr<T> {
41///       LitInt(i32) : Expr<i32>,
42///       LitBool(bool) : Expr<bool>,
43///       Add(Box<Expr<i32>>, Box<Expr<i32>>) : Expr<i32>,
44///       Or(Box<Expr<bool>>, Box<Expr<bool>>) : Expr<bool>,
45///    }
46/// }
47/// ```
48///
49/// Or with functions using existential return types
50///
51/// ```ignore
52/// type_enum! {
53///    enum Expr<T> { ... }
54///
55///    fn eval(&self) -> T {
56///       LitInt(i) => *i,
57///       LitBool(b) => *b,
58///       Add(lhs, rhs) => lhs.eval() + rhs.eval(),
59///       Or(lhs, rhs) => lhs.eval() || rhs.eval(),
60///    }
61/// }
62/// ```
63#[proc_macro]
64pub fn type_enum(input: TokenStream) -> TokenStream {
65    let parsed = match syn::parse::<ParsedEnum>(input) {
66        Ok(p) => p,
67        Err(e) => return e.to_compile_error().into(),
68    };
69
70    let enum_name = &parsed.ident;
71    let vis = &parsed.vis;
72    let generics = &parsed.generics;
73
74    let all_type_params_ordered = collect_ordered_type_params(generics);
75    let all_type_params: HashSet<String> = all_type_params_ordered.iter().cloned().collect();
76
77    let generics_with_static = add_static_bounds(generics);
78    let (_impl_generics_static, _, where_clause_static) = generics_with_static.split_for_impl();
79
80    let structs_and_impls: Vec<_> = parsed
81        .variants
82        .iter()
83        .map(|variant| {
84            generate_variant_code(
85                variant,
86                &parsed.methods,
87                &generics_with_static,
88                &all_type_params,
89                &all_type_params_ordered,
90                vis,
91                enum_name,
92            )
93        })
94        .collect();
95
96    let trait_def = if !parsed.methods.is_empty() {
97        let method_sigs: Vec<_> = parsed.methods.iter().map(|m| &m.sig).collect();
98        quote! {
99            #vis trait #enum_name #generics_with_static: std::any::Any #where_clause_static {
100                #(#method_sigs;)*
101            }
102        }
103    } else {
104        quote! {
105            #vis trait #enum_name #generics_with_static: std::any::Any #where_clause_static {}
106        }
107    };
108
109    let expanded = quote! {
110        #trait_def
111        #(#structs_and_impls)*
112    };
113
114    TokenStream::from(expanded)
115}
116
117/// Pattern match on trait objects based on their concrete types.
118/// It supports both reference (`&dyn Trait`) and boxed (`Box<dyn Trait>`)
119/// trait objects.
120///
121/// Use `move` keyword to indicate ownership transfer when matching on `Box<dyn Trait>`.
122///
123/// # Example
124///
125/// ```ignore
126/// type_enum! {
127///     enum Tree<T: Display> {
128///         Leaf(T),
129///         Node(Box<Tree<T>>, Box<Tree<T>>),
130///     }
131/// }
132///
133/// let tree: Box<dyn Tree<i32>> = Box::new(...);
134/// let tree_ref: &dyn Tree<i32> = &...;
135/// let describe = match_t! {
136///     move tree {
137///         Leaf(value) => format!("Leaf: {}", value),
138///         Node(left, right) => format!("Node with left and right"),
139///     }
140/// }
141/// let describe_ref = match_t! {
142///     tree_ref {
143///         Leaf(value) => format!("Leaf: {}", value),
144///         Node(left, right) => format!("Node with left and right"),
145///     }
146/// }
147/// ```
148#[proc_macro]
149pub fn match_t(input: TokenStream) -> TokenStream {
150    let input_parsed = match parse_match_t(input) {
151        Ok(parsed) => parsed,
152        Err(e) => return e.to_compile_error().into(),
153    };
154
155    let expr = &input_parsed.expr;
156    let is_move = input_parsed.is_move;
157    let type_hint = &input_parsed.type_hint;
158
159    let hint_generics = type_hint
160        .as_ref()
161        .and_then(|hint| extract_generics_from_type_hint(hint));
162
163    if is_move {
164        let type_checks = input_parsed.arms.iter().enumerate().map(|(idx, arm)| {
165            let pattern = &arm.pattern;
166            let (type_name, _) = extract_type_and_pattern(pattern);
167            let type_name = apply_type_hint_to_pattern(type_name, &hint_generics);
168
169            quote! {
170                if (&*__expr as &dyn std::any::Any).is::<#type_name>() {
171                    __matched_idx = Some(#idx);
172                }
173            }
174        });
175
176        let match_arms = input_parsed.arms.iter().enumerate().map(|(idx, arm)| {
177            let pattern = &arm.pattern;
178            let body = &arm.body;
179            let (type_name, pattern_for_match) = extract_type_and_pattern(pattern);
180            let type_name = apply_type_hint_to_pattern(type_name, &hint_generics);
181
182            quote! {
183                #idx => {
184                    let __any_box: Box<dyn std::any::Any> = __expr;
185                    if let Ok(__concrete_box) = __any_box.downcast::<#type_name>() {
186                        match *__concrete_box {
187                            #pattern_for_match => #body,
188                            _ => panic!("Pattern match failed in match_t!")
189                        }
190                    } else {
191                        panic!("Downcast failed in match_t!");
192                    }
193                }
194            }
195        });
196
197        let expanded = quote! {
198            {
199                let __expr = #expr;
200                let mut __matched_idx: Option<usize> = None;
201
202                #(#type_checks)*
203
204                match __matched_idx {
205                    Some(__idx) => {
206                        match __idx {
207                            #(#match_arms,)*
208                            _ => panic!("Invalid match index in match_t!")
209                        }
210                    }
211                    None => panic!("No matching type found in match_t!")
212                }
213            }
214        };
215
216        TokenStream::from(expanded)
217    } else {
218        let match_arms = input_parsed.arms.iter().map(|arm| {
219            let pattern = &arm.pattern;
220            let body = &arm.body;
221            let (type_name, pattern_for_match) = extract_type_and_pattern(pattern);
222            let type_name = apply_type_hint_to_pattern(type_name, &hint_generics);
223
224            quote! {
225                if let Some(__value_ref) = (&*__expr as &dyn std::any::Any).downcast_ref::<#type_name>() {
226                    if let #pattern_for_match = __value_ref {
227                        return Some(#body);
228                    }
229                }
230            }
231        });
232
233        let expanded = quote! {
234            {
235                (|| -> Option<_> {
236                    let __expr = #expr;
237                    #(#match_arms)*
238                    None
239                })().expect("No matching type found in match_t!")
240            }
241        };
242
243        TokenStream::from(expanded)
244    }
245}