derive_discriminant/
lib.rs1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields};
6
7#[proc_macro_derive(Discriminant)]
8pub fn discriminant_derive(input: TokenStream) -> TokenStream {
9 let ast = parse_macro_input!(input as DeriveInput);
10 impl_discriminant_macro(ast)
11}
12
13#[allow(clippy::too_many_lines)] fn impl_discriminant_macro(ast: DeriveInput) -> TokenStream {
15 let name = &ast.ident;
16 let vis = &ast.vis;
17
18 let global_attrs: Vec<_> = ast
20 .attrs
21 .into_iter()
22 .filter(|attr| !attr.path().is_ident("doc"))
23 .collect();
24
25 let Data::Enum(data_enum) = ast.data else {
26 panic!("Discriminant can only be derived for enums");
27 };
28
29 let variant_names: Vec<_> = data_enum
30 .variants
31 .iter()
32 .map(|variant| &variant.ident)
33 .collect();
34
35 let cast_method = quote! {
38 impl #name {
39 #vis fn cast<U: ?Sized>(self) -> Box<U> where #(#variant_names: ::core::marker::Unsize<U>),* {
40 let value = self;
41 #(
43 let value = match #variant_names::try_from(value) {
44 Ok(v) => {
45 let x = Box::new(v);
46 return x;
47 }
48 Err(v) => v,
49 };
50 )*
51
52 unreachable!();
53 }
54 }
55 };
56
57 let variant_impls = data_enum.variants.into_iter().map(|variant| {
58 let variant_name = &variant.ident;
59 let fields = &variant.fields;
60 let variant_attrs = variant.attrs;
61
62 let is_variant_name: syn::Ident = {
63 let lowercase = variant_name.to_string().to_lowercase();
64 let name = format!("is_{lowercase}");
65 syn::parse_str(&name).expect("failed to parse variant name")
66 };
67
68 match fields {
69 Fields::Unit => {
70 quote! {
71 impl From<#variant_name> for #name {
72 fn from(value: #variant_name) -> Self {
73 Self::#variant_name
74 }
75 }
76
77 impl std::convert::TryFrom<#name> for #variant_name {
78 type Error = #name;
79
80 fn try_from(value: #name) -> Result<Self, Self::Error> {
81 if let #name::#variant_name = value {
82 Ok(#variant_name)
83 } else {
84 Err(value)
85 }
86 }
87 }
88
89 impl #name {
90 #vis fn #is_variant_name(&self) -> bool {
91 matches!(self, Self::#variant_name)
92 }
93 }
94
95 #(#global_attrs)*
96 #(#variant_attrs)*
97 #vis struct #variant_name;
98 }
99 }
100 _ => {
101 let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
102 let field_type = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
103
104 quote! {
105 impl From<#variant_name> for #name {
106 fn from(value: #variant_name) -> Self {
107 Self::#variant_name {
108 #(#field_name: value.#field_name),*
109 }
110 }
111 }
112
113 impl std::convert::TryFrom<#name> for #variant_name {
114 type Error = #name;
115
116 fn try_from(value: #name) -> Result<Self, Self::Error> {
117 if let #name::#variant_name { #(#field_name),* } = value {
118 Ok(#variant_name {
119 #(#field_name),*
120 })
121 } else {
122 Err(value)
123 }
124 }
125 }
126
127 impl #name {
128 #vis fn #is_variant_name(&self) -> bool {
129 matches!(self, Self::#variant_name { .. })
130 }
131 }
132
133 #(#global_attrs)*
134 #(#variant_attrs)*
135 #vis struct #variant_name {
136 #(#vis #field_name: #field_type),*
137 }
138 }
139 }
140 }
141 });
142
143 let output = quote! {
144 #(#variant_impls)*
145 #cast_method
146 };
147
148 TokenStream::from(output)
149}