derive_variants_derive/
lib.rs

1use std::collections::HashMap;
2
3use proc_macro2::{Ident, Span};
4use quote::quote;
5use syn::{Data, DeriveInput, Fields, Variant};
6
7#[proc_macro_derive(EnumVariants, attributes(variant_derive, variant_attr))]
8pub fn derive_partial(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9  let DeriveInput {
10    attrs,
11    vis,
12    ident,
13    data,
14    ..
15  } = syn::parse(input).unwrap();
16
17  let variant_ident = Ident::new(&format!("{}Variant", ident), Span::call_site());
18
19  let variant_derives = attrs
20    .iter()
21    .find(|attr| attr.path().is_ident("variant_derive"));
22
23  let variant_derives = if let Some(variant_derives) = variant_derives {
24    variant_derives
25      .parse_args()
26      .expect("failed to parse partial_derive")
27  } else {
28    proc_macro2::TokenStream::new()
29  };
30
31  let variant_attrs = attrs
32    .iter()
33    .filter(|attr| attr.path().is_ident("variant_attr"))
34    .map(|attr| {
35      attr
36        .parse_args::<proc_macro2::TokenStream>()
37        .expect("failed to parse variant_attr args")
38    });
39
40  let variants = match data {
41    Data::Enum(e) => e.variants,
42    _ => panic!(""),
43  };
44
45  let variant_variants = variants.iter().map(|v| v.ident.clone());
46
47  let variant_from_full = variants.iter().map(
48    |Variant {
49       ident: v_ident,
50       fields,
51       ..
52     }| {
53      match &fields {
54        Fields::Named(_) => quote!(#ident::#v_ident { .. } => #variant_ident::#v_ident),
55        Fields::Unnamed(_) => {
56          let underscores = fields.iter().map(|_| quote!(_));
57          quote!(#ident::#v_ident(#(#underscores),*) => #variant_ident::#v_ident)
58        }
59        Fields::Unit => quote!(#ident::#v_ident => #variant_ident::#v_ident),
60      }
61    },
62  );
63
64  let derive_extract_data = variants.iter().filter_map(
65    |Variant {
66       ident: v_ident,
67       fields,
68       ..
69     }| {
70      match &fields {
71        Fields::Unnamed(_) => {
72          let data_tys = fields.iter().map(|syn::Field { ty, .. }| ty);
73          let data = quote!((#(#data_tys),*));
74          let idents = fields
75            .iter()
76            .enumerate()
77            .map(|(i, _)| Ident::new(&format!("f{i}"), Span::call_site()));
78          let idents = quote!((#(#idents),*));
79          Some((
80            // use as hashmap key
81            data.to_string(),
82            quote! {
83              // impl derive_variants::ExtractData<#variant_ident, #data> for #ident {
84              //   fn extract_data(&self, variant: &#variant_ident) -> #data {
85              //     match variant {
86              #variant_ident::#v_ident => match self {
87                #ident::#v_ident(#idents) => Ok(#idents),
88                _ => Err(derive_variants::Error::VariantMismatch)
89              }
90              //     }
91              //   }
92              // }
93            },
94            data,
95          ))
96        }
97        Fields::Unit | Fields::Named(_) => None,
98      }
99    },
100  );
101
102  // group the handlers by data ty
103  let mut data_handlers =
104    HashMap::<String, (proc_macro2::TokenStream, Vec<proc_macro2::TokenStream>)>::default();
105  for (key, handler, data) in derive_extract_data {
106    let entry = data_handlers.entry(key).or_default();
107    entry.0 = data;
108    entry.1.push(handler);
109  }
110
111  let data_impls = data_handlers.values().map(|(data, handlers)| {
112    quote! {
113      impl derive_variants::ExtractData<#variant_ident, #data> for #ident {
114        fn extract_data(self, variant: &#variant_ident) -> Result<#data, derive_variants::Error> {
115          match variant {
116            #(#handlers),*
117            _ => Err(derive_variants::Error::WrongVariantForData)
118          }
119        }
120      }
121    }
122  });
123
124  quote! {
125    #[derive(#variant_derives)]
126    #(#variant_attrs)*
127    #vis enum #variant_ident {
128      #(#variant_variants),*
129    }
130
131    impl derive_variants::ExtractVariant<#variant_ident> for #ident {
132      fn extract_variant(&self) -> #variant_ident {
133        match self {
134          #(#variant_from_full),*
135        }
136      }
137    }
138
139    #(#data_impls)*
140  }
141  .into()
142}