// insert_only_set/lib.rs

extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident};

7/// A procedural macro to generate an insert-only set for any enum.
8///
9/// This macro generates a struct with `insert`, `contains`, and `iter` methods for an enum.
10/// The struct uses `OnceLock` for thread-safe, one-time insertion of enum variants.
11///
12/// # Examples
13///
14/// ```rust
15/// use insert_only_set::InsertOnlySet;
16///
17/// #[derive(InsertOnlySet, Debug, PartialEq)]
18/// pub enum Type {
19///     Customer,
20///     Employee,
21/// }
22///
23/// fn main() {
24///     let set = Type::InsertOnlySet();
25///
26///     assert!(!set.contains(Type::Customer));
27///     assert!(!set.contains(Type::Employee));
28///
29///     set.insert(Type::Customer);
30///     assert!(set.contains(Type::Customer));
31///     assert!(!set.contains(Type::Employee));
32///
33///     set.insert(Type::Employee);
34///     assert!(set.contains(Type::Customer));
35///     assert!(set.contains(Type::Employee));
36///
37///     for variant in set.iter() {
38///         println!("{:?}", variant);
39///     }
40/// }
41/// ```
42#[proc_macro_derive(InsertOnlySet)]
43pub fn generate_add_only_set(input: TokenStream) -> TokenStream {
44    let input = parse_macro_input!(input as DeriveInput);
45
46    let name = &input.ident;
47    let set_name = Ident::new(&format!("{}InsertOnlySet", name), name.span());
48
49    let fields = if let Data::Enum(ref data_enum) = input.data {
50        data_enum.variants.iter().map(|variant| {
51            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
52            quote! {
53                pub #field_name: std::sync::OnceLock<bool>,
54            }
55        }).collect::<Vec<_>>()
56    } else {
57        vec![]
58    };
59
60    let new_fields_init = if let Data::Enum(ref data_enum) = input.data {
61        data_enum.variants.iter().map(|variant| {
62            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
63            quote! {
64                #field_name: std::sync::OnceLock::new(),
65            }
66        }).collect::<Vec<_>>()
67    } else {
68        vec![]
69    };
70
71    let insert_methods = if let Data::Enum(ref data_enum) = input.data {
72        data_enum.variants.iter().map(|variant| {
73            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
74            let variant_name = &variant.ident;
75            quote! {
76                #name::#variant_name => {
77                    if self.#field_name.set(true).is_ok() {
78                        true
79                    } else {
80                        false
81                    }
82                },
83            }
84        }).collect::<Vec<_>>()
85    } else {
86        vec![]
87    };
88
89    let contains_methods = if let Data::Enum(ref data_enum) = input.data {
90        data_enum.variants.iter().map(|variant| {
91            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
92            let variant_name = &variant.ident;
93            quote! {
94                #name::#variant_name => self.#field_name.get().copied().unwrap_or(false),
95            }
96        }).collect::<Vec<_>>()
97    } else {
98        vec![]
99    };
100
101    let iter_body = if let Data::Enum(ref data_enum) = input.data {
102        data_enum.variants.iter().map(|variant| {
103            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
104            let variant_name = &variant.ident;
105            quote! {
106                if self.#field_name.get().copied().unwrap_or(false) {
107                    variants.push(#name::#variant_name);
108                }
109            }
110        }).collect::<Vec<_>>()
111    } else {
112        vec![]
113    };
114
115    let expanded = quote! {
116        #[derive(Debug, Default)]
117        pub struct #set_name {
118            #(#fields)*
119        }
120
121        impl #set_name {
122            pub fn new() -> Self {
123                Self {
124                    #(#new_fields_init)*
125                }
126            }
127
128            pub fn insert(&self, t: #name) -> bool {
129                match t {
130                    #(#insert_methods)*
131                }
132            }
133
134            pub fn contains(&self, t: #name) -> bool {
135                match t {
136                    #(#contains_methods)*
137                }
138            }
139
140            pub fn iter(&self) -> impl Iterator<Item = #name> + '_ {
141                let mut variants = Vec::new();
142                #(#iter_body)*
143                variants.into_iter()
144            }
145        }
146
147        impl #name {
148            /// Creates a new, empty insert-only set for this enum.
149            pub fn InsertOnlySet() -> #set_name {
150                #set_name::new()
151            }
152        }
153    };
154
155    TokenStream::from(expanded)
156}