use super::var::*;
use super::*;
use crate::debug_span;
use crate::infer::instantiate::IntoBindersAndValue;
use chalk_ir::cast::Cast;
use chalk_ir::fold::{Fold, Folder};
use chalk_ir::interner::{HasInterner, Interner};
use chalk_ir::zip::{Zip, Zipper};
use std::fmt::Debug;
impl<I: Interner> InferenceTable<I> {
    /// Attempts to unify `a` and `b` in `environment`.
    ///
    /// The attempt is transactional. A snapshot of the table is taken first. It is
    /// committed when unification succeeds and rolled back when it fails, so a failed
    /// attempt leaves the table as it was. On success, the returned
    /// `UnificationResult` lists the goals that still have to be proven.
    #[instrument(level = "debug", skip(self, interner, environment))]
    pub fn unify<T>(
        &mut self,
        interner: &I,
        environment: &Environment<I>,
        a: &T,
        b: &T,
    ) -> Fallible<UnificationResult<I>>
    where
        T: ?Sized + Zip<I>,
    {
        let snapshot = self.snapshot();
        let outcome = Unifier::new(interner, self, environment).unify(a, b);
        if outcome.is_ok() {
            self.commit(snapshot);
        } else {
            // Undo every binding the failed attempt may have written.
            self.rollback_to(snapshot);
        }
        outcome
    }
}
/// Working state for a single unification attempt.
///
/// Variable bindings are written straight into `table`. Constraints that cannot be
/// settled on the spot (alias equalities, lifetime outlives pairs) are collected in
/// `goals`, each paired with `environment`.
struct Unifier<'t, I: Interner> {
    /// Interner used to inspect and build types, lifetimes and consts.
    interner: &'t I,
    /// Environment attached to every deferred goal.
    environment: &'t Environment<I>,
    /// Inference table that receives variable bindings.
    table: &'t mut InferenceTable<I>,
    /// Goals produced during unification that still have to be proven.
    goals: Vec<InEnvironment<Goal<I>>>,
}
/// Outcome of a successful unification.
#[derive(Debug)]
pub struct UnificationResult<I>
where
    I: Interner,
{
    /// Goals that must additionally hold for the unification to be valid.
    /// Examples are `AliasEq` goals for alias types and lifetime-outlives constraints.
    pub goals: Vec<InEnvironment<Goal<I>>>,
}
impl<'t, I: Interner> Unifier<'t, I> {
    /// Creates a unifier that writes variable bindings into `table` and pairs every
    /// deferred goal with `environment`.
    fn new(
        interner: &'t I,
        table: &'t mut InferenceTable<I>,
        environment: &'t Environment<I>,
    ) -> Self {
        Unifier {
            environment,
            table,
            goals: vec![],
            interner,
        }
    }

    /// Unifies `a` with `b`, consuming the unifier.
    ///
    /// On success, returns the goals that still have to be proven: alias equalities
    /// and lifetime-outlives constraints. On `Err(NoSolution)`, bindings already written
    /// to the table are *not* undone here. `InferenceTable::unify` snapshots the table
    /// beforehand and rolls back in that case.
    fn unify<T>(mut self, a: &T, b: &T) -> Fallible<UnificationResult<I>>
    where
        T: ?Sized + Zip<I>,
    {
        Zip::zip_with(&mut self, a, b)?;
        Ok(UnificationResult { goals: self.goals })
    }

    /// Unifies two types.
    ///
    /// Each side is first replaced by its shallowly-normalized form when the table
    /// provides one. Then:
    /// - inference variables are bound; integer/float variables may only be bound
    ///   to integer/float types;
    /// - aliases produce a deferred `AliasEq` goal;
    /// - any other pair must share the same type constructor, and its components
    ///   are unified via `Zip`.
    ///
    /// # Panics
    ///
    /// Panics if either side is a bound variable; binders must be instantiated first.
    fn unify_ty_ty(&mut self, a: &Ty<I>, b: &Ty<I>) -> Fallible<()> {
        let interner = self.interner;

        let n_a = self.table.normalize_ty_shallow(interner, a);
        let n_b = self.table.normalize_ty_shallow(interner, b);
        let a = n_a.as_ref().unwrap_or(a);
        let b = n_b.as_ref().unwrap_or(b);

        debug_span!("unify_ty_ty", ?a, ?b);

        // Arm order matters. Inference variables and aliases are matched before the
        // blanket rejections for `Function`, `Placeholder` and `Dyn` below. A variable
        // or alias on one side is therefore bound or deferred rather than rejected.
        match (a.kind(interner), b.kind(interner)) {
            (&TyKind::InferenceVar(var1, kind1), &TyKind::InferenceVar(var2, kind2)) => {
                if kind1 == kind2 {
                    self.unify_var_var(var1, var2)
                } else if kind1 == TyVariableKind::General {
                    // A general variable can take on the more specific variable.
                    self.unify_general_var_specific_ty(var1, b.clone())
                } else if kind2 == TyVariableKind::General {
                    self.unify_general_var_specific_ty(var2, a.clone())
                } else {
                    debug!(
                        "Tried to unify mis-matching inference variables: {:?} and {:?}",
                        kind1, kind2
                    );
                    Err(NoSolution)
                }
            }

            // Function types: `sig` must match exactly before the binder-wrapped
            // contents are unified.
            (&TyKind::Function(ref fn1), &TyKind::Function(ref fn2)) => {
                if fn1.sig == fn2.sig {
                    self.unify_binders(fn1, fn2)
                } else {
                    Err(NoSolution)
                }
            }

            (&TyKind::Placeholder(ref p1), &TyKind::Placeholder(ref p2)) => {
                Zip::zip_with(self, p1, p2)
            }

            (&TyKind::Dyn(ref qwc1), &TyKind::Dyn(ref qwc2)) => Zip::zip_with(self, qwc1, qwc2),

            (TyKind::BoundVar(_), _) | (_, TyKind::BoundVar(_)) => panic!(
                "unification encountered bound variable: a={:?} b={:?}",
                a, b
            ),

            // Aliases are never compared structurally; equality is deferred as an
            // `AliasEq` goal.
            (_, &TyKind::Alias(ref alias)) => self.unify_alias_ty(alias, a),
            (&TyKind::Alias(ref alias), _) => self.unify_alias_ty(alias, b),

            (&TyKind::InferenceVar(var, kind), ty_data)
            | (ty_data, &TyKind::InferenceVar(var, kind)) => {
                let ty = ty_data.clone().intern(interner);
                // Integer and float variables may only be bound to integer and float types.
                match (kind, ty.is_integer(interner), ty.is_float(interner)) {
                    (TyVariableKind::General, _, _)
                    | (TyVariableKind::Integer, true, _)
                    | (TyVariableKind::Float, _, true) => self.unify_var_ty(var, &ty),
                    _ => Err(NoSolution),
                }
            }

            // A function, placeholder or `dyn` type only unifies with its own kind,
            // and that case was handled above.
            (&TyKind::Function(_), _) | (_, &TyKind::Function(_)) => Err(NoSolution),
            (_, &TyKind::Placeholder(_)) | (&TyKind::Placeholder(_), _) => Err(NoSolution),
            (_, &TyKind::Dyn(_)) | (&TyKind::Dyn(_), _) => Err(NoSolution),

            // Remaining constructors: both sides must use the same constructor, and its
            // components are then zipped pairwise.
            (TyKind::Adt(id_a, substitution_a), TyKind::Adt(id_b, substitution_b)) => {
                Zip::zip_with(self, id_a, id_b)?;
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (
                TyKind::AssociatedType(assoc_ty_a, substitution_a),
                TyKind::AssociatedType(assoc_ty_b, substitution_b),
            ) => {
                Zip::zip_with(self, assoc_ty_a, assoc_ty_b)?;
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (TyKind::Scalar(scalar_a), TyKind::Scalar(scalar_b)) => {
                Zip::zip_with(self, scalar_a, scalar_b)
            }
            (TyKind::Str, TyKind::Str) => Ok(()),
            (TyKind::Tuple(_arity_a, substitution_a), TyKind::Tuple(_arity_b, substitution_b)) => {
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (
                TyKind::OpaqueType(opaque_ty_a, substitution_a),
                TyKind::OpaqueType(opaque_ty_b, substitution_b),
            ) => {
                Zip::zip_with(self, opaque_ty_a, opaque_ty_b)?;
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (TyKind::Slice(substitution_a), TyKind::Slice(substitution_b)) => {
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (TyKind::FnDef(fn_def_a, substitution_a), TyKind::FnDef(fn_def_b, substitution_b)) => {
                Zip::zip_with(self, fn_def_a, fn_def_b)?;
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (
                TyKind::Ref(mutability_a, lifetime_a, ty_a),
                TyKind::Ref(mutability_b, lifetime_b, ty_b),
            ) => {
                Zip::zip_with(self, mutability_a, mutability_b)?;
                Zip::zip_with(self, lifetime_a, lifetime_b)?;
                Zip::zip_with(self, ty_a, ty_b)
            }
            (TyKind::Raw(mutability_a, ty_a), TyKind::Raw(mutability_b, ty_b)) => {
                Zip::zip_with(self, mutability_a, mutability_b)?;
                Zip::zip_with(self, ty_a, ty_b)
            }
            (TyKind::Never, TyKind::Never) => Ok(()),
            (TyKind::Array(ty_a, const_a), TyKind::Array(ty_b, const_b)) => {
                Zip::zip_with(self, ty_a, ty_b)?;
                Zip::zip_with(self, const_a, const_b)
            }
            (TyKind::Closure(id_a, substitution_a), TyKind::Closure(id_b, substitution_b)) => {
                Zip::zip_with(self, id_a, id_b)?;
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (TyKind::Generator(id_a, substitution_a), TyKind::Generator(id_b, substitution_b)) => {
                Zip::zip_with(self, id_a, id_b)?;
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (
                TyKind::GeneratorWitness(id_a, substitution_a),
                TyKind::GeneratorWitness(id_b, substitution_b),
            ) => {
                Zip::zip_with(self, id_a, id_b)?;
                Zip::zip_with(self, substitution_a, substitution_b)
            }
            (TyKind::Foreign(id_a), TyKind::Foreign(id_b)) => Zip::zip_with(self, id_a, id_b),
            (TyKind::Error, TyKind::Error) => Ok(()),
            (_, _) => Err(NoSolution),
        }
    }

    /// Unifies two inference variables. In this file it is called only when both
    /// variables have the same kind.
    #[instrument(level = "debug", skip(self))]
    fn unify_var_var(&mut self, a: InferenceVar, b: InferenceVar) -> Fallible<()> {
        let var1 = EnaVariable::from(a);
        let var2 = EnaVariable::from(b);
        self.table
            .unify
            .unify_var_var(var1, var2)
            .expect("unification of two unbound variables cannot fail");
        Ok(())
    }

    /// Binds the general-kind variable `general_var` to `specific_ty`. In this file,
    /// `specific_ty` is an integer- or float-kind inference variable.
    #[instrument(level = "debug", skip(self))]
    fn unify_general_var_specific_ty(
        &mut self,
        general_var: InferenceVar,
        specific_ty: Ty<I>,
    ) -> Fallible<()> {
        self.table
            .unify
            .unify_var_value(
                general_var,
                InferenceValue::from_ty(self.interner, specific_ty),
            )
            .unwrap();
        Ok(())
    }

    /// Unifies two binder-wrapped values, such as the contents of two function types.
    ///
    /// The check runs in both directions, and both must succeed:
    /// 1. `a` instantiated with placeholders against `b` instantiated with fresh
    ///    inference variables;
    /// 2. then the same with the roles of `a` and `b` swapped.
    #[instrument(level = "debug", skip(self))]
    fn unify_binders<'a, T, R>(
        &mut self,
        a: impl IntoBindersAndValue<'a, I, Value = T> + Copy + Debug,
        b: impl IntoBindersAndValue<'a, I, Value = T> + Copy + Debug,
    ) -> Fallible<()>
    where
        T: Fold<I, Result = R>,
        R: Zip<I> + Fold<I, Result = R>,
        't: 'a,
    {
        let interner = self.interner;
        {
            let a_universal = self.table.instantiate_binders_universally(interner, a);
            let b_existential = self.table.instantiate_binders_existentially(interner, b);
            Zip::zip_with(self, &a_universal, &b_existential)?;
        }
        {
            let b_universal = self.table.instantiate_binders_universally(interner, b);
            let a_existential = self.table.instantiate_binders_existentially(interner, a);
            Zip::zip_with(self, &a_existential, &b_universal)
        }
    }

    /// Defers the equality `alias == ty` as an `AliasEq` goal in the unifier's
    /// environment. It never fails here.
    fn unify_alias_ty(&mut self, alias: &AliasTy<I>, ty: &Ty<I>) -> Fallible<()> {
        let interner = self.interner;
        self.goals.push(InEnvironment::new(
            self.environment,
            AliasEq {
                alias: alias.clone(),
                ty: ty.clone(),
            }
            .cast(interner),
        ));
        Ok(())
    }

    /// Binds the unbound variable `var` to `ty`.
    ///
    /// `ty` is first folded through `OccursCheck`. That fold fails if `ty` mentions
    /// `var` itself or a placeholder type/const that `var`'s universe cannot name.
    /// Where needed, it also moves variables in `ty` down into `var`'s universe.
    fn unify_var_ty(&mut self, var: InferenceVar, ty: &Ty<I>) -> Fallible<()> {
        debug_span!("unify_var_ty", ?var, ?ty);
        let interner = self.interner;
        let var = EnaVariable::from(var);
        let universe_index = self.table.universe_of_unbound_var(var);
        let ty1 = ty.fold_with(
            &mut OccursCheck::new(self, var, universe_index),
            DebruijnIndex::INNERMOST,
        )?;
        self.table
            .unify
            .unify_var_value(var, InferenceValue::from_ty(interner, ty1.clone()))
            .unwrap();
        debug!("var {:?} set to {:?}", var, ty1);
        Ok(())
    }

    /// Unifies two lifetimes.
    ///
    /// Variables are bound where universes allow. Distinct `'static`/placeholder
    /// pairs never fail here; they are recorded as outlives goals in both directions.
    ///
    /// # Panics
    ///
    /// Panics on bound variables, and on `Phantom` data (which should be unreachable).
    fn unify_lifetime_lifetime(&mut self, a: &Lifetime<I>, b: &Lifetime<I>) -> Fallible<()> {
        let interner = self.interner;
        let n_a = self.table.normalize_lifetime_shallow(interner, a);
        let n_b = self.table.normalize_lifetime_shallow(interner, b);
        let a = n_a.as_ref().unwrap_or(a);
        let b = n_b.as_ref().unwrap_or(b);

        debug_span!("unify_lifetime_lifetime", ?a, ?b);

        match (a.data(interner), b.data(interner)) {
            (&LifetimeData::InferenceVar(var_a), &LifetimeData::InferenceVar(var_b)) => {
                let var_a = EnaVariable::from(var_a);
                let var_b = EnaVariable::from(var_b);
                debug!(?var_a, ?var_b);
                self.table.unify.unify_var_var(var_a, var_b).unwrap();
                Ok(())
            }
            (&LifetimeData::InferenceVar(a_var), &LifetimeData::Placeholder(b_idx)) => {
                self.unify_lifetime_var(a, b, a_var, b, b_idx.ui)
            }
            (&LifetimeData::Placeholder(a_idx), &LifetimeData::InferenceVar(b_var)) => {
                self.unify_lifetime_var(a, b, b_var, a, a_idx.ui)
            }
            (&LifetimeData::InferenceVar(a_var), &LifetimeData::Static) => {
                self.unify_lifetime_var(a, b, a_var, b, UniverseIndex::root())
            }
            (&LifetimeData::Static, &LifetimeData::InferenceVar(b_var)) => {
                self.unify_lifetime_var(a, b, b_var, a, UniverseIndex::root())
            }
            (&LifetimeData::Static, &LifetimeData::Static) => Ok(()),
            (&LifetimeData::Static, &LifetimeData::Placeholder(_))
            | (&LifetimeData::Placeholder(_), &LifetimeData::Static)
            | (&LifetimeData::Placeholder(_), &LifetimeData::Placeholder(_)) => {
                if a != b {
                    self.push_lifetime_eq_goals(a.clone(), b.clone());
                }
                Ok(())
            }
            (LifetimeData::BoundVar(_), _) | (_, LifetimeData::BoundVar(_)) => panic!(
                "unification encountered bound variable: a={:?} b={:?}",
                a, b
            ),
            (LifetimeData::Phantom(..), _) | (_, LifetimeData::Phantom(..)) => unreachable!(),
        }
    }

    /// Unifies the lifetime variable `var` with `value`, whose universe is `value_ui`.
    ///
    /// If `var`'s universe can see `value_ui`, `var` is bound directly. Otherwise the
    /// original pair `a`/`b` is recorded as an equality via outlives goals.
    #[instrument(level = "debug", skip(self, a, b))]
    fn unify_lifetime_var(
        &mut self,
        a: &Lifetime<I>,
        b: &Lifetime<I>,
        var: InferenceVar,
        value: &Lifetime<I>,
        value_ui: UniverseIndex,
    ) -> Fallible<()> {
        let var = EnaVariable::from(var);
        let var_ui = self.table.universe_of_unbound_var(var);
        if var_ui.can_see(value_ui) {
            debug!("{:?} in {:?} can see {:?}; unifying", var, var_ui, value_ui);
            self.table
                .unify
                .unify_var_value(
                    var,
                    InferenceValue::from_lifetime(self.interner, value.clone()),
                )
                .unwrap();
        } else {
            debug!(
                "{:?} in {:?} cannot see {:?}; pushing constraint",
                var, var_ui, value_ui
            );
            self.push_lifetime_eq_goals(a.clone(), b.clone());
        }
        Ok(())
    }

    /// Unifies two consts: first their types, then their values.
    ///
    /// # Panics
    ///
    /// Panics if, once the types have unified, either value is a bound variable.
    fn unify_const_const(&mut self, a: &Const<I>, b: &Const<I>) -> Fallible<()> {
        let interner = self.interner;
        let n_a = self.table.normalize_const_shallow(interner, a);
        let n_b = self.table.normalize_const_shallow(interner, b);
        let a = n_a.as_ref().unwrap_or(a);
        let b = n_b.as_ref().unwrap_or(b);

        debug_span!("unify_const_const", ?a, ?b);

        let ConstData {
            ty: a_ty,
            value: a_val,
        } = a.data(interner);
        let ConstData {
            ty: b_ty,
            value: b_val,
        } = b.data(interner);

        self.unify_ty_ty(a_ty, b_ty)?;

        match (a_val, b_val) {
            (&ConstValue::InferenceVar(var1), &ConstValue::InferenceVar(var2)) => {
                debug!(?var1, ?var2, "unify_var_var");
                let var1 = EnaVariable::from(var1);
                let var2 = EnaVariable::from(var2);
                self.table
                    .unify
                    .unify_var_var(var1, var2)
                    .expect("unification of two unbound variables cannot fail");
                Ok(())
            }
            (&ConstValue::InferenceVar(var), &ConstValue::Concrete(_))
            | (&ConstValue::InferenceVar(var), &ConstValue::Placeholder(_)) => {
                debug!(?var, c = ?b, "unify_var_const");
                self.unify_var_const(var, b)
            }
            (&ConstValue::Concrete(_), &ConstValue::InferenceVar(var))
            | (&ConstValue::Placeholder(_), &ConstValue::InferenceVar(var)) => {
                debug!(?var, c = ?a, "unify_var_const");
                self.unify_var_const(var, a)
            }
            (&ConstValue::Placeholder(p1), &ConstValue::Placeholder(p2)) => {
                Zip::zip_with(self, &p1, &p2)
            }
            (&ConstValue::Concrete(ref ev1), &ConstValue::Concrete(ref ev2)) => {
                // The types are already unified, so `a_ty` describes both values.
                if ev1.const_eq(a_ty, ev2, interner) {
                    Ok(())
                } else {
                    Err(NoSolution)
                }
            }
            (&ConstValue::Concrete(_), &ConstValue::Placeholder(_))
            | (&ConstValue::Placeholder(_), &ConstValue::Concrete(_)) => Err(NoSolution),
            (ConstValue::BoundVar(_), _) | (_, ConstValue::BoundVar(_)) => panic!(
                "unification encountered bound variable: a={:?} b={:?}",
                a, b
            ),
        }
    }

    /// Binds the unbound const variable `var` to `c`, after folding `c` through
    /// `OccursCheck` (the same check used by `unify_var_ty`).
    #[instrument(level = "debug", skip(self))]
    fn unify_var_const(&mut self, var: InferenceVar, c: &Const<I>) -> Fallible<()> {
        let interner = self.interner;
        let var = EnaVariable::from(var);
        let universe_index = self.table.universe_of_unbound_var(var);
        let c1 = c.fold_with(
            &mut OccursCheck::new(self, var, universe_index),
            DebruijnIndex::INNERMOST,
        )?;
        debug!("unify_var_const: var {:?} set to {:?}", var, c1);
        self.table
            .unify
            .unify_var_value(var, InferenceValue::from_const(interner, c1))
            .unwrap();
        Ok(())
    }

    /// Records that `a` and `b` must be equal.
    ///
    /// This is done with two `LifetimeOutlives` goals, one in each direction (`a`/`b`
    /// then `b`/`a`), both placed in the unifier's environment.
    fn push_lifetime_eq_goals(&mut self, a: Lifetime<I>, b: Lifetime<I>) {
        self.goals.push(InEnvironment::new(
            self.environment,
            WhereClause::LifetimeOutlives(LifetimeOutlives {
                a: a.clone(),
                b: b.clone(),
            })
            .cast(self.interner),
        ));
        self.goals.push(InEnvironment::new(
            self.environment,
            WhereClause::LifetimeOutlives(LifetimeOutlives { a: b, b: a }).cast(self.interner),
        ));
    }
}
/// Sends the structural walk of `Zip::zip_with` back into the unifier, so every
/// nested type, lifetime, const and binder pair gets unified.
impl<'i, I: Interner> Zipper<'i, I> for Unifier<'i, I> {
    fn interner(&self) -> &'i I {
        self.interner
    }

    fn zip_tys(&mut self, a: &Ty<I>, b: &Ty<I>) -> Fallible<()> {
        Self::unify_ty_ty(self, a, b)
    }

    fn zip_lifetimes(&mut self, a: &Lifetime<I>, b: &Lifetime<I>) -> Fallible<()> {
        Self::unify_lifetime_lifetime(self, a, b)
    }

    fn zip_consts(&mut self, a: &Const<I>, b: &Const<I>) -> Fallible<()> {
        Self::unify_const_const(self, a, b)
    }

    fn zip_binders<T>(&mut self, a: &Binders<T>, b: &Binders<T>) -> Fallible<()>
    where
        T: HasInterner<Interner = I> + Zip<I> + Fold<I, Result = T>,
    {
        Self::unify_binders(self, a, b)
    }
}
/// Folder that prepares a value to be assigned to the inference variable `var`.
///
/// While folding, it does three things:
/// - rejects values that mention `var` itself (the occurs check);
/// - rejects values that mention placeholder types or consts from universes `var`
///   cannot name, while replacing such placeholder lifetimes with a fresh variable
///   plus equality goals;
/// - moves inference variables from higher universes down into `universe_index`.
struct OccursCheck<'u, 't, I: Interner> {
    /// The variable that is about to be bound.
    var: EnaVariable<I>,
    /// Universe of `var`.
    universe_index: UniverseIndex,
    /// Unifier whose table is read and updated during the fold.
    unifier: &'u mut Unifier<'t, I>,
}
impl<'u, 't, I: Interner> OccursCheck<'u, 't, I> {
    /// Builds an occurs-check folder for binding `var`, which lives in `universe_index`.
    fn new(
        unifier: &'u mut Unifier<'t, I>,
        var: EnaVariable<I>,
        universe_index: UniverseIndex,
    ) -> Self {
        Self {
            var,
            universe_index,
            unifier,
        }
    }
}
impl<'i, I: Interner> Folder<'i, I> for OccursCheck<'_, 'i, I>
where
    I: 'i,
{
    fn as_dyn(&mut self) -> &mut dyn Folder<'i, I> {
        self
    }

    /// Keeps a placeholder type only if `var`'s universe can name it.
    fn fold_free_placeholder_ty(
        &mut self,
        placeholder: PlaceholderIndex,
        _outer_binder: DebruijnIndex,
    ) -> Fallible<Ty<I>> {
        let interner = self.interner();
        if self.universe_index < placeholder.ui {
            return Err(NoSolution);
        }
        Ok(placeholder.to_ty(interner))
    }

    /// Keeps a placeholder const only if `var`'s universe can name it.
    fn fold_free_placeholder_const(
        &mut self,
        ty: &Ty<I>,
        placeholder: PlaceholderIndex,
        _outer_binder: DebruijnIndex,
    ) -> Fallible<Const<I>> {
        let interner = self.interner();
        if self.universe_index < placeholder.ui {
            return Err(NoSolution);
        }
        Ok(placeholder.to_const(interner, ty.clone()))
    }

    /// Handles a placeholder lifetime from a universe `var` cannot name. Instead of
    /// failing, it is replaced by a fresh variable in `var`'s universe that is
    /// constrained to equal the placeholder.
    #[instrument(level = "debug", skip(self))]
    fn fold_free_placeholder_lifetime(
        &mut self,
        ui: PlaceholderIndex,
        _outer_binder: DebruijnIndex,
    ) -> Fallible<Lifetime<I>> {
        let interner = self.interner();
        if self.universe_index >= ui.ui {
            return Ok(ui.to_lifetime(interner));
        }
        let fresh = self.unifier.table.new_variable(self.universe_index);
        self.unifier
            .push_lifetime_eq_goals(fresh.to_lifetime(interner), ui.to_lifetime(interner));
        Ok(fresh.to_lifetime(interner))
    }

    /// Handles an inference type.
    ///
    /// A bound variable is replaced by its (recursively checked) value. An unbound
    /// variable fails the occurs check if it is unioned with `var`; otherwise it is
    /// pulled down into `var`'s universe when its own universe is higher.
    fn fold_inference_ty(
        &mut self,
        var: InferenceVar,
        kind: TyVariableKind,
        _outer_binder: DebruijnIndex,
    ) -> Fallible<Ty<I>> {
        let interner = self.interner();
        let ena_var = EnaVariable::from(var);
        let var_ui = match self.unifier.table.unify.probe_value(ena_var) {
            InferenceValue::Bound(value) => {
                let folded = value
                    .assert_ty_ref(interner)
                    .fold_with(self, DebruijnIndex::INNERMOST)?;
                assert!(!folded.needs_shift(interner));
                return Ok(folded);
            }
            InferenceValue::Unbound(ui) => ui,
        };
        if self.unifier.table.unify.unioned(ena_var, self.var) {
            return Err(NoSolution);
        }
        if self.universe_index < var_ui {
            self.unifier
                .table
                .unify
                .unify_var_value(ena_var, InferenceValue::Unbound(self.universe_index))
                .unwrap();
        }
        Ok(ena_var.to_ty_with_kind(interner, kind))
    }

    /// Const counterpart of `fold_inference_ty`: the same occurs check and universe
    /// lowering, applied to const inference variables.
    fn fold_inference_const(
        &mut self,
        ty: &Ty<I>,
        var: InferenceVar,
        _outer_binder: DebruijnIndex,
    ) -> Fallible<Const<I>> {
        let interner = self.interner();
        let ena_var = EnaVariable::from(var);
        let var_ui = match self.unifier.table.unify.probe_value(ena_var) {
            InferenceValue::Bound(value) => {
                let folded = value
                    .assert_const_ref(interner)
                    .fold_with(self, DebruijnIndex::INNERMOST)?;
                assert!(!folded.needs_shift(interner));
                return Ok(folded);
            }
            InferenceValue::Unbound(ui) => ui,
        };
        if self.unifier.table.unify.unioned(ena_var, self.var) {
            return Err(NoSolution);
        }
        if self.universe_index < var_ui {
            self.unifier
                .table
                .unify
                .unify_var_value(ena_var, InferenceValue::Unbound(self.universe_index))
                .unwrap();
        }
        Ok(ena_var.to_const(interner, ty.clone()))
    }

    /// Handles an inference lifetime: a bound variable is replaced by its folded value,
    /// and an unbound variable is lowered into `var`'s universe if needed. Unlike types
    /// and consts, no occurs check is performed here.
    fn fold_inference_lifetime(
        &mut self,
        var: InferenceVar,
        outer_binder: DebruijnIndex,
    ) -> Fallible<Lifetime<I>> {
        let interner = self.interner();
        let ena_var = EnaVariable::from(var);
        match self.unifier.table.unify.probe_value(ena_var) {
            InferenceValue::Bound(value) => {
                let folded = value
                    .assert_lifetime_ref(interner)
                    .fold_with(self, outer_binder)?;
                assert!(!folded.needs_shift(interner));
                Ok(folded)
            }
            InferenceValue::Unbound(var_ui) => {
                if self.universe_index < var_ui {
                    self.unifier
                        .table
                        .unify
                        .unify_var_value(ena_var, InferenceValue::Unbound(self.universe_index))
                        .unwrap();
                }
                Ok(ena_var.to_lifetime(interner))
            }
        }
    }

    fn forbid_free_vars(&self) -> bool {
        true
    }

    fn interner(&self) -> &'i I {
        self.unifier.interner
    }

    fn target_interner(&self) -> &'i I {
        self.interner()
    }
}