use std::{collections::HashMap, fmt::Display, ops::Deref as _};
use tracing::instrument;
use super::{
Constraint, Evidence, Expr, ItemWrapper, NativeItem, NodeId, TypeInference, TypeScheme,
TypedVar, Var,
inst::Instantiate,
ty::{Row, RowCombination, RowUniVar, Type, TypeUniVar},
};
/// Result of inferring or checking an expression.
///
/// Holds the constraints generated while walking the expression together with the expression
/// rebuilt over `TypedVar` (variables annotated with their types).
#[derive(Debug)]
pub struct InferOut {
    /// Constraints emitted while walking the expression; they still need to be solved.
    pub constraints: Vec<Constraint>,
    /// The walked expression, with variables and parameters carried as `TypedVar`s.
    pub typed_expr: Expr<TypedVar>,
}
/// Renders the constraints (alternate/pretty `Debug` form) followed by the typed expression
/// (`Display` form), each on its own tab-indented line inside braces.
impl Display for InferOut {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            constraints,
            typed_expr,
        } = self;
        write!(
            f,
            "{{\n\tconstraints: {cs:#?}\n\ttyped_expr: {expr}\n}}",
            cs = constraints,
            expr = typed_expr
        )
    }
}
impl InferOut {
fn with_typed_expr(self, f: impl FnOnce(Expr<TypedVar>) -> Expr<TypedVar>) -> Self {
InferOut {
constraints: self.constraints,
typed_expr: f(self.typed_expr),
}
}
}
impl InferOut {
fn new(constraints: Vec<Constraint>, typed_expr: Expr<TypedVar>) -> Self {
Self {
constraints,
typed_expr,
}
}
}
/// Bidirectional constraint generation over `Expr<Var>`.
///
/// `infer` synthesizes a type for an expression. `check` pushes an expected type into an
/// expression and falls back to `infer` plus a `Constraint::TypeEqual` when no checking rule
/// applies. Both return the generated constraints with the typed expression. They also record
/// per-node side information on `self` (row combinations in `row_to_ev`, branch return types in
/// `branch_to_ret_typ`, item wrappers in `item_wrappers`).
impl TypeInference {
    /// Allocates a fresh, unbound type unification variable.
    fn fresh_ty_var(&mut self) -> TypeUniVar {
        self.unification_table.new_key(None)
    }

    /// Allocates a fresh, unbound row unification variable.
    fn fresh_row_var(&mut self) -> RowUniVar {
        self.row_unification_table.new_key(None)
    }

    /// Builds a row combination whose `left`, `right` and `goal` rows are all fresh unifiers.
    ///
    /// The three variables are allocated in `left`, `right`, `goal` order.
    fn fresh_row_combination(&mut self) -> RowCombination {
        RowCombination {
            left: Row::Unifier(self.fresh_row_var()),
            right: Row::Unifier(self.fresh_row_var()),
            goal: Row::Unifier(self.fresh_row_var()),
        }
    }

    /// Checks a native item's body against its declared type scheme.
    ///
    /// The scheme is instantiated with fresh unifiers, and the item's abstraction is checked
    /// against the resulting type. Returns the check output and that instantiated type.
    ///
    /// The wrapper variables and constraints produced by instantiation are discarded here.
    /// NOTE(review): presumably the scheme's own constraints are treated as given while
    /// checking the item's body rather than emitted as wanted constraints — confirm.
    pub fn infer_item(
        &mut self,
        env: im::HashMap<Var, Type>,
        item: NativeItem<Var>,
    ) -> (InferOut, Type) {
        let id = item.abstraction.id();
        let (_, _, _, typ) = self.instantiate(id, item.typ);
        let out = self.check(env, item.abstraction, typ.clone());
        (out, typ)
    }

    /// Synthesizes a type for `expr` under `env`.
    ///
    /// Returns the constraints generated while walking `expr`, the typed expression, and the
    /// synthesized type. Project, concatenate, inject and branch nodes store their row
    /// combination in `self.row_to_ev`. Branches also store their return type in
    /// `self.branch_to_ret_typ`, and item references store an `ItemWrapper` in
    /// `self.item_wrappers`. All of these are keyed by the node id.
    ///
    /// # Panics
    ///
    /// Panics if an `Expr::Variable` is not bound in `env`.
    pub fn infer(&mut self, env: im::HashMap<Var, Type>, expr: Expr<Var>) -> (InferOut, Type) {
        match expr {
            Expr::Unit(id) => (InferOut::new(vec![], Expr::Unit(id)), Type::Unit),
            Expr::Integer(id, i) => (InferOut::new(vec![], Expr::Integer(id, i)), Type::Int),
            Expr::Float(id, f) => (InferOut::new(vec![], Expr::Float(id, f)), Type::Float),
            Expr::String(id, s) => (InferOut::new(vec![], Expr::String(id, s)), Type::String),
            Expr::Variable(id, v) => {
                let ty = env
                    .get(&v)
                    .cloned()
                    .expect("unbound variable encountered during type inference");
                (
                    InferOut::new(vec![], Expr::Variable(id, TypedVar(v, ty.clone()))),
                    ty,
                )
            }
            // The parameter gets a fresh type unifier, and the body is inferred in the
            // extended environment.
            Expr::Abstraction {
                id,
                parameter,
                body,
            } => {
                let arg_ty_var = self.fresh_ty_var();
                let env = env.update(parameter, Type::Unifier(arg_ty_var));
                let (body_out, body_ty) = self.infer(env, *body);
                (
                    InferOut {
                        typed_expr: Expr::abstraction(
                            id,
                            TypedVar(parameter, Type::Unifier(arg_ty_var)),
                            body_out.typed_expr,
                        ),
                        ..body_out
                    },
                    Type::abstraction(Type::Unifier(arg_ty_var), body_ty),
                )
            }
            // The argument is inferred first. The function is then checked against
            // `arg_ty -> fresh_ret`, and `fresh_ret` is the application's type.
            Expr::Application {
                id,
                abstraction: function,
                parameter,
            } => {
                let (parameter_out, parameter_ty) = self.infer(env.clone(), *parameter);
                let ret_ty = Type::Unifier(self.fresh_ty_var());
                let fun_ty = Type::abstraction(parameter_ty, ret_ty.clone());
                let fun_out = self.check(env, *function, fun_ty);
                (
                    InferOut::new(
                        parameter_out
                            .constraints
                            .into_iter()
                            .chain(fun_out.constraints)
                            .collect(),
                        Expr::application(id, fun_out.typed_expr, parameter_out.typed_expr),
                    ),
                    ret_ty,
                )
            }
            Expr::Label { id, label, expr } => {
                let (out, expr_ty) = self.infer(env, *expr);
                (
                    out.with_typed_expr(|expr| Expr::label(id, label.clone(), expr)),
                    Type::label(label, expr_ty),
                )
            }
            // The operand must have type `label: t` for a fresh `t`; the result is `t`.
            Expr::Unlabel { id, expr, label } => {
                let expr_var = self.fresh_ty_var();
                let expected_ty = Type::label(label.clone(), Type::Unifier(expr_var));
                let out = self.check(env, *expr, expected_ty);
                (
                    out.with_typed_expr(|expr| Expr::unlabel(id, expr, label)),
                    Type::Unifier(expr_var),
                )
            }
            // The operand must be a product over `goal`, and the result is a product over
            // `left`. The combination is emitted as a constraint and recorded as evidence.
            Expr::Project(id, goal) => {
                let row_comb = self.fresh_row_combination();
                let sub_row = row_comb.left.clone();
                let mut out = self.check(env, *goal, Type::Prod(row_comb.goal.clone()));
                out.constraints
                    .push(Constraint::RowCombine(id, row_comb.clone()));
                self.row_to_ev.insert(id, row_comb);
                (
                    out.with_typed_expr(|expr| Expr::project(id, expr)),
                    Type::Prod(sub_row),
                )
            }
            // Products over `left` and `right` combine into a product over `goal`.
            Expr::Concatenate { id, left, right } => {
                let row_comb = self.fresh_row_combination();
                let left_out = self.check(env.clone(), *left, Type::Prod(row_comb.left.clone()));
                let right_out = self.check(env, *right, Type::Prod(row_comb.right.clone()));
                let out_ty = Type::Prod(row_comb.goal.clone());
                let mut constraints = left_out.constraints;
                constraints.extend(right_out.constraints);
                constraints.push(Constraint::RowCombine(id, row_comb.clone()));
                self.row_to_ev.insert(id, row_comb);
                let typed_expr = Expr::concatenate(id, left_out.typed_expr, right_out.typed_expr);
                (
                    InferOut {
                        constraints,
                        typed_expr,
                    },
                    out_ty,
                )
            }
            // A sum over `left` is injected into a sum over `goal`.
            Expr::Inject(id, left) => {
                let row_comb = self.fresh_row_combination();
                let left_row = row_comb.left.clone();
                let out_typ = Type::Sum(row_comb.goal.clone());
                let mut out = self.check(env, *left, Type::Sum(left_row));
                out.constraints
                    .push(Constraint::RowCombine(id, row_comb.clone()));
                self.row_to_ev.insert(id, row_comb);
                (out.with_typed_expr(|expr| Expr::inject(id, expr)), out_typ)
            }
            // Each handler is checked as a function from a sum over its own row to a shared
            // return unifier. The branch handles a sum over `goal`.
            Expr::Branch { id, left, right } => {
                let row_comb = self.fresh_row_combination();
                let ret_ty_var = self.fresh_ty_var();
                let left_out = self.check(
                    env.clone(),
                    *left,
                    Type::abstraction(Type::Sum(row_comb.left.clone()), Type::Unifier(ret_ty_var)),
                );
                let right_out = self.check(
                    env.clone(),
                    *right,
                    Type::abstraction(Type::Sum(row_comb.right.clone()), Type::Unifier(ret_ty_var)),
                );
                let ret_typ =
                    Type::abstraction(Type::Sum(row_comb.goal.clone()), Type::Unifier(ret_ty_var));
                let mut constraints = left_out.constraints;
                constraints.extend(right_out.constraints);
                constraints.push(Constraint::RowCombine(id, row_comb.clone()));
                self.row_to_ev.insert(id, row_comb);
                self.branch_to_ret_typ.insert(id, Type::Unifier(ret_ty_var));
                let typed_expr = Expr::branch(id, left_out.typed_expr, right_out.typed_expr);
                (
                    InferOut {
                        constraints,
                        typed_expr,
                    },
                    ret_typ,
                )
            }
            // An item reference instantiates the item's scheme. It records the fresh unifiers,
            // plus every row-combination constraint as an `Evidence::RowEquation`, in an
            // `ItemWrapper` for this node. The instantiated constraints become wanted constraints.
            Expr::Item(id, item_id, symbol) => {
                let ty_scheme = self.item_source.type_of_item(item_id);
                let (wrapper_tyvars, wrapper_rowvars, constraints, ty) =
                    self.instantiate(id, ty_scheme);
                let wrapper = ItemWrapper {
                    types: wrapper_tyvars,
                    rows: wrapper_rowvars,
                    // Borrow the constraints and clone only the row combinations; the whole
                    // constraint list is moved into the output below.
                    evidence: constraints
                        .iter()
                        .filter_map(|c| match c {
                            Constraint::RowCombine(_, row_combo) => Some(Evidence::RowEquation {
                                left: row_combo.left.clone(),
                                right: row_combo.right.clone(),
                                goal: row_combo.goal.clone(),
                            }),
                            _ => None,
                        })
                        .collect(),
                };
                self.item_wrappers.insert(id, wrapper);
                (
                    InferOut::new(constraints, Expr::Item(id, item_id, symbol)),
                    ty,
                )
            }
        }
    }

    /// Instantiates `ty_scheme` at node `id`, giving each variable listed in
    /// `ty_scheme.unbound_tys` / `ty_scheme.unbound_rows` a fresh unifier.
    ///
    /// Returns, in order:
    /// - the fresh type unifiers (one per `unbound_tys` entry, in iteration order);
    /// - the fresh row unifiers (likewise for `unbound_rows`);
    /// - the constraints and type produced by `Instantiate::type_scheme`.
    ///
    /// All type unifiers are allocated before any row unifier.
    #[instrument(skip(self),ret(level=tracing::Level::TRACE))]
    fn instantiate(
        &mut self,
        id: NodeId,
        ty_scheme: TypeScheme,
    ) -> (Vec<Type>, Vec<Row>, Vec<Constraint>, Type) {
        let mut wrapper_tyvars = vec![];
        let tyvar_to_unifiers = ty_scheme
            .unbound_tys
            .iter()
            .map(|ty_var| {
                let unifier = self.fresh_ty_var();
                wrapper_tyvars.push(Type::Unifier(unifier));
                (*ty_var, unifier)
            })
            .collect::<HashMap<_, _>>();
        let mut wrapper_rowvars = vec![];
        let rowvar_to_unifiers = ty_scheme
            .unbound_rows
            .iter()
            .map(|row_var| {
                let unifier = self.fresh_row_var();
                wrapper_rowvars.push(Row::Unifier(unifier));
                (*row_var, unifier)
            })
            .collect::<HashMap<_, _>>();
        let (constraints, ty) =
            Instantiate::new(id, &tyvar_to_unifiers, &rowvar_to_unifiers).type_scheme(ty_scheme);
        (wrapper_tyvars, wrapper_rowvars, constraints, ty)
    }

    /// Checks `expr` against the expected type `ty` under `env`.
    ///
    /// Dedicated rules handle the following cases:
    /// - integers against `Int`;
    /// - lambdas against function types;
    /// - labels against a matching label type;
    /// - unlabel;
    /// - project and concatenate against products;
    /// - branch against function types.
    ///
    /// A label expectation on concatenate/project is rewritten to a singleton product, and on
    /// branch/inject to a singleton sum. Everything else is inferred, and a
    /// `Constraint::TypeEqual(id, expected, actual)` is emitted.
    pub(crate) fn check(
        &mut self,
        env: im::HashMap<Var, Type>,
        expr: Expr<Var>,
        ty: Type,
    ) -> InferOut {
        match (expr, ty) {
            (Expr::Integer(id, i), Type::Int) => InferOut::new(vec![], Expr::Integer(id, i)),
            // The parameter takes the expected argument type, and the body is checked against
            // the expected return type.
            (
                Expr::Abstraction {
                    id,
                    parameter,
                    body,
                },
                Type::Abs(parameter_ty, ret_ty),
            ) => {
                let env = env.update(parameter, *parameter_ty.clone());
                self.check(env, *body, *ret_ty).with_typed_expr(|body| {
                    Expr::abstraction(id, TypedVar(parameter, *parameter_ty), body)
                })
            }
            (Expr::Label { id, label, expr }, Type::Label(ty_lbl, ty)) if label == ty_lbl => self
                .check(env, *expr, *ty)
                .with_typed_expr(|expr| Expr::label(id, label, expr)),
            // Unlabelling to `ty` requires the operand to have type `label: ty`.
            (Expr::Unlabel { id, expr, label }, ty) => self
                .check(env, *expr, Type::label(label.clone(), ty))
                .with_typed_expr(|expr| Expr::unlabel(id, expr, label)),
            (expr @ Expr::Concatenate { .. }, Type::Label(lbl, ty))
            | (expr @ Expr::Project(_, _), Type::Label(lbl, ty)) => {
                self.check(env, expr, Type::Prod(Row::single(lbl, *ty)))
            }
            // NOTE(review): a branch synthesizes a function type, so rewriting a Label
            // expectation to a singleton sum can only end in the fallback's TypeEqual mismatch.
            (expr @ Expr::Branch { .. }, Type::Label(lbl, ty))
            | (expr @ Expr::Inject(_, _), Type::Label(lbl, ty)) => {
                self.check(env, expr, Type::Sum(Row::single(lbl, *ty)))
            }
            // The expected product row is the projection's `left`. The operand must be a
            // product over a fresh `goal`.
            (Expr::Project(id, goal), Type::Prod(sub_row)) => {
                let goal_row = Row::Unifier(self.fresh_row_var());
                let left = sub_row;
                let right = Row::Unifier(self.fresh_row_var());
                let mut out = self.check(env, *goal, Type::Prod(goal_row.clone()));
                let row_comb = RowCombination {
                    left,
                    right,
                    goal: goal_row,
                };
                out.constraints
                    .push(Constraint::RowCombine(id, row_comb.clone()));
                self.row_to_ev.insert(id, row_comb);
                out.with_typed_expr(|expr| Expr::project(id, expr))
            }
            // The expected product row is the concatenation's `goal`. Each side is checked
            // against a product over its own fresh row.
            (Expr::Concatenate { id, left, right }, Type::Prod(goal_row)) => {
                let left_row = Row::Unifier(self.fresh_row_var());
                let right_row = Row::Unifier(self.fresh_row_var());
                let left_out = self.check(env.clone(), *left, Type::Prod(left_row.clone()));
                let right_out = self.check(env, *right, Type::Prod(right_row.clone()));
                let mut constraints = left_out.constraints;
                constraints.extend(right_out.constraints);
                let row_comb = RowCombination {
                    left: left_row,
                    right: right_row,
                    goal: goal_row,
                };
                constraints.push(Constraint::RowCombine(id, row_comb.clone()));
                self.row_to_ev.insert(id, row_comb);
                InferOut {
                    constraints,
                    typed_expr: Expr::concatenate(id, left_out.typed_expr, right_out.typed_expr),
                }
            }
            (Expr::Branch { id, left, right }, Type::Abs(param_typ, ret_typ)) => {
                let mut constraints = vec![];
                // Reuse the expected parameter's row when it is already a sum. Otherwise,
                // equate it with a sum over a fresh goal row.
                let goal = match param_typ.deref() {
                    Type::Sum(goal) => goal.clone(),
                    _ => {
                        let goal = self.fresh_row_var();
                        constraints.push(Constraint::TypeEqual(
                            id,
                            *param_typ,
                            Type::Sum(Row::Unifier(goal)),
                        ));
                        Row::Unifier(goal)
                    }
                };
                let left_row = Row::Unifier(self.fresh_row_var());
                let right_row = Row::Unifier(self.fresh_row_var());
                // Both handlers share the expected return type.
                let left_out = self.check(
                    env.clone(),
                    *left,
                    Type::abstraction(Type::Sum(left_row.clone()), ret_typ.deref().clone()),
                );
                let right_out = self.check(
                    env,
                    *right,
                    Type::abstraction(Type::Sum(right_row.clone()), ret_typ.deref().clone()),
                );
                constraints.extend(left_out.constraints);
                constraints.extend(right_out.constraints);
                let row_comb = RowCombination {
                    left: left_row,
                    right: right_row,
                    goal,
                };
                constraints.push(Constraint::RowCombine(id, row_comb.clone()));
                self.row_to_ev.insert(id, row_comb);
                self.branch_to_ret_typ.insert(id, *ret_typ);
                InferOut {
                    constraints,
                    typed_expr: Expr::branch(id, left_out.typed_expr, right_out.typed_expr),
                }
            }
            // No checking rule applies: synthesize a type and require it to equal the expected
            // one. The constraint records (expected, actual) in that order.
            (expr, expected_ty) => {
                let id = expr.id();
                let (mut out, actual_ty) = self.infer(env, expr);
                out.constraints
                    .push(Constraint::TypeEqual(id, expected_ty, actual_ty));
                out
            }
        }
    }
}