1mod codegen;
2mod enum_parser;
3mod helpers;
4mod pattern_parser;
5mod type_analysis;
6mod variant_gen;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use std::collections::HashSet;
11
12use codegen::apply_type_hint_to_pattern;
13use enum_parser::ParsedEnum;
14use helpers::{add_static_bounds, collect_ordered_type_params};
15use pattern_parser::{extract_generics_from_type_hint, extract_type_and_pattern, parse_match_t};
16use variant_gen::generate_variant_code;
17
18#[proc_macro]
64pub fn type_enum(input: TokenStream) -> TokenStream {
65 let parsed = match syn::parse::<ParsedEnum>(input) {
66 Ok(p) => p,
67 Err(e) => return e.to_compile_error().into(),
68 };
69
70 let enum_name = &parsed.ident;
71 let vis = &parsed.vis;
72 let generics = &parsed.generics;
73
74 let all_type_params_ordered = collect_ordered_type_params(generics);
75 let all_type_params: HashSet<String> = all_type_params_ordered.iter().cloned().collect();
76
77 let generics_with_static = add_static_bounds(generics);
78 let (_impl_generics_static, _, where_clause_static) = generics_with_static.split_for_impl();
79
80 let structs_and_impls: Vec<_> = parsed
81 .variants
82 .iter()
83 .map(|variant| {
84 generate_variant_code(
85 variant,
86 &parsed.methods,
87 &generics_with_static,
88 &all_type_params,
89 &all_type_params_ordered,
90 vis,
91 enum_name,
92 )
93 })
94 .collect();
95
96 let trait_def = if !parsed.methods.is_empty() {
97 let method_sigs: Vec<_> = parsed.methods.iter().map(|m| &m.sig).collect();
98 quote! {
99 #vis trait #enum_name #generics_with_static: std::any::Any #where_clause_static {
100 #(#method_sigs;)*
101 }
102 }
103 } else {
104 quote! {
105 #vis trait #enum_name #generics_with_static: std::any::Any #where_clause_static {}
106 }
107 };
108
109 let expanded = quote! {
110 #trait_def
111 #(#structs_and_impls)*
112 };
113
114 TokenStream::from(expanded)
115}
116
117#[proc_macro]
149pub fn match_t(input: TokenStream) -> TokenStream {
150 let input_parsed = match parse_match_t(input) {
151 Ok(parsed) => parsed,
152 Err(e) => return e.to_compile_error().into(),
153 };
154
155 let expr = &input_parsed.expr;
156 let is_move = input_parsed.is_move;
157 let type_hint = &input_parsed.type_hint;
158
159 let hint_generics = type_hint
160 .as_ref()
161 .and_then(|hint| extract_generics_from_type_hint(hint));
162
163 if is_move {
164 let type_checks = input_parsed.arms.iter().enumerate().map(|(idx, arm)| {
165 let pattern = &arm.pattern;
166 let (type_name, _) = extract_type_and_pattern(pattern);
167 let type_name = apply_type_hint_to_pattern(type_name, &hint_generics);
168
169 quote! {
170 if (&*__expr as &dyn std::any::Any).is::<#type_name>() {
171 __matched_idx = Some(#idx);
172 }
173 }
174 });
175
176 let match_arms = input_parsed.arms.iter().enumerate().map(|(idx, arm)| {
177 let pattern = &arm.pattern;
178 let body = &arm.body;
179 let (type_name, pattern_for_match) = extract_type_and_pattern(pattern);
180 let type_name = apply_type_hint_to_pattern(type_name, &hint_generics);
181
182 quote! {
183 #idx => {
184 let __any_box: Box<dyn std::any::Any> = __expr;
185 if let Ok(__concrete_box) = __any_box.downcast::<#type_name>() {
186 match *__concrete_box {
187 #pattern_for_match => #body,
188 _ => panic!("Pattern match failed in match_t!")
189 }
190 } else {
191 panic!("Downcast failed in match_t!");
192 }
193 }
194 }
195 });
196
197 let expanded = quote! {
198 {
199 let __expr = #expr;
200 let mut __matched_idx: Option<usize> = None;
201
202 #(#type_checks)*
203
204 match __matched_idx {
205 Some(__idx) => {
206 match __idx {
207 #(#match_arms,)*
208 _ => panic!("Invalid match index in match_t!")
209 }
210 }
211 None => panic!("No matching type found in match_t!")
212 }
213 }
214 };
215
216 TokenStream::from(expanded)
217 } else {
218 let match_arms = input_parsed.arms.iter().map(|arm| {
219 let pattern = &arm.pattern;
220 let body = &arm.body;
221 let (type_name, pattern_for_match) = extract_type_and_pattern(pattern);
222 let type_name = apply_type_hint_to_pattern(type_name, &hint_generics);
223
224 quote! {
225 if let Some(__value_ref) = (&*__expr as &dyn std::any::Any).downcast_ref::<#type_name>() {
226 if let #pattern_for_match = __value_ref {
227 return Some(#body);
228 }
229 }
230 }
231 });
232
233 let expanded = quote! {
234 {
235 (|| -> Option<_> {
236 let __expr = #expr;
237 #(#match_arms)*
238 None
239 })().expect("No matching type found in match_t!")
240 }
241 };
242
243 TokenStream::from(expanded)
244 }
245}