derive_discriminant/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields};
6
7#[proc_macro_derive(Discriminant)]
8pub fn discriminant_derive(input: TokenStream) -> TokenStream {
9    let ast = parse_macro_input!(input as DeriveInput);
10    impl_discriminant_macro(ast)
11}
12
13#[allow(clippy::too_many_lines)] // TODO: fix too_many_lines allow
14fn impl_discriminant_macro(ast: DeriveInput) -> TokenStream {
15    let name = &ast.ident;
16    let vis = &ast.vis;
17
18    // all non-doc attributes
19    let global_attrs: Vec<_> = ast
20        .attrs
21        .into_iter()
22        .filter(|attr| !attr.path().is_ident("doc"))
23        .collect();
24
25    let Data::Enum(data_enum) = ast.data else {
26        panic!("Discriminant can only be derived for enums");
27    };
28
29    let variant_names: Vec<_> = data_enum
30        .variants
31        .iter()
32        .map(|variant| &variant.ident)
33        .collect();
34
35    // implementation for the .cast() method to cast into a trait object
36    // this requires nightly
37    let cast_method = quote! {
38        impl #name {
39            #vis fn cast<U: ?Sized>(self) -> Box<U> where #(#variant_names: ::core::marker::Unsize<U>),* {
40                let value = self;
41                // TODO: use a singular match expression
42                #(
43                    let value = match #variant_names::try_from(value) {
44                        Ok(v) => {
45                            let x = Box::new(v);
46                            return x;
47                        }
48                        Err(v) => v,
49                    };
50                )*
51
52                unreachable!();
53            }
54        }
55    };
56
57    let variant_impls = data_enum.variants.into_iter().map(|variant| {
58        let variant_name = &variant.ident;
59        let fields = &variant.fields;
60        let variant_attrs = variant.attrs;
61
62        let is_variant_name: syn::Ident = {
63            let lowercase = variant_name.to_string().to_lowercase();
64            let name = format!("is_{lowercase}");
65            syn::parse_str(&name).expect("failed to parse variant name")
66        };
67
68        match fields {
69            Fields::Unit => {
70                quote! {
71                    impl From<#variant_name> for #name {
72                        fn from(value: #variant_name) -> Self {
73                            Self::#variant_name
74                        }
75                    }
76
77                    impl std::convert::TryFrom<#name> for #variant_name {
78                        type Error = #name;
79
80                        fn try_from(value: #name) -> Result<Self, Self::Error> {
81                            if let #name::#variant_name = value {
82                                Ok(#variant_name)
83                            } else {
84                                Err(value)
85                            }
86                        }
87                    }
88
89                    impl #name {
90                        #vis fn #is_variant_name(&self) -> bool {
91                            matches!(self, Self::#variant_name)
92                        }
93                    }
94
95                    #(#global_attrs)*
96                    #(#variant_attrs)*
97                    #vis struct #variant_name;
98                }
99            }
100            _ => {
101                let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
102                let field_type = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
103
104                quote! {
105                    impl From<#variant_name> for #name {
106                        fn from(value: #variant_name) -> Self {
107                            Self::#variant_name {
108                                #(#field_name: value.#field_name),*
109                            }
110                        }
111                    }
112
113                    impl std::convert::TryFrom<#name> for #variant_name {
114                        type Error = #name;
115
116                        fn try_from(value: #name) -> Result<Self, Self::Error> {
117                            if let #name::#variant_name { #(#field_name),* } = value {
118                                Ok(#variant_name {
119                                    #(#field_name),*
120                                })
121                            } else {
122                                Err(value)
123                            }
124                        }
125                    }
126
127                    impl #name {
128                        #vis fn #is_variant_name(&self) -> bool {
129                            matches!(self, Self::#variant_name { .. })
130                        }
131                    }
132
133                    #(#global_attrs)*
134                    #(#variant_attrs)*
135                    #vis struct #variant_name {
136                        #(#vis #field_name: #field_type),*
137                    }
138                }
139            }
140        }
141    });
142
143    let output = quote! {
144        #(#variant_impls)*
145        #cast_method
146    };
147
148    TokenStream::from(output)
149}