cantor_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Literal, Span, TokenStream as TokenStream2, TokenTree};
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::*;
5
6#[proc_macro_derive(Finite)]
7pub fn derive_finite(input: TokenStream) -> TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9    let name = input.ident;
10    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
11    let (count, index_of, nth) = match input.data {
12        Data::Struct(data) => match data.fields {
13            Fields::Named(fields) => {
14                let mut field_tys = Vec::new();
15                let mut field_idents = Vec::new();
16                for field in fields.named {
17                    field_tys.push(field.ty.to_token_stream());
18                    field_idents.push(field.ident.to_token_stream());
19                }
20                let count = product_count(&*field_tys);
21                let index_of = product_index_of(&*field_tys, &*field_idents);
22                let nth = product_nth(
23                    &*field_tys,
24                    quote! { index },
25                    &*field_idents,
26                    quote! { Self { #(#field_idents),* } },
27                );
28                (
29                    quote! { #count }, 
30                    quote! {
31                        let Self { #(#field_idents),* } = value;
32                        #index_of
33                    },
34                    quote! {
35                        if index < <Self as ::cantor::Finite>::COUNT {
36                            Some(#nth)
37                        } else {
38                            None
39                        }
40                    }
41                )
42            },
43            Fields::Unnamed(fields) => {
44                let mut field_tys = Vec::new();
45                let mut field_idents = Vec::new();
46                for field in fields.unnamed {
47                    field_tys.push(field.ty.to_token_stream());
48                    let field_ident = format!("f{}", field_idents.len());
49                    let field_ident = Ident::new(&*field_ident, Span::call_site());
50                    field_idents.push(field_ident.to_token_stream());
51                }
52                let count = product_count(&*field_tys);
53                let index_of = product_index_of(&*field_tys, &*field_idents);
54                let nth = product_nth(
55                    &*field_tys,
56                    quote! { index },
57                    &*field_idents,
58                    quote! { Self(#(#field_idents),*) },
59                );
60                (
61                    quote! { #count }, 
62                    quote! {
63                        let Self(#(#field_idents),*) = value;
64                        #index_of
65                    },
66                    quote! {
67                        if index < <Self as ::cantor::Finite>::COUNT {
68                            Some(#nth)
69                        } else {
70                            None
71                        }
72                    }
73                )
74            }
75            Fields::Unit => (
76                quote! { 1 },
77                quote! { 0 },
78                quote! {
79                    if index < 1 {
80                        Some(Self)
81                    } else {
82                        None
83                    }
84                },
85            ),
86        },
87        Data::Enum(data) => {
88            // Gather info from variants
89            let mut count = SumExpr::new_zero();
90            let mut const_count = SumExpr::new_zero();
91            let mut consts = Vec::new();
92            let mut index_of_arms = Vec::new();
93            let mut nth_arms = Vec::new();
94            for variant in data.variants {
95                // Consider the different types of variant definitions
96                let variant_name = variant.ident;
97                let start_index = const_count.get_simple(&mut consts);
98                const_count.set_zero();
99                const_count.add(start_index.clone().into());
100                match variant.fields {
101                    Fields::Named(fields) => {
102                        let mut field_tys = Vec::new();
103                        let mut field_idents = Vec::new();
104                        for field in fields.named {
105                            field_tys.push(field.ty.to_token_stream());
106                            field_idents.push(field.ident.to_token_stream());
107                        }
108                        let index_of_arm = product_index_of(&*field_tys, &*field_idents);
109                        index_of_arms.push(quote! {
110                            Self::#variant_name { #(#field_idents),* } => #count + #index_of_arm
111                        });
112                        let nth_arm = product_nth(
113                            &*field_tys,
114                            quote! { index - #start_index },
115                            &*field_idents,
116                            quote! { Self::#variant_name { #(#field_idents),* } },
117                        );
118                        let variant_count = product_count(&*field_tys);
119                        count.add(variant_count.clone());
120                        const_count.add(variant_count);
121                        const_count.add(NumTerm::Literal(-1));
122                        let end_index = const_count.get_simple(&mut consts);
123                        const_count.set_zero();
124                        const_count.add(end_index.clone().into());
125                        const_count.add(NumTerm::Literal(1));
126                        nth_arms.push(quote! {
127                            #start_index..=#end_index => Some(#nth_arm)
128                        });
129                    }
130                    Fields::Unnamed(fields) => {
131                        let mut field_tys = Vec::new();
132                        let mut field_idents = Vec::new();
133                        for field in fields.unnamed {
134                            field_tys.push(field.ty.to_token_stream());
135                            let field_ident = format!("f{}", field_idents.len());
136                            let field_ident = Ident::new(&*field_ident, Span::call_site());
137                            field_idents.push(field_ident.to_token_stream());
138                        }
139                        let index_of_arm = product_index_of(&*field_tys, &*field_idents);
140                        index_of_arms.push(quote! {
141                            Self::#variant_name(#(#field_idents),*) => #count + #index_of_arm
142                        });
143                        let nth_arm = product_nth(
144                            &*field_tys,
145                            quote! { index - #start_index },
146                            &*field_idents,
147                            quote! { Self::#variant_name(#(#field_idents),*) },
148                        );
149                        let variant_count = product_count(&*field_tys);
150                        count.add(variant_count.clone());
151                        const_count.add(variant_count);
152                        const_count.add(NumTerm::Literal(-1));
153                        let end_index = const_count.get_simple(&mut consts);
154                        const_count.set_zero();
155                        const_count.add(end_index.clone().into());
156                        const_count.add(NumTerm::Literal(1));
157                        nth_arms.push(quote! {
158                            #start_index..=#end_index => Some(#nth_arm)
159                        });
160                    }
161                    Fields::Unit => {
162                        index_of_arms.push(quote! {
163                            Self::#variant_name => #start_index
164                        });
165                        nth_arms.push(quote! {
166                            #start_index => Some(Self::#variant_name)
167                        });
168                        count.add(NumTerm::Literal(1));
169                        const_count.add(NumTerm::Literal(1));
170                    }
171                };
172            }
173            nth_arms.push(quote! { _ => None });
174            (
175                quote! { #count },
176                quote! {
177                    #(#consts)*
178                    match value {
179                        #(#index_of_arms,)*
180                    }
181                },
182                quote! {
183                    #(#consts)*
184                    match index {
185                        #(#nth_arms,)*
186                    }
187                },
188            )
189        }
190        Data::Union(_) => todo!(),
191    };
192
193    // Build implementation
194    let mut res = quote! {
195        #[automatically_derived]
196        unsafe impl #impl_generics ::cantor::Finite for #name #ty_generics #where_clause {
197            const COUNT: usize = #count;
198
199            fn index_of(value: Self) -> usize {
200                #index_of
201            }
202
203            fn nth(index: usize) -> Option<Self> {
204                #nth
205            }
206        }
207    };
208
209    // If this is a concrete type (no generic parameters), also implement helper traits.
210    if input.generics.type_params().next().is_none() {
211        res.extend(quote! {
212            ::cantor::impl_concrete_finite!(#name);
213        });
214    }
215
216    // Return final result
217    TokenStream::from(res)
218}
219
220/// A [`NumTerm`] that can be used as a range bound.
221#[derive(Clone)]
222enum SimpleNumTerm {
223    Literal(i64),
224    Constant(Ident),
225}
226
227impl ToTokens for SimpleNumTerm {
228    fn to_tokens(&self, tokens: &mut TokenStream2) {
229        match self {
230            SimpleNumTerm::Literal(value) => {
231                tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(*value)))
232            }
233            SimpleNumTerm::Constant(ident) => tokens.append(TokenTree::Ident(ident.clone())),
234        }
235    }
236}
237
238/// A [`NumTerm`] which is not a literal.
239enum NonLiteralNumTerm {
240    Constant(Ident),
241    Complex(TokenStream2),
242}
243
244impl ToTokens for NonLiteralNumTerm {
245    fn to_tokens(&self, tokens: &mut TokenStream2) {
246        match self {
247            NonLiteralNumTerm::Constant(ident) => tokens.append(TokenTree::Ident(ident.clone())),
248            NonLiteralNumTerm::Complex(expr) => tokens.extend(expr.clone()),
249        }
250    }
251}
252
253/// A term which provides a number.
254#[derive(Clone)]
255enum NumTerm {
256    Literal(i64),
257    Constant(Ident),
258    Complex(TokenStream2),
259}
260
261impl From<SimpleNumTerm> for NumTerm {
262    fn from(term: SimpleNumTerm) -> Self {
263        match term {
264            SimpleNumTerm::Literal(value) => NumTerm::Literal(value),
265            SimpleNumTerm::Constant(ident) => NumTerm::Constant(ident),
266        }
267    }
268}
269
270impl ToTokens for NumTerm {
271    fn to_tokens(&self, tokens: &mut TokenStream2) {
272        match self {
273            NumTerm::Literal(value) => {
274                tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(*value)))
275            }
276            NumTerm::Constant(ident) => tokens.append(TokenTree::Ident(ident.clone())),
277            NumTerm::Complex(expr) => tokens.extend(expr.clone()),
278        }
279    }
280}
281
282/// An expression for a sum of values.
283struct SumExpr {
284    lit: i64,
285    non_lit: Vec<NonLiteralNumTerm>,
286}
287
288impl SumExpr {
289    /// Creates a [`SumExpr`] with an initial value of zero.
290    pub fn new_zero() -> Self {
291        Self {
292            lit: 0,
293            non_lit: Vec::new(),
294        }
295    }
296
297    /// Adds a value to this expression.
298    pub fn add(&mut self, value: NumTerm) {
299        match value {
300            NumTerm::Literal(value) => self.lit += value,
301            NumTerm::Constant(value) => self.non_lit.push(NonLiteralNumTerm::Constant(value)),
302            NumTerm::Complex(value) => self.non_lit.push(NonLiteralNumTerm::Complex(value)),
303        }
304    }
305
306    /// Sets this expression to 0.
307    pub fn set_zero(&mut self) {
308        self.lit = 0;
309        self.non_lit.clear();
310    }
311
312    /// Gets a [`SimpleNumTerm`] representation of this expression, assuming its possible to define
313    /// an arbitrary constant ahead of time.
314    pub fn get_simple(&mut self, consts: &mut Vec<TokenStream2>) -> SimpleNumTerm {
315        if self.non_lit.len() == 0 {
316            return SimpleNumTerm::Literal(self.lit);
317        } else if self.lit == 0 && self.non_lit.len() == 1 {
318            match &self.non_lit[0] {
319                NonLiteralNumTerm::Constant(ident) => {
320                    return SimpleNumTerm::Constant(ident.clone());
321                }
322                _ => (),
323            }
324        }
325        let ident = format!("C_{}", consts.len());
326        let ident = Ident::new(&*ident, Span::call_site());
327        consts.push(quote! { const #ident: usize = #self; });
328        SimpleNumTerm::Constant(ident)
329    }
330}
331
332impl ToTokens for SumExpr {
333    fn to_tokens(&self, tokens: &mut TokenStream2) {
334        if let Some((head_non_lit, tail_non_lit)) = self.non_lit.split_first() {
335            if self.lit > 0 {
336                tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(self.lit)));
337                tokens.extend(quote! { + });
338            }
339            tokens.extend(quote! { #head_non_lit #(+ #tail_non_lit)* });
340            if self.lit < 0 {
341                tokens.extend(quote! { - });
342                tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(-self.lit)));
343            }
344        } else {
345            tokens.append(TokenTree::Literal(Literal::i64_unsuffixed(self.lit)));
346        }
347    }
348}
349
350/// Gets an expression for the number of values for a product of the given types.
351fn product_count(field_tys: &[TokenStream2]) -> NumTerm {
352    if let Some((head_field_ty, tail_field_tys)) = field_tys.split_first() {
353        NumTerm::Complex(quote! {
354            <#head_field_ty as ::cantor::Finite>::COUNT
355            #(* <#tail_field_tys as ::cantor::Finite>::COUNT)*
356        })
357    } else {
358        NumTerm::Literal(1)
359    }
360}
361
362/// Gets an expression which produces the index of a value of the product type, given the values
363/// of its fields.
364fn product_index_of(field_tys: &[TokenStream2], fields: &[TokenStream2]) -> TokenStream2 {
365    quote! {
366        {
367            let __index = 0;
368            #(let __index = __index *
369                <#field_tys as ::cantor::Finite>::COUNT +
370                <#field_tys as ::cantor::Finite>::index_of(#fields);)*
371            __index
372        }
373    }
374}
375
376/// Gets an expression which produces a value of the product, given an expression for a
377/// valid index and a constructor for values of the product.
378fn product_nth(
379    field_tys: &[TokenStream2],
380    index: TokenStream2,
381    fields: &[TokenStream2],
382    cons: TokenStream2,
383) -> TokenStream2 {
384    let field_tys_rev = field_tys.iter().rev();
385    let fields_rev = fields.iter().rev();
386    quote! {
387        {
388            let __index = #index;
389            #(
390                let #fields_rev = <#field_tys_rev as ::cantor::Finite>::nth(__index %
391                    <#field_tys_rev as ::cantor::Finite>::COUNT).unwrap();
392                let __index = __index / <#field_tys_rev as ::cantor::Finite>::COUNT;
393            )*
394            #cons
395        }
396    }
397}