1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{
5 parse::Parse, parse::ParseStream, parse_macro_input, Data, DeriveInput, Ident, LitStr, Token,
6 Type,
7};
8
9struct ConfigArgs {
10 target_type: Ident,
11 prefix: LitStr,
12 extra_types: Vec<Type>,
13}
14
15impl Parse for ConfigArgs {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let target_type: Ident = input.parse()?;
18 input.parse::<Token![,]>()?;
19 let prefix: LitStr = input.parse()?;
20 let mut extra_types = Vec::new();
21 while !input.is_empty() {
22 if input.peek(Token![,]) {
23 input.parse::<Token![,]>()?;
24 }
25 if input.is_empty() {
26 break;
27 }
28 extra_types.push(input.parse()?);
29 }
30 Ok(ConfigArgs {
31 target_type,
32 prefix,
33 extra_types,
34 })
35 }
36}
37
38fn to_upper_snake_case(s: &str) -> String {
39 s.to_case(Case::UpperSnake)
40}
41
42fn get_variant_prefix(variant: &syn::Variant, default_prefix: &str) -> String {
43 for attr in &variant.attrs {
44 if attr.path().is_ident("prefix") {
45 if let Ok(lit_str) = attr.parse_args::<LitStr>() {
46 return lit_str.value();
47 }
48 }
49 }
50 default_prefix.to_string()
51}
52
53fn get_target_variant(variant: &syn::Variant, default_prefix: &str) -> Ident {
54 for attr in &variant.attrs {
55 if attr.path().is_ident("alias") {
56 if let Ok(lit_str) = attr.parse_args::<LitStr>() {
57 return format_ident!("{}", lit_str.value());
58 }
59 }
60 }
61
62 let variant_prefix = get_variant_prefix(variant, default_prefix);
63 for attr in &variant.attrs {
64 if attr.path().is_ident("suffix") {
65 if let Ok(lit_str) = attr.parse_args::<LitStr>() {
66 return format_ident!("{}{}", variant_prefix, lit_str.value());
67 }
68 }
69 }
70
71 let upper_snake_variant = to_upper_snake_case(&variant.ident.to_string());
72 format_ident!("{}{}", variant_prefix, upper_snake_variant)
73}
74
75#[proc_macro_derive(EnumFrom, attributes(config, prefix, suffix, alias))]
76pub fn enum_from(input: TokenStream) -> TokenStream {
77 let input = parse_macro_input!(input as DeriveInput);
78 let name = &input.ident;
79
80 let args = input
81 .attrs
82 .iter()
83 .find(|attr| attr.path().is_ident("config"))
84 .map(|attr| attr.parse_args::<ConfigArgs>())
85 .expect("config attribute is required")
86 .expect("Failed to parse config attribute");
87
88 let target_type = args.target_type;
89 let default_prefix = args.prefix.value();
90 let extra_types = args.extra_types;
91
92 let variants = match &input.data {
93 Data::Enum(data_enum) => &data_enum.variants,
94 _ => panic!("EnumFrom can only be derived for enums"),
95 };
96
97 let from_attribute_type_arms: Vec<_> = variants
98 .iter()
99 .map(|v| {
100 let variant = &v.ident;
101 let target_variant = get_target_variant(v, &default_prefix);
102 quote! {
103 #name::#variant => #target_variant,
104 }
105 })
106 .collect();
107
108 let from_target_type_arms: Vec<_> = variants
109 .iter()
110 .map(|v| {
111 let variant = &v.ident;
112 let target_variant = get_target_variant(v, &default_prefix);
113 quote! {
114 #target_variant => #name::#variant,
115 }
116 })
117 .collect();
118
119 let try_from_target_type_arms: Vec<_> = variants
120 .iter()
121 .map(|v| {
122 let variant = &v.ident;
123 let target_variant = get_target_variant(v, &default_prefix);
124 quote! {
125 #target_variant => ::std::option::Option::Some(#name::#variant),
126 }
127 })
128 .collect();
129
130 let extra_from_attribute_type_impls: Vec<_> = extra_types
131 .iter()
132 .map(|extra_type| {
133 quote! {
134 impl From<#name> for #extra_type {
135 fn from(attr: #name) -> Self {
136 let raw: #target_type = attr.into();
137 raw as #extra_type
138 }
139 }
140 }
141 })
142 .collect();
143
144 let extra_from_target_type_impls: Vec<_> = extra_types
145 .iter()
146 .map(|extra_type| {
147 quote! {
148 impl From<#extra_type> for #name {
149 fn from(attr: #extra_type) -> Self {
150 let raw = attr as #target_type;
151 raw.into()
152 }
153 }
154 }
155 })
156 .collect();
157
158 let expanded = quote! {
159 impl From<#name> for #target_type {
160 fn from(attr: #name) -> Self {
161 match attr {
162 #(#from_attribute_type_arms)*
163 _ => unreachable!("Invalid attribute value"),
164 }
165 }
166 }
167
168 impl From<#target_type> for #name {
169 fn from(attr: #target_type) -> Self {
170 match attr {
171 #(#from_target_type_arms)*
172 _ => unreachable!("Invalid attribute value"),
173 }
174 }
175 }
176
177 impl #name {
178 pub fn try_from_raw(attr: #target_type) -> ::std::option::Option<Self> {
179 match attr {
180 #(#try_from_target_type_arms)*
181 _ => ::std::option::Option::None,
182 }
183 }
184 }
185
186 #(#extra_from_attribute_type_impls)*
187
188 #(#extra_from_target_type_impls)*
189 };
190
191 TokenStream::from(expanded)
192}