1use proc_macro::TokenStream;
2use proc_macro2::{Literal, Span, TokenStream as TokenStream2, TokenTree};
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::*;
5
6#[proc_macro_derive(Finite)]
7pub fn derive_finite(input: TokenStream) -> TokenStream {
8 let input = parse_macro_input!(input as DeriveInput);
9 let name = input.ident;
10 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
11 let (count, index_of, nth) = match input.data {
12 Data::Struct(data) => match data.fields {
13 Fields::Named(fields) => {
14 let mut field_tys = Vec::new();
15 let mut field_idents = Vec::new();
16 for field in fields.named {
17 field_tys.push(field.ty.to_token_stream());
18 field_idents.push(field.ident.to_token_stream());
19 }
20 let count = product_count(&*field_tys);
21 let index_of = product_index_of(&*field_tys, &*field_idents);
22 let nth = product_nth(
23 &*field_tys,
24 quote! { index },
25 &*field_idents,
26 quote! { Self { #(#field_idents),* } },
27 );
28 (
29 quote! { #count },
30 quote! {
31 let Self { #(#field_idents),* } = value;
32 #index_of
33 },
34 quote! {
35 if index < <Self as ::cantor::Finite>::COUNT {
36 Some(#nth)
37 } else {
38 None
39 }
40 }
41 )
42 },
43 Fields::Unnamed(fields) => {
44 let mut field_tys = Vec::new();
45 let mut field_idents = Vec::new();
46 for field in fields.unnamed {
47 field_tys.push(field.ty.to_token_stream());
48 let field_ident = format!("f{}", field_idents.len());
49 let field_ident = Ident::new(&*field_ident, Span::call_site());
50 field_idents.push(field_ident.to_token_stream());
51 }
52 let count = product_count(&*field_tys);
53 let index_of = product_index_of(&*field_tys, &*field_idents);
54 let nth = product_nth(
55 &*field_tys,
56 quote! { index },
57 &*field_idents,
58 quote! { Self(#(#field_idents),*) },
59 );
60 (
61 quote! { #count },
62 quote! {
63 let Self(#(#field_idents),*) = value;
64 #index_of
65 },
66 quote! {
67 if index < <Self as ::cantor::Finite>::COUNT {
68 Some(#nth)
69 } else {
70 None
71 }
72 }
73 )
74 }
75 Fields::Unit => (
76 quote! { 1 },
77 quote! { 0 },
78 quote! {
79 if index < 1 {
80 Some(Self)
81 } else {
82 None
83 }
84 },
85 ),
86 },
87 Data::Enum(data) => {
88 let mut count = SumExpr::new_zero();
90 let mut const_count = SumExpr::new_zero();
91 let mut consts = Vec::new();
92 let mut index_of_arms = Vec::new();
93 let mut nth_arms = Vec::new();
94 for variant in data.variants {
95 let variant_name = variant.ident;
97 let start_index = const_count.get_simple(&mut consts);
98 const_count.set_zero();
99 const_count.add(start_index.clone().into());
100 match variant.fields {
101 Fields::Named(fields) => {
102 let mut field_tys = Vec::new();
103 let mut field_idents = Vec::new();
104 for field in fields.named {
105 field_tys.push(field.ty.to_token_stream());
106 field_idents.push(field.ident.to_token_stream());
107 }
108 let index_of_arm = product_index_of(&*field_tys, &*field_idents);
109 index_of_arms.push(quote! {
110 Self::#variant_name { #(#field_idents),* } => #count + #index_of_arm
111 });
112 let nth_arm = product_nth(
113 &*field_tys,
114 quote! { index - #start_index },
115 &*field_idents,
116 quote! { Self::#variant_name { #(#field_idents),* } },
117 );
118 let variant_count = product_count(&*field_tys);
119 count.add(variant_count.clone());
120 const_count.add(variant_count);
121 const_count.add(NumTerm::Literal(-1));
122 let end_index = const_count.get_simple(&mut consts);
123 const_count.set_zero();
124 const_count.add(end_index.clone().into());
125 const_count.add(NumTerm::Literal(1));
126 nth_arms.push(quote! {
127 #start_index..=#end_index => Some(#nth_arm)
128 });
129 }
130 Fields::Unnamed(fields) => {
131 let mut field_tys = Vec::new();
132 let mut field_idents = Vec::new();
133 for field in fields.unnamed {
134 field_tys.push(field.ty.to_token_stream());
135 let field_ident = format!("f{}", field_idents.len());
136 let field_ident = Ident::new(&*field_ident, Span::call_site());
137 field_idents.push(field_ident.to_token_stream());
138 }
139 let index_of_arm = product_index_of(&*field_tys, &*field_idents);
140 index_of_arms.push(quote! {
141 Self::#variant_name(#(#field_idents),*) => #count + #index_of_arm
142 });
143 let nth_arm = product_nth(
144 &*field_tys,
145 quote! { index - #start_index },
146 &*field_idents,
147 quote! { Self::#variant_name(#(#field_idents),*) },
148 );
149 let variant_count = product_count(&*field_tys);
150 count.add(variant_count.clone());
151 const_count.add(variant_count);
152 const_count.add(NumTerm::Literal(-1));
153 let end_index = const_count.get_simple(&mut consts);
154 const_count.set_zero();
155 const_count.add(end_index.clone().into());
156 const_count.add(NumTerm::Literal(1));
157 nth_arms.push(quote! {
158 #start_index..=#end_index => Some(#nth_arm)
159 });
160 }
161 Fields::Unit => {
162 index_of_arms.push(quote! {
163 Self::#variant_name => #start_index
164 });
165 nth_arms.push(quote! {
166 #start_index => Some(Self::#variant_name)
167 });
168 count.add(NumTerm::Literal(1));
169 const_count.add(NumTerm::Literal(1));
170 }
171 };
172 }
173 nth_arms.push(quote! { _ => None });
174 (
175 quote! { #count },
176 quote! {
177 #(#consts)*
178 match value {
179 #(#index_of_arms,)*
180 }
181 },
182 quote! {
183 #(#consts)*
184 match index {
185 #(#nth_arms,)*
186 }
187 },
188 )
189 }
190 Data::Union(_) => todo!(),
191 };
192
193 let mut res = quote! {
195 #[automatically_derived]
196 unsafe impl #impl_generics ::cantor::Finite for #name #ty_generics #where_clause {
197 const COUNT: usize = #count;
198
199 fn index_of(value: Self) -> usize {
200 #index_of
201 }
202
203 fn nth(index: usize) -> Option<Self> {
204 #nth
205 }
206 }
207 };
208
209 if input.generics.type_params().next().is_none() {
211 res.extend(quote! {
212 ::cantor::impl_concrete_finite!(#name);
213 });
214 }
215
216 TokenStream::from(res)
218}
219
220#[derive(Clone)]
222enum SimpleNumTerm {
223 Literal(i64),
224 Constant(Ident),
225}
226
227impl ToTokens for SimpleNumTerm {
228 fn to_tokens(&self, tokens: &mut TokenStream2) {
229 match self {
230 SimpleNumTerm::Literal(value) => {
231 tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(*value)))
232 }
233 SimpleNumTerm::Constant(ident) => tokens.append(TokenTree::Ident(ident.clone())),
234 }
235 }
236}
237
238enum NonLiteralNumTerm {
240 Constant(Ident),
241 Complex(TokenStream2),
242}
243
244impl ToTokens for NonLiteralNumTerm {
245 fn to_tokens(&self, tokens: &mut TokenStream2) {
246 match self {
247 NonLiteralNumTerm::Constant(ident) => tokens.append(TokenTree::Ident(ident.clone())),
248 NonLiteralNumTerm::Complex(expr) => tokens.extend(expr.clone()),
249 }
250 }
251}
252
253#[derive(Clone)]
255enum NumTerm {
256 Literal(i64),
257 Constant(Ident),
258 Complex(TokenStream2),
259}
260
261impl From<SimpleNumTerm> for NumTerm {
262 fn from(term: SimpleNumTerm) -> Self {
263 match term {
264 SimpleNumTerm::Literal(value) => NumTerm::Literal(value),
265 SimpleNumTerm::Constant(ident) => NumTerm::Constant(ident),
266 }
267 }
268}
269
270impl ToTokens for NumTerm {
271 fn to_tokens(&self, tokens: &mut TokenStream2) {
272 match self {
273 NumTerm::Literal(value) => {
274 tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(*value)))
275 }
276 NumTerm::Constant(ident) => tokens.append(TokenTree::Ident(ident.clone())),
277 NumTerm::Complex(expr) => tokens.extend(expr.clone()),
278 }
279 }
280}
281
282struct SumExpr {
284 lit: i64,
285 non_lit: Vec<NonLiteralNumTerm>,
286}
287
288impl SumExpr {
289 pub fn new_zero() -> Self {
291 Self {
292 lit: 0,
293 non_lit: Vec::new(),
294 }
295 }
296
297 pub fn add(&mut self, value: NumTerm) {
299 match value {
300 NumTerm::Literal(value) => self.lit += value,
301 NumTerm::Constant(value) => self.non_lit.push(NonLiteralNumTerm::Constant(value)),
302 NumTerm::Complex(value) => self.non_lit.push(NonLiteralNumTerm::Complex(value)),
303 }
304 }
305
306 pub fn set_zero(&mut self) {
308 self.lit = 0;
309 self.non_lit.clear();
310 }
311
312 pub fn get_simple(&mut self, consts: &mut Vec<TokenStream2>) -> SimpleNumTerm {
315 if self.non_lit.len() == 0 {
316 return SimpleNumTerm::Literal(self.lit);
317 } else if self.lit == 0 && self.non_lit.len() == 1 {
318 match &self.non_lit[0] {
319 NonLiteralNumTerm::Constant(ident) => {
320 return SimpleNumTerm::Constant(ident.clone());
321 }
322 _ => (),
323 }
324 }
325 let ident = format!("C_{}", consts.len());
326 let ident = Ident::new(&*ident, Span::call_site());
327 consts.push(quote! { const #ident: usize = #self; });
328 SimpleNumTerm::Constant(ident)
329 }
330}
331
332impl ToTokens for SumExpr {
333 fn to_tokens(&self, tokens: &mut TokenStream2) {
334 if let Some((head_non_lit, tail_non_lit)) = self.non_lit.split_first() {
335 if self.lit > 0 {
336 tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(self.lit)));
337 tokens.extend(quote! { + });
338 }
339 tokens.extend(quote! { #head_non_lit #(+ #tail_non_lit)* });
340 if self.lit < 0 {
341 tokens.extend(quote! { - });
342 tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(-self.lit)));
343 }
344 } else {
345 tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(self.lit)));
346 }
347 }
348}
349
350fn product_count(field_tys: &[TokenStream2]) -> NumTerm {
352 if let Some((head_field_ty, tail_field_tys)) = field_tys.split_first() {
353 NumTerm::Complex(quote! {
354 <#head_field_ty as ::cantor::Finite>::COUNT
355 #(* <#tail_field_tys as ::cantor::Finite>::COUNT)*
356 })
357 } else {
358 NumTerm::Literal(1)
359 }
360}
361
362fn product_index_of(field_tys: &[TokenStream2], fields: &[TokenStream2]) -> TokenStream2 {
365 quote! {
366 {
367 let __index = 0;
368 #(let __index = __index *
369 <#field_tys as ::cantor::Finite>::COUNT +
370 <#field_tys as ::cantor::Finite>::index_of(#fields);)*
371 __index
372 }
373 }
374}
375
376fn product_nth(
379 field_tys: &[TokenStream2],
380 index: TokenStream2,
381 fields: &[TokenStream2],
382 cons: TokenStream2,
383) -> TokenStream2 {
384 let field_tys_rev = field_tys.iter().rev();
385 let fields_rev = fields.iter().rev();
386 quote! {
387 {
388 let __index = #index;
389 #(
390 let #fields_rev = <#field_tys_rev as ::cantor::Finite>::nth(__index %
391 <#field_tys_rev as ::cantor::Finite>::COUNT).unwrap();
392 let __index = __index / <#field_tys_rev as ::cantor::Finite>::COUNT;
393 )*
394 #cons
395 }
396 }
397}