1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput, Data, Ident};
6
7#[proc_macro_derive(InsertOnlySet)]
43pub fn generate_add_only_set(input: TokenStream) -> TokenStream {
44 let input = parse_macro_input!(input as DeriveInput);
45
46 let name = &input.ident;
47 let set_name = Ident::new(&format!("{}InsertOnlySet", name), name.span());
48
49 let fields = if let Data::Enum(ref data_enum) = input.data {
50 data_enum.variants.iter().map(|variant| {
51 let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
52 quote! {
53 pub #field_name: std::sync::OnceLock<bool>,
54 }
55 }).collect::<Vec<_>>()
56 } else {
57 vec![]
58 };
59
60 let new_fields_init = if let Data::Enum(ref data_enum) = input.data {
61 data_enum.variants.iter().map(|variant| {
62 let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
63 quote! {
64 #field_name: std::sync::OnceLock::new(),
65 }
66 }).collect::<Vec<_>>()
67 } else {
68 vec![]
69 };
70
71 let insert_methods = if let Data::Enum(ref data_enum) = input.data {
72 data_enum.variants.iter().map(|variant| {
73 let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
74 let variant_name = &variant.ident;
75 quote! {
76 #name::#variant_name => {
77 if self.#field_name.set(true).is_ok() {
78 true
79 } else {
80 false
81 }
82 },
83 }
84 }).collect::<Vec<_>>()
85 } else {
86 vec![]
87 };
88
89 let contains_methods = if let Data::Enum(ref data_enum) = input.data {
90 data_enum.variants.iter().map(|variant| {
91 let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
92 let variant_name = &variant.ident;
93 quote! {
94 #name::#variant_name => self.#field_name.get().copied().unwrap_or(false),
95 }
96 }).collect::<Vec<_>>()
97 } else {
98 vec![]
99 };
100
101 let iter_body = if let Data::Enum(ref data_enum) = input.data {
102 data_enum.variants.iter().map(|variant| {
103 let field_name = Ident::new(&variant.ident.to_string().to_lowercase(), variant.ident.span());
104 let variant_name = &variant.ident;
105 quote! {
106 if self.#field_name.get().copied().unwrap_or(false) {
107 variants.push(#name::#variant_name);
108 }
109 }
110 }).collect::<Vec<_>>()
111 } else {
112 vec![]
113 };
114
115 let expanded = quote! {
116 #[derive(Debug, Default)]
117 pub struct #set_name {
118 #(#fields)*
119 }
120
121 impl #set_name {
122 pub fn new() -> Self {
123 Self {
124 #(#new_fields_init)*
125 }
126 }
127
128 pub fn insert(&self, t: #name) -> bool {
129 match t {
130 #(#insert_methods)*
131 }
132 }
133
134 pub fn contains(&self, t: #name) -> bool {
135 match t {
136 #(#contains_methods)*
137 }
138 }
139
140 pub fn iter(&self) -> impl Iterator<Item = #name> + '_ {
141 let mut variants = Vec::new();
142 #(#iter_body)*
143 variants.into_iter()
144 }
145 }
146
147 impl #name {
148 pub fn InsertOnlySet() -> #set_name {
150 #set_name::new()
151 }
152 }
153 };
154
155 TokenStream::from(expanded)
156}