1use itertools::Itertools;
6use std::collections::{hash_map::Entry, HashMap};
7
8use rustc_hir::def::DefKind;
9use rustc_hir::def_id::DefId;
10use rustc_middle::traits::CodegenObligationError;
11use rustc_middle::ty::*;
12use rustc_trait_selection::traits::ImplSource;
13
14use crate::{self_predicate, traits::utils::erase_and_norm};
15
16use super::utils::{implied_predicates, required_predicates, ToPolyTraitRef};
17
18#[derive(Debug, Clone)]
19pub enum PathChunk<'tcx> {
20 AssocItem {
21 item: AssocItem,
22 generic_args: &'tcx [GenericArg<'tcx>],
24 impl_exprs: Vec<ImplExpr<'tcx>>,
32 predicate: PolyTraitPredicate<'tcx>,
34 index: usize,
36 },
37 Parent {
38 predicate: PolyTraitPredicate<'tcx>,
40 index: usize,
42 },
43}
44pub type Path<'tcx> = Vec<PathChunk<'tcx>>;
45
46#[derive(Debug, Clone)]
47pub enum ImplExprAtom<'tcx> {
48 Concrete {
50 def_id: DefId,
51 generics: GenericArgsRef<'tcx>,
52 impl_exprs: Vec<ImplExpr<'tcx>>,
54 },
55 LocalBound {
57 predicate: Predicate<'tcx>,
58 index: usize,
61 r#trait: PolyTraitRef<'tcx>,
62 path: Path<'tcx>,
63 },
64 SelfImpl {
66 r#trait: PolyTraitRef<'tcx>,
67 path: Path<'tcx>,
68 },
69 Dyn,
76 Builtin {
80 r#trait: PolyTraitRef<'tcx>,
81 impl_exprs: Vec<ImplExpr<'tcx>>,
85 types: Vec<(DefId, Ty<'tcx>)>,
87 },
88 Error(String),
90}
91
92#[derive(Clone, Debug)]
93pub struct ImplExpr<'tcx> {
94 pub r#trait: PolyTraitRef<'tcx>,
96 pub r#impl: ImplExprAtom<'tcx>,
98}
99
/// Which in-scope bound a local candidate comes from.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BoundPredicateOrigin {
    /// The implicit `Self: Trait` predicate of the enclosing trait.
    SelfPred,
    /// The trait clause with this index among the required predicates of the item and of the
    /// parents it inherits predicates from, numbered outermost item first and counting trait
    /// clauses only.
    Item(usize),
}
111
112#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
113pub struct AnnotatedTraitPred<'tcx> {
114 pub origin: BoundPredicateOrigin,
115 pub clause: PolyTraitPredicate<'tcx>,
116}
117
118fn initial_search_predicates<'tcx>(
122 tcx: TyCtxt<'tcx>,
123 def_id: rustc_span::def_id::DefId,
124) -> Vec<AnnotatedTraitPred<'tcx>> {
125 fn acc_predicates<'tcx>(
126 tcx: TyCtxt<'tcx>,
127 def_id: rustc_span::def_id::DefId,
128 predicates: &mut Vec<AnnotatedTraitPred<'tcx>>,
129 pred_id: &mut usize,
130 ) {
131 let next_item_origin = |pred_id: &mut usize| {
132 let origin = BoundPredicateOrigin::Item(*pred_id);
133 *pred_id += 1;
134 origin
135 };
136 use DefKind::*;
137 match tcx.def_kind(def_id) {
138 AssocTy | AssocFn | AssocConst | Closure | Ctor(..) | Variant => {
140 let parent = tcx.parent(def_id);
141 acc_predicates(tcx, parent, predicates, pred_id);
142 }
143 Trait => {
144 let self_pred = self_predicate(tcx, def_id).upcast(tcx);
145 predicates.push(AnnotatedTraitPred {
146 origin: BoundPredicateOrigin::SelfPred,
147 clause: self_pred,
148 })
149 }
150 _ => {}
151 }
152 predicates.extend(
153 required_predicates(tcx, def_id)
154 .predicates
155 .iter()
156 .map(|(clause, _span)| *clause)
157 .filter_map(|clause| {
158 clause.as_trait_clause().map(|clause| AnnotatedTraitPred {
159 origin: next_item_origin(pred_id),
160 clause,
161 })
162 }),
163 );
164 }
165
166 let mut predicates = vec![];
167 acc_predicates(tcx, def_id, &mut predicates, &mut 0);
168 predicates
169}
170
171#[tracing::instrument(level = "trace", skip(tcx))]
172fn parents_trait_predicates<'tcx>(
173 tcx: TyCtxt<'tcx>,
174 pred: PolyTraitPredicate<'tcx>,
175) -> Vec<PolyTraitPredicate<'tcx>> {
176 let self_trait_ref = pred.to_poly_trait_ref();
177 implied_predicates(tcx, pred.def_id())
178 .predicates
179 .iter()
180 .map(|(clause, _span)| *clause)
181 .map(|clause| clause.instantiate_supertrait(tcx, self_trait_ref))
184 .filter_map(|pred| pred.as_trait_clause())
185 .collect()
186}
187
188#[derive(Debug, Clone)]
191struct Candidate<'tcx> {
192 path: Path<'tcx>,
193 pred: PolyTraitPredicate<'tcx>,
194 origin: AnnotatedTraitPred<'tcx>,
195}
196
197pub struct PredicateSearcher<'tcx> {
199 tcx: TyCtxt<'tcx>,
200 typing_env: rustc_middle::ty::TypingEnv<'tcx>,
201 candidates: HashMap<PolyTraitPredicate<'tcx>, Candidate<'tcx>>,
203}
204
205impl<'tcx> PredicateSearcher<'tcx> {
206 pub fn new_for_owner(tcx: TyCtxt<'tcx>, owner_id: DefId) -> Self {
208 let mut out = Self {
209 tcx,
210 typing_env: TypingEnv {
211 param_env: tcx.param_env(owner_id),
212 typing_mode: TypingMode::PostAnalysis,
213 },
214 candidates: Default::default(),
215 };
216 out.extend(
217 initial_search_predicates(tcx, owner_id)
218 .into_iter()
219 .map(|clause| Candidate {
220 path: vec![],
221 pred: clause.clause,
222 origin: clause,
223 }),
224 );
225 out
226 }
227
228 fn extend(&mut self, candidates: impl IntoIterator<Item = Candidate<'tcx>>) {
231 let tcx = self.tcx;
232 let mut new_candidates = Vec::new();
234 for mut candidate in candidates {
235 candidate.pred = erase_and_norm(tcx, self.typing_env, candidate.pred);
237 if let Entry::Vacant(entry) = self.candidates.entry(candidate.pred) {
238 entry.insert(candidate.clone());
239 new_candidates.push(candidate);
240 }
241 }
242 if !new_candidates.is_empty() {
243 self.extend_parents(new_candidates);
244 }
245 }
246
247 fn extend_parents(&mut self, new_candidates: Vec<Candidate<'tcx>>) {
251 let tcx = self.tcx;
252 self.extend(new_candidates.into_iter().flat_map(|candidate| {
255 parents_trait_predicates(tcx, candidate.pred)
256 .into_iter()
257 .enumerate()
258 .map(move |(index, parent_pred)| {
259 let mut parent_candidate = Candidate {
260 pred: parent_pred,
261 path: candidate.path.clone(),
262 origin: candidate.origin,
263 };
264 parent_candidate.path.push(PathChunk::Parent {
265 predicate: parent_pred,
266 index,
267 });
268 parent_candidate
269 })
270 }));
271 }
272
273 fn add_associated_type_refs(
275 &mut self,
276 ty: Binder<'tcx, Ty<'tcx>>,
277 warn: &impl Fn(&str),
279 ) -> Result<(), String> {
280 let tcx = self.tcx;
281 let TyKind::Alias(AliasTyKind::Projection, alias_ty) = ty.skip_binder().kind() else {
283 return Ok(());
284 };
285 let (trait_ref, item_args) = alias_ty.trait_ref_and_own_args(tcx);
286 let trait_ref = ty.rebind(trait_ref).upcast(tcx);
287
288 let Some(trait_candidate) = self.resolve_local(trait_ref, warn)? else {
291 return Ok(());
292 };
293
294 let item_bounds = implied_predicates(tcx, alias_ty.def_id)
296 .predicates
297 .iter()
298 .map(|(clause, _span)| *clause)
299 .filter_map(|pred| pred.as_trait_clause())
300 .map(|pred| EarlyBinder::bind(pred).instantiate(tcx, alias_ty.args))
302 .enumerate();
303
304 let nested_impl_exprs =
306 self.resolve_item_required_predicates(alias_ty.def_id, alias_ty.args, warn)?;
307
308 self.extend(item_bounds.map(|(index, pred)| {
310 let mut candidate = Candidate {
311 path: trait_candidate.path.clone(),
312 pred,
313 origin: trait_candidate.origin,
314 };
315 candidate.path.push(PathChunk::AssocItem {
316 item: tcx.associated_item(alias_ty.def_id),
317 generic_args: item_args,
318 impl_exprs: nested_impl_exprs.clone(),
319 predicate: pred,
320 index,
321 });
322 candidate
323 }));
324
325 Ok(())
326 }
327
328 fn resolve_local(
331 &mut self,
332 target: PolyTraitPredicate<'tcx>,
333 warn: &impl Fn(&str),
335 ) -> Result<Option<Candidate<'tcx>>, String> {
336 tracing::trace!("Looking for {target:?}");
337
338 let ret = self.candidates.get(&target).cloned();
340 if ret.is_some() {
341 return Ok(ret);
342 }
343
344 self.add_associated_type_refs(target.self_ty(), warn)?;
346
347 let ret = self.candidates.get(&target).cloned();
348 if ret.is_none() {
349 tracing::trace!(
350 "Couldn't find {target:?} in: [\n{}]",
351 self.candidates
352 .iter()
353 .map(|(_, c)| format!(" - {:?}\n", c.pred))
354 .join("")
355 );
356 }
357 Ok(ret)
358 }
359
360 #[tracing::instrument(level = "trace", skip(self, warn))]
362 pub fn resolve(
363 &mut self,
364 tref: &PolyTraitRef<'tcx>,
365 warn: &impl Fn(&str),
367 ) -> Result<ImplExpr<'tcx>, String> {
368 use rustc_trait_selection::traits::{
369 BuiltinImplSource, ImplSource, ImplSourceUserDefinedData,
370 };
371
372 let erased_tref = erase_and_norm(self.tcx, self.typing_env, *tref);
373
374 let tcx = self.tcx;
375 let impl_source = shallow_resolve_trait_ref(tcx, self.typing_env.param_env, erased_tref);
376 let atom = match impl_source {
377 Ok(ImplSource::UserDefined(ImplSourceUserDefinedData {
378 impl_def_id,
379 args: generics,
380 ..
381 })) => {
382 let impl_exprs =
384 self.resolve_item_required_predicates(impl_def_id, generics, warn)?;
385 ImplExprAtom::Concrete {
386 def_id: impl_def_id,
387 generics,
388 impl_exprs,
389 }
390 }
391 Ok(ImplSource::Param(_)) => {
392 match self.resolve_local(erased_tref.upcast(self.tcx), warn)? {
393 Some(candidate) => {
394 let path = candidate.path;
395 let r#trait = candidate.origin.clause.to_poly_trait_ref();
396 match candidate.origin.origin {
397 BoundPredicateOrigin::SelfPred => {
398 ImplExprAtom::SelfImpl { r#trait, path }
399 }
400 BoundPredicateOrigin::Item(index) => ImplExprAtom::LocalBound {
401 predicate: candidate.origin.clause.upcast(tcx),
402 index,
403 r#trait,
404 path,
405 },
406 }
407 }
408 None => {
409 let msg = format!(
410 "Could not find a clause for `{tref:?}` in the item parameters"
411 );
412 warn(&msg);
413 ImplExprAtom::Error(msg)
414 }
415 }
416 }
417 Ok(ImplSource::Builtin(BuiltinImplSource::Object { .. }, _)) => ImplExprAtom::Dyn,
418 Ok(ImplSource::Builtin(_, _)) => {
419 let trait_def_id = erased_tref.skip_binder().def_id;
421 let impl_exprs = self.resolve_item_implied_predicates(
425 trait_def_id,
426 erased_tref.skip_binder().args,
427 warn,
428 )?;
429 let types = tcx
430 .associated_items(trait_def_id)
431 .in_definition_order()
432 .filter(|assoc| matches!(assoc.kind, AssocKind::Type))
433 .filter_map(|assoc| {
434 let ty =
435 Ty::new_projection(tcx, assoc.def_id, erased_tref.skip_binder().args);
436 let ty = erase_and_norm(tcx, self.typing_env, ty);
437 if let TyKind::Alias(_, alias_ty) = ty.kind() {
438 if alias_ty.def_id == assoc.def_id {
439 return None;
448 }
449 }
450 Some((assoc.def_id, ty))
451 })
452 .collect();
453 ImplExprAtom::Builtin {
454 r#trait: *tref,
455 impl_exprs,
456 types,
457 }
458 }
459 Err(e) => {
460 let msg = format!(
461 "Could not find a clause for `{tref:?}` in the current context: `{e:?}`"
462 );
463 warn(&msg);
464 ImplExprAtom::Error(msg)
465 }
466 };
467
468 Ok(ImplExpr {
469 r#impl: atom,
470 r#trait: *tref,
471 })
472 }
473
474 pub fn resolve_item_required_predicates(
476 &mut self,
477 def_id: DefId,
478 generics: GenericArgsRef<'tcx>,
479 warn: &impl Fn(&str),
481 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
482 let tcx = self.tcx;
483 self.resolve_predicates(generics, required_predicates(tcx, def_id), warn)
484 }
485
486 pub fn resolve_item_implied_predicates(
488 &mut self,
489 def_id: DefId,
490 generics: GenericArgsRef<'tcx>,
491 warn: &impl Fn(&str),
493 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
494 let tcx = self.tcx;
495 self.resolve_predicates(generics, implied_predicates(tcx, def_id), warn)
496 }
497
498 pub fn resolve_predicates(
501 &mut self,
502 generics: GenericArgsRef<'tcx>,
503 predicates: GenericPredicates<'tcx>,
504 warn: &impl Fn(&str),
506 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
507 let tcx = self.tcx;
508 predicates
509 .predicates
510 .iter()
511 .map(|(clause, _span)| *clause)
512 .filter_map(|clause| clause.as_trait_clause())
513 .map(|trait_pred| trait_pred.map_bound(|p| p.trait_ref))
514 .map(|trait_ref| EarlyBinder::bind(trait_ref).instantiate(tcx, generics))
516 .map(|trait_ref| self.resolve(&trait_ref, warn))
518 .collect()
519 }
520}
521
522pub fn shallow_resolve_trait_ref<'tcx>(
530 tcx: TyCtxt<'tcx>,
531 param_env: ParamEnv<'tcx>,
532 trait_ref: PolyTraitRef<'tcx>,
533) -> Result<ImplSource<'tcx, ()>, CodegenObligationError> {
534 use rustc_infer::infer::TyCtxtInferExt;
535 use rustc_middle::traits::CodegenObligationError;
536 use rustc_middle::ty::TypeVisitableExt;
537 use rustc_trait_selection::traits::{
538 Obligation, ObligationCause, ObligationCtxt, SelectionContext, Unimplemented,
539 };
540 let infcx = tcx
543 .infer_ctxt()
544 .ignoring_regions()
545 .build(TypingMode::PostAnalysis);
546 let mut selcx = SelectionContext::new(&infcx);
547
548 let obligation_cause = ObligationCause::dummy();
549 let obligation = Obligation::new(tcx, obligation_cause, param_env, trait_ref);
550
551 let selection = match selcx.poly_select(&obligation) {
552 Ok(Some(selection)) => selection,
553 Ok(None) => return Err(CodegenObligationError::Ambiguity),
554 Err(Unimplemented) => return Err(CodegenObligationError::Unimplemented),
555 Err(_) => return Err(CodegenObligationError::Ambiguity),
556 };
557
558 let ocx = ObligationCtxt::new(&infcx);
563 let impl_source = selection.map(|obligation| {
564 ocx.register_obligation(obligation.clone());
565 ()
566 });
567
568 let errors = ocx.select_all_or_error();
569 if !errors.is_empty() {
570 return Err(CodegenObligationError::Ambiguity);
571 }
572
573 let impl_source = infcx.resolve_vars_if_possible(impl_source);
574 let impl_source = tcx.erase_regions(impl_source);
575
576 if impl_source.has_infer() {
577 return Err(CodegenObligationError::Ambiguity);
579 }
580
581 Ok(impl_source)
582}