use crate::ext::*;
use crate::infer::InferenceTable;
use crate::solve::slg::SlgContext;
use crate::solve::slg::SlgContextOps;
use crate::solve::slg::SubstitutionExt;
use crate::solve::{Guidance, Solution};
use chalk_ir::cast::Cast;
use chalk_ir::interner::Interner;
use chalk_ir::*;
use chalk_engine::context::{self, AnswerResult, ContextOps};
use chalk_engine::CompleteAnswer;
use std::fmt::Debug;
impl<I: Interner> context::AggregateOps<I, SlgContext<I>> for SlgContextOps<'_, I> {
    /// Turns the stream of answers found for `root_goal` into a `Solution`.
    ///
    /// * No answers at all: `None`.
    /// * The first query runs out of budget (`QuantumExceeded`):
    ///   `Solution::Ambig(Guidance::Unknown)`.
    /// * Exactly one non-ambiguous answer, and the stream reports no more
    ///   solutions: `Solution::Unique` with that answer.
    /// * Otherwise the answers are merged (anti-unified) into a single
    ///   substitution and returned as `Solution::Ambig` guidance. Region
    ///   constraints are dropped on this path.
    ///
    /// `should_continue` is forwarded to every stream query. When the stream
    /// reports `QuantumExceeded`, the result degrades to `Unknown` or
    /// `Suggested` guidance.
    ///
    /// # Panics
    ///
    /// When `self.expected_answers` is `Some(n)`, panics if the number of
    /// answers merged into the guidance is not exactly `n`.
    fn make_solution(
        &self,
        root_goal: &UCanonical<InEnvironment<Goal<I>>>,
        mut answers: impl context::AnswerStream<I>,
        should_continue: impl std::ops::Fn() -> bool,
    ) -> Option<Solution<I>> {
        let interner = self.program.interner();

        // Consume the first answer. A floundered goal is treated as an
        // ambiguous answer carrying the identity substitution.
        let first = match answers.next_answer(|| should_continue()) {
            AnswerResult::Answer(answer) => answer,
            AnswerResult::Floundered => CompleteAnswer {
                subst: self.identity_constrained_subst(root_goal),
                ambiguous: true,
            },
            AnswerResult::NoMoreSolutions => return None,
            AnswerResult::QuantumExceeded => return Some(Solution::Ambig(Guidance::Unknown)),
        };
        let CompleteAnswer {
            subst: first_subst,
            ambiguous,
        } = first;

        // Look at the following answer without consuming it.
        let lookahead = answers.peek_answer(|| should_continue());
        if lookahead.is_quantum_exceeded() {
            let guidance = if first_subst.value.subst.is_identity_subst(interner) {
                Guidance::Unknown
            } else {
                Guidance::Suggested(first_subst.map(interner, |cs| cs.subst))
            };
            return Some(Solution::Ambig(guidance));
        }
        if !ambiguous && lookahead.is_no_more_solutions() {
            return Some(Solution::Unique(first_subst));
        }

        // Only guidance can be produced from here on. Guidance carries no
        // region constraints, so keep just the substitution.
        let mut merged = first_subst.map(interner, |cs| cs.subst);
        let mut merged_count = 1;

        let guidance = loop {
            // A substitution that maps every variable to itself tells the
            // caller nothing.
            if merged.value.is_empty(interner) || is_trivial(interner, &merged) {
                break Guidance::Unknown;
            }

            // If no remaining answer could contradict `merged`, it is definite.
            let could_change = answers
                .any_future_answer(|candidate| candidate.may_invalidate(interner, &merged));
            if !could_change {
                break Guidance::Definite(merged);
            }

            if let Some(limit) = self.expected_answers {
                if merged_count >= limit {
                    panic!("Too many answers for solution.");
                }
            }

            let next = match answers.next_answer(|| should_continue()) {
                AnswerResult::Answer(answer) => answer.subst,
                AnswerResult::Floundered => self.identity_constrained_subst(root_goal),
                AnswerResult::NoMoreSolutions => break Guidance::Definite(merged),
                AnswerResult::QuantumExceeded => break Guidance::Suggested(merged),
            };
            merged = merge_into_guidance(interner, &root_goal.canonical, merged, &next);
            merged_count += 1;
        };

        if let Some(expected) = self.expected_answers {
            assert_eq!(expected, merged_count, "Not enough answers for solution.");
        }
        Some(Solution::Ambig(guidance))
    }
}
fn merge_into_guidance<I: Interner>(
interner: &I,
root_goal: &Canonical<InEnvironment<Goal<I>>>,
guidance: Canonical<Substitution<I>>,
answer: &Canonical<ConstrainedSubst<I>>,
) -> Canonical<Substitution<I>> {
let mut infer = InferenceTable::new();
let Canonical {
value: ConstrainedSubst {
subst: subst1,
constraints: _,
},
binders: _,
} = answer;
let aggr_generic_args: Vec<_> = guidance
.value
.iter(interner)
.zip(subst1.iter(interner))
.enumerate()
.map(|(index, (p1, p2))| {
let universe = *root_goal.binders.as_slice(interner)[index].skip_kind();
match p1.data(interner) {
GenericArgData::Ty(_) => (),
GenericArgData::Lifetime(_) => {
return infer
.new_variable(universe)
.to_lifetime(interner)
.cast(interner);
}
GenericArgData::Const(_) => (),
};
let mut aggr = AntiUnifier {
infer: &mut infer,
universe,
interner,
};
aggr.aggregate_generic_args(p1, p2)
})
.collect();
let aggr_subst = Substitution::from(interner, aggr_generic_args);
infer.canonicalize(interner, &aggr_subst).quantified
}
/// Returns `true` when `subst` carries no information.
///
/// That holds when every entry is a type or const that is the innermost
/// bound variable whose index equals the entry's own position, i.e. each
/// variable maps to itself. Any lifetime entry makes the substitution
/// non-trivial. An empty substitution is trivial.
fn is_trivial<I: Interner>(interner: &I, subst: &Canonical<Substitution<I>>) -> bool {
    // Does `var` name the innermost bound variable at `position`?
    let maps_to_self = |position: usize, var: Option<BoundVar>| -> bool {
        var.and_then(|bound| bound.index_if_innermost()) == Some(position)
    };

    subst
        .value
        .iter(interner)
        .enumerate()
        .all(|(position, arg)| match arg.data(interner) {
            GenericArgData::Lifetime(_) => false,
            GenericArgData::Ty(ty) => maps_to_self(position, ty.bound_var(interner)),
            GenericArgData::Const(ct) => maps_to_self(position, ct.bound_var(interner)),
        })
}
/// Working state for anti-unifying two terms into a common generalization.
struct AntiUnifier<'table, 'i, I: Interner> {
    /// Interner used to inspect and build types, lifetimes and consts.
    interner: &'i I,
    /// Table from which fresh inference variables are allocated.
    infer: &'table mut InferenceTable<I>,
    /// Universe in which every fresh variable is created.
    universe: UniverseIndex,
}
impl<I: Interner> AntiUnifier<'_, '_, I> {
    /// Anti-unifies two types.
    ///
    /// Structure shared by both sides is kept and recursed into: the same
    /// application name, projection, opaque type or placeholder. Every point
    /// of disagreement becomes a fresh inference variable in `self.universe`.
    /// Bound variables, function types and `dyn` types always become a fresh
    /// variable, even when both sides have that shape, so the result may be
    /// more general than strictly necessary.
    fn aggregate_tys(&mut self, ty0: &Ty<I>, ty1: &Ty<I>) -> Ty<I> {
        let interner = self.interner;
        match (ty0.data(interner), ty1.data(interner)) {
            // Two inference variables always generalize to a fresh variable,
            // whichever variables they are.
            (TyData::InferenceVar(_, _), TyData::InferenceVar(_, _)) => self.new_ty_variable(),

            // Not aggregated structurally; a fresh variable is used instead.
            (TyData::BoundVar(_), TyData::BoundVar(_))
            | (TyData::Function(_), TyData::Function(_))
            | (TyData::Dyn(_), TyData::Dyn(_)) => self.new_ty_variable(),

            (TyData::Apply(apply1), TyData::Apply(apply2)) => {
                self.aggregate_application_tys(apply1, apply2)
            }

            (
                TyData::Alias(AliasTy::Projection(proj1)),
                TyData::Alias(AliasTy::Projection(proj2)),
            ) => self.aggregate_projection_tys(proj1, proj2),

            (
                TyData::Alias(AliasTy::Opaque(opaque_ty1)),
                TyData::Alias(AliasTy::Opaque(opaque_ty2)),
            ) => self.aggregate_opaque_ty_tys(opaque_ty1, opaque_ty2),

            (TyData::Placeholder(placeholder1), TyData::Placeholder(placeholder2)) => {
                self.aggregate_placeholder_tys(placeholder1, placeholder2)
            }

            // All remaining combinations have differently shaped sides, so
            // there is nothing to keep.
            (TyData::InferenceVar(_, _), _)
            | (TyData::BoundVar(_), _)
            | (TyData::Dyn(_), _)
            | (TyData::Function(_), _)
            | (TyData::Apply(_), _)
            | (TyData::Alias(_), _)
            | (TyData::Placeholder(_), _) => self.new_ty_variable(),
        }
    }

    /// Anti-unifies two application types.
    ///
    /// Keeps the type name and merges the arguments pairwise when the names
    /// agree. Otherwise returns a fresh type variable.
    fn aggregate_application_tys(
        &mut self,
        apply1: &ApplicationTy<I>,
        apply2: &ApplicationTy<I>,
    ) -> Ty<I> {
        let interner = self.interner;
        let ApplicationTy {
            name: name1,
            substitution: substitution1,
        } = apply1;
        let ApplicationTy {
            name: name2,
            substitution: substitution2,
        } = apply2;

        self.aggregate_name_and_substs(name1, substitution1, name2, substitution2)
            .map(|(&name, substitution)| {
                TyData::Apply(ApplicationTy { name, substitution }).intern(interner)
            })
            .unwrap_or_else(|| self.new_ty_variable())
    }

    /// Keeps the placeholder if both sides are the same placeholder.
    /// Otherwise returns a fresh type variable.
    fn aggregate_placeholder_tys(
        &mut self,
        index1: &PlaceholderIndex,
        index2: &PlaceholderIndex,
    ) -> Ty<I> {
        let interner = self.interner;
        if index1 != index2 {
            self.new_ty_variable()
        } else {
            TyData::Placeholder(index1.clone()).intern(interner)
        }
    }

    /// Anti-unifies two projection types.
    ///
    /// Keeps the associated type and merges its arguments when both sides
    /// project the same associated type. Otherwise returns a fresh type
    /// variable.
    fn aggregate_projection_tys(
        &mut self,
        proj1: &ProjectionTy<I>,
        proj2: &ProjectionTy<I>,
    ) -> Ty<I> {
        let interner = self.interner;
        let ProjectionTy {
            associated_ty_id: name1,
            substitution: substitution1,
        } = proj1;
        let ProjectionTy {
            associated_ty_id: name2,
            substitution: substitution2,
        } = proj2;

        self.aggregate_name_and_substs(name1, substitution1, name2, substitution2)
            .map(|(&associated_ty_id, substitution)| {
                TyData::Alias(AliasTy::Projection(ProjectionTy {
                    associated_ty_id,
                    substitution,
                }))
                .intern(interner)
            })
            .unwrap_or_else(|| self.new_ty_variable())
    }

    /// Anti-unifies two opaque types.
    ///
    /// Keeps the opaque type id and merges its arguments when both sides
    /// refer to the same opaque type. Otherwise returns a fresh type variable.
    fn aggregate_opaque_ty_tys(
        &mut self,
        opaque_ty1: &OpaqueTy<I>,
        opaque_ty2: &OpaqueTy<I>,
    ) -> Ty<I> {
        // Bind the interner up front, like the sibling methods, so the `map`
        // closure does not need to reach through `self`.
        let interner = self.interner;
        let OpaqueTy {
            opaque_ty_id: name1,
            substitution: substitution1,
        } = opaque_ty1;
        let OpaqueTy {
            opaque_ty_id: name2,
            substitution: substitution2,
        } = opaque_ty2;

        self.aggregate_name_and_substs(name1, substitution1, name2, substitution2)
            .map(|(&opaque_ty_id, substitution)| {
                TyData::Alias(AliasTy::Opaque(OpaqueTy {
                    opaque_ty_id,
                    substitution,
                }))
                .intern(interner)
            })
            .unwrap_or_else(|| self.new_ty_variable())
    }

    /// Merges two `name<substitution>` pairs.
    ///
    /// Returns `None` if the names differ. Otherwise returns the shared name
    /// with the two substitutions anti-unified element by element, in order.
    ///
    /// # Panics
    ///
    /// Panics if the names agree but the substitutions have different lengths.
    fn aggregate_name_and_substs<N>(
        &mut self,
        name1: N,
        substitution1: &Substitution<I>,
        name2: N,
        substitution2: &Substitution<I>,
    ) -> Option<(N, Substitution<I>)>
    where
        N: Copy + Eq + Debug,
    {
        let interner = self.interner;

        if name1 != name2 {
            return None;
        }

        let name = name1;

        assert_eq!(
            substitution1.len(interner),
            substitution2.len(interner),
            "does {:?} take {} substitution or {}? can't both be right",
            name,
            substitution1.len(interner),
            substitution2.len(interner)
        );

        let substitution = Substitution::from(
            interner,
            substitution1
                .iter(interner)
                .zip(substitution2.iter(interner))
                .map(|(p1, p2)| self.aggregate_generic_args(p1, p2)),
        );

        Some((name, substitution))
    }

    /// Anti-unifies two generic arguments of the same kind.
    ///
    /// # Panics
    ///
    /// Panics if `p1` and `p2` are of different kinds (type, lifetime, const).
    fn aggregate_generic_args(&mut self, p1: &GenericArg<I>, p2: &GenericArg<I>) -> GenericArg<I> {
        let interner = self.interner;
        match (p1.data(interner), p2.data(interner)) {
            (GenericArgData::Ty(ty1), GenericArgData::Ty(ty2)) => {
                self.aggregate_tys(ty1, ty2).cast(interner)
            }
            (GenericArgData::Lifetime(l1), GenericArgData::Lifetime(l2)) => {
                self.aggregate_lifetimes(l1, l2).cast(interner)
            }
            (GenericArgData::Const(c1), GenericArgData::Const(c2)) => {
                self.aggregate_consts(c1, c2).cast(interner)
            }
            (GenericArgData::Ty(_), _)
            | (GenericArgData::Lifetime(_), _)
            | (GenericArgData::Const(_), _) => {
                panic!("mismatched parameter kinds: p1={:?} p2={:?}", p1, p2)
            }
        }
    }

    /// Anti-unifies two lifetimes.
    ///
    /// An inference or bound variable on either side yields a fresh lifetime
    /// variable. Two placeholders are kept only if they are equal.
    fn aggregate_lifetimes(&mut self, l1: &Lifetime<I>, l2: &Lifetime<I>) -> Lifetime<I> {
        let interner = self.interner;
        match (l1.data(interner), l2.data(interner)) {
            (LifetimeData::InferenceVar(_), _) | (_, LifetimeData::InferenceVar(_)) => {
                self.new_lifetime_variable()
            }

            (LifetimeData::BoundVar(_), _) | (_, LifetimeData::BoundVar(_)) => {
                self.new_lifetime_variable()
            }

            (LifetimeData::Placeholder(_), LifetimeData::Placeholder(_)) => {
                if l1 == l2 {
                    l1.clone()
                } else {
                    self.new_lifetime_variable()
                }
            }

            (LifetimeData::Phantom(..), _) | (_, LifetimeData::Phantom(..)) => unreachable!(),
        }
    }

    /// Anti-unifies two constants.
    ///
    /// An inference or bound variable on either side yields a fresh const
    /// variable. Equal placeholders, and concrete values that compare equal
    /// under `const_eq`, are kept. Anything else yields a fresh const
    /// variable. The result always uses `c1`'s type; `c2`'s type is not
    /// checked.
    fn aggregate_consts(&mut self, c1: &Const<I>, c2: &Const<I>) -> Const<I> {
        let interner = self.interner;

        let ConstData {
            ty: c1_ty,
            value: c1_value,
        } = c1.data(interner);
        let ConstData {
            ty: _c2_ty,
            value: c2_value,
        } = c2.data(interner);

        let ty = c1_ty.clone();

        match (c1_value, c2_value) {
            (ConstValue::InferenceVar(_), _) | (_, ConstValue::InferenceVar(_)) => {
                self.new_const_variable(ty)
            }

            // `ty` is owned and not used after this arm, so it is moved
            // rather than cloned.
            (ConstValue::BoundVar(_), _) | (_, ConstValue::BoundVar(_)) => {
                self.new_const_variable(ty)
            }

            (ConstValue::Placeholder(_), ConstValue::Placeholder(_)) => {
                if c1 == c2 {
                    c1.clone()
                } else {
                    self.new_const_variable(ty)
                }
            }

            (ConstValue::Concrete(e1), ConstValue::Concrete(e2)) => {
                if e1.const_eq(&ty, e2, interner) {
                    c1.clone()
                } else {
                    self.new_const_variable(ty)
                }
            }

            (ConstValue::Placeholder(_), _) | (_, ConstValue::Placeholder(_)) => {
                self.new_const_variable(ty)
            }
        }
    }

    /// Allocates a fresh type inference variable in `self.universe`.
    fn new_ty_variable(&mut self) -> Ty<I> {
        let interner = self.interner;
        self.infer.new_variable(self.universe).to_ty(interner)
    }

    /// Allocates a fresh lifetime inference variable in `self.universe`.
    fn new_lifetime_variable(&mut self) -> Lifetime<I> {
        let interner = self.interner;
        self.infer.new_variable(self.universe).to_lifetime(interner)
    }

    /// Allocates a fresh const inference variable of type `ty` in
    /// `self.universe`.
    fn new_const_variable(&mut self, ty: Ty<I>) -> Const<I> {
        let interner = self.interner;
        self.infer
            .new_variable(self.universe)
            .to_const(interner, ty)
    }
}
#[test]
fn vec_i32_vs_vec_u32() {
    use chalk_integration::interner::ChalkIr;

    // Both applications share the head `item 0` but differ in their argument,
    // so only the argument becomes a fresh variable (the first in the table).
    let mut table: InferenceTable<ChalkIr> = InferenceTable::new();
    let aggregated = AntiUnifier {
        interner: &ChalkIr,
        universe: UniverseIndex::root(),
        infer: &mut table,
    }
    .aggregate_tys(
        &ty!(apply (item 0) (apply (item 1))),
        &ty!(apply (item 0) (apply (item 2))),
    );

    assert_eq!(ty!(apply (item 0) (infer 0)), aggregated);
}
#[test]
fn vec_i32_vs_vec_i32() {
    use chalk_integration::interner::ChalkIr;

    // Identical inputs generalize to themselves; no fresh variable is needed.
    let mut table: InferenceTable<ChalkIr> = InferenceTable::new();
    let same = ty!(apply (item 0) (apply (item 1)));
    let aggregated = {
        let mut anti_unifier = AntiUnifier {
            universe: UniverseIndex::root(),
            infer: &mut table,
            interner: &ChalkIr,
        };
        anti_unifier.aggregate_tys(&same, &ty!(apply (item 0) (apply (item 1))))
    };

    assert_eq!(ty!(apply (item 0) (apply (item 1))), aggregated);
}
#[test]
fn vec_x_vs_vec_y() {
    use chalk_integration::interner::ChalkIr;

    // Two inference variables in argument position generalize to one fresh
    // variable, the first allocated in the (empty) table.
    let mut table: InferenceTable<ChalkIr> = InferenceTable::new();
    let (lhs, rhs) = (
        ty!(apply (item 0) (infer 0)),
        ty!(apply (item 0) (infer 1)),
    );
    let aggregated = AntiUnifier {
        infer: &mut table,
        interner: &ChalkIr,
        universe: UniverseIndex::root(),
    }
    .aggregate_tys(&lhs, &rhs);

    assert_eq!(ty!(apply (item 0) (infer 0)), aggregated);
}