1extern crate proc_macro;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use quote::ToTokens;
6use syn::{parse_quote, DeriveInput, Ident, TypeParam, TypeParamBound};
7
8use synstructure::decl_derive;
9
10fn has_interner(param: &TypeParam) -> Option<&Ident> {
12 bounded_by_trait(param, "HasInterner")
13}
14
15fn is_interner(param: &TypeParam) -> Option<&Ident> {
17 bounded_by_trait(param, "Interner")
18}
19
20fn has_interner_attr(input: &DeriveInput) -> Option<TokenStream> {
21 Some(
22 input
23 .attrs
24 .iter()
25 .find(|a| a.path().is_ident("has_interner"))?
26 .parse_args::<TokenStream>()
27 .expect("Expected has_interner argument"),
28 )
29}
30
31fn bounded_by_trait<'p>(param: &'p TypeParam, name: &str) -> Option<&'p Ident> {
32 let name = Some(String::from(name));
33 param.bounds.iter().find_map(|b| {
34 if let TypeParamBound::Trait(trait_bound) = b {
35 if trait_bound
36 .path
37 .segments
38 .last()
39 .map(|s| s.ident.to_string())
40 == name
41 {
42 return Some(¶m.ident);
43 }
44 }
45 None
46 })
47}
48
49fn get_intern_param(input: &DeriveInput) -> Option<(DeriveKind, &Ident)> {
50 let mut params = input.generics.type_params().filter_map(|param| {
51 has_interner(param)
52 .map(|ident| (DeriveKind::FromHasInterner, ident))
53 .or_else(|| is_interner(param).map(|ident| (DeriveKind::FromInterner, ident)))
54 });
55
56 let param = params.next();
57 assert!(params.next().is_none(), "deriving this trait only works with at most one type parameter that implements HasInterner or Interner");
58
59 param
60}
61
62fn get_intern_param_name(input: &DeriveInput) -> &Ident {
63 get_intern_param(input)
64 .expect("deriving this trait requires a parameter that implements HasInterner or Interner")
65 .1
66}
67
68fn try_find_interner(s: &mut synstructure::Structure) -> Option<(TokenStream, DeriveKind)> {
69 let input = s.ast();
70
71 if let Some(arg) = has_interner_attr(input) {
72 return Some((arg, DeriveKind::FromHasInternerAttr));
79 }
80
81 get_intern_param(input).map(|generic_param0| match generic_param0 {
82 (DeriveKind::FromHasInterner, param) => {
83 s.add_impl_generic(parse_quote! { _I });
89
90 s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
91 s.add_where_predicate(
92 parse_quote! { #param: ::chalk_ir::interner::HasInterner<Interner = _I> },
93 );
94
95 (quote! { _I }, DeriveKind::FromHasInterner)
96 }
97 (DeriveKind::FromInterner, i) => {
98 (quote! { #i }, DeriveKind::FromInterner)
104 }
105 _ => unreachable!(),
106 })
107}
108
109fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
110 try_find_interner(s)
111 .expect("deriving this trait requires a `#[has_interner]` attr or a parameter that implements HasInterner or Interner")
112}
113
/// Records how the interner for a derived impl was discovered.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum DeriveKind {
    /// Given explicitly by a `#[has_interner(...)]` attribute on the type.
    FromHasInternerAttr,
    /// Inferred from a type parameter bounded by `HasInterner`; the derive
    /// introduces a separate `_I` generic for the interner.
    FromHasInterner,
    /// Taken directly from a type parameter bounded by `Interner`.
    FromInterner,
}
120
121decl_derive!([FallibleTypeFolder, attributes(has_interner)] => derive_fallible_type_folder);
122decl_derive!([HasInterner, attributes(has_interner)] => derive_has_interner);
123decl_derive!([TypeVisitable, attributes(has_interner)] => derive_type_visitable);
124decl_derive!([TypeSuperVisitable, attributes(has_interner)] => derive_type_super_visitable);
125decl_derive!([TypeFoldable, attributes(has_interner)] => derive_type_foldable);
126decl_derive!([Zip, attributes(has_interner)] => derive_zip);
127
128fn derive_has_interner(mut s: synstructure::Structure) -> TokenStream {
129 s.underscore_const(true);
130 let (interner, _) = find_interner(&mut s);
131
132 s.add_bounds(synstructure::AddBounds::None);
133 s.bound_impl(
134 quote!(::chalk_ir::interner::HasInterner),
135 quote! {
136 type Interner = #interner;
137 },
138 )
139}
140
141fn derive_type_visitable(s: synstructure::Structure) -> TokenStream {
146 derive_any_type_visitable(
147 s,
148 parse_quote! { TypeVisitable },
149 parse_quote! { visit_with },
150 )
151}
152
153fn derive_type_super_visitable(s: synstructure::Structure) -> TokenStream {
155 derive_any_type_visitable(
156 s,
157 parse_quote! { TypeSuperVisitable },
158 parse_quote! { super_visit_with },
159 )
160}
161
162fn derive_any_type_visitable(
163 mut s: synstructure::Structure,
164 trait_name: Ident,
165 method_name: Ident,
166) -> TokenStream {
167 s.underscore_const(true);
168 let input = s.ast();
169 let (interner, kind) = find_interner(&mut s);
170
171 let body = s.each(|bi| {
172 quote! {
173 ::chalk_ir::try_break!(::chalk_ir::visit::TypeVisitable::visit_with(#bi, visitor, outer_binder));
174 }
175 });
176
177 if kind == DeriveKind::FromHasInterner {
178 let param = get_intern_param_name(input);
179 s.add_where_predicate(parse_quote! { #param: ::chalk_ir::visit::TypeVisitable<#interner> });
180 }
181
182 s.add_bounds(synstructure::AddBounds::None);
183 s.bound_impl(
184 quote!(::chalk_ir::visit:: #trait_name <#interner>),
185 quote! {
186 fn #method_name <B>(
187 &self,
188 visitor: &mut dyn ::chalk_ir::visit::TypeVisitor < #interner, BreakTy = B >,
189 outer_binder: ::chalk_ir::DebruijnIndex,
190 ) -> std::ops::ControlFlow<B> {
191 match *self {
192 #body
193 }
194 std::ops::ControlFlow::Continue(())
195 }
196 },
197 )
198}
199
200fn each_variant_pair<F, R>(
201 a: &mut synstructure::Structure,
202 b: &mut synstructure::Structure,
203 mut f: F,
204) -> TokenStream
205where
206 F: FnMut(&synstructure::VariantInfo<'_>, &synstructure::VariantInfo<'_>) -> R,
207 R: ToTokens,
208{
209 let mut t = TokenStream::new();
210 for (v_a, v_b) in a.variants_mut().iter_mut().zip(b.variants_mut().iter_mut()) {
211 v_a.binding_name(|_, i| Ident::new(&format!("a_{}", i), Span::call_site()));
212 v_b.binding_name(|_, i| Ident::new(&format!("b_{}", i), Span::call_site()));
213
214 let pat_a = v_a.pat();
215 let pat_b = v_b.pat();
216 let body = f(v_a, v_b);
217
218 quote!((#pat_a, #pat_b) => {#body}).to_tokens(&mut t);
219 }
220 t
221}
222
223fn derive_zip(mut s: synstructure::Structure) -> TokenStream {
224 s.underscore_const(true);
225 let (interner, _) = find_interner(&mut s);
226
227 let mut a = s.clone();
228 let mut b = s.clone();
229
230 let mut body = each_variant_pair(&mut a, &mut b, |v_a, v_b| {
231 let mut t = TokenStream::new();
232 for (b_a, b_b) in v_a.bindings().iter().zip(v_b.bindings().iter()) {
233 quote!(chalk_ir::zip::Zip::zip_with(zipper, variance, #b_a, #b_b)?;).to_tokens(&mut t);
234 }
235 quote!(Ok(())).to_tokens(&mut t);
236 t
237 });
238
239 quote!((_, _) => Err(::chalk_ir::NoSolution)).to_tokens(&mut body);
241
242 s.add_bounds(synstructure::AddBounds::None);
243 s.bound_impl(
244 quote!(::chalk_ir::zip::Zip<#interner>),
245 quote! {
246
247 fn zip_with<Z: ::chalk_ir::zip::Zipper<#interner>>(
248 zipper: &mut Z,
249 variance: ::chalk_ir::Variance,
250 a: &Self,
251 b: &Self,
252 ) -> ::chalk_ir::Fallible<()> {
253 match (a, b) { #body }
254 }
255 },
256 )
257}
258
259fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
264 s.underscore_const(true);
265 s.bind_with(|_| synstructure::BindStyle::Move);
266
267 let (interner, kind) = find_interner(&mut s);
268
269 let body = s.each_variant(|vi| {
270 let bindings = vi.bindings();
271 vi.construct(|_, index| {
272 let bind = &bindings[index];
273 quote! {
274 ::chalk_ir::fold::TypeFoldable::try_fold_with(#bind, folder, outer_binder)?
275 }
276 })
277 });
278
279 let input = s.ast();
280
281 if kind == DeriveKind::FromHasInterner {
282 let param = get_intern_param_name(input);
283 s.add_where_predicate(parse_quote! { #param: ::chalk_ir::fold::TypeFoldable<#interner> });
284 };
285
286 s.add_bounds(synstructure::AddBounds::None);
287 s.bound_impl(
288 quote!(::chalk_ir::fold::TypeFoldable<#interner>),
289 quote! {
290 fn try_fold_with<E>(
291 self,
292 folder: &mut dyn ::chalk_ir::fold::FallibleTypeFolder < #interner, Error = E >,
293 outer_binder: ::chalk_ir::DebruijnIndex,
294 ) -> ::std::result::Result<Self, E> {
295 Ok(match self { #body })
296 }
297 },
298 )
299}
300
/// Derives `::chalk_ir::fold::FallibleTypeFolder<I>` for a type that
/// implements the infallible `::chalk_ir::fold::TypeFolder<I>`.
///
/// `Error` is set to `::core::convert::Infallible`. Every generated `try_*`
/// method forwards to the matching `TypeFolder` method and wraps the result
/// in `Ok`. `forbid_free_vars`, `forbid_free_placeholders`,
/// `forbid_inference_vars` and `interner` forward unchanged. `as_dyn`
/// returns `self`.
///
/// The interner is resolved like in the other derives (`#[has_interner]`
/// attribute, then a `HasInterner`/`Interner` type parameter). Unlike
/// `find_interner`, this derive does not panic when neither is present.
/// Instead, the impl is made generic over a fresh `_I: Interner`.
///
/// Uses `unbound_impl`, so synstructure does not add trait bounds for the
/// type's generic parameters.
fn derive_fallible_type_folder(mut s: synstructure::Structure) -> TokenStream {
    // Fall back to a fresh `_I: Interner` impl generic when the type names no
    // interner of its own.
    let interner = try_find_interner(&mut s).map_or_else(
        || {
            s.add_impl_generic(parse_quote! { _I });
            s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
            quote! { _I }
        },
        |(interner, _)| interner,
    );
    s.underscore_const(true);
    s.unbound_impl(
        quote!(::chalk_ir::fold::FallibleTypeFolder<#interner>),
        quote! {
            type Error = ::core::convert::Infallible;

            fn as_dyn(&mut self) -> &mut dyn ::chalk_ir::fold::FallibleTypeFolder<#interner, Error = Self::Error> {
                self
            }

            fn try_fold_ty(
                &mut self,
                ty: ::chalk_ir::Ty<#interner>,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_ty(self, ty, outer_binder))
            }

            fn try_fold_lifetime(
                &mut self,
                lifetime: ::chalk_ir::Lifetime<#interner>,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_lifetime(self, lifetime, outer_binder))
            }

            fn try_fold_const(
                &mut self,
                constant: ::chalk_ir::Const<#interner>,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_const(self, constant, outer_binder))
            }

            fn try_fold_program_clause(
                &mut self,
                clause: ::chalk_ir::ProgramClause<#interner>,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::ProgramClause<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_program_clause(self, clause, outer_binder))
            }

            fn try_fold_goal(
                &mut self,
                goal: ::chalk_ir::Goal<#interner>,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Goal<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_goal(self, goal, outer_binder))
            }

            fn forbid_free_vars(&self) -> bool {
                ::chalk_ir::fold::TypeFolder::forbid_free_vars(self)
            }

            fn try_fold_free_var_ty(
                &mut self,
                bound_var: ::chalk_ir::BoundVar,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_ty(self, bound_var, outer_binder))
            }

            fn try_fold_free_var_lifetime(
                &mut self,
                bound_var: ::chalk_ir::BoundVar,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_lifetime(self, bound_var, outer_binder))
            }

            fn try_fold_free_var_const(
                &mut self,
                ty: ::chalk_ir::Ty<#interner>,
                bound_var: ::chalk_ir::BoundVar,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_const(self, ty, bound_var, outer_binder))
            }

            fn forbid_free_placeholders(&self) -> bool {
                ::chalk_ir::fold::TypeFolder::forbid_free_placeholders(self)
            }

            fn try_fold_free_placeholder_ty(
                &mut self,
                universe: ::chalk_ir::PlaceholderIndex,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_ty(self, universe, outer_binder))
            }

            fn try_fold_free_placeholder_lifetime(
                &mut self,
                universe: ::chalk_ir::PlaceholderIndex,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_lifetime(self, universe, outer_binder))
            }

            fn try_fold_free_placeholder_const(
                &mut self,
                ty: ::chalk_ir::Ty<#interner>,
                universe: ::chalk_ir::PlaceholderIndex,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_const(self, ty, universe, outer_binder))
            }

            fn forbid_inference_vars(&self) -> bool {
                ::chalk_ir::fold::TypeFolder::forbid_inference_vars(self)
            }

            fn try_fold_inference_ty(
                &mut self,
                var: ::chalk_ir::InferenceVar,
                kind: ::chalk_ir::TyVariableKind,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_ty(self, var, kind, outer_binder))
            }

            fn try_fold_inference_lifetime(
                &mut self,
                var: ::chalk_ir::InferenceVar,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_lifetime(self, var, outer_binder))
            }

            fn try_fold_inference_const(
                &mut self,
                ty: ::chalk_ir::Ty<#interner>,
                var: ::chalk_ir::InferenceVar,
                outer_binder: ::chalk_ir::DebruijnIndex,
            ) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
                ::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_const(self, ty, var, outer_binder))
            }

            fn interner(&self) -> #interner {
                ::chalk_ir::fold::TypeFolder::interner(self)
            }
        },
    )
}