1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Data, Ident};

/// A procedural macro to generate an insert-only set for any enum.
///
/// This macro generates a struct with `insert`, `contains`, and `iter` methods for an enum.
/// The struct uses `OnceLock` for thread-safe, one-time insertion of enum variants.
///
/// # Examples
///
/// ```rust
/// use insert_only_set::InsertOnlySet;
///
/// #[derive(InsertOnlySet, Debug, PartialEq)]
/// pub enum Type {
///     Customer,
///     Employee,
/// }
///
/// fn main() {
///     let set = Type::InsertOnlySet();
///
///     assert!(!set.contains(Type::Customer));
///     assert!(!set.contains(Type::Employee));
///
///     set.insert(Type::Customer);
///     assert!(set.contains(Type::Customer));
///     assert!(!set.contains(Type::Employee));
///
///     set.insert(Type::Employee);
///     assert!(set.contains(Type::Customer));
///     assert!(set.contains(Type::Employee));
///
///     for variant in set.iter() {
///         println!("{:?}", variant);
///     }
/// }
/// ```
#[proc_macro_derive(InsertOnlySet)]
pub fn generate_add_only_set(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = &input.ident;
    let set_name = Ident::new(&format!("{}InsertOnlySet", name), name.span());

    let fields = if let Data::Enum(ref data_enum) = input.data {
        data_enum.variants.iter().map(|variant| {
            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
            quote! {
                pub #field_name: std::sync::OnceLock<bool>,
            }
        }).collect::<Vec<_>>()
    } else {
        vec![]
    };

    let new_fields_init = if let Data::Enum(ref data_enum) = input.data {
        data_enum.variants.iter().map(|variant| {
            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
            quote! {
                #field_name: std::sync::OnceLock::new(),
            }
        }).collect::<Vec<_>>()
    } else {
        vec![]
    };

    let insert_methods = if let Data::Enum(ref data_enum) = input.data {
        data_enum.variants.iter().map(|variant| {
            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
            let variant_name = &variant.ident;
            quote! {
                #name::#variant_name => {
                    if self.#field_name.set(true).is_ok() {
                        true
                    } else {
                        false
                    }
                },
            }
        }).collect::<Vec<_>>()
    } else {
        vec![]
    };

    let contains_methods = if let Data::Enum(ref data_enum) = input.data {
        data_enum.variants.iter().map(|variant| {
            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
            let variant_name = &variant.ident;
            quote! {
                #name::#variant_name => self.#field_name.get().copied().unwrap_or(false),
            }
        }).collect::<Vec<_>>()
    } else {
        vec![]
    };

    let iter_body = if let Data::Enum(ref data_enum) = input.data {
        data_enum.variants.iter().map(|variant| {
            let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
            let variant_name = &variant.ident;
            quote! {
                if self.#field_name.get().copied().unwrap_or(false) {
                    variants.push(#name::#variant_name);
                }
            }
        }).collect::<Vec<_>>()
    } else {
        vec![]
    };

    let expanded = quote! {
        pub struct #set_name {
            #(#fields)*
        }

        impl #set_name {
            pub fn new() -> Self {
                Self {
                    #(#new_fields_init)*
                }
            }

            pub fn insert(&self, t: #name) -> bool {
                match t {
                    #(#insert_methods)*
                }
            }

            pub fn contains(&self, t: #name) -> bool {
                match t {
                    #(#contains_methods)*
                }
            }

            pub fn iter(&self) -> impl Iterator<Item = #name> + '_ {
                let mut variants = Vec::new();
                #(#iter_body)*
                variants.into_iter()
            }
        }

        impl #name {
            /// Creates a new, empty insert-only set for this enum.
            pub fn InsertOnlySet() -> #set_name {
                #set_name::new()
            }
        }
    };

    TokenStream::from(expanded)
}