chalk_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use quote::ToTokens;
6use syn::{parse_quote, DeriveInput, Ident, TypeParam, TypeParamBound};
7
8use synstructure::decl_derive;
9
10/// Checks whether a generic parameter has a `: HasInterner` bound
11fn has_interner(param: &TypeParam) -> Option<&Ident> {
12    bounded_by_trait(param, "HasInterner")
13}
14
15/// Checks whether a generic parameter has a `: Interner` bound
16fn is_interner(param: &TypeParam) -> Option<&Ident> {
17    bounded_by_trait(param, "Interner")
18}
19
20fn has_interner_attr(input: &DeriveInput) -> Option<TokenStream> {
21    Some(
22        input
23            .attrs
24            .iter()
25            .find(|a| a.path().is_ident("has_interner"))?
26            .parse_args::<TokenStream>()
27            .expect("Expected has_interner argument"),
28    )
29}
30
31fn bounded_by_trait<'p>(param: &'p TypeParam, name: &str) -> Option<&'p Ident> {
32    let name = Some(String::from(name));
33    param.bounds.iter().find_map(|b| {
34        if let TypeParamBound::Trait(trait_bound) = b {
35            if trait_bound
36                .path
37                .segments
38                .last()
39                .map(|s| s.ident.to_string())
40                == name
41            {
42                return Some(&param.ident);
43            }
44        }
45        None
46    })
47}
48
49fn get_intern_param(input: &DeriveInput) -> Option<(DeriveKind, &Ident)> {
50    let mut params = input.generics.type_params().filter_map(|param| {
51        has_interner(param)
52            .map(|ident| (DeriveKind::FromHasInterner, ident))
53            .or_else(|| is_interner(param).map(|ident| (DeriveKind::FromInterner, ident)))
54    });
55
56    let param = params.next();
57    assert!(params.next().is_none(), "deriving this trait only works with at most one type parameter that implements HasInterner or Interner");
58
59    param
60}
61
62fn get_intern_param_name(input: &DeriveInput) -> &Ident {
63    get_intern_param(input)
64        .expect("deriving this trait requires a parameter that implements HasInterner or Interner")
65        .1
66}
67
68fn try_find_interner(s: &mut synstructure::Structure) -> Option<(TokenStream, DeriveKind)> {
69    let input = s.ast();
70
71    if let Some(arg) = has_interner_attr(input) {
72        // Hardcoded interner:
73        //
74        // #[has_interner(ChalkIr)]
75        // struct S {
76        //
77        // }
78        return Some((arg, DeriveKind::FromHasInternerAttr));
79    }
80
81    get_intern_param(input).map(|generic_param0| match generic_param0 {
82        (DeriveKind::FromHasInterner, param) => {
83            // HasInterner bound:
84            //
85            // Example:
86            //
87            // struct Binders<T: HasInterner> { }
88            s.add_impl_generic(parse_quote! { _I });
89
90            s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
91            s.add_where_predicate(
92                parse_quote! { #param: ::chalk_ir::interner::HasInterner<Interner = _I> },
93            );
94
95            (quote! { _I }, DeriveKind::FromHasInterner)
96        }
97        (DeriveKind::FromInterner, i) => {
98            // Interner bound:
99            //
100            // Example:
101            //
102            // struct Foo<I: Interner> { }
103            (quote! { #i }, DeriveKind::FromInterner)
104        }
105        _ => unreachable!(),
106    })
107}
108
109fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
110    try_find_interner(s)
111        .expect("deriving this trait requires a `#[has_interner]` attr or a parameter that implements HasInterner or Interner")
112}
113
/// How the interner type for a derived impl was discovered.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum DeriveKind {
    /// The interner was named explicitly via a `#[has_interner(...)]` attribute.
    FromHasInternerAttr,
    /// The interner is the `Interner` associated type of a type parameter bounded by
    /// `HasInterner`.
    FromHasInterner,
    /// A type parameter bounded by `Interner` is itself the interner.
    FromInterner,
}
120
121decl_derive!([FallibleTypeFolder, attributes(has_interner)] => derive_fallible_type_folder);
122decl_derive!([HasInterner, attributes(has_interner)] => derive_has_interner);
123decl_derive!([TypeVisitable, attributes(has_interner)] => derive_type_visitable);
124decl_derive!([TypeSuperVisitable, attributes(has_interner)] => derive_type_super_visitable);
125decl_derive!([TypeFoldable, attributes(has_interner)] => derive_type_foldable);
126decl_derive!([Zip, attributes(has_interner)] => derive_zip);
127
128fn derive_has_interner(mut s: synstructure::Structure) -> TokenStream {
129    s.underscore_const(true);
130    let (interner, _) = find_interner(&mut s);
131
132    s.add_bounds(synstructure::AddBounds::None);
133    s.bound_impl(
134        quote!(::chalk_ir::interner::HasInterner),
135        quote! {
136            type Interner = #interner;
137        },
138    )
139}
140
141/// Derives TypeVisitable for structs and enums for which one of the following is true:
142/// - It has a `#[has_interner(TheInterner)]` attribute
143/// - There is a single parameter `T: HasInterner` (does not have to be named `T`)
144/// - There is a single parameter `I: Interner` (does not have to be named `I`)
145fn derive_type_visitable(s: synstructure::Structure) -> TokenStream {
146    derive_any_type_visitable(
147        s,
148        parse_quote! { TypeVisitable },
149        parse_quote! { visit_with },
150    )
151}
152
153/// Same as TypeVisitable, but derives TypeSuperVisitable instead
154fn derive_type_super_visitable(s: synstructure::Structure) -> TokenStream {
155    derive_any_type_visitable(
156        s,
157        parse_quote! { TypeSuperVisitable },
158        parse_quote! { super_visit_with },
159    )
160}
161
162fn derive_any_type_visitable(
163    mut s: synstructure::Structure,
164    trait_name: Ident,
165    method_name: Ident,
166) -> TokenStream {
167    s.underscore_const(true);
168    let input = s.ast();
169    let (interner, kind) = find_interner(&mut s);
170
171    let body = s.each(|bi| {
172        quote! {
173            ::chalk_ir::try_break!(::chalk_ir::visit::TypeVisitable::visit_with(#bi, visitor, outer_binder));
174        }
175    });
176
177    if kind == DeriveKind::FromHasInterner {
178        let param = get_intern_param_name(input);
179        s.add_where_predicate(parse_quote! { #param: ::chalk_ir::visit::TypeVisitable<#interner> });
180    }
181
182    s.add_bounds(synstructure::AddBounds::None);
183    s.bound_impl(
184        quote!(::chalk_ir::visit:: #trait_name <#interner>),
185        quote! {
186            fn #method_name <B>(
187                &self,
188                visitor: &mut dyn ::chalk_ir::visit::TypeVisitor < #interner, BreakTy = B >,
189                outer_binder: ::chalk_ir::DebruijnIndex,
190            ) -> std::ops::ControlFlow<B> {
191                match *self {
192                    #body
193                }
194                std::ops::ControlFlow::Continue(())
195            }
196        },
197    )
198}
199
200fn each_variant_pair<F, R>(
201    a: &mut synstructure::Structure,
202    b: &mut synstructure::Structure,
203    mut f: F,
204) -> TokenStream
205where
206    F: FnMut(&synstructure::VariantInfo<'_>, &synstructure::VariantInfo<'_>) -> R,
207    R: ToTokens,
208{
209    let mut t = TokenStream::new();
210    for (v_a, v_b) in a.variants_mut().iter_mut().zip(b.variants_mut().iter_mut()) {
211        v_a.binding_name(|_, i| Ident::new(&format!("a_{}", i), Span::call_site()));
212        v_b.binding_name(|_, i| Ident::new(&format!("b_{}", i), Span::call_site()));
213
214        let pat_a = v_a.pat();
215        let pat_b = v_b.pat();
216        let body = f(v_a, v_b);
217
218        quote!((#pat_a, #pat_b)  => {#body}).to_tokens(&mut t);
219    }
220    t
221}
222
223fn derive_zip(mut s: synstructure::Structure) -> TokenStream {
224    s.underscore_const(true);
225    let (interner, _) = find_interner(&mut s);
226
227    let mut a = s.clone();
228    let mut b = s.clone();
229
230    let mut body = each_variant_pair(&mut a, &mut b, |v_a, v_b| {
231        let mut t = TokenStream::new();
232        for (b_a, b_b) in v_a.bindings().iter().zip(v_b.bindings().iter()) {
233            quote!(chalk_ir::zip::Zip::zip_with(zipper, variance, #b_a, #b_b)?;).to_tokens(&mut t);
234        }
235        quote!(Ok(())).to_tokens(&mut t);
236        t
237    });
238
239    // when the two variants are different
240    quote!((_, _)  => Err(::chalk_ir::NoSolution)).to_tokens(&mut body);
241
242    s.add_bounds(synstructure::AddBounds::None);
243    s.bound_impl(
244        quote!(::chalk_ir::zip::Zip<#interner>),
245        quote! {
246
247            fn zip_with<Z: ::chalk_ir::zip::Zipper<#interner>>(
248                zipper: &mut Z,
249                variance: ::chalk_ir::Variance,
250                a: &Self,
251                b: &Self,
252            ) -> ::chalk_ir::Fallible<()> {
253                    match (a, b) { #body }
254                }
255        },
256    )
257}
258
259/// Derives TypeFoldable for structs and enums for which one of the following is true:
260/// - It has a `#[has_interner(TheInterner)]` attribute
261/// - There is a single parameter `T: HasInterner` (does not have to be named `T`)
262/// - There is a single parameter `I: Interner` (does not have to be named `I`)
263fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
264    s.underscore_const(true);
265    s.bind_with(|_| synstructure::BindStyle::Move);
266
267    let (interner, kind) = find_interner(&mut s);
268
269    let body = s.each_variant(|vi| {
270        let bindings = vi.bindings();
271        vi.construct(|_, index| {
272            let bind = &bindings[index];
273            quote! {
274                ::chalk_ir::fold::TypeFoldable::try_fold_with(#bind, folder, outer_binder)?
275            }
276        })
277    });
278
279    let input = s.ast();
280
281    if kind == DeriveKind::FromHasInterner {
282        let param = get_intern_param_name(input);
283        s.add_where_predicate(parse_quote! { #param: ::chalk_ir::fold::TypeFoldable<#interner> });
284    };
285
286    s.add_bounds(synstructure::AddBounds::None);
287    s.bound_impl(
288        quote!(::chalk_ir::fold::TypeFoldable<#interner>),
289        quote! {
290            fn try_fold_with<E>(
291                self,
292                folder: &mut dyn ::chalk_ir::fold::FallibleTypeFolder < #interner, Error = E >,
293                outer_binder: ::chalk_ir::DebruijnIndex,
294            ) -> ::std::result::Result<Self, E> {
295                Ok(match self { #body })
296            }
297        },
298    )
299}
300
301fn derive_fallible_type_folder(mut s: synstructure::Structure) -> TokenStream {
302    let interner = try_find_interner(&mut s).map_or_else(
303        || {
304            s.add_impl_generic(parse_quote! { _I });
305            s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
306            quote! { _I }
307        },
308        |(interner, _)| interner,
309    );
310    s.underscore_const(true);
311    s.unbound_impl(
312        quote!(::chalk_ir::fold::FallibleTypeFolder<#interner>),
313        quote! {
314            type Error = ::core::convert::Infallible;
315
316            fn as_dyn(&mut self) -> &mut dyn ::chalk_ir::fold::FallibleTypeFolder<#interner, Error = Self::Error> {
317                self
318            }
319
320            fn try_fold_ty(
321                &mut self,
322                ty: ::chalk_ir::Ty<#interner>,
323                outer_binder: ::chalk_ir::DebruijnIndex,
324            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
325                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_ty(self, ty, outer_binder))
326            }
327
328            fn try_fold_lifetime(
329                &mut self,
330                lifetime: ::chalk_ir::Lifetime<#interner>,
331                outer_binder: ::chalk_ir::DebruijnIndex,
332            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
333                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_lifetime(self, lifetime, outer_binder))
334            }
335
336            fn try_fold_const(
337                &mut self,
338                constant: ::chalk_ir::Const<#interner>,
339                outer_binder: ::chalk_ir::DebruijnIndex,
340            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
341                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_const(self, constant, outer_binder))
342            }
343
344            fn try_fold_program_clause(
345                &mut self,
346                clause: ::chalk_ir::ProgramClause<#interner>,
347                outer_binder: ::chalk_ir::DebruijnIndex,
348            ) -> ::core::result::Result<::chalk_ir::ProgramClause<#interner>, Self::Error> {
349                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_program_clause(self, clause, outer_binder))
350            }
351
352            fn try_fold_goal(
353                &mut self,
354                goal: ::chalk_ir::Goal<#interner>,
355                outer_binder: ::chalk_ir::DebruijnIndex,
356            ) -> ::core::result::Result<::chalk_ir::Goal<#interner>, Self::Error> {
357                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_goal(self, goal, outer_binder))
358            }
359
360            fn forbid_free_vars(&self) -> bool {
361                ::chalk_ir::fold::TypeFolder::forbid_free_vars(self)
362            }
363
364            fn try_fold_free_var_ty(
365                &mut self,
366                bound_var: ::chalk_ir::BoundVar,
367                outer_binder: ::chalk_ir::DebruijnIndex,
368            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
369                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_ty(self, bound_var, outer_binder))
370            }
371
372            fn try_fold_free_var_lifetime(
373                &mut self,
374                bound_var: ::chalk_ir::BoundVar,
375                outer_binder: ::chalk_ir::DebruijnIndex,
376            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
377                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_lifetime(self, bound_var, outer_binder))
378            }
379
380            fn try_fold_free_var_const(
381                &mut self,
382                ty: ::chalk_ir::Ty<#interner>,
383                bound_var: ::chalk_ir::BoundVar,
384                outer_binder: ::chalk_ir::DebruijnIndex,
385            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
386                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_const(self, ty, bound_var, outer_binder))
387            }
388
389            fn forbid_free_placeholders(&self) -> bool {
390                ::chalk_ir::fold::TypeFolder::forbid_free_placeholders(self)
391            }
392
393            fn try_fold_free_placeholder_ty(
394                &mut self,
395                universe: ::chalk_ir::PlaceholderIndex,
396                outer_binder: ::chalk_ir::DebruijnIndex,
397            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
398                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_ty(self, universe, outer_binder))
399            }
400
401            fn try_fold_free_placeholder_lifetime(
402                &mut self,
403                universe: ::chalk_ir::PlaceholderIndex,
404                outer_binder: ::chalk_ir::DebruijnIndex,
405            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
406                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_lifetime(self, universe, outer_binder))
407            }
408
409            fn try_fold_free_placeholder_const(
410                &mut self,
411                ty: ::chalk_ir::Ty<#interner>,
412                universe: ::chalk_ir::PlaceholderIndex,
413                outer_binder: ::chalk_ir::DebruijnIndex,
414            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
415                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_const(self, ty, universe, outer_binder))
416            }
417
418            fn forbid_inference_vars(&self) -> bool {
419                ::chalk_ir::fold::TypeFolder::forbid_inference_vars(self)
420            }
421
422            fn try_fold_inference_ty(
423                &mut self,
424                var: ::chalk_ir::InferenceVar,
425                kind: ::chalk_ir::TyVariableKind,
426                outer_binder: ::chalk_ir::DebruijnIndex,
427            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
428                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_ty(self, var, kind, outer_binder))
429            }
430
431            fn try_fold_inference_lifetime(
432                &mut self,
433                var: ::chalk_ir::InferenceVar,
434                outer_binder: ::chalk_ir::DebruijnIndex,
435            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
436                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_lifetime(self, var, outer_binder))
437            }
438
439            fn try_fold_inference_const(
440                &mut self,
441                ty: ::chalk_ir::Ty<#interner>,
442                var: ::chalk_ir::InferenceVar,
443                outer_binder: ::chalk_ir::DebruijnIndex,
444            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
445                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_const(self, ty, var, outer_binder))
446            }
447
448            fn interner(&self) -> #interner {
449                ::chalk_ir::fold::TypeFolder::interner(self)
450            }
451        },
452    )
453}