use std::{cell::RefCell, fmt::Display};
use ena::unify::{InPlaceUnificationTable, UnifyKey, UnifyValue};
use snafu::Snafu;
use tracing::trace;
use super::{
Constraint, Evidence, NodeId, TypeInference,
ty::{ClosedRow, Row, RowCombination, RowUniVar, RowVar, Type, TypeUniVar, TypeVar},
};
/// The reason a unification step failed; paired with a node id in [`TypeError`].
#[derive(Debug, PartialEq, Eq)]
pub enum TypeErrorKind {
/// Two types could not be unified; carries the left and right types.
TypeNotEqual((Type, Type)),
/// The occurs check failed for this unifier while binding it; the `Type` is the value
/// `occurs_check` returned as its error.
InfiniteType(TypeUniVar, Type),
/// Two rows could not be unified (e.g. different field sets, or an open row against a
/// closed one); carries the left and right rows.
RowsNotEqual((Row, Row)),
/// Not constructed in this file. Presumably reported when checking against a signature
/// introduces type variables, row variables, or evidence beyond those allowed — TODO confirm.
CheckIntroducedExtraVariablesOrConstraints {
extra_types: Vec<TypeVar>,
extra_row: Vec<RowVar>,
extra_evidence: Vec<Evidence>,
},
}
/// A unification failure attributed to the `NodeId` of the constraint that produced it.
///
/// Displays as `{node_id}@{kind}`.
#[derive(Debug, Snafu, PartialEq, Eq)]
#[snafu(display("{node_id}@{kind}"))]
pub struct TypeError {
/// What went wrong.
pub kind: TypeErrorKind,
/// Node id carried by the failing `Constraint`.
pub node_id: NodeId,
}
impl Display for TypeErrorKind {
    /// Renders a human-readable description of the error.
    ///
    /// Every variant produces a message. `TypeError`'s display embeds this text, so a panic
    /// here would turn any attempt to report a type error into a crash.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TypeErrorKind::TypeNotEqual((l, r)) => write!(f, "types not equal: {l:?} != {r:?}"),
            TypeErrorKind::InfiniteType(var, ty) => {
                write!(f, "infinite type: unifier {var:?} occurs in {ty:?}")
            }
            TypeErrorKind::RowsNotEqual((l, r)) => write!(f, "rows not equal: {l:?} != {r:?}"),
            TypeErrorKind::CheckIntroducedExtraVariablesOrConstraints {
                extra_types,
                extra_row,
                extra_evidence,
            } => write!(
                f,
                "check introduced extra type variables {extra_types:?}, \
                 row variables {extra_row:?}, or evidence {extra_evidence:?}"
            ),
        }
    }
}
impl From<(ClosedRow, ClosedRow)> for TypeErrorKind {
    /// Wraps a pair of mismatched closed rows as a [`TypeErrorKind::RowsNotEqual`] error.
    fn from(pair: (ClosedRow, ClosedRow)) -> Self {
        let (lhs, rhs) = pair;
        let rows = (Row::Closed(lhs), Row::Closed(rhs));
        Self::RowsNotEqual(rows)
    }
}
impl TypeInference {
pub(crate) fn unification(&mut self, constraints: Vec<Constraint>) -> Result<(), TypeError> {
for constr in constraints {
match constr {
Constraint::TypeEqual(node_id, left, right) => self
.unify_ty_ty(left, right)
.map_err(|kind| TypeError { kind, node_id })?,
Constraint::RowCombine(node_id, row_comb) => self
.unify_row_comb(row_comb)
.map_err(|kind| TypeError { kind, node_id })?,
}
}
Ok(())
}
fn normalize_closed_row(&mut self, closed: ClosedRow) -> ClosedRow {
ClosedRow {
fields: closed.fields,
values: closed
.values
.into_iter()
.map(|ty| self.normalize_ty(ty))
.collect(),
}
}
fn normalize_row(&mut self, row: Row) -> Row {
match row {
Row::Unifier(var) => match self.row_unification_table.probe_value(var) {
Some(Row::Closed(closed)) => Row::Closed(self.normalize_closed_row(closed)),
Some(row) => row,
None => row,
},
Row::Open(var) => Row::Open(var),
Row::Closed(closed) => Row::Closed(self.normalize_closed_row(closed)),
}
}
fn dispatch_any_solved(&mut self, var: RowUniVar, row: ClosedRow) -> Result<(), TypeErrorKind> {
let var = self.row_unification_table.find(var);
let mut changed_combs = vec![];
trace!(?var,?self.partial_row_combs,"dispatch_any_solved");
self.partial_row_combs = std::mem::take(&mut self.partial_row_combs)
.into_iter()
.filter_map(|comb| match comb {
RowCombination {
left: Row::Unifier(left),
right,
goal,
} if self.row_unification_table.find(left) == var => {
changed_combs.push(RowCombination {
left: Row::Closed(row.clone()),
right,
goal,
});
None
}
RowCombination {
left,
right: Row::Unifier(right),
goal,
} if self.row_unification_table.find(right) == var => {
changed_combs.push(RowCombination {
left,
right: Row::Closed(row.clone()),
goal,
});
None
}
RowCombination {
left,
right,
goal: Row::Unifier(goal),
} if self.row_unification_table.find(goal) == var => {
changed_combs.push(RowCombination {
left,
right,
goal: Row::Closed(row.clone()),
});
None
}
comb => Some(comb),
})
.collect();
for row_comb in changed_combs {
self.unify_row_comb(row_comb)?;
}
Ok(())
}
fn normalize_ty(&mut self, ty: Type) -> Type {
match ty {
Type::Unit => Type::Unit,
Type::Int => Type::Int,
Type::Float => Type::Float,
Type::String => Type::String,
Type::Var(var) => Type::Var(var),
Type::Abs(arg, ret) => {
let arg = self.normalize_ty(*arg);
let ret = self.normalize_ty(*ret);
Type::abstraction(arg, ret)
}
Type::Unifier(v) => match self.unification_table.probe_value(v) {
Some(ty) => self.normalize_ty(ty),
None => Type::Unifier(v),
},
Type::Label(label, ty) => {
let ty = self.normalize_ty(*ty);
Type::label(label, ty)
}
Type::Prod(row) => Type::Prod(self.normalize_row(row)),
Type::Sum(row) => Type::Sum(self.normalize_row(row)),
Type::DataFrame => Type::DataFrame,
}
}
fn unify_ty_ty(&mut self, unnorm_left: Type, unnorm_right: Type) -> Result<(), TypeErrorKind> {
trace!(?unnorm_left, ?unnorm_right, "unify_ty_ty");
let left = self.normalize_ty(unnorm_left);
let right = self.normalize_ty(unnorm_right);
match (left, right) {
(Type::Unit, Type::Unit) => Ok(()),
(Type::Int, Type::Int) => Ok(()),
(Type::Float, Type::Float) => Ok(()),
(Type::String, Type::String) => Ok(()),
(Type::DataFrame, Type::DataFrame) => Ok(()),
(Type::Var(a), Type::Var(b)) => (a == b)
.then_some(())
.ok_or(TypeErrorKind::TypeNotEqual((Type::Var(a), Type::Var(b)))),
(Type::Abs(a_arg, a_ret), Type::Abs(b_arg, b_ret)) => {
self.unify_ty_ty(*a_arg, *b_arg)?;
self.unify_ty_ty(*a_ret, *b_ret)
}
(Type::Unifier(a), Type::Unifier(b)) => self
.unification_table
.unify_var_var(a, b)
.map_err(TypeErrorKind::TypeNotEqual),
(Type::Unifier(v), ty) | (ty, Type::Unifier(v)) => {
ty.occurs_check(v)
.map_err(|ty| TypeErrorKind::InfiniteType(v, ty))?;
self.unification_table
.unify_var_value(v, Some(ty))
.map_err(TypeErrorKind::TypeNotEqual)
}
(Type::Prod(left), Type::Prod(right)) | (Type::Sum(left), Type::Sum(right)) => {
self.unify_row_row(left, right)
}
(Type::Label(field, ty), Type::Prod(row))
| (Type::Prod(row), Type::Label(field, ty))
| (Type::Label(field, ty), Type::Sum(row))
| (Type::Sum(row), Type::Label(field, ty)) => self.unify_row_row(
Row::Closed(ClosedRow {
fields: vec![field],
values: vec![*ty],
}),
row,
),
(left, right) => Err(TypeErrorKind::TypeNotEqual((left, right))),
}
}
fn diff_and_unify(
&mut self,
goal: ClosedRow,
sub: ClosedRow,
) -> Result<ClosedRow, TypeErrorKind> {
let mut diff_fields = vec![];
let mut diff_values = vec![];
for (field, value) in goal.fields_and_values() {
match sub.fields.binary_search(field) {
Ok(indx) => {
self.unify_ty_ty(value.clone(), sub.values[indx].clone())?;
}
Err(_) => {
diff_fields.push(field.clone());
diff_values.push(value.clone());
}
}
}
let mut extra_fields = vec![];
let mut extra_values = vec![];
for (field, value) in sub.fields_and_values() {
if goal.fields.binary_search(field).is_err() {
extra_fields.push(field.clone());
extra_values.push(value.clone());
}
}
if !extra_fields.is_empty() {
let expected = Row::Closed(ClosedRow::merge(
ClosedRow {
fields: extra_fields,
values: extra_values,
},
goal.clone(),
));
let goal = Row::Closed(goal);
return Err(TypeErrorKind::RowsNotEqual((goal, expected)));
}
Ok(ClosedRow {
fields: diff_fields,
values: diff_values,
})
}
fn unify_row_row(&mut self, left: Row, right: Row) -> Result<(), TypeErrorKind> {
trace!(?left, ?right, "unify_row_row");
let left = self.normalize_row(left);
let right = self.normalize_row(right);
match (left, right) {
(Row::Open(left), Row::Open(right)) => {
(left == right)
.then_some(())
.ok_or(TypeErrorKind::RowsNotEqual((
Row::Open(left),
Row::Open(right),
)))
}
(Row::Unifier(left), Row::Unifier(right)) => self
.row_unification_table
.unify_var_var(left, right)
.map_err(TypeErrorKind::RowsNotEqual),
(Row::Unifier(var), Row::Open(row)) | (Row::Open(row), Row::Unifier(var)) => self
.row_unification_table
.unify_var_value(var, Some(Row::Open(row)))
.map_err(TypeErrorKind::RowsNotEqual),
(Row::Unifier(var), Row::Closed(row)) | (Row::Closed(row), Row::Unifier(var)) => {
self.row_unification_table
.unify_var_value(var, Some(Row::Closed(row.clone())))
.map_err(TypeErrorKind::RowsNotEqual)?;
self.dispatch_any_solved(var, row)
}
(Row::Closed(left), Row::Closed(right)) => {
if left.fields != right.fields {
return Err(TypeErrorKind::from((left, right)));
}
for (left_ty, right_ty) in left.values.into_iter().zip(right.values) {
self.unify_ty_ty(left_ty, right_ty)?;
}
Ok(())
}
(Row::Open(var), Row::Closed(row)) | (Row::Closed(row), Row::Open(var)) => Err(
TypeErrorKind::RowsNotEqual((Row::Open(var), Row::Closed(row))),
),
}
}
fn unify_row_comb(&mut self, row_comb: RowCombination) -> Result<(), TypeErrorKind> {
let left = self.normalize_row(row_comb.left);
let right = self.normalize_row(row_comb.right);
let goal = self.normalize_row(row_comb.goal);
trace!(?left, ?right, ?goal, "unify_row_comb");
match (left, right, goal) {
(Row::Closed(left), Row::Closed(right), goal) => {
let calc_goal = ClosedRow::merge(left, right);
self.unify_row_row(Row::Closed(calc_goal), goal)
}
(Row::Unifier(var), Row::Closed(sub), Row::Closed(goal))
| (Row::Closed(sub), Row::Unifier(var), Row::Closed(goal)) => {
let diff_row = self.diff_and_unify(goal, sub)?;
self.unify_row_row(Row::Unifier(var), Row::Closed(diff_row))
}
(left, right, goal) => {
let new_comb = RowCombination { left, right, goal };
let mut poss_uni = None;
self.partial_row_combs = std::mem::take(&mut self.partial_row_combs)
.into_iter()
.map(|comb| {
let comb = RowCombination {
left: self.normalize_row(comb.left),
right: self.normalize_row(comb.right),
goal: self.normalize_row(comb.goal),
};
if comb.is_unifiable(&new_comb) {
poss_uni = Some(comb.clone());
} else if comb.is_comm_unifiable(&new_comb) {
poss_uni = Some(RowCombination {
left: comb.right.clone(),
right: comb.left.clone(),
goal: comb.goal.clone(),
});
}
comb
})
.collect();
match poss_uni {
Some(match_comb) => {
self.unify_row_row(new_comb.left, match_comb.left)?;
self.unify_row_row(new_comb.right, match_comb.right)?;
self.unify_row_row(new_comb.goal, match_comb.goal)?;
}
None => {
self.partial_row_combs.insert(new_comb);
}
}
Ok(())
}
}
}
}
/// Wrapper around ena's in-place unification table that also records every key it creates,
/// so its `Debug` impl can list each variable with its root and current value.
pub struct UnificationTable<K: ena::unify::UnifyKey> {
/// Wrapped in `RefCell` so `Debug::fmt`, which only has `&self`, can call `find` and
/// `probe_value` through `borrow_mut`.
table: RefCell<InPlaceUnificationTable<K>>,
/// Every key returned by `new_key`, in creation order.
keys: Vec<K>,
}
impl<K: ena::unify::UnifyKey> std::fmt::Debug for UnificationTable<K> {
    /// Lists every created key as a `(key, root, value)` triple under a `vars` field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut inner = self.table.borrow_mut();
        let mut entries = Vec::with_capacity(self.keys.len());
        for key in &self.keys {
            let root = inner.find(*key);
            let value = inner.probe_value(root);
            entries.push((key, root, value));
        }
        let mut out = f.debug_struct("UnificationTable");
        out.field("vars", &entries);
        out.finish()
    }
}
impl<K: UnifyKey> Default for UnificationTable<K> {
fn default() -> Self {
Self {
table: Default::default(),
keys: Default::default(),
}
}
}
// Every method here already holds `&mut self`, so `RefCell::get_mut` reaches the table with
// static borrow checking instead of a runtime `borrow_mut` (which could panic). The `RefCell`
// is only needed by the `Debug` impl, which has `&self`.
impl<K: UnifyKey> UnificationTable<K> {
    /// Creates a fresh key with initial `value` and records it for `Debug` output.
    pub fn new_key(&mut self, value: <K as UnifyKey>::Value) -> K {
        let key = self.table.get_mut().new_key(value);
        self.keys.push(key);
        key
    }

    /// Unifies the equivalence classes of `a` and `b`.
    ///
    /// # Errors
    /// Returns the value type's `UnifyValue::Error` when ena reports that the two classes'
    /// values cannot be combined.
    pub fn unify_var_var<K1, K2>(
        &mut self,
        a: K1,
        b: K2,
    ) -> Result<(), <<K as UnifyKey>::Value as UnifyValue>::Error>
    where
        K1: Into<K>,
        K2: Into<K>,
    {
        self.table.get_mut().unify_var_var(a, b)
    }

    /// Combines `b` into the value of `a_id`'s equivalence class.
    ///
    /// # Errors
    /// Returns the value type's `UnifyValue::Error` when ena reports that the existing value
    /// and `b` cannot be combined.
    pub fn unify_var_value<K1>(
        &mut self,
        a_id: K1,
        b: <K as UnifyKey>::Value,
    ) -> Result<(), <<K as UnifyKey>::Value as UnifyValue>::Error>
    where
        K1: Into<K>,
    {
        self.table.get_mut().unify_var_value(a_id, b)
    }

    /// Returns the root key of `id`'s equivalence class.
    pub fn find<K1>(&mut self, id: K1) -> K
    where
        K1: Into<K>,
    {
        self.table.get_mut().find(id)
    }

    /// Returns the current value associated with `id`'s equivalence class.
    pub fn probe_value<K1>(&mut self, id: K1) -> <K as UnifyKey>::Value
    where
        K1: Into<K>,
    {
        self.table.get_mut().probe_value(id)
    }
}