use std::collections::HashSet;
use ena::unify::{EqUnifyValue, UnifyKey};
use super::{Evidence, Label};
/// A row whose labels are all known: a finite list of `label: type` entries.
///
/// `fields` and `values` are parallel vectors — `fields[i]` labels the entry
/// whose type is `values[i]` — and are expected to have the same length.
/// NOTE(review): `ClosedRow::merge` only produces label-sorted output when
/// its inputs are label-sorted; presumably every constructor keeps `fields`
/// sorted — confirm.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare
/// `fields` before `values`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ClosedRow {
/// Entry labels.
pub fields: Vec<Label>,
/// Entry types, index-aligned with `fields`.
pub values: Vec<Type>,
}
impl ClosedRow {
    /// Merges two closed rows into one, interleaving their entries by label.
    ///
    /// This is the merge step of merge sort: if both inputs have their
    /// `fields` sorted ascending, the result is sorted as well. Each label
    /// travels together with its value, so `fields[i]`/`values[i]` stay
    /// paired in the output.
    ///
    /// When both rows contain the same label, the entry from `left` is
    /// emitted first and both entries are kept — duplicates are neither
    /// removed nor reported here.
    ///
    /// Relies on each input having `fields.len() == values.len()`; a row
    /// with fewer values than fields can make this panic, and other
    /// mismatches yield a malformed result.
    pub fn merge(left: ClosedRow, right: ClosedRow) -> ClosedRow {
        // The output size is known up front: reserve once instead of
        // letting both vectors grow and reallocate during the merge.
        let capacity = left.fields.len() + right.fields.len();
        let mut fields = Vec::with_capacity(capacity);
        let mut values = Vec::with_capacity(capacity);

        let mut left_fields = left.fields.into_iter().peekable();
        let mut left_values = left.values.into_iter();
        let mut right_fields = right.fields.into_iter().peekable();
        let mut right_values = right.values.into_iter();

        loop {
            match (left_fields.peek(), right_fields.peek()) {
                (Some(left_label), Some(right_label)) => {
                    // `<=` makes ties favour `left`, keeping the merge stable.
                    if left_label <= right_label {
                        fields.push(left_fields.next().unwrap());
                        values.push(left_values.next().unwrap());
                    } else {
                        fields.push(right_fields.next().unwrap());
                        values.push(right_values.next().unwrap());
                    }
                }
                (Some(_), None) => {
                    // `right` is exhausted: the rest of `left` is already in order.
                    fields.extend(left_fields);
                    values.extend(left_values);
                    break;
                }
                (None, Some(_)) => {
                    // `left` is exhausted: the rest of `right` is already in order.
                    fields.extend(right_fields);
                    values.extend(right_values);
                    break;
                }
                (None, None) => break,
            }
        }

        ClosedRow { fields, values }
    }

    /// Returns `true` if any value type in this row mentions a type unifier
    /// from `unbound_tys` or a row unifier from `unbound_rows`
    /// (as decided by [`Type::mentions`]). Stops at the first match.
    pub fn mentions(
        &self,
        unbound_tys: &HashSet<TypeUniVar>,
        unbound_rows: &HashSet<RowUniVar>,
    ) -> bool {
        self.values
            .iter()
            .any(|ty| ty.mentions(unbound_tys, unbound_rows))
    }

    /// Iterates over the `(label, type)` entries in storage order.
    pub fn fields_and_values(&self) -> impl Iterator<Item = (&Label, &Type)> {
        self.fields.iter().zip(self.values.iter())
    }
}
/// Opt into `ena`'s `EqUnifyValue` handling for closed rows (values unify
/// only when they are equal).
impl EqUnifyValue for ClosedRow {}
/// The row component of a product (`Type::Prod`) or sum (`Type::Sum`) type.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` order rows
/// by variant first.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Row {
/// A row unification variable; its `UnifyKey` value type is `Option<Row>`.
Unifier(RowUniVar),
/// A row variable. NOTE(review): presumably rigid/quantified rather than
/// solvable by unification — confirm.
Open(RowVar),
/// A row with a fully known set of labels.
Closed(ClosedRow),
}
// Opt into `ena`'s `EqUnifyValue` handling for rows (values unify only when equal).
impl EqUnifyValue for Row {}
impl Row {
    /// Builds a closed row holding exactly one entry, `lbl: ty`.
    ///
    /// The label may be anything implementing [`ToString`] (e.g. `&str`
    /// or `String`); it is converted to an owned label.
    pub fn single<S: ToString>(lbl: S, ty: Type) -> Self {
        let field = lbl.to_string();
        Self::Closed(ClosedRow {
            fields: vec![field],
            values: vec![ty],
        })
    }

    /// Reports whether `self` and `other` have matching shape.
    ///
    /// Two unifiers, or two open rows, must be the very same variable. Two
    /// closed rows must have identical `fields` vectors (same labels in the
    /// same order); their value types are not compared. Rows of different
    /// kinds are never equatable.
    pub fn equatable(&self, other: &Self) -> bool {
        match self {
            Row::Unifier(var) => {
                matches!(other, Row::Unifier(other_var) if var == other_var)
            }
            Row::Open(var) => {
                matches!(other, Row::Open(other_var) if var == other_var)
            }
            Row::Closed(row) => {
                matches!(other, Row::Closed(other_row) if row.fields == other_row.fields)
            }
        }
    }
}
/// A type in the language.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` order types
/// by variant first, so reordering variants changes comparison results.
#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, Hash)]
pub enum Type {
/// Leaf type with no sub-terms.
Unit,
/// Leaf type with no sub-terms.
Int,
/// Leaf type with no sub-terms.
Float,
/// Leaf type with no sub-terms.
String,
/// A type unification variable; its `UnifyKey` value type is `Option<Type>`.
Unifier(TypeUniVar),
/// A type variable. NOTE(review): presumably rigid/quantified rather than
/// solvable by unification — confirm.
Var(TypeVar),
/// Function type: argument type, then return type.
Abs(Box<Self>, Box<Self>),
/// Product type over a row.
Prod(Row),
/// Sum type over a row.
Sum(Row),
/// A type tagged with a label.
Label(Label, Box<Self>),
/// Leaf type with no sub-terms.
DataFrame,
}
// Opt into `ena`'s `EqUnifyValue` handling for types (values unify only when equal).
impl EqUnifyValue for Type {}
impl Type {
    /// Builds the function type `arg -> ret`.
    pub fn abstraction(arg: Self, ret: Self) -> Self {
        Type::Abs(Box::from(arg), Box::from(ret))
    }

    /// Builds a curried function type from a sequence of parameter types.
    ///
    /// `abstractions([a, b], r)` produces `a -> (b -> r)`; an empty parameter
    /// list yields `ret` unchanged. Parameters are folded in from the back,
    /// which is why the iterator must be double-ended.
    pub fn abstractions<T>(arg: T, ret: Self) -> Self
    where
        T: IntoIterator<Item = Type>,
        <T as IntoIterator>::IntoIter: DoubleEndedIterator<Item = Type>,
    {
        arg.into_iter()
            .rev()
            .fold(ret, |body, param| Type::abstraction(param, body))
    }

    /// Wraps `value` in a `Type::Label` tagged with `label`.
    pub fn label(label: Label, value: Self) -> Self {
        Type::Label(label, Box::from(value))
    }

    /// Fails if the type unification variable `var` occurs anywhere in `self`.
    ///
    /// On failure the error carries a clone of `self`, i.e. the whole type
    /// being checked rather than only the offending sub-term. Only type
    /// structure is searched: the value types of closed rows are inspected,
    /// but row variables (`Row::Unifier`, `Row::Open`) are not looked through.
    pub fn occurs_check(&self, var: TypeUniVar) -> Result<(), Type> {
        // Pure structural search; `||` and `any` stop at the first hit.
        fn occurs(ty: &Type, var: TypeUniVar) -> bool {
            match ty {
                Type::Unifier(candidate) => *candidate == var,
                Type::Abs(arg, ret) => occurs(arg, var) || occurs(ret, var),
                Type::Label(_, inner) => occurs(inner, var),
                Type::Prod(Row::Closed(row)) | Type::Sum(Row::Closed(row)) => {
                    row.values.iter().any(|field_ty| occurs(field_ty, var))
                }
                // Non-closed rows are opaque to the type-level occurs check.
                Type::Prod(_) | Type::Sum(_) => false,
                Type::Unit
                | Type::Int
                | Type::Float
                | Type::String
                | Type::Var(_)
                | Type::DataFrame => false,
            }
        }

        if occurs(self, var) {
            Err(self.clone())
        } else {
            Ok(())
        }
    }

    /// Returns `true` if `self` contains a type unifier listed in
    /// `unbound_tys` or a row unifier listed in `unbound_rows`.
    ///
    /// Closed rows are searched through their value types; `Type::Var` and
    /// `Row::Open` never count.
    pub fn mentions(
        &self,
        unbound_tys: &HashSet<TypeUniVar>,
        unbound_rows: &HashSet<RowUniVar>,
    ) -> bool {
        match self {
            Type::Unifier(var) => unbound_tys.contains(var),
            Type::Abs(arg, ret) => [arg, ret]
                .iter()
                .any(|part| part.mentions(unbound_tys, unbound_rows)),
            Type::Label(_, inner) => inner.mentions(unbound_tys, unbound_rows),
            Type::Prod(row) | Type::Sum(row) => match row {
                Row::Closed(closed) => closed.mentions(unbound_tys, unbound_rows),
                Row::Unifier(row_var) => unbound_rows.contains(row_var),
                Row::Open(_) => false,
            },
            Type::Unit
            | Type::Int
            | Type::Float
            | Type::String
            | Type::Var(_)
            | Type::DataFrame => false,
        }
    }
}
/// A row variable, as carried by `Row::Open`.
/// NOTE(review): presumably a rigid/quantified variable identified by this
/// number — confirm against where `RowVar`s are created.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct RowVar(pub u32);
/// A row unification variable; `id` is its index in the unification table
/// (it is what the `UnifyKey` impl reports as the key's index).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct RowUniVar {
    pub id: u32,
}

impl RowUniVar {
    /// Wraps a raw table index as a row unification variable.
    pub fn new(index: u32) -> Self {
        RowUniVar { id: index }
    }
}
impl UnifyKey for RowUniVar {
type Value = Option<Row>;
fn index(&self) -> u32 {
self.id
}
fn from_index(id: u32) -> Self {
Self::new(id)
}
fn tag() -> &'static str {
"RowUniVar"
}
}
/// A type variable, as carried by `Type::Var`.
/// NOTE(review): presumably a rigid/quantified variable identified by this
/// number — confirm against where `TypeVar`s are created.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct TypeVar(pub u32);
/// A type unification variable; `id` is its index in the unification table
/// (it is what the `UnifyKey` impl reports as the key's index).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct TypeUniVar {
    pub id: u32,
}

impl TypeUniVar {
    /// Wraps a raw table index as a type unification variable.
    // Public for parity with `RowUniVar::new`: `id` is already a public
    // field, so keeping the constructor private protected no invariant.
    pub fn new(id: u32) -> Self {
        Self { id }
    }
}
impl UnifyKey for TypeUniVar {
type Value = Option<Type>;
fn index(&self) -> u32 {
self.id
}
fn from_index(id: u32) -> Self {
Self::new(id)
}
fn tag() -> &'static str {
"TypeUniVar"
}
}
/// A row-combination constraint relating `left`, `right` and `goal`; it
/// converts into `Evidence::RowEquation` with the same three rows.
/// NOTE(review): presumably read as "`left` combined with `right` is `goal`"
/// — confirm against the solver.
///
/// Field order is significant for the derived `PartialOrd`/`Ord`. Unlike
/// the other types in this module, this one does not derive `Hash`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct RowCombination {
/// Left operand row.
pub left: Row,
/// Right operand row.
pub right: Row,
/// Result row.
pub goal: Row,
}
impl RowCombination {
pub fn is_unifiable(&self, other: &Self) -> bool {
let left_equatable = self.left.equatable(&other.left);
let right_equatable = self.right.equatable(&other.right);
let goal_equatable = self.goal.equatable(&other.goal);
(goal_equatable && (left_equatable || right_equatable))
|| (left_equatable && right_equatable)
}
pub fn is_comm_unifiable(&self, other: &Self) -> bool {
let left_equatable = self.left.equatable(&other.right);
let right_equatable = self.right.equatable(&other.left);
let goal_equatable = self.goal.equatable(&other.goal);
(goal_equatable && (left_equatable || right_equatable))
|| (left_equatable && right_equatable)
}
pub fn into_evidence(self) -> Evidence {
Evidence::RowEquation {
left: self.left,
right: self.right,
goal: self.goal,
}
}
}