derive_variants_derive/
lib.rs1use std::collections::HashMap;
2
3use proc_macro2::{Ident, Span};
4use quote::quote;
5use syn::{Data, DeriveInput, Fields, Variant};
6
7#[proc_macro_derive(EnumVariants, attributes(variant_derive, variant_attr))]
8pub fn derive_partial(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9 let DeriveInput {
10 attrs,
11 vis,
12 ident,
13 data,
14 ..
15 } = syn::parse(input).unwrap();
16
17 let variant_ident = Ident::new(&format!("{}Variant", ident), Span::call_site());
18
19 let variant_derives = attrs
20 .iter()
21 .find(|attr| attr.path().is_ident("variant_derive"));
22
23 let variant_derives = if let Some(variant_derives) = variant_derives {
24 variant_derives
25 .parse_args()
26 .expect("failed to parse partial_derive")
27 } else {
28 proc_macro2::TokenStream::new()
29 };
30
31 let variant_attrs = attrs
32 .iter()
33 .filter(|attr| attr.path().is_ident("variant_attr"))
34 .map(|attr| {
35 attr
36 .parse_args::<proc_macro2::TokenStream>()
37 .expect("failed to parse variant_attr args")
38 });
39
40 let variants = match data {
41 Data::Enum(e) => e.variants,
42 _ => panic!(""),
43 };
44
45 let variant_variants = variants.iter().map(|v| v.ident.clone());
46
47 let variant_from_full = variants.iter().map(
48 |Variant {
49 ident: v_ident,
50 fields,
51 ..
52 }| {
53 match &fields {
54 Fields::Named(_) => quote!(#ident::#v_ident { .. } => #variant_ident::#v_ident),
55 Fields::Unnamed(_) => {
56 let underscores = fields.iter().map(|_| quote!(_));
57 quote!(#ident::#v_ident(#(#underscores),*) => #variant_ident::#v_ident)
58 }
59 Fields::Unit => quote!(#ident::#v_ident => #variant_ident::#v_ident),
60 }
61 },
62 );
63
64 let derive_extract_data = variants.iter().filter_map(
65 |Variant {
66 ident: v_ident,
67 fields,
68 ..
69 }| {
70 match &fields {
71 Fields::Unnamed(_) => {
72 let data_tys = fields.iter().map(|syn::Field { ty, .. }| ty);
73 let data = quote!((#(#data_tys),*));
74 let idents = fields
75 .iter()
76 .enumerate()
77 .map(|(i, _)| Ident::new(&format!("f{i}"), Span::call_site()));
78 let idents = quote!((#(#idents),*));
79 Some((
80 data.to_string(),
82 quote! {
83 #variant_ident::#v_ident => match self {
87 #ident::#v_ident(#idents) => Ok(#idents),
88 _ => Err(derive_variants::Error::VariantMismatch)
89 }
90 },
94 data,
95 ))
96 }
97 Fields::Unit | Fields::Named(_) => None,
98 }
99 },
100 );
101
102 let mut data_handlers =
104 HashMap::<String, (proc_macro2::TokenStream, Vec<proc_macro2::TokenStream>)>::default();
105 for (key, handler, data) in derive_extract_data {
106 let entry = data_handlers.entry(key).or_default();
107 entry.0 = data;
108 entry.1.push(handler);
109 }
110
111 let data_impls = data_handlers.values().map(|(data, handlers)| {
112 quote! {
113 impl derive_variants::ExtractData<#variant_ident, #data> for #ident {
114 fn extract_data(self, variant: &#variant_ident) -> Result<#data, derive_variants::Error> {
115 match variant {
116 #(#handlers),*
117 _ => Err(derive_variants::Error::WrongVariantForData)
118 }
119 }
120 }
121 }
122 });
123
124 quote! {
125 #[derive(#variant_derives)]
126 #(#variant_attrs)*
127 #vis enum #variant_ident {
128 #(#variant_variants),*
129 }
130
131 impl derive_variants::ExtractVariant<#variant_ident> for #ident {
132 fn extract_variant(&self) -> #variant_ident {
133 match self {
134 #(#variant_from_full),*
135 }
136 }
137 }
138
139 #(#data_impls)*
140 }
141 .into()
142}