unbound_derive/
lib.rs

1//! Derive macros for unbound library
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident};
6
7/// Derive macro for the Alpha trait
8#[proc_macro_derive(Alpha)]
9pub fn derive_alpha(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    let name = &input.ident;
12    let generics = &input.generics;
13    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
14
15    let aeq_impl = generate_aeq_impl(&input.data, name);
16    let aeq_in_impl = generate_aeq_in_impl(&input.data);
17    let fv_in_impl = generate_fv_in_impl(&input.data);
18
19    let expanded = quote! {
20        impl #impl_generics unbound::Alpha for #name #ty_generics #where_clause {
21            fn aeq(&self, other: &Self) -> bool {
22                #aeq_impl
23            }
24
25            fn aeq_in(&self, ctx: &mut unbound::alpha::AlphaCtx, other: &Self) -> bool {
26                #aeq_in_impl
27            }
28
29            fn fv_in(&self, vars: &mut Vec<String>) {
30                #fv_in_impl
31            }
32        }
33    };
34
35    TokenStream::from(expanded)
36}
37
38fn generate_aeq_impl(data: &Data, name: &Ident) -> proc_macro2::TokenStream {
39    match data {
40        Data::Struct(data_struct) => match &data_struct.fields {
41            Fields::Named(fields) => {
42                let field_checks = fields.named.iter().map(|f| {
43                    let field_name = &f.ident;
44                    quote! {
45                        self.#field_name.aeq(&other.#field_name)
46                    }
47                });
48                quote! {
49                    #(#field_checks)&&*
50                }
51            }
52            Fields::Unnamed(fields) => {
53                let field_checks = fields.unnamed.iter().enumerate().map(|(i, _)| {
54                    let index = syn::Index::from(i);
55                    quote! {
56                        self.#index.aeq(&other.#index)
57                    }
58                });
59                quote! {
60                    #(#field_checks)&&*
61                }
62            }
63            Fields::Unit => quote! { true },
64        },
65        Data::Enum(data_enum) => {
66            let variant_matches = data_enum.variants.iter().map(|variant| {
67                let variant_name = &variant.ident;
68                match &variant.fields {
69                    Fields::Named(fields) => {
70                        let field_names: Vec<_> = fields
71                            .named
72                            .iter()
73                            .filter_map(|f| f.ident.as_ref())
74                            .collect();
75                        let other_field_names: Vec<_> = field_names
76                            .iter()
77                            .map(|f| quote::format_ident!("other_{}", f))
78                            .collect();
79                        let field_checks = field_names.iter().zip(other_field_names.iter()).map(
80                            |(field_name, other_field_name)| {
81                                quote! {
82                                    #field_name.aeq(#other_field_name)
83                                }
84                            },
85                        );
86                        let other_bindings = field_names.iter().zip(other_field_names.iter()).map(
87                            |(field_name, other_field_name)| {
88                                quote! { #field_name: #other_field_name }
89                            },
90                        );
91                        quote! {
92                            (#name::#variant_name { #(#field_names),* },
93                             #name::#variant_name { #(#other_bindings),* }) => {
94                                #(#field_checks)&&*
95                            }
96                        }
97                    }
98                    Fields::Unnamed(fields) => {
99                        let field_names: Vec<_> = (0..fields.unnamed.len())
100                            .map(|i| quote::format_ident!("f{}", i))
101                            .collect();
102                        let other_names: Vec<_> = (0..fields.unnamed.len())
103                            .map(|i| quote::format_ident!("other_f{}", i))
104                            .collect();
105                        let field_checks =
106                            field_names
107                                .iter()
108                                .zip(other_names.iter())
109                                .map(|(f, other_f)| {
110                                    quote! {
111                                        #f.aeq(#other_f)
112                                    }
113                                });
114                        quote! {
115                            (#name::#variant_name(#(#field_names),*),
116                             #name::#variant_name(#(#other_names),*)) => {
117                                #(#field_checks)&&*
118                            }
119                        }
120                    }
121                    Fields::Unit => {
122                        quote! {
123                            (#name::#variant_name, #name::#variant_name) => true
124                        }
125                    }
126                }
127            });
128            quote! {
129                match (self, other) {
130                    #(#variant_matches,)*
131                    _ => false,
132                }
133            }
134        }
135        Data::Union(_) => panic!("Unions are not supported"),
136    }
137}
138
139fn generate_aeq_in_impl(data: &Data) -> proc_macro2::TokenStream {
140    match data {
141        Data::Struct(data_struct) => match &data_struct.fields {
142            Fields::Named(fields) => {
143                let field_checks = fields.named.iter().map(|f| {
144                    let field_name = &f.ident;
145                    quote! {
146                        self.#field_name.aeq_in(ctx, &other.#field_name)
147                    }
148                });
149                quote! {
150                    #(#field_checks)&&*
151                }
152            }
153            Fields::Unnamed(fields) => {
154                let field_checks = fields.unnamed.iter().enumerate().map(|(i, _)| {
155                    let index = syn::Index::from(i);
156                    quote! {
157                        self.#index.aeq_in(ctx, &other.#index)
158                    }
159                });
160                quote! {
161                    #(#field_checks)&&*
162                }
163            }
164            Fields::Unit => quote! { true },
165        },
166        Data::Enum(data_enum) => {
167            // Generate proper pattern matching for enums using aeq_in
168            let variant_matches = data_enum.variants.iter().map(|variant| {
169                let variant_name = &variant.ident;
170                match &variant.fields {
171                    Fields::Named(fields) => {
172                        let field_names: Vec<_> = fields
173                            .named
174                            .iter()
175                            .filter_map(|f| f.ident.as_ref())
176                            .collect();
177                        let other_field_names: Vec<_> = field_names
178                            .iter()
179                            .map(|f| quote::format_ident!("other_{}", f))
180                            .collect();
181                        let field_checks = field_names.iter().zip(other_field_names.iter()).map(
182                            |(field_name, other_field_name)| {
183                                quote! {
184                                    #field_name.aeq_in(ctx, #other_field_name)
185                                }
186                            },
187                        );
188                        let other_bindings = field_names.iter().zip(other_field_names.iter()).map(
189                            |(field_name, other_field_name)| {
190                                quote! { #field_name: #other_field_name }
191                            },
192                        );
193                        quote! {
194                            (Self::#variant_name { #(#field_names),* },
195                             Self::#variant_name { #(#other_bindings),* }) => {
196                                #(#field_checks)&&*
197                            }
198                        }
199                    }
200                    Fields::Unnamed(fields) => {
201                        let field_names: Vec<_> = (0..fields.unnamed.len())
202                            .map(|i| quote::format_ident!("f{}", i))
203                            .collect();
204                        let other_names: Vec<_> = (0..fields.unnamed.len())
205                            .map(|i| quote::format_ident!("other_f{}", i))
206                            .collect();
207                        let field_checks =
208                            field_names
209                                .iter()
210                                .zip(other_names.iter())
211                                .map(|(f, other_f)| {
212                                    quote! {
213                                        #f.aeq_in(ctx, #other_f)
214                                    }
215                                });
216                        quote! {
217                            (Self::#variant_name(#(#field_names),*),
218                             Self::#variant_name(#(#other_names),*)) => {
219                                #(#field_checks)&&*
220                            }
221                        }
222                    }
223                    Fields::Unit => {
224                        quote! {
225                            (Self::#variant_name, Self::#variant_name) => true
226                        }
227                    }
228                }
229            });
230            quote! {
231                match (self, other) {
232                    #(#variant_matches,)*
233                    _ => false,
234                }
235            }
236        }
237        Data::Union(_) => panic!("Unions are not supported"),
238    }
239}
240
241fn generate_fv_in_impl(data: &Data) -> proc_macro2::TokenStream {
242    match data {
243        Data::Struct(data_struct) => match &data_struct.fields {
244            Fields::Named(fields) => {
245                let field_calls = fields.named.iter().map(|f| {
246                    let field_name = &f.ident;
247                    quote! {
248                        self.#field_name.fv_in(vars);
249                    }
250                });
251                quote! {
252                    #(#field_calls)*
253                }
254            }
255            Fields::Unnamed(fields) => {
256                let field_calls = fields.unnamed.iter().enumerate().map(|(i, _)| {
257                    let index = syn::Index::from(i);
258                    quote! {
259                        self.#index.fv_in(vars);
260                    }
261                });
262                quote! {
263                    #(#field_calls)*
264                }
265            }
266            Fields::Unit => quote! {},
267        },
268        Data::Enum(data_enum) => {
269            let variant_matches = data_enum.variants.iter().map(|variant| {
270                let variant_name = &variant.ident;
271                match &variant.fields {
272                    Fields::Named(fields) => {
273                        let field_names: Vec<_> = fields
274                            .named
275                            .iter()
276                            .filter_map(|f| f.ident.as_ref())
277                            .collect();
278                        let field_calls = field_names.iter().map(|field_name| {
279                            quote! {
280                                #field_name.fv_in(vars);
281                            }
282                        });
283                        quote! {
284                            Self::#variant_name { #(#field_names),* } => {
285                                #(#field_calls)*
286                            }
287                        }
288                    }
289                    Fields::Unnamed(fields) => {
290                        let field_names: Vec<_> = (0..fields.unnamed.len())
291                            .map(|i| quote::format_ident!("f{}", i))
292                            .collect();
293                        let field_calls = field_names.iter().map(|f| {
294                            quote! {
295                                #f.fv_in(vars);
296                            }
297                        });
298                        quote! {
299                            Self::#variant_name(#(#field_names),*) => {
300                                #(#field_calls)*
301                            }
302                        }
303                    }
304                    Fields::Unit => {
305                        quote! {
306                            Self::#variant_name => {}
307                        }
308                    }
309                }
310            });
311            quote! {
312                match self {
313                    #(#variant_matches)*
314                }
315            }
316        }
317        Data::Union(_) => panic!("Unions are not supported"),
318    }
319}
320
321/// Derive macro for the Subst trait
322#[proc_macro_derive(Subst, attributes(subst_var))]
323pub fn derive_subst(input: TokenStream) -> TokenStream {
324    let input = parse_macro_input!(input as DeriveInput);
325    let name = &input.ident;
326    let generics = &input.generics;
327    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
328
329    let (is_var_impl, subst_impl) = generate_subst_impl(&input.data, name);
330
331    let expanded = quote! {
332        impl #impl_generics unbound::Subst<#name #ty_generics> for #name #ty_generics #where_clause {
333            fn is_var(&self) -> Option<unbound::SubstName<#name #ty_generics>> {
334                #is_var_impl
335            }
336
337            fn subst(&self, var: &unbound::Name<#name #ty_generics>, value: &#name #ty_generics) -> Self {
338                #subst_impl
339            }
340        }
341    };
342
343    TokenStream::from(expanded)
344}
345
346fn generate_subst_impl(
347    data: &Data,
348    name: &Ident,
349) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
350    match data {
351        Data::Enum(data_enum) => {
352            // Check for a V or Var variant
353            let var_variant = data_enum
354                .variants
355                .iter()
356                .find(|v| v.ident == "V" || v.ident == "Var" || v.ident == "Variable");
357
358            let is_var_impl = if let Some(var_variant) = var_variant {
359                let variant_name = &var_variant.ident;
360                quote! {
361                    match self {
362                        #name::#variant_name(x) => Some(unbound::SubstName::Name(x.clone())),
363                        _ => None,
364                    }
365                }
366            } else {
367                quote! { None }
368            };
369
370            let subst_cases = data_enum.variants.iter().map(|variant| {
371                let variant_name = &variant.ident;
372
373                // Special handling for variable variant
374                if Some(&variant.ident) == var_variant.as_ref().map(|v| &v.ident) {
375                    quote! {
376                        #name::#variant_name(x) => {
377                            if x == var {
378                                value.clone()
379                            } else {
380                                self.clone()
381                            }
382                        }
383                    }
384                } else if variant.ident == "Lam" {
385                    // Special handling for lambda variant with Bind
386                    match &variant.fields {
387                        Fields::Unnamed(_) => {
388                            quote! {
389                                #name::#variant_name(bnd) => {
390                                    // Check if the bound variable is the same as the substitution variable
391                                    let bound_var = bnd.pattern();
392                                    if bound_var == var {
393                                        // No substitution under the binder
394                                        self.clone()
395                                    } else {
396                                        // Perform capture-avoiding substitution
397                                        let body_subst = bnd.body().subst(var, value);
398                                        #name::#variant_name(unbound::bind(bound_var.clone(), body_subst))
399                                    }
400                                }
401                            }
402                        }
403                        _ => {
404                            // Fallback for other field types
405                            match &variant.fields {
406                                Fields::Named(fields) => {
407                                    let field_names: Vec<_> =
408                                        fields.named.iter().filter_map(|f| f.ident.as_ref()).collect();
409                                    let field_substs = field_names.iter().map(|field_name| {
410                                        quote! {
411                                            #field_name: #field_name.subst(var, value)
412                                        }
413                                    });
414                                    quote! {
415                                        #name::#variant_name { #(#field_names),* } => {
416                                            #name::#variant_name {
417                                                #(#field_substs),*
418                                            }
419                                        }
420                                    }
421                                }
422                                Fields::Unnamed(fields) => {
423                                    let field_names: Vec<_> = (0..fields.unnamed.len())
424                                        .map(|i| quote::format_ident!("f{}", i))
425                                        .collect();
426                                    let field_substs = field_names.iter().map(|f| {
427                                        quote! {
428                                            #f.subst(var, value)
429                                        }
430                                    });
431                                    quote! {
432                                        #name::#variant_name(#(#field_names),*) => {
433                                            #name::#variant_name(#(#field_substs),*)
434                                        }
435                                    }
436                                }
437                                Fields::Unit => {
438                                    quote! {
439                                        #name::#variant_name => #name::#variant_name
440                                    }
441                                }
442                            }
443                        }
444                    }
445                } else {
446                    match &variant.fields {
447                        Fields::Named(fields) => {
448                            let field_names: Vec<_> =
449                                fields.named.iter().filter_map(|f| f.ident.as_ref()).collect();
450                            let field_substs = field_names.iter().map(|field_name| {
451                                quote! {
452                                    #field_name: #field_name.subst(var, value)
453                                }
454                            });
455                            quote! {
456                                #name::#variant_name { #(#field_names),* } => {
457                                    #name::#variant_name {
458                                        #(#field_substs),*
459                                    }
460                                }
461                            }
462                        }
463                        Fields::Unnamed(fields) => {
464                            let field_names: Vec<_> = (0..fields.unnamed.len())
465                                .map(|i| quote::format_ident!("f{}", i))
466                                .collect();
467                            let field_substs = field_names.iter().map(|f| {
468                                quote! {
469                                    #f.subst(var, value)
470                                }
471                            });
472                            quote! {
473                                #name::#variant_name(#(#field_names),*) => {
474                                    #name::#variant_name(#(#field_substs),*)
475                                }
476                            }
477                        }
478                        Fields::Unit => {
479                            quote! {
480                                #name::#variant_name => #name::#variant_name
481                            }
482                        }
483                    }
484                }
485            });
486
487            let subst_impl = quote! {
488                match self {
489                    #(#subst_cases),*
490                }
491            };
492
493            (is_var_impl, subst_impl)
494        }
495        Data::Struct(data_struct) => {
496            let is_var_impl = quote! { None };
497
498            let subst_impl = match &data_struct.fields {
499                Fields::Named(fields) => {
500                    let field_names: Vec<_> = fields
501                        .named
502                        .iter()
503                        .filter_map(|f| f.ident.as_ref())
504                        .collect();
505                    let field_substs = field_names.iter().map(|field_name| {
506                        quote! {
507                            #field_name: self.#field_name.subst(var, value)
508                        }
509                    });
510                    quote! {
511                        #name {
512                            #(#field_substs),*
513                        }
514                    }
515                }
516                Fields::Unnamed(fields) => {
517                    let field_substs = (0..fields.unnamed.len()).map(|i| {
518                        let index = syn::Index::from(i);
519                        quote! {
520                            self.#index.subst(var, value)
521                        }
522                    });
523                    quote! {
524                        #name(#(#field_substs),*)
525                    }
526                }
527                Fields::Unit => quote! { #name },
528            };
529
530            (is_var_impl, subst_impl)
531        }
532        Data::Union(_) => panic!("Unions are not supported"),
533    }
534}