// enum_visitor_derive/lib.rs
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::ext::IdentExt;
use syn::{parse_macro_input, Data, DeriveInput, Fields};
4
5#[proc_macro_derive(VisitEnum)]
15pub fn derive_visit_enum(input: TokenStream) -> TokenStream {
16 let input = parse_macro_input!(input as DeriveInput);
17
18 let enum_ident = input.ident;
19 let enum_name = enum_ident.to_string();
20 let macro_ident = format_ident!("visit_{}", to_snake_case(&enum_name));
21
22 let data_enum = match input.data {
23 Data::Enum(e) => e,
24 _ => {
25 return syn::Error::new_spanned(enum_ident, "VisitEnum can only be derived for enums")
26 .to_compile_error()
27 .into();
28 }
29 };
30
31 let mut arms_expr = Vec::new();
33 let mut arms_block = Vec::new();
34
35 for variant in data_enum.variants.iter() {
36 let v_ident = &variant.ident;
37 match &variant.fields {
38 Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => {
39 }
41 _ => {
42 return syn::Error::new_spanned(
43 &variant.ident,
44 "VisitEnum only supports tuple variants with exactly 1 field",
45 )
46 .to_compile_error()
47 .into();
48 }
49 }
50
51 arms_expr.push(quote! { #enum_ident::#v_ident($v) => { $body } });
53 arms_block.push(quote! { #enum_ident::#v_ident($v) => { $($tt)* } });
55 }
56
57 let gen = quote! {
61 #[allow(non_snake_case, unused_macros)]
64 macro_rules! #macro_ident {
65 ($expr:expr, |$v:pat_param| $body:expr $(,)?) => {{
66 match $expr {
67 #( #arms_expr ),*
68 }
69 }};
70 ($expr:expr, |$v:pat_param| { $($tt:tt)* } $(,)?) => {{
71 match $expr {
72 #( #arms_block ),*
73 }
74 }};
75 }
76
77 #[allow(unused_macros)]
79 macro_rules! visit {
80 ($expr:expr, |$v:pat_param| $body:expr $(,)?) => { #macro_ident!($expr, |$v| $body) };
81 ($expr:expr, |$v:pat_param| { $($tt:tt)* } $(,)?) => { #macro_ident!($expr, |$v| { $($tt)* }) };
82 }
83 };
84
85 gen.into()
86}
87
/// Converts a `CamelCase` type name into the `snake_case` suffix of a generated macro name.
///
/// Every uppercase character (per [`char::is_uppercase`]) is lowercased. Unless it is the very
/// first character, it is also preceded by `_`. All other characters are copied unchanged.
/// Acronyms and existing underscores get no special handling, so `"HTTPServer"` becomes
/// `"h_t_t_p_server"` and `"My_Enum"` becomes `"my__enum"`. Changing this would rename the
/// macros generated for existing enums.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::new();
    for (position, c) in name.chars().enumerate() {
        if !c.is_uppercase() {
            snake.push(c);
            continue;
        }
        if position > 0 {
            snake.push('_');
        }
        // `to_lowercase` can yield more than one char for some non-ASCII letters.
        snake.extend(c.to_lowercase());
    }
    snake
}