filterable_enum_derive/
lib.rs

// MIT License

// Copyright (c) 2022 Gino Valente
// Copyright (c) 2024 Levi Zim

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
23
24use darling::FromDeriveInput;
25use proc_macro::TokenStream;
26use proc_macro_crate::{crate_name, FoundCrate};
27use quote::{format_ident, quote};
28use syn::Data;
29
30mod opts;
31
32#[proc_macro_derive(FilterableEnum, attributes(filterable_enum))]
33pub fn derive_filterable_enum(ts: TokenStream) -> TokenStream {
34    let input = syn::parse_macro_input!(ts as syn::DeriveInput);
35
36    let attrs = self::opts::Opts::from_derive_input(&input).unwrap();
37
38    let Data::Enum(data) = input.data else {
39        return syn::Error::new_spanned(
40            input,
41            "Cannot derive FilterableEnum on this type, expected enum",
42        )
43        .to_compile_error()
44        .into();
45    };
46
47    let kind_extra_derive = attrs.kind_extra_derive;
48    let repr = format_ident!("{}", attrs.repr.as_deref().unwrap_or("u32"));
49    let kind_extra_attrs = attrs.kind_extra_attrs.iter().map(|attr| {
50        let attr = syn::parse_str::<syn::Meta>(attr).unwrap();
51        quote!(#[#attr])
52    });
53
54    let vis = &input.vis;
55    let ident = &input.ident;
56    let kinds = data
57        .variants
58        .iter()
59        .map(|variant| &variant.ident)
60        .collect::<Vec<_>>();
61    let patterns = data.variants.iter().map(|v| match &v.fields {
62        syn::Fields::Unit => quote! {},
63        syn::Fields::Named(_) => quote! { { .. } },
64        syn::Fields::Unnamed(_) => quote! { (_) },
65    });
66    let filterable_enum = get_crate("filterable-enum");
67    let ident_kind = format_ident!("{}Kind", ident);
68    let ident_filterable = format_ident!("Filterable{}", ident);
69    let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
70
71    TokenStream::from(quote::quote! {
72        // Create EnumKind
73        #[#filterable_enum::enumflags2::bitflags]
74        #[repr(#repr)]
75        #[derive(Debug, PartialEq, Eq, Clone, Copy, #(#kind_extra_derive,)*)]
76        #(#kind_extra_attrs)*
77        #vis enum #ident_kind {
78            #(#kinds,)*
79        }
80
81        impl #filterable_enum::EnumFilter<#ident_kind> for #filterable_enum::enumflags2::BitFlags<#ident_kind> {
82            fn contains(&self, id: #ident_kind) -> bool {
83                self.intersects(id)
84            }
85        }
86
87        #vis struct #ident_filterable #ty_generics {
88            inner: #ident #ty_generics,
89            id: #ident_kind,
90        }
91
92        impl #impl_generics #filterable_enum::FilterableEnum<#ident #ty_generics> for #ident_filterable #where_clause {
93            type Id = #ident_kind;
94            type Filter = #filterable_enum::enumflags2::BitFlags<#ident_kind>;
95
96            fn filterable_id(&self) -> Self::Id {
97                self.id
98            }
99
100            fn filter_ref(&self, filter: impl Into<Self::Filter>) -> Option<&#ident> {
101                if filter.into().contains(self.id) {
102                    Some(&self.inner)
103                } else {
104                    None
105                }
106            }
107
108            fn filter_and_take(self, filter: impl Into<Self::Filter>) -> Option<#ident> {
109                if filter.into().contains(self.id) {
110                    Some(self.inner)
111                } else {
112                    None
113                }
114            }
115        }
116
117        impl #impl_generics From<#ident #ty_generics> for #ident_filterable #ty_generics #where_clause {
118            fn from(inner: #ident #ty_generics) -> Self {
119                let id = match inner {
120                    #(
121                        #ident::#kinds #patterns => #ident_kind::#kinds,
122                    )*
123                };
124                #ident_filterable { inner, id }
125            }
126        }
127    })
128}
129
130fn get_crate(name: &str) -> proc_macro2::TokenStream {
131    let found_crate =
132        crate_name(name).unwrap_or_else(|_| panic!("`{}` not found in `Cargo.toml`", name));
133
134    match found_crate {
135        FoundCrate::Itself => quote!(crate),
136        FoundCrate::Name(name) => {
137            let ident = format_ident!("{}", &name);
138            quote!( #ident )
139        }
140    }
141}