// hax_frontend_exporter/traits/resolution.rs

//! Trait resolution: given a trait reference, we track which local clause caused it to be true.
//! This module is independent of the rest of hax; in particular, it doesn't use hax's
//! state-tracking machinery.
4
5use itertools::Itertools;
6use std::collections::{hash_map::Entry, HashMap};
7
8use rustc_hir::def::DefKind;
9use rustc_hir::def_id::DefId;
10use rustc_middle::traits::CodegenObligationError;
11use rustc_middle::ty::*;
12use rustc_trait_selection::traits::ImplSource;
13
14use crate::{self_predicate, traits::utils::erase_and_norm};
15
16use super::utils::{implied_predicates, required_predicates, ToPolyTraitRef};
17
18#[derive(Debug, Clone)]
19pub enum PathChunk<'tcx> {
20    AssocItem {
21        item: AssocItem,
22        /// The arguments provided to the item (for GATs).
23        generic_args: &'tcx [GenericArg<'tcx>],
24        /// The impl exprs that must be satisfied to apply the given arguments to the item. E.g.
25        /// `T: Clone` in the following example:
26        /// ```ignore
27        /// trait Foo {
28        ///     type Type<T: Clone>: Debug;
29        /// }
30        /// ```
31        impl_exprs: Vec<ImplExpr<'tcx>>,
32        /// The implemented predicate.
33        predicate: PolyTraitPredicate<'tcx>,
34        /// The index of this predicate in the list returned by `implied_predicates`.
35        index: usize,
36    },
37    Parent {
38        /// The implemented predicate.
39        predicate: PolyTraitPredicate<'tcx>,
40        /// The index of this predicate in the list returned by `implied_predicates`.
41        index: usize,
42    },
43}
44pub type Path<'tcx> = Vec<PathChunk<'tcx>>;
45
46#[derive(Debug, Clone)]
47pub enum ImplExprAtom<'tcx> {
48    /// A concrete `impl Trait for Type {}` item.
49    Concrete {
50        def_id: DefId,
51        generics: GenericArgsRef<'tcx>,
52        /// The impl exprs that prove the clauses on the impl.
53        impl_exprs: Vec<ImplExpr<'tcx>>,
54    },
55    /// A context-bound clause like `where T: Trait`.
56    LocalBound {
57        predicate: Predicate<'tcx>,
58        /// The nth (non-self) predicate found for this item. We use predicates from
59        /// `required_predicates` starting from the parentmost item.
60        index: usize,
61        r#trait: PolyTraitRef<'tcx>,
62        path: Path<'tcx>,
63    },
64    /// The automatic clause `Self: Trait` present inside a `impl Trait for Type {}` item.
65    SelfImpl {
66        r#trait: PolyTraitRef<'tcx>,
67        path: Path<'tcx>,
68    },
69    /// `dyn Trait` is a wrapped value with a virtual table for trait
70    /// `Trait`.  In other words, a value `dyn Trait` is a dependent
71    /// triple that gathers a type τ, a value of type τ and an
72    /// instance of type `Trait`.
73    /// `dyn Trait` implements `Trait` using a built-in implementation; this refers to that
74    /// built-in implementation.
75    Dyn,
76    /// A built-in trait whose implementation is computed by the compiler, such as `FnMut`. This
77    /// morally points to an invisible `impl` block; as such it contains the information we may
78    /// need from one.
79    Builtin {
80        r#trait: PolyTraitRef<'tcx>,
81        /// The `ImplExpr`s required to satisfy the implied predicates on the trait declaration.
82        /// E.g. since `FnMut: FnOnce`, a built-in `T: FnMut` impl would have an `ImplExpr` for `T:
83        /// FnOnce`.
84        impl_exprs: Vec<ImplExpr<'tcx>>,
85        /// The values of the associated types for this trait.
86        types: Vec<(DefId, Ty<'tcx>)>,
87    },
88    /// An error happened while resolving traits.
89    Error(String),
90}
91
92#[derive(Clone, Debug)]
93pub struct ImplExpr<'tcx> {
94    /// The trait this is an impl for.
95    pub r#trait: PolyTraitRef<'tcx>,
96    /// The kind of implemention of the root of the tree.
97    pub r#impl: ImplExprAtom<'tcx>,
98}
99
/// Where a predicate that is in scope for an item comes from. `path_to` starts trait
/// resolution from these predicates.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BoundPredicateOrigin {
    /// The implicit `Self: Trait` predicate of a trait declaration. It is not added for trait
    /// implementations (open question: should it be?).
    SelfPred,
    /// The item's nth non-self predicate. Predicates come from `required_predicates` and are
    /// numbered starting from the parentmost item.
    Item(usize),
}
111
112#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
113pub struct AnnotatedTraitPred<'tcx> {
114    pub origin: BoundPredicateOrigin,
115    pub clause: PolyTraitPredicate<'tcx>,
116}
117
118/// The predicates to use as a starting point for resolving trait references within this item. This
119/// includes the "self" predicate if applicable and the `required_predicates` of this item and all
120/// its parents, numbered starting from the parents.
121fn initial_search_predicates<'tcx>(
122    tcx: TyCtxt<'tcx>,
123    def_id: rustc_span::def_id::DefId,
124) -> Vec<AnnotatedTraitPred<'tcx>> {
125    fn acc_predicates<'tcx>(
126        tcx: TyCtxt<'tcx>,
127        def_id: rustc_span::def_id::DefId,
128        predicates: &mut Vec<AnnotatedTraitPred<'tcx>>,
129        pred_id: &mut usize,
130    ) {
131        let next_item_origin = |pred_id: &mut usize| {
132            let origin = BoundPredicateOrigin::Item(*pred_id);
133            *pred_id += 1;
134            origin
135        };
136        use DefKind::*;
137        match tcx.def_kind(def_id) {
138            // These inherit some predicates from their parent.
139            AssocTy | AssocFn | AssocConst | Closure | Ctor(..) | Variant => {
140                let parent = tcx.parent(def_id);
141                acc_predicates(tcx, parent, predicates, pred_id);
142            }
143            Trait => {
144                let self_pred = self_predicate(tcx, def_id).upcast(tcx);
145                predicates.push(AnnotatedTraitPred {
146                    origin: BoundPredicateOrigin::SelfPred,
147                    clause: self_pred,
148                })
149            }
150            _ => {}
151        }
152        predicates.extend(
153            required_predicates(tcx, def_id)
154                .predicates
155                .iter()
156                .map(|(clause, _span)| *clause)
157                .filter_map(|clause| {
158                    clause.as_trait_clause().map(|clause| AnnotatedTraitPred {
159                        origin: next_item_origin(pred_id),
160                        clause,
161                    })
162                }),
163        );
164    }
165
166    let mut predicates = vec![];
167    acc_predicates(tcx, def_id, &mut predicates, &mut 0);
168    predicates
169}
170
171#[tracing::instrument(level = "trace", skip(tcx))]
172fn parents_trait_predicates<'tcx>(
173    tcx: TyCtxt<'tcx>,
174    pred: PolyTraitPredicate<'tcx>,
175) -> Vec<PolyTraitPredicate<'tcx>> {
176    let self_trait_ref = pred.to_poly_trait_ref();
177    implied_predicates(tcx, pred.def_id())
178        .predicates
179        .iter()
180        .map(|(clause, _span)| *clause)
181        // Substitute with the `self` args so that the clause makes sense in the
182        // outside context.
183        .map(|clause| clause.instantiate_supertrait(tcx, self_trait_ref))
184        .filter_map(|pred| pred.as_trait_clause())
185        .collect()
186}
187
188/// A candidate projects `self` along a path reaching some predicate. A candidate is
189/// selected when its predicate is the one expected, aka `target`.
190#[derive(Debug, Clone)]
191struct Candidate<'tcx> {
192    path: Path<'tcx>,
193    pred: PolyTraitPredicate<'tcx>,
194    origin: AnnotatedTraitPred<'tcx>,
195}
196
197/// Stores a set of predicates along with where they came from.
198pub struct PredicateSearcher<'tcx> {
199    tcx: TyCtxt<'tcx>,
200    typing_env: rustc_middle::ty::TypingEnv<'tcx>,
201    /// Local clauses available in the current context.
202    candidates: HashMap<PolyTraitPredicate<'tcx>, Candidate<'tcx>>,
203}
204
205impl<'tcx> PredicateSearcher<'tcx> {
206    /// Initialize the elaborator with the predicates accessible within this item.
207    pub fn new_for_owner(tcx: TyCtxt<'tcx>, owner_id: DefId) -> Self {
208        let mut out = Self {
209            tcx,
210            typing_env: TypingEnv {
211                param_env: tcx.param_env(owner_id),
212                typing_mode: TypingMode::PostAnalysis,
213            },
214            candidates: Default::default(),
215        };
216        out.extend(
217            initial_search_predicates(tcx, owner_id)
218                .into_iter()
219                .map(|clause| Candidate {
220                    path: vec![],
221                    pred: clause.clause,
222                    origin: clause,
223                }),
224        );
225        out
226    }
227
228    /// Insert new candidates and all their parent predicates. This deduplicates predicates
229    /// to avoid divergence.
230    fn extend(&mut self, candidates: impl IntoIterator<Item = Candidate<'tcx>>) {
231        let tcx = self.tcx;
232        // Filter out duplicated candidates.
233        let mut new_candidates = Vec::new();
234        for mut candidate in candidates {
235            // Normalize and erase all lifetimes.
236            candidate.pred = erase_and_norm(tcx, self.typing_env, candidate.pred);
237            if let Entry::Vacant(entry) = self.candidates.entry(candidate.pred) {
238                entry.insert(candidate.clone());
239                new_candidates.push(candidate);
240            }
241        }
242        if !new_candidates.is_empty() {
243            self.extend_parents(new_candidates);
244        }
245    }
246
247    /// Add the parents of these candidates. This is a separate function to avoid
248    /// polymorphic recursion due to the closures capturing the type parameters of this
249    /// function.
250    fn extend_parents(&mut self, new_candidates: Vec<Candidate<'tcx>>) {
251        let tcx = self.tcx;
252        // Then recursively add their parents. This way ensures a breadth-first order,
253        // which means we select the shortest path when looking up predicates.
254        self.extend(new_candidates.into_iter().flat_map(|candidate| {
255            parents_trait_predicates(tcx, candidate.pred)
256                .into_iter()
257                .enumerate()
258                .map(move |(index, parent_pred)| {
259                    let mut parent_candidate = Candidate {
260                        pred: parent_pred,
261                        path: candidate.path.clone(),
262                        origin: candidate.origin,
263                    };
264                    parent_candidate.path.push(PathChunk::Parent {
265                        predicate: parent_pred,
266                        index,
267                    });
268                    parent_candidate
269                })
270        }));
271    }
272
273    /// If the type is a trait associated type, we add any relevant bounds to our context.
274    fn add_associated_type_refs(
275        &mut self,
276        ty: Binder<'tcx, Ty<'tcx>>,
277        // Call back into hax-related code to display a nice warning.
278        warn: &impl Fn(&str),
279    ) -> Result<(), String> {
280        let tcx = self.tcx;
281        // Note: We skip a binder but rebind it just after.
282        let TyKind::Alias(AliasTyKind::Projection, alias_ty) = ty.skip_binder().kind() else {
283            return Ok(());
284        };
285        let (trait_ref, item_args) = alias_ty.trait_ref_and_own_args(tcx);
286        let trait_ref = ty.rebind(trait_ref).upcast(tcx);
287
288        // The predicate we're looking for is is `<T as Trait>::Type: OtherTrait`. We look up `T as
289        // Trait` in the current context and add all the bounds on `Trait::Type` to our context.
290        let Some(trait_candidate) = self.resolve_local(trait_ref, warn)? else {
291            return Ok(());
292        };
293
294        // The bounds that hold on the associated type.
295        let item_bounds = implied_predicates(tcx, alias_ty.def_id)
296            .predicates
297            .iter()
298            .map(|(clause, _span)| *clause)
299            .filter_map(|pred| pred.as_trait_clause())
300            // Substitute the item generics
301            .map(|pred| EarlyBinder::bind(pred).instantiate(tcx, alias_ty.args))
302            .enumerate();
303
304        // Resolve predicates required to mention the item.
305        let nested_impl_exprs =
306            self.resolve_item_required_predicates(alias_ty.def_id, alias_ty.args, warn)?;
307
308        // Add all the bounds on the corresponding associated item.
309        self.extend(item_bounds.map(|(index, pred)| {
310            let mut candidate = Candidate {
311                path: trait_candidate.path.clone(),
312                pred,
313                origin: trait_candidate.origin,
314            };
315            candidate.path.push(PathChunk::AssocItem {
316                item: tcx.associated_item(alias_ty.def_id),
317                generic_args: item_args,
318                impl_exprs: nested_impl_exprs.clone(),
319                predicate: pred,
320                index,
321            });
322            candidate
323        }));
324
325        Ok(())
326    }
327
328    /// Resolve a local clause by looking it up in this set. If the predicate applies to an
329    /// associated type, we add the relevant implied associated type bounds to the set as well.
330    fn resolve_local(
331        &mut self,
332        target: PolyTraitPredicate<'tcx>,
333        // Call back into hax-related code to display a nice warning.
334        warn: &impl Fn(&str),
335    ) -> Result<Option<Candidate<'tcx>>, String> {
336        tracing::trace!("Looking for {target:?}");
337
338        // Look up the predicate
339        let ret = self.candidates.get(&target).cloned();
340        if ret.is_some() {
341            return Ok(ret);
342        }
343
344        // Add clauses related to associated type in the `Self` type of the predicate.
345        self.add_associated_type_refs(target.self_ty(), warn)?;
346
347        let ret = self.candidates.get(&target).cloned();
348        if ret.is_none() {
349            tracing::trace!(
350                "Couldn't find {target:?} in: [\n{}]",
351                self.candidates
352                    .iter()
353                    .map(|(_, c)| format!("  - {:?}\n", c.pred))
354                    .join("")
355            );
356        }
357        Ok(ret)
358    }
359
360    /// Resolve the given trait reference in the local context.
361    #[tracing::instrument(level = "trace", skip(self, warn))]
362    pub fn resolve(
363        &mut self,
364        tref: &PolyTraitRef<'tcx>,
365        // Call back into hax-related code to display a nice warning.
366        warn: &impl Fn(&str),
367    ) -> Result<ImplExpr<'tcx>, String> {
368        use rustc_trait_selection::traits::{
369            BuiltinImplSource, ImplSource, ImplSourceUserDefinedData,
370        };
371
372        let erased_tref = erase_and_norm(self.tcx, self.typing_env, *tref);
373
374        let tcx = self.tcx;
375        let impl_source = shallow_resolve_trait_ref(tcx, self.typing_env.param_env, erased_tref);
376        let atom = match impl_source {
377            Ok(ImplSource::UserDefined(ImplSourceUserDefinedData {
378                impl_def_id,
379                args: generics,
380                ..
381            })) => {
382                // Resolve the predicates required by the impl.
383                let impl_exprs =
384                    self.resolve_item_required_predicates(impl_def_id, generics, warn)?;
385                ImplExprAtom::Concrete {
386                    def_id: impl_def_id,
387                    generics,
388                    impl_exprs,
389                }
390            }
391            Ok(ImplSource::Param(_)) => {
392                match self.resolve_local(erased_tref.upcast(self.tcx), warn)? {
393                    Some(candidate) => {
394                        let path = candidate.path;
395                        let r#trait = candidate.origin.clause.to_poly_trait_ref();
396                        match candidate.origin.origin {
397                            BoundPredicateOrigin::SelfPred => {
398                                ImplExprAtom::SelfImpl { r#trait, path }
399                            }
400                            BoundPredicateOrigin::Item(index) => ImplExprAtom::LocalBound {
401                                predicate: candidate.origin.clause.upcast(tcx),
402                                index,
403                                r#trait,
404                                path,
405                            },
406                        }
407                    }
408                    None => {
409                        let msg = format!(
410                            "Could not find a clause for `{tref:?}` in the item parameters"
411                        );
412                        warn(&msg);
413                        ImplExprAtom::Error(msg)
414                    }
415                }
416            }
417            Ok(ImplSource::Builtin(BuiltinImplSource::Object { .. }, _)) => ImplExprAtom::Dyn,
418            Ok(ImplSource::Builtin(_, _)) => {
419                // Resolve the predicates implied by the trait.
420                let trait_def_id = erased_tref.skip_binder().def_id;
421                // If we wanted to not skip this binder, we'd have to instantiate the bound
422                // regions, solve, then wrap the result in a binder. And track  higher-kinded
423                // clauses better all over.
424                let impl_exprs = self.resolve_item_implied_predicates(
425                    trait_def_id,
426                    erased_tref.skip_binder().args,
427                    warn,
428                )?;
429                let types = tcx
430                    .associated_items(trait_def_id)
431                    .in_definition_order()
432                    .filter(|assoc| matches!(assoc.kind, AssocKind::Type))
433                    .filter_map(|assoc| {
434                        let ty =
435                            Ty::new_projection(tcx, assoc.def_id, erased_tref.skip_binder().args);
436                        let ty = erase_and_norm(tcx, self.typing_env, ty);
437                        if let TyKind::Alias(_, alias_ty) = ty.kind() {
438                            if alias_ty.def_id == assoc.def_id {
439                                // Couldn't normalize the type to anything different than itself;
440                                // this must be a built-in associated type such as
441                                // `DiscriminantKind::Discriminant`.
442                                // We can't return the unnormalized associated type as that would
443                                // make the trait ref contain itself, which would make hax's
444                                // `sinto` infrastructure loop. That's ok because we can't provide
445                                // a value for this type other than the associate type alias
446                                // itself.
447                                return None;
448                            }
449                        }
450                        Some((assoc.def_id, ty))
451                    })
452                    .collect();
453                ImplExprAtom::Builtin {
454                    r#trait: *tref,
455                    impl_exprs,
456                    types,
457                }
458            }
459            Err(e) => {
460                let msg = format!(
461                    "Could not find a clause for `{tref:?}` in the current context: `{e:?}`"
462                );
463                warn(&msg);
464                ImplExprAtom::Error(msg)
465            }
466        };
467
468        Ok(ImplExpr {
469            r#impl: atom,
470            r#trait: *tref,
471        })
472    }
473
474    /// Resolve the predicates required by the given item.
475    pub fn resolve_item_required_predicates(
476        &mut self,
477        def_id: DefId,
478        generics: GenericArgsRef<'tcx>,
479        // Call back into hax-related code to display a nice warning.
480        warn: &impl Fn(&str),
481    ) -> Result<Vec<ImplExpr<'tcx>>, String> {
482        let tcx = self.tcx;
483        self.resolve_predicates(generics, required_predicates(tcx, def_id), warn)
484    }
485
486    /// Resolve the predicates implied by the given item.
487    pub fn resolve_item_implied_predicates(
488        &mut self,
489        def_id: DefId,
490        generics: GenericArgsRef<'tcx>,
491        // Call back into hax-related code to display a nice warning.
492        warn: &impl Fn(&str),
493    ) -> Result<Vec<ImplExpr<'tcx>>, String> {
494        let tcx = self.tcx;
495        self.resolve_predicates(generics, implied_predicates(tcx, def_id), warn)
496    }
497
498    /// Apply the given generics to the provided clauses and resolve the trait references in the
499    /// current context.
500    pub fn resolve_predicates(
501        &mut self,
502        generics: GenericArgsRef<'tcx>,
503        predicates: GenericPredicates<'tcx>,
504        // Call back into hax-related code to display a nice warning.
505        warn: &impl Fn(&str),
506    ) -> Result<Vec<ImplExpr<'tcx>>, String> {
507        let tcx = self.tcx;
508        predicates
509            .predicates
510            .iter()
511            .map(|(clause, _span)| *clause)
512            .filter_map(|clause| clause.as_trait_clause())
513            .map(|trait_pred| trait_pred.map_bound(|p| p.trait_ref))
514            // Substitute the item generics
515            .map(|trait_ref| EarlyBinder::bind(trait_ref).instantiate(tcx, generics))
516            // Resolve
517            .map(|trait_ref| self.resolve(&trait_ref, warn))
518            .collect()
519    }
520}
521
522/// Attempts to resolve an obligation to an `ImplSource`. The result is a shallow `ImplSource`
523/// resolution, meaning that we do not resolve all nested obligations on the impl. Note that type
524/// check should guarantee to us that all nested obligations *could be* resolved if we wanted to.
525///
526/// This expects that `trait_ref` is fully normalized.
527///
528/// This is based on `rustc_traits::codegen::codegen_select_candidate` in rustc.
529pub fn shallow_resolve_trait_ref<'tcx>(
530    tcx: TyCtxt<'tcx>,
531    param_env: ParamEnv<'tcx>,
532    trait_ref: PolyTraitRef<'tcx>,
533) -> Result<ImplSource<'tcx, ()>, CodegenObligationError> {
534    use rustc_infer::infer::TyCtxtInferExt;
535    use rustc_middle::traits::CodegenObligationError;
536    use rustc_middle::ty::TypeVisitableExt;
537    use rustc_trait_selection::traits::{
538        Obligation, ObligationCause, ObligationCtxt, SelectionContext, Unimplemented,
539    };
540    // Do the initial selection for the obligation. This yields the
541    // shallow result we are looking for -- that is, what specific impl.
542    let infcx = tcx
543        .infer_ctxt()
544        .ignoring_regions()
545        .build(TypingMode::PostAnalysis);
546    let mut selcx = SelectionContext::new(&infcx);
547
548    let obligation_cause = ObligationCause::dummy();
549    let obligation = Obligation::new(tcx, obligation_cause, param_env, trait_ref);
550
551    let selection = match selcx.poly_select(&obligation) {
552        Ok(Some(selection)) => selection,
553        Ok(None) => return Err(CodegenObligationError::Ambiguity),
554        Err(Unimplemented) => return Err(CodegenObligationError::Unimplemented),
555        Err(_) => return Err(CodegenObligationError::Ambiguity),
556    };
557
558    // Currently, we use a fulfillment context to completely resolve
559    // all nested obligations. This is because they can inform the
560    // inference of the impl's type parameters.
561    // FIXME(-Znext-solver): Doesn't need diagnostics if new solver.
562    let ocx = ObligationCtxt::new(&infcx);
563    let impl_source = selection.map(|obligation| {
564        ocx.register_obligation(obligation.clone());
565        ()
566    });
567
568    let errors = ocx.select_all_or_error();
569    if !errors.is_empty() {
570        return Err(CodegenObligationError::Ambiguity);
571    }
572
573    let impl_source = infcx.resolve_vars_if_possible(impl_source);
574    let impl_source = tcx.erase_regions(impl_source);
575
576    if impl_source.has_infer() {
577        // Unused lifetimes on an impl get replaced with inference vars, but never resolved.
578        return Err(CodegenObligationError::Ambiguity);
579    }
580
581    Ok(impl_source)
582}