chalk_engine/slg/aggregate.rs

1use crate::context::{self, AnswerResult};
2use crate::slg::SlgContextOps;
3use crate::slg::SubstitutionExt;
4use crate::CompleteAnswer;
5use chalk_ir::cast::Cast;
6use chalk_ir::interner::Interner;
7use chalk_ir::*;
8use chalk_solve::ext::*;
9use chalk_solve::infer::InferenceTable;
10use chalk_solve::solve::{Guidance, Solution};
11
12use std::fmt::Debug;
13
14/// Methods for combining solutions to yield an aggregate solution.
15pub trait AggregateOps<I: Interner> {
16    fn make_solution(
17        &self,
18        root_goal: &UCanonical<InEnvironment<Goal<I>>>,
19        answers: impl context::AnswerStream<I>,
20        should_continue: impl std::ops::Fn() -> bool + Clone,
21    ) -> Option<Solution<I>>;
22}
23
24/// Draws as many answers as it needs from `answers` (but
25/// no more!) in order to come up with a solution.
26impl<I: Interner> AggregateOps<I> for SlgContextOps<'_, I> {
27    fn make_solution(
28        &self,
29        root_goal: &UCanonical<InEnvironment<Goal<I>>>,
30        mut answers: impl context::AnswerStream<I>,
31        should_continue: impl std::ops::Fn() -> bool + Clone,
32    ) -> Option<Solution<I>> {
33        let interner = self.program.interner();
34        let CompleteAnswer { subst, ambiguous } = match answers.next_answer(&should_continue) {
35            AnswerResult::NoMoreSolutions => {
36                // No answers at all
37                return None;
38            }
39            AnswerResult::Answer(answer) => answer,
40            AnswerResult::Floundered => CompleteAnswer {
41                subst: self.identity_constrained_subst(root_goal),
42                ambiguous: true,
43            },
44            AnswerResult::QuantumExceeded => {
45                return Some(Solution::Ambig(Guidance::Unknown));
46            }
47        };
48
49        // Exactly 1 unconditional answer?
50        let next_answer = answers.peek_answer(&should_continue);
51        if next_answer.is_quantum_exceeded() {
52            if subst.value.subst.is_identity_subst(interner) {
53                return Some(Solution::Ambig(Guidance::Unknown));
54            } else {
55                return Some(Solution::Ambig(Guidance::Suggested(
56                    subst.map(interner, |cs| cs.subst),
57                )));
58            }
59        }
60        if next_answer.is_no_more_solutions() && !ambiguous {
61            return Some(Solution::Unique(subst));
62        }
63
64        // Otherwise, we either have >1 answer, or else we have
65        // ambiguity.  Either way, we are only going to be giving back
66        // **guidance**, and with guidance, the caller doesn't get
67        // back any region constraints. So drop them from our `subst`
68        // variable.
69        //
70        // FIXME-- there is actually a 3rd possibility. We could have
71        // >1 answer where all the answers have the same substitution,
72        // but different region constraints. We should collapse those
73        // cases into an `OR` region constraint at some point, but I
74        // leave that for future work. This is basically
75        // rust-lang/rust#21974.
76        let mut subst = subst.map(interner, |cs| cs.subst);
77
78        // Extract answers and merge them into `subst`. Stop once we have
79        // a trivial subst (or run out of answers).
80        let mut num_answers = 1;
81        let guidance = loop {
82            if subst.value.is_empty(interner) || is_trivial(interner, &subst) {
83                break Guidance::Unknown;
84            }
85
86            if !answers
87                .any_future_answer(|ref mut new_subst| new_subst.may_invalidate(interner, &subst))
88            {
89                break Guidance::Definite(subst);
90            }
91
92            if let Some(expected_answers) = self.expected_answers {
93                if num_answers >= expected_answers {
94                    panic!("Too many answers for solution.");
95                }
96            }
97
98            let new_subst = match answers.next_answer(&should_continue) {
99                AnswerResult::Answer(answer1) => answer1.subst,
100                AnswerResult::Floundered => {
101                    // FIXME: this doesn't trigger for any current tests
102                    self.identity_constrained_subst(root_goal)
103                }
104                AnswerResult::NoMoreSolutions => {
105                    break Guidance::Definite(subst);
106                }
107                AnswerResult::QuantumExceeded => {
108                    break Guidance::Suggested(subst);
109                }
110            };
111            subst = merge_into_guidance(interner, &root_goal.canonical, subst, &new_subst);
112            num_answers += 1;
113        };
114
115        if let Some(expected_answers) = self.expected_answers {
116            assert_eq!(
117                expected_answers, num_answers,
118                "Not enough answers for solution."
119            );
120        }
121        Some(Solution::Ambig(guidance))
122    }
123}
124
125/// Given a current substitution used as guidance for `root_goal`, and
126/// a new possible answer to `root_goal`, returns a new set of
127/// guidance that encompasses both of them. This is often more general
128/// than the old guidance. For example, if we had a guidance of `?0 =
129/// u32` and the new answer is `?0 = i32`, then the guidance would
130/// become `?0 = ?X` (where `?X` is some fresh variable).
131fn merge_into_guidance<I: Interner>(
132    interner: I,
133    root_goal: &Canonical<InEnvironment<Goal<I>>>,
134    guidance: Canonical<Substitution<I>>,
135    answer: &Canonical<ConstrainedSubst<I>>,
136) -> Canonical<Substitution<I>> {
137    let mut infer = InferenceTable::new();
138    let Canonical {
139        value: ConstrainedSubst {
140            subst: subst1,
141            constraints: _,
142        },
143        binders: _,
144    } = answer;
145
146    // Collect the types that the two substitutions have in
147    // common.
148    let aggr_generic_args: Vec<_> = guidance
149        .value
150        .iter(interner)
151        .zip(subst1.iter(interner))
152        .enumerate()
153        .map(|(index, (p1, p2))| {
154            // We have two values for some variable X that
155            // appears in the root goal. Find out the universe
156            // of X.
157            let universe = *root_goal.binders.as_slice(interner)[index].skip_kind();
158
159            match p1.data(interner) {
160                GenericArgData::Ty(_) => (),
161                GenericArgData::Lifetime(_) => {
162                    // Ignore the lifetimes from the substitution: we're just
163                    // creating guidance here anyway.
164                    return infer
165                        .new_variable(universe)
166                        .to_lifetime(interner)
167                        .cast(interner);
168                }
169                GenericArgData::Const(_) => (),
170            };
171
172            // Combine the two types into a new type.
173            let mut aggr = AntiUnifier {
174                infer: &mut infer,
175                universe,
176                interner,
177            };
178            aggr.aggregate_generic_args(p1, p2)
179        })
180        .collect();
181
182    let aggr_subst = Substitution::from_iter(interner, aggr_generic_args);
183
184    infer.canonicalize(interner, aggr_subst).quantified
185}
186
187fn is_trivial<I: Interner>(interner: I, subst: &Canonical<Substitution<I>>) -> bool {
188    // A subst is trivial if..
189    subst
190        .value
191        .iter(interner)
192        .enumerate()
193        .all(|(index, parameter)| {
194            let is_trivial = |b: Option<BoundVar>| match b {
195                None => false,
196                Some(bound_var) => {
197                    if let Some(index1) = bound_var.index_if_innermost() {
198                        index == index1
199                    } else {
200                        false
201                    }
202                }
203            };
204
205            match parameter.data(interner) {
206                // All types and consts are mapped to distinct variables. Since this
207                // has been canonicalized, those will also be the first N
208                // variables.
209                GenericArgData::Ty(t) => is_trivial(t.bound_var(interner)),
210                GenericArgData::Const(t) => is_trivial(t.bound_var(interner)),
211
212                // And no lifetime mappings. (This is too strict, but we never
213                // product substs with lifetimes.)
214                GenericArgData::Lifetime(_) => false,
215            }
216        })
217}
218
219/// [Anti-unification] is the act of taking two things that do not
220/// unify and finding a minimal generalization of them. So for
221/// example `Vec<u32>` anti-unified with `Vec<i32>` might be
222/// `Vec<?X>`. This is a **very simplistic** anti-unifier.
223///
224/// NOTE: The values here are canonicalized, but output is not, this means
225/// that any escaping bound variables that we see have to be replaced with
226/// inference variables.
227///
228/// [Anti-unification]: https://en.wikipedia.org/wiki/Anti-unification_(computer_science)
229struct AntiUnifier<'infer, I: Interner> {
230    infer: &'infer mut InferenceTable<I>,
231    universe: UniverseIndex,
232    interner: I,
233}
234
235impl<I: Interner> AntiUnifier<'_, I> {
236    fn aggregate_tys(&mut self, ty0: &Ty<I>, ty1: &Ty<I>) -> Ty<I> {
237        let interner = self.interner;
238        match (ty0.kind(interner), ty1.kind(interner)) {
239            // If we see bound things on either side, just drop in a
240            // fresh variable. This means we will sometimes
241            // overgeneralize.  So for example if we have two
242            // solutions that are both `(X, X)`, we just produce `(Y,
243            // Z)` in all cases.
244            (TyKind::InferenceVar(_, _), TyKind::InferenceVar(_, _)) => self.new_ty_variable(),
245
246            // Ugh. Aggregating two types like `for<'a> fn(&'a u32,
247            // &'a u32)` and `for<'a, 'b> fn(&'a u32, &'b u32)` seems
248            // kinda hard. Don't try to be smart for now, just plop a
249            // variable in there and be done with it.
250            // This also ensures that any bound variables we do see
251            // were bound by `Canonical`.
252            (TyKind::BoundVar(_), TyKind::BoundVar(_))
253            | (TyKind::Function(_), TyKind::Function(_))
254            | (TyKind::Dyn(_), TyKind::Dyn(_)) => self.new_ty_variable(),
255
256            (
257                TyKind::Alias(AliasTy::Projection(proj1)),
258                TyKind::Alias(AliasTy::Projection(proj2)),
259            ) => self.aggregate_projection_tys(proj1, proj2),
260
261            (
262                TyKind::Alias(AliasTy::Opaque(opaque_ty1)),
263                TyKind::Alias(AliasTy::Opaque(opaque_ty2)),
264            ) => self.aggregate_opaque_ty_tys(opaque_ty1, opaque_ty2),
265
266            (TyKind::Placeholder(placeholder1), TyKind::Placeholder(placeholder2)) => {
267                self.aggregate_placeholder_tys(placeholder1, placeholder2)
268            }
269
270            (TyKind::Adt(id_a, substitution_a), TyKind::Adt(id_b, substitution_b)) => self
271                .aggregate_name_and_substs(id_a, substitution_a, id_b, substitution_b)
272                .map(|(&name, substitution)| TyKind::Adt(name, substitution).intern(interner))
273                .unwrap_or_else(|| self.new_ty_variable()),
274            (
275                TyKind::AssociatedType(id_a, substitution_a),
276                TyKind::AssociatedType(id_b, substitution_b),
277            ) => self
278                .aggregate_name_and_substs(id_a, substitution_a, id_b, substitution_b)
279                .map(|(&name, substitution)| {
280                    TyKind::AssociatedType(name, substitution).intern(interner)
281                })
282                .unwrap_or_else(|| self.new_ty_variable()),
283            (TyKind::Scalar(scalar_a), TyKind::Scalar(scalar_b)) => {
284                if scalar_a == scalar_b {
285                    TyKind::Scalar(*scalar_a).intern(interner)
286                } else {
287                    self.new_ty_variable()
288                }
289            }
290            (TyKind::Str, TyKind::Str) => TyKind::Str.intern(interner),
291            (TyKind::Tuple(arity_a, substitution_a), TyKind::Tuple(arity_b, substitution_b)) => {
292                self.aggregate_name_and_substs(arity_a, substitution_a, arity_b, substitution_b)
293                    .map(|(&name, substitution)| TyKind::Tuple(name, substitution).intern(interner))
294                    .unwrap_or_else(|| self.new_ty_variable())
295            }
296            (
297                TyKind::OpaqueType(id_a, substitution_a),
298                TyKind::OpaqueType(id_b, substitution_b),
299            ) => self
300                .aggregate_name_and_substs(id_a, substitution_a, id_b, substitution_b)
301                .map(|(&name, substitution)| {
302                    TyKind::OpaqueType(name, substitution).intern(interner)
303                })
304                .unwrap_or_else(|| self.new_ty_variable()),
305            (TyKind::Slice(ty_a), TyKind::Slice(ty_b)) => {
306                TyKind::Slice(self.aggregate_tys(ty_a, ty_b)).intern(interner)
307            }
308            (TyKind::FnDef(id_a, substitution_a), TyKind::FnDef(id_b, substitution_b)) => self
309                .aggregate_name_and_substs(id_a, substitution_a, id_b, substitution_b)
310                .map(|(&name, substitution)| TyKind::FnDef(name, substitution).intern(interner))
311                .unwrap_or_else(|| self.new_ty_variable()),
312            (TyKind::Ref(id_a, lifetime_a, ty_a), TyKind::Ref(id_b, lifetime_b, ty_b)) => {
313                if id_a == id_b {
314                    TyKind::Ref(
315                        *id_a,
316                        self.aggregate_lifetimes(lifetime_a, lifetime_b),
317                        self.aggregate_tys(ty_a, ty_b),
318                    )
319                    .intern(interner)
320                } else {
321                    self.new_ty_variable()
322                }
323            }
324            (TyKind::Raw(id_a, ty_a), TyKind::Raw(id_b, ty_b)) => {
325                if id_a == id_b {
326                    TyKind::Raw(*id_a, self.aggregate_tys(ty_a, ty_b)).intern(interner)
327                } else {
328                    self.new_ty_variable()
329                }
330            }
331            (TyKind::Never, TyKind::Never) => TyKind::Never.intern(interner),
332            (TyKind::Array(ty_a, const_a), TyKind::Array(ty_b, const_b)) => TyKind::Array(
333                self.aggregate_tys(ty_a, ty_b),
334                self.aggregate_consts(const_a, const_b),
335            )
336            .intern(interner),
337            (TyKind::Closure(id_a, substitution_a), TyKind::Closure(id_b, substitution_b)) => self
338                .aggregate_name_and_substs(id_a, substitution_a, id_b, substitution_b)
339                .map(|(&name, substitution)| TyKind::Closure(name, substitution).intern(interner))
340                .unwrap_or_else(|| self.new_ty_variable()),
341            (TyKind::Coroutine(id_a, substitution_a), TyKind::Coroutine(id_b, substitution_b)) => {
342                self.aggregate_name_and_substs(id_a, substitution_a, id_b, substitution_b)
343                    .map(|(&name, substitution)| {
344                        TyKind::Coroutine(name, substitution).intern(interner)
345                    })
346                    .unwrap_or_else(|| self.new_ty_variable())
347            }
348            (
349                TyKind::CoroutineWitness(id_a, substitution_a),
350                TyKind::CoroutineWitness(id_b, substitution_b),
351            ) => self
352                .aggregate_name_and_substs(id_a, substitution_a, id_b, substitution_b)
353                .map(|(&name, substitution)| {
354                    TyKind::CoroutineWitness(name, substitution).intern(interner)
355                })
356                .unwrap_or_else(|| self.new_ty_variable()),
357            (TyKind::Foreign(id_a), TyKind::Foreign(id_b)) => {
358                if id_a == id_b {
359                    TyKind::Foreign(*id_a).intern(interner)
360                } else {
361                    self.new_ty_variable()
362                }
363            }
364            (TyKind::Error, TyKind::Error) => TyKind::Error.intern(interner),
365
366            (_, _) => self.new_ty_variable(),
367        }
368    }
369
370    fn aggregate_placeholder_tys(
371        &mut self,
372        index1: &PlaceholderIndex,
373        index2: &PlaceholderIndex,
374    ) -> Ty<I> {
375        let interner = self.interner;
376        if index1 != index2 {
377            self.new_ty_variable()
378        } else {
379            TyKind::Placeholder(*index1).intern(interner)
380        }
381    }
382
383    fn aggregate_projection_tys(
384        &mut self,
385        proj1: &ProjectionTy<I>,
386        proj2: &ProjectionTy<I>,
387    ) -> Ty<I> {
388        let interner = self.interner;
389        let ProjectionTy {
390            associated_ty_id: name1,
391            substitution: substitution1,
392        } = proj1;
393        let ProjectionTy {
394            associated_ty_id: name2,
395            substitution: substitution2,
396        } = proj2;
397
398        self.aggregate_name_and_substs(name1, substitution1, name2, substitution2)
399            .map(|(&associated_ty_id, substitution)| {
400                TyKind::Alias(AliasTy::Projection(ProjectionTy {
401                    associated_ty_id,
402                    substitution,
403                }))
404                .intern(interner)
405            })
406            .unwrap_or_else(|| self.new_ty_variable())
407    }
408
409    fn aggregate_opaque_ty_tys(
410        &mut self,
411        opaque_ty1: &OpaqueTy<I>,
412        opaque_ty2: &OpaqueTy<I>,
413    ) -> Ty<I> {
414        let OpaqueTy {
415            opaque_ty_id: name1,
416            substitution: substitution1,
417        } = opaque_ty1;
418        let OpaqueTy {
419            opaque_ty_id: name2,
420            substitution: substitution2,
421        } = opaque_ty2;
422
423        self.aggregate_name_and_substs(name1, substitution1, name2, substitution2)
424            .map(|(&opaque_ty_id, substitution)| {
425                TyKind::Alias(AliasTy::Opaque(OpaqueTy {
426                    opaque_ty_id,
427                    substitution,
428                }))
429                .intern(self.interner)
430            })
431            .unwrap_or_else(|| self.new_ty_variable())
432    }
433
434    fn aggregate_name_and_substs<N>(
435        &mut self,
436        name1: N,
437        substitution1: &Substitution<I>,
438        name2: N,
439        substitution2: &Substitution<I>,
440    ) -> Option<(N, Substitution<I>)>
441    where
442        N: Copy + Eq + Debug,
443    {
444        let interner = self.interner;
445        if name1 != name2 {
446            return None;
447        }
448
449        let name = name1;
450
451        assert_eq!(
452            substitution1.len(interner),
453            substitution2.len(interner),
454            "does {:?} take {} substitution or {}? can't both be right",
455            name,
456            substitution1.len(interner),
457            substitution2.len(interner)
458        );
459
460        let substitution = Substitution::from_iter(
461            interner,
462            substitution1
463                .iter(interner)
464                .zip(substitution2.iter(interner))
465                .map(|(p1, p2)| self.aggregate_generic_args(p1, p2)),
466        );
467
468        Some((name, substitution))
469    }
470
471    fn aggregate_generic_args(&mut self, p1: &GenericArg<I>, p2: &GenericArg<I>) -> GenericArg<I> {
472        let interner = self.interner;
473        match (p1.data(interner), p2.data(interner)) {
474            (GenericArgData::Ty(ty1), GenericArgData::Ty(ty2)) => {
475                self.aggregate_tys(ty1, ty2).cast(interner)
476            }
477            (GenericArgData::Lifetime(l1), GenericArgData::Lifetime(l2)) => {
478                self.aggregate_lifetimes(l1, l2).cast(interner)
479            }
480            (GenericArgData::Const(c1), GenericArgData::Const(c2)) => {
481                self.aggregate_consts(c1, c2).cast(interner)
482            }
483            (GenericArgData::Ty(_), _)
484            | (GenericArgData::Lifetime(_), _)
485            | (GenericArgData::Const(_), _) => {
486                panic!("mismatched parameter kinds: p1={:?} p2={:?}", p1, p2)
487            }
488        }
489    }
490
491    fn aggregate_lifetimes(&mut self, l1: &Lifetime<I>, l2: &Lifetime<I>) -> Lifetime<I> {
492        let interner = self.interner;
493        match (l1.data(interner), l2.data(interner)) {
494            (LifetimeData::Phantom(void, ..), _) | (_, LifetimeData::Phantom(void, ..)) => {
495                match *void {}
496            }
497            (LifetimeData::BoundVar(..), _) | (_, LifetimeData::BoundVar(..)) => {
498                self.new_lifetime_variable()
499            }
500            _ => {
501                if l1 == l2 {
502                    l1.clone()
503                } else {
504                    self.new_lifetime_variable()
505                }
506            }
507        }
508    }
509
510    fn aggregate_consts(&mut self, c1: &Const<I>, c2: &Const<I>) -> Const<I> {
511        let interner = self.interner;
512
513        // It would be nice to check that c1 and c2 have the same type, even though
514        // on this stage of solving they should already have the same type.
515
516        let ConstData {
517            ty: c1_ty,
518            value: c1_value,
519        } = c1.data(interner);
520        let ConstData {
521            ty: _c2_ty,
522            value: c2_value,
523        } = c2.data(interner);
524
525        let ty = c1_ty.clone();
526
527        match (c1_value, c2_value) {
528            (ConstValue::InferenceVar(_), _) | (_, ConstValue::InferenceVar(_)) => {
529                self.new_const_variable(ty)
530            }
531
532            (ConstValue::BoundVar(_), _) | (_, ConstValue::BoundVar(_)) => {
533                self.new_const_variable(ty)
534            }
535
536            (ConstValue::Placeholder(_), ConstValue::Placeholder(_)) => {
537                if c1 == c2 {
538                    c1.clone()
539                } else {
540                    self.new_const_variable(ty)
541                }
542            }
543            (ConstValue::Concrete(e1), ConstValue::Concrete(e2)) => {
544                if e1.const_eq(&ty, e2, interner) {
545                    c1.clone()
546                } else {
547                    self.new_const_variable(ty)
548                }
549            }
550
551            (ConstValue::Placeholder(_), _) | (_, ConstValue::Placeholder(_)) => {
552                self.new_const_variable(ty)
553            }
554        }
555    }
556
557    fn new_ty_variable(&mut self) -> Ty<I> {
558        let interner = self.interner;
559        self.infer.new_variable(self.universe).to_ty(interner)
560    }
561
562    fn new_lifetime_variable(&mut self) -> Lifetime<I> {
563        let interner = self.interner;
564        self.infer.new_variable(self.universe).to_lifetime(interner)
565    }
566
567    fn new_const_variable(&mut self, ty: Ty<I>) -> Const<I> {
568        let interner = self.interner;
569        self.infer
570            .new_variable(self.universe)
571            .to_const(interner, ty)
572    }
573}
574
#[cfg(test)]
mod test {
    use crate::slg::aggregate::AntiUnifier;
    use chalk_integration::interner::ChalkIr;
    use chalk_integration::{arg, ty};
    use chalk_ir::{Ty, UniverseIndex};
    use chalk_solve::infer::InferenceTable;

    /// Anti-unifies `a` with `b` in the root universe.
    ///
    /// A fresh inference table is used, so the variables created for the
    /// result are numbered from `?0` (that is, `infer 0`).
    fn anti_unify(a: &Ty<ChalkIr>, b: &Ty<ChalkIr>) -> Ty<ChalkIr> {
        let mut infer: InferenceTable<ChalkIr> = InferenceTable::new();
        let mut anti_unifier = AntiUnifier {
            infer: &mut infer,
            universe: UniverseIndex::root(),
            interner: ChalkIr,
        };
        anti_unifier.aggregate_tys(a, b)
    }

    /// Test the equivalent of `Vec<i32>` vs `Vec<u32>`: the differing
    /// argument is generalized to a fresh variable.
    #[test]
    fn vec_i32_vs_vec_u32() {
        let ty = anti_unify(
            &ty!(apply (item 0) (apply (item 1))),
            &ty!(apply (item 0) (apply (item 2))),
        );
        assert_eq!(ty!(apply (item 0) (infer 0)), ty);
    }

    /// Test the equivalent of `Vec<i32>` vs `Vec<i32>`: identical types are
    /// returned unchanged.
    #[test]
    fn vec_i32_vs_vec_i32() {
        let ty = anti_unify(
            &ty!(apply (item 0) (apply (item 1))),
            &ty!(apply (item 0) (apply (item 1))),
        );
        assert_eq!(ty!(apply (item 0) (apply (item 1))), ty);
    }

    /// Test the equivalent of `Vec<X>` vs `Vec<Y>`: two distinct variables
    /// generalize to a single fresh variable.
    #[test]
    fn vec_x_vs_vec_y() {
        // The `infer 0` and `infer 1` in these inputs stand for canonicalized
        // free variables, not variables of the table used by `anti_unify`.
        let ty = anti_unify(
            &ty!(apply (item 0) (infer 0)),
            &ty!(apply (item 0) (infer 1)),
        );

        // This `infer 0`, by contrast, is the first variable created in the
        // table used by `anti_unify`.
        assert_eq!(ty!(apply (item 0) (infer 0)), ty);
    }
}