mod monomorph;
pub mod sexpr;
mod simplify;
pub use monomorph::monomorph_module;
pub use simplify::simplify_module;
use tracing::instrument;
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap},
};
use super::crust::{self, Evidence, NodeId, TypedVar};
use crate::{external_type::FunctionType, trace_alt};
/// Identifier of a top-level item in a lowered IR [`Module`].
///
/// Allocated by `ItemSupply`; the ordering derives let it key the `BTreeMap`
/// in `Module::items`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct ItemId(pub u32);
/// Identifier of a term-level IR variable, allocated by `VarSupply`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct VarId(u32);
/// A term variable: its identifier together with the IR type it is bound at.
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct Var {
    /// Identifier of the variable.
    pub id: VarId,
    /// IR type of the variable.
    pub typ: Type,
}
impl Var {
fn new(id: VarId, ty: Type) -> Self {
Self { id, typ: ty }
}
pub fn map_typ(self, f: impl FnOnce(Type) -> Type) -> Self {
Self {
typ: f(self.typ),
..self
}
}
}
/// A type-level variable, represented as a de Bruijn index.
///
/// Index 0 refers to the innermost enclosing `Type::TypAbs` / `Expr::TypAbs`
/// binder, 1 to the one outside it, and so on. `Subst::subst_ty` bumps its
/// target index under each `TypAbs`, and `Type::adjust` bumps its cutoff.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub struct TypeVar(usize);
/// The kind of a type-level binder.
///
/// A binder abstracts either over an ordinary type (`Type`) or over a row of
/// labelled types (`Row`). Variant order determines the derived `Ord`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum Kind {
    Type,
    Row,
}
/// A row of labelled types, the contents of `Type::Prod` and `Type::Sum`.
#[derive(PartialEq, Eq, Hash, Clone)]
pub enum Row {
    /// A row variable (bound by a `Kind::Row` binder) standing for unknown fields.
    Open(TypeVar),
    /// Concrete `(label, type)` fields. IR field/tag access is positional,
    /// by index into this list.
    Closed(Vec<(String, Type)>),
}
/// An IR type.
///
/// Type variables are de Bruijn indices (see [`TypeVar`]). Products and sums
/// are usually built via `Type::prod` / `Type::sum`, which collapse a closed
/// single-field row to the field's type.
#[derive(PartialEq, Eq, Clone, Hash)]
pub enum Type {
    Unit,
    Int,
    Float,
    String,
    /// A type variable (de Bruijn index).
    Var(TypeVar),
    /// Function type `param -> ret`.
    Abs(Box<Self>, Box<Self>),
    /// Type abstraction: binds a variable of the given kind as index 0 in the body.
    TypAbs(Kind, Box<Self>),
    /// Product (record) over a row.
    Prod(Row),
    /// Sum (variant) over a row.
    Sum(Row),
    DataFrame,
}
impl Type {
    /// Builds the function type `param -> ret`.
    pub fn fun(param: Self, ret: Self) -> Self {
        Self::Abs(param.into(), ret.into())
    }

    /// Builds the curried function type `p1 -> p2 -> … -> ret`.
    /// The first parameter yielded by `params` is the outermost.
    pub fn funs<T>(params: T, ret: Self) -> Self
    where
        T: IntoIterator<Item = Self>,
        <T as IntoIterator>::IntoIter: DoubleEndedIterator<Item = Self>,
    {
        let mut result = ret;
        for param in params.into_iter().rev() {
            result = Self::Abs(Box::new(param), Box::new(result));
        }
        result
    }

    /// Builds a type abstraction that binds a variable of `kind` as index 0 in `body`.
    pub fn ty_fun(kind: Kind, body: Self) -> Self {
        Self::TypAbs(kind, body.into())
    }

    /// Builds the product over `row`.
    ///
    /// A closed row with exactly one field collapses to that field's type.
    pub fn prod(row: Row) -> Self {
        match row {
            Row::Closed(mut fields) if fields.len() == 1 => fields.pop().unwrap().1,
            other => Self::Prod(other),
        }
    }

    /// Builds the sum over `row`.
    ///
    /// A closed row with exactly one variant collapses to that variant's type.
    pub fn sum(row: Row) -> Self {
        match row {
            Row::Closed(mut variants) if variants.len() == 1 => variants.pop().unwrap().1,
            other => Self::Sum(other),
        }
    }

    /// Replaces variable 0 in `self` with the row `row` and lowers the
    /// higher indices. This instantiates the binder whose body `self` is.
    fn subst_row(self, row: Row) -> Self {
        let payload = Subst::RowPayload(row);
        payload.subst_ty(self, 0)
    }

    /// Replaces variable 0 in `self` with the type `ty` and lowers the
    /// higher indices. This instantiates the binder whose body `self` is.
    fn subst_typ(self, ty: Self) -> Self {
        let payload = Subst::TyPayload(ty);
        payload.subst_ty(self, 0)
    }
}
impl std::fmt::Debug for Type {
    /// Delegates to the shared s-expression printer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use crate::compiler::sexpr::format_sexpr_for;
        format_sexpr_for(self, &(), f)
    }
}
impl std::fmt::Display for Type {
    /// Delegates to the shared box-layout printer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use crate::compiler::sexpr::format_box_for;
        format_box_for(self, &(), f)
    }
}
/// The payload of a single type-level substitution.
///
/// The payload is a row or a type and must match the kind of the variable
/// being replaced.
#[derive(Clone)]
enum Subst {
    /// Replaces a row variable with this row.
    RowPayload(Row),
    /// Replaces a type variable with this type.
    TyPayload(Type),
}
impl Subst {
    /// Shifts the payload's free type variables up by one.
    ///
    /// This keeps them pointing at the same binders once the payload is
    /// carried underneath one more `TypAbs`.
    fn shift(&mut self) {
        match self {
            Subst::RowPayload(row) => row.shift(),
            Subst::TyPayload(ty) => ty.shift(),
        }
    }

    /// By-value form of `shift`.
    fn shifted(mut self) -> Self {
        self.shift();
        self
    }

    /// Returns the payload to put in place of a row variable.
    ///
    /// # Panics
    /// If the payload is a type (kind mismatch).
    fn subst_row_var(self) -> Row {
        match self {
            Subst::RowPayload(row) => row,
            // A row variable is being replaced, but the payload is a type.
            Subst::TyPayload(_) => panic!("ICE: Kind mismatch. A type was substituted for a row"),
        }
    }

    /// Returns the payload to put in place of a type variable.
    ///
    /// # Panics
    /// If the payload is a row (kind mismatch).
    fn subst_ty_var(self) -> Type {
        match self {
            Subst::TyPayload(ty) => ty,
            // A type variable is being replaced, but the payload is a row.
            Subst::RowPayload(_) => panic!("ICE: Kind mismatch. A row was substituted for a type"),
        }
    }

    /// Substitutes the payload for type variable `needle` throughout `haystack`.
    ///
    /// This substitution removes `needle`'s binder, so indices are adjusted:
    /// - Variables with a larger index are bound further out and are
    ///   decremented by one.
    /// - Variables with a smaller index are bound inside it and are left
    ///   unchanged.
    fn subst_row(self, haystack: Row, needle: usize) -> Row {
        match haystack {
            Row::Open(row_var) => match row_var.0.cmp(&needle) {
                Ordering::Equal => self.subst_row_var(),
                Ordering::Less => Row::Open(row_var),
                Ordering::Greater => Row::Open(TypeVar(row_var.0 - 1)),
            },
            Row::Closed(elems) => Row::Closed(
                elems
                    .into_iter()
                    .map(|(name, elem)| (name, self.clone().subst_ty(elem, needle)))
                    .collect(),
            ),
        }
    }

    /// Type-level counterpart of `subst_row`, with the same index adjustment.
    ///
    /// Under a `TypAbs`, the payload is shifted and `needle` is bumped by one.
    /// `Prod`/`Sum` are rebuilt with `Type::prod`/`Type::sum`, so a row that
    /// ends up as a single closed field collapses to that field's type.
    fn subst_ty(self, haystack: Type, needle: usize) -> Type {
        match haystack {
            Type::Unit => Type::Unit,
            Type::Int => Type::Int,
            Type::Float => Type::Float,
            Type::String => Type::String,
            Type::DataFrame => Type::DataFrame,
            Type::Var(type_var) => match type_var.0.cmp(&needle) {
                Ordering::Equal => self.subst_ty_var(),
                Ordering::Less => Type::Var(type_var),
                Ordering::Greater => Type::Var(TypeVar(type_var.0 - 1)),
            },
            Type::Abs(param, ret) => Type::fun(
                self.clone().subst_ty(*param, needle),
                self.subst_ty(*ret, needle),
            ),
            Type::TypAbs(kind, body) => {
                Type::ty_fun(kind, self.shifted().subst_ty(*body, needle + 1))
            }
            Type::Prod(row) => Type::prod(self.subst_row(row, needle)),
            Type::Sum(row) => Type::sum(self.subst_row(row, needle)),
        }
    }
}
/// The argument of a type application.
///
/// A type instantiates a `Kind::Type` binder; a row instantiates a
/// `Kind::Row` binder.
#[derive(PartialEq, Clone)]
pub enum TypApp {
    Ty(Type),
    Row(Row),
}
impl TypApp {
fn subst_typ(self, payload: Type) -> TypApp {
match self {
TypApp::Ty(ty) => Self::Ty(ty.subst_typ(payload)),
TypApp::Row(_row) => todo!(),
}
}
}
/// A lowered IR module.
pub struct Module {
    /// Items keyed by IR id. This is a `BTreeMap`, so iteration follows id order.
    pub items: BTreeMap<ItemId, Item>,
    /// Types recorded per symbol. `lower_module` leaves this empty.
    /// NOTE(review): presumably populated by a later pass — confirm.
    pub instances: HashMap<Symbol, Vec<Type>>,
    /// The item-id allocator used while lowering. It is kept with the module
    /// so further ids can be minted without clashing.
    item_supply: ItemSupply,
}
impl Module {
fn map(self, f: impl Fn(Expr) -> Expr) -> Self {
Self {
items: self
.items
.into_iter()
.map(|(id, item)| match item {
Item::Native(native_item) => (
id,
Item::Native(NativeItem {
expr: f(native_item.expr),
..native_item
}),
),
item @ Item::External(_) => (id, item),
})
.collect(),
..self
}
}
}
impl std::fmt::Debug for Module {
    /// Delegates to the shared s-expression printer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use crate::compiler::sexpr::format_sexpr_for;
        format_sexpr_for(self, &(), f)
    }
}
impl std::fmt::Display for Module {
    /// Delegates to the shared box-layout printer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use crate::compiler::sexpr::format_box_for;
        format_box_for(self, &(), f)
    }
}
/// A module item: either defined in source (`Native`) or provided externally.
#[derive(Clone)]
pub enum Item {
    Native(NativeItem),
    External(ExternalItem),
}
/// An item defined in source and lowered to IR.
#[derive(Clone)]
pub struct NativeItem {
    /// Name of the item.
    pub symbol: Symbol,
    /// Lowered body.
    pub expr: Expr,
    /// Lowered type scheme of the item.
    pub typ: Type,
}
/// An externally provided item, described only by its name and signature.
#[derive(Clone)]
pub struct ExternalItem {
    /// Name of the item.
    pub symbol: Symbol,
    /// Signature of the external function.
    pub external_type: FunctionType,
}
/// An IR term. Only `PartialEq` is derived because `Float` holds an `f64`.
#[derive(PartialEq, Clone)]
pub enum Expr {
    Unit,
    /// Reference to a variable.
    Variable(Var),
    Integer(i64),
    Float(f64),
    String(String),
    /// `λvar. body`.
    Abstraction(Var, Box<Self>),
    /// Application of a function to a single argument.
    Application(Box<Self>, Box<Self>),
    /// Type abstraction: binds a variable of the given kind as index 0 in the body.
    TypAbs(Kind, Box<Self>),
    /// Instantiation of a type abstraction with a type or row argument.
    TypApp(Box<Self>, TypApp),
    /// `let var = defn in body`.
    Local(Var, Box<Self>, Box<Self>),
    /// Record of labelled values; its type is a `Prod` of the field types.
    Tuple(Vec<(String, Self)>),
    /// Positional field access into a product value.
    Field(Box<Self>, usize),
    /// Injection: (sum type, variant index, payload).
    Tag(Type, usize, Box<Self>),
    /// Case analysis: (result type, scrutinee, one branch per variant in order).
    Case(Type, Box<Self>, Vec<Branch>),
    /// Reference to a module item: (its type, id, symbol).
    Item(Type, ItemId, Symbol),
}
impl std::fmt::Debug for Expr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for Expr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for Item {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for Item {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for NativeItem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for NativeItem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for ExternalItem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for ExternalItem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for Row {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for Row {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for Var {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for Var {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for VarId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for VarId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for ItemId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for ItemId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for TypeVar {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for TypeVar {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for Kind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for Kind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for TypApp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for TypApp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for Branch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for Branch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
impl std::fmt::Debug for Symbol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_sexpr_for(self, &(), f)
}
}
impl std::fmt::Display for Symbol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
crate::compiler::sexpr::format_box_for(self, &(), f)
}
}
/// One arm of an `Expr::Case`: binds the variant's payload to `param` in `body`.
#[derive(PartialEq, Clone)]
pub struct Branch {
    pub param: Var,
    pub body: Expr,
}
impl Branch {
fn as_fun(&self) -> Expr {
Expr::abs(self.param.clone(), self.body.clone())
}
}
/// Qualified name of an item: the module it lives in plus its field name.
#[derive(Default, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)]
pub struct Symbol {
    pub module: String,
    pub field: String,
}
impl From<crust::Symbol> for Symbol {
fn from(value: crust::Symbol) -> Self {
Self {
module: value.module,
field: value.field,
}
}
}
impl Expr {
    /// Builds `λvar. body`.
    pub fn abs(var: Var, body: Self) -> Self {
        Self::Abstraction(var, Box::new(body))
    }
    /// Builds `λv1. λv2. … body`; the first variable is outermost.
    fn abss<I>(vars: I, body: Expr) -> Expr
    where
        I: IntoIterator<Item = Var>,
        I::IntoIter: DoubleEndedIterator,
    {
        vars.into_iter()
            .rfold(body, |body, var| Expr::abs(var, body))
    }
    /// Builds the application `fun param`.
    pub fn app(fun: Self, param: Self) -> Self {
        Self::Application(Box::new(fun), Box::new(param))
    }
    /// Builds a type abstraction binding a variable of `kind`.
    fn ty_abs(kind: Kind, expr: Self) -> Self {
        Self::TypAbs(kind, Box::new(expr))
    }
    /// Builds a type application of `body` to `ty`.
    fn ty_app(body: Expr, ty: TypApp) -> Expr {
        Self::TypApp(Box::new(body), ty)
    }
    /// Builds a record from labelled elements.
    fn tuple(elems: impl IntoIterator<Item = (String, Self)>) -> Expr {
        Self::Tuple(elems.into_iter().collect())
    }
    /// Builds positional field access `body.index`.
    fn field(body: Self, index: usize) -> Self {
        Self::Field(Box::new(body), index)
    }
    /// Builds the injection of `body` as variant `tag` of the sum type `ty`.
    fn tag(ty: Type, tag: usize, body: Self) -> Self {
        Self::Tag(ty, tag, Box::new(body))
    }
    /// Builds a case analysis with result type `ty`.
    fn case(ty: Type, scrutinee: Self, branch: impl IntoIterator<Item = Branch>) -> Self {
        Self::Case(ty, Box::new(scrutinee), branch.into_iter().collect())
    }
    /// Builds a case branch.
    fn branch(param: Var, body: Expr) -> Branch {
        Branch { param, body }
    }
    /// Builds `let var = defn in body`.
    fn local(var: Var, defn: Self, body: Self) -> Self {
        Self::Local(var, Box::new(defn), Box::new(body))
    }
    /// Computes the type of this term, checking internal consistency on the way.
    ///
    /// # Panics
    /// With an `ICE:` message when the term is ill-typed. Examples include
    /// applying a non-function, an argument/parameter type mismatch, a kind
    /// mismatch in a type application, a field access on a non-product, or a
    /// case whose branches do not match the scrutinee's variants.
    pub fn type_of(&self) -> Type {
        match self {
            Expr::Variable(var) => var.typ.clone(),
            Expr::Unit => Type::Unit,
            Expr::Integer(_) => Type::Int,
            Expr::Float(_) => Type::Float,
            Expr::String(_) => Type::String,
            Expr::Abstraction(param, body) => Type::fun(param.typ.clone(), body.type_of()),
            Expr::TypAbs(kind, body) => Type::ty_fun(*kind, body.type_of()),
            Expr::Application(fun, param) => {
                let Type::Abs(fun_param_ty, ret_ty) = fun.type_of() else {
                    panic!(
                        "ICE: IR used non-function type as a function: {}\n{}",
                        fun.type_of(),
                        self.clone(),
                    )
                };
                if param.type_of() != *fun_param_ty {
                    panic!(
                        "ICE: Function applied to wrong parameter type {} != {}\n{}",
                        param.type_of(),
                        fun_param_ty,
                        Expr::Application(fun.clone(), param.clone()),
                    );
                }
                *ret_ty
            }
            Expr::TypApp(body, ty_app) => {
                let Type::TypAbs(kind, ret_ty) = body.type_of() else {
                    panic!("ICE: Type applied to a non-forall IR term");
                };
                match (kind, ty_app) {
                    (Kind::Type, TypApp::Ty(ty)) => ret_ty.subst_typ(ty.clone()),
                    (Kind::Row, TypApp::Row(row)) => ret_ty.subst_row(row.clone()),
                    (Kind::Type, TypApp::Row(_)) => {
                        panic!("ICE: Kind mismatch. Type applied a Row to variable of kind Type")
                    }
                    (Kind::Row, TypApp::Ty(_)) => {
                        panic!("ICE: Kind mismatch. Type applied a Type to variable of kind Row")
                    }
                }
            }
            Expr::Local(_var, defn, body) => {
                // The definition is still type-checked recursively, so its own
                // ICE checks fire. A mismatch between the binder's declared type
                // and the definition's type is not reported.
                // NOTE(review): confirm whether that comparison is intentionally
                // disabled; every other construct here panics on mismatch.
                let _defn_ty = defn.type_of();
                body.type_of()
            }
            Expr::Tuple(fields) => Type::Prod(Row::Closed(
                fields
                    .iter()
                    .map(|(name, field)| (name.clone(), field.type_of()))
                    .collect(),
            )),
            Expr::Field(body, field) => {
                let Type::Prod(Row::Closed(elems)) = body.type_of() else {
                    panic!("ICE: IR accessed field of a non product type");
                };
                elems[*field].1.clone()
            }
            Expr::Tag(typ, tag, body) => {
                let Type::Sum(Row::Closed(elems)) = typ else {
                    panic!("ICE: Tagged value with non sum type");
                };
                if !body.type_of().eq(&elems[*tag].1) {
                    panic!("ICE: Tagged value has element with the wrong type")
                };
                typ.clone()
            }
            Expr::Case(typ, elem, branches) => {
                let Type::Sum(Row::Closed(elems)) = elem.type_of() else {
                    panic!("ICE: Case scrutinee does not have sum type")
                };
                // `zip` below stops at the shorter list. Check the arity first,
                // or missing or surplus branches would go unchecked.
                if branches.len() != elems.len() {
                    panic!(
                        "ICE: Case has {} branches but its scrutinee has {} variants",
                        branches.len(),
                        elems.len(),
                    )
                }
                for (branch, elem) in branches.iter().zip(elems.iter()) {
                    if elem.1 != branch.param.typ {
                        panic!(
                            "ICE: Branch has unexpected parameter type {} != {}",
                            elem.1, branch.param.typ,
                        )
                    }
                    if typ != &branch.body.type_of() {
                        panic!("ICE: Branch body has unexpected type")
                    }
                }
                typ.clone()
            }
            Expr::Item(ty, _, _) => ty.clone(),
        }
    }
}
/// Allocator for term-variable ids within one item.
///
/// `lower_with_items` creates a fresh one per item.
#[derive(Default)]
struct VarSupply {
    /// Next id to hand out.
    next: u32,
    /// Id already assigned to each `crust` variable, so every occurrence of a
    /// source variable lowers to the same `VarId`.
    cache: HashMap<crust::Var, VarId>,
}
impl VarSupply {
    /// Returns the IR id for source variable `var`.
    ///
    /// A fresh id is minted the first time `var` is seen, so repeated
    /// occurrences share an id.
    fn supply_for(&mut self, var: crust::Var) -> VarId {
        let counter = &mut self.next;
        *self.cache.entry(var).or_insert_with(|| {
            let fresh = VarId(*counter);
            *counter += 1;
            fresh
        })
    }

    /// Mints a fresh id that is not tied to any source variable.
    fn supply(&mut self) -> VarId {
        let fresh = VarId(self.next);
        self.next += 1;
        fresh
    }
}
/// Allocator for IR item ids.
///
/// `lower_module` creates one per module and shares it across all items.
#[derive(Debug, Default, Clone)]
pub struct ItemSupply {
    /// Next id to hand out.
    next: u32,
    /// Id already assigned to each `crust` item.
    cache: HashMap<crust::ItemId, ItemId>,
}
impl ItemSupply {
    /// Returns the IR id for the `crust` item `item`.
    ///
    /// A fresh id is minted on first sight, so every reference to the same
    /// item shares an id.
    fn supply_for(&mut self, item: crust::ItemId) -> ItemId {
        let counter = &mut self.next;
        *self.cache.entry(item).or_insert_with(|| {
            let fresh = ItemId(*counter);
            *counter += 1;
            fresh
        })
    }

    /// Mints a fresh item id that is not cached for any `crust` item.
    fn new_supply(&mut self) -> ItemId {
        let fresh = ItemId(self.next);
        self.next += 1;
        fresh
    }
}
/// A type-level variable from a `crust` type scheme.
///
/// It is tagged with whether it ranges over types or rows, and serves as the
/// key when assigning de Bruijn indices in `lower_ty_scheme`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Hash)]
enum SemanticTypeVar {
    Ty(crust::TypeVar),
    Row(crust::RowVar),
}
impl SemanticTypeVar {
    /// Kind of the IR binder this variable becomes.
    ///
    /// Type variables map to `Kind::Type` and row variables to `Kind::Row`.
    fn kind(&self) -> Kind {
        match *self {
            Self::Ty(..) => Kind::Type,
            Self::Row(..) => Kind::Row,
        }
    }
}
impl TypeVar {
fn adjust(&mut self, cutoff: usize) {
if self.0 >= cutoff {
self.0 += 1;
}
}
}
impl Type {
    /// Increments every type variable in `self` whose index is at least `cutoff`.
    ///
    /// `cutoff` grows by one under each `TypAbs`, so variables bound inside
    /// `self` are left untouched.
    fn adjust(&mut self, cutoff: usize) {
        match self {
            Type::Var(var) => var.adjust(cutoff),
            Type::Abs(param, ret) => {
                for side in [param, ret] {
                    side.adjust(cutoff);
                }
            }
            Type::TypAbs(_, body) => body.adjust(cutoff + 1),
            Type::Prod(row) | Type::Sum(row) => row.adjust(cutoff),
            Type::Unit | Type::Int | Type::Float | Type::String | Type::DataFrame => {}
        }
    }

    /// Shifts every free type variable up by one (cutoff 0). This is needed
    /// when the type is moved underneath one additional type binder.
    fn shift(&mut self) {
        self.adjust(0)
    }

    /// By-value form of `shift`.
    fn shifted(self) -> Self {
        let mut ty = self;
        ty.shift();
        ty
    }
}
impl Row {
    /// Increments every type variable in the row whose index is at least
    /// `cutoff`. This covers the row variable itself or every field type.
    fn adjust(&mut self, cutoff: usize) {
        match self {
            Row::Open(var) => var.adjust(cutoff),
            Row::Closed(fields) => fields.iter_mut().for_each(|(_, ty)| ty.adjust(cutoff)),
        }
    }

    /// Shifts every free type variable in the row up by one (cutoff 0).
    fn shift(&mut self) {
        self.adjust(0)
    }
}
/// Environment for lowering `crust` types to IR types.
struct LowerTypes {
    /// De Bruijn index assigned to each type/row variable of the current scheme.
    env: HashMap<SemanticTypeVar, TypeVar>,
}
/// Lowers `crust` types to IR types, resolving variables through `env`.
impl LowerTypes {
    /// Lowers a closed row's fields, pairing each name with its lowered type.
    fn lower_closed_row_ty(&self, closed_row: crust::ClosedRow) -> Vec<(String, Type)> {
        closed_row
            .fields
            .into_iter()
            .zip(closed_row.values)
            .map(|(name, ty)| (name, self.lower_ty(ty)))
            .collect()
    }
    /// Lowers a row. Open rows are resolved to their index through `env`.
    ///
    /// # Panics
    /// On a row unifier, or (via map indexing) if an open row's variable is
    /// missing from `env`.
    fn lower_row_ty(&self, row: crust::Row) -> Row {
        match row {
            crust::Row::Open(var) => {
                let ty_var = self.env[&SemanticTypeVar::Row(var)];
                Row::Open(ty_var)
            }
            crust::Row::Closed(closed_row) => {
                let values = self.lower_closed_row_ty(closed_row);
                Row::Closed(values)
            }
            crust::Row::Unifier(_) => panic!("ICE: Unexpected row unifier in lowering"),
        }
    }
    /// Lowers a type.
    ///
    /// Labels are erased (only the labelled type is kept). Products and sums
    /// go through `Type::prod` / `Type::sum`, so single-field rows collapse.
    ///
    /// # Panics
    /// On a type unifier, or if a type variable is missing from `env`.
    fn lower_ty(&self, ty: crust::Type) -> Type {
        match ty {
            crust::Type::Unit => Type::Unit,
            crust::Type::Int => Type::Int,
            crust::Type::Float => Type::Float,
            crust::Type::String => Type::String,
            crust::Type::DataFrame => Type::DataFrame,
            crust::Type::Var(v) => Type::Var(self.env[&SemanticTypeVar::Ty(v)]),
            crust::Type::Abs(param, ret) => {
                let param = self.lower_ty(*param);
                let ret = self.lower_ty(*ret);
                Type::fun(param, ret)
            }
            crust::Type::Prod(row) => Type::prod(self.lower_row_ty(row)),
            crust::Type::Sum(row) => Type::sum(self.lower_row_ty(row)),
            crust::Type::Label(_, ty) => self.lower_ty(*ty),
            crust::Type::Unifier(_) => panic!("ICE: Unexpected type unifier in lowering"),
        }
    }
    /// Lowers row-equation evidence (`left` and `right` combine into `goal`)
    /// to the type of the record that implements it.
    ///
    /// The record's fields, in this order, are:
    /// - `project`: `goal_prod -> left_prod`
    /// - `concat`: `left_prod -> right_prod -> goal_prod`
    /// - `inject`: `left_sum -> goal_sum`
    /// - `branch`: `∀a. (left_sum -> a) -> (right_sum -> a) -> goal_sum -> a`
    ///
    /// This order must match `LowerSolvedEv::lower_ev_term` and the
    /// `*_FIELD_INDEX` constants.
    fn lower_ev_ty(&self, evidence: crust::Evidence) -> Type {
        match evidence {
            crust::Evidence::RowEquation { left, right, goal } => {
                let left = self.lower_row_ty(left);
                let (left_prod, left_sum) = (Type::prod(left.clone()), Type::sum(left));
                let right = self.lower_row_ty(right);
                let (right_prod, right_sum) = (Type::prod(right.clone()), Type::sum(right));
                let goal = self.lower_row_ty(goal);
                let (goal_prod, goal_sum) = (Type::prod(goal.clone()), Type::sum(goal));
                let concat = Type::funs([left_prod.clone(), right_prod.clone()], goal_prod.clone());
                // `branch` introduces its own binder `a` (index 0). The row
                // types are shifted so their free variables still refer to
                // the same outer binders.
                let branch = {
                    let a = TypeVar(0);
                    Type::ty_fun(
                        Kind::Type,
                        Type::funs(
                            [
                                Type::fun(left_sum.clone().shifted(), Type::Var(a)),
                                Type::fun(right_sum.clone().shifted(), Type::Var(a)),
                                goal_sum.clone().shifted(),
                            ],
                            Type::Var(a),
                        ),
                    )
                };
                let prj_left = Type::fun(goal_prod.clone(), left_prod);
                let inj_left = Type::fun(left_sum, goal_sum.clone());
                Type::prod(Row::Closed(vec![
                    ("project".to_string(), prj_left),
                    ("concat".to_string(), concat),
                    ("inject".to_string(), inj_left),
                    ("branch".to_string(), branch),
                ]))
            }
        }
    }
}
/// Result of `lower_ty_scheme`.
struct LoweredTyScheme {
    /// The fully bound IR type: `∀… . ev_1 -> … -> ev_n -> typ`.
    scheme: Type,
    /// Environment that maps the scheme's variables to de Bruijn indices.
    lower_types: LowerTypes,
    /// `kinds[i]` is the kind of the binder for index `i` (innermost first).
    kinds: Vec<Kind>,
    /// Lowered type of each piece of the scheme's evidence.
    ev_to_ty: BTreeMap<crust::Evidence, Type>,
}
/// Lowers a `crust` type scheme to a closed IR type.
///
/// The result has the shape `∀(ty vars…)(row vars…). ev_1 -> … -> ev_n -> typ`.
/// Binders run outermost to innermost: first the scheme's unbound type
/// variables in order, then its unbound row variables in order. The last row
/// variable therefore gets de Bruijn index 0.
fn lower_ty_scheme(scheme: crust::TypeScheme) -> LoweredTyScheme {
    // Placeholder kinds; each slot is overwritten below as the env is built.
    let mut kinds = vec![Kind::Type; scheme.unbound_tys.len() + scheme.unbound_rows.len()];
    // Enumerate in reverse so the last variable gets index 0 (innermost).
    // `kinds` is filled as a side effect while `collect` drives the iterator.
    let ty_env = scheme
        .unbound_tys
        .into_iter()
        .map(SemanticTypeVar::Ty)
        .chain(scheme.unbound_rows.into_iter().map(SemanticTypeVar::Row))
        .rev()
        .enumerate()
        .map(|(i, tyvar)| {
            kinds[i] = tyvar.kind();
            (tyvar, TypeVar(i))
        })
        .collect();
    let lower = LowerTypes { env: ty_env };
    let lower_ty = lower.lower_ty(scheme.typ);
    let mut ev_to_ty = BTreeMap::new();
    // `ev_tys` keeps the order of `scheme.evidence`. `ev_to_ty` is keyed, and
    // therefore ordered, by the evidence value itself.
    let ev_tys = scheme
        .evidence
        .into_iter()
        .map(|ev| {
            let ty = lower.lower_ev_ty(ev.clone());
            ev_to_ty.insert(ev, ty.clone());
            ty
        })
        .collect::<Vec<_>>();
    let evident_lower_ty = Type::funs(ev_tys, lower_ty);
    // Folding from `kinds[0]` wraps the innermost binder first.
    let bound_lower_ty = kinds
        .iter()
        .fold(evident_lower_ty, |ty, kind| Type::ty_fun(*kind, ty));
    LoweredTyScheme {
        scheme: bound_lower_ty,
        lower_types: lower,
        kinds,
        ev_to_ty,
    }
}
/// Where a goal-row field comes from in a solved row equation.
///
/// The payload is the field's index in the left or the right row.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum RowIndex {
    Left(usize),
    Right(usize),
}
/// Inputs for building the evidence record of a fully solved row equation,
/// where all three rows are closed.
struct LowerSolvedEv<'a> {
    /// Variable allocator of the enclosing item.
    supply: &'a mut VarSupply,
    /// Lowered fields of the left row.
    left: Vec<(String, Type)>,
    /// Lowered fields of the right row.
    right: Vec<(String, Type)>,
    /// Lowered fields of the goal row.
    goal: Vec<(String, Type)>,
    /// For each goal field, in goal order: its name and the side/index it comes from.
    goal_indices: Vec<(String, RowIndex)>,
    /// `left_indices[k]` is the goal position of the k-th left field.
    left_indices: Vec<usize>,
}
/// Wraps `var` as an expression. The expression is returned as-is when
/// `len == 1`; otherwise it is passed through `else_fn`.
///
/// Single-field products and single-variant sums are collapsed by
/// `Type::prod` / `Type::sum`, so no projection or tagging is needed then.
fn unwrap_single(len: usize, var: Var, else_fn: impl FnOnce(Expr) -> Expr) -> Expr {
    let expr = Expr::Variable(var);
    match len {
        1 => expr,
        _ => else_fn(expr),
    }
}
/// Projects field `index` out of the product variable `prod`, which has `len` fields.
///
/// A one-field product is represented by its field's value, so the variable
/// itself is returned in that case.
fn unwrap_prj(index: usize, len: usize, prod: Var) -> Expr {
    let product = Expr::Variable(prod);
    if len == 1 {
        product
    } else {
        Expr::field(product, index)
    }
}
/// Builds the evidence record (`project`, `concat`, `inject`, `branch`) for a
/// fully solved row equation. Its type matches `LowerTypes::lower_ev_ty`.
impl LowerSolvedEv<'_> {
    /// Product over the left row (a single-field row collapses to its field type).
    fn left_prod(&self) -> Type {
        Type::prod(Row::Closed(self.left.clone()))
    }
    /// Product over the right row.
    fn right_prod(&self) -> Type {
        Type::prod(Row::Closed(self.right.clone()))
    }
    /// Product over the goal row.
    fn goal_prod(&self) -> Type {
        Type::prod(Row::Closed(self.goal.clone()))
    }
    /// Sum over the left row (a single-variant row collapses to its variant type).
    fn left_sum(&self) -> Type {
        Type::sum(Row::Closed(self.left.clone()))
    }
    /// Sum over the right row.
    fn right_sum(&self) -> Type {
        Type::sum(Row::Closed(self.right.clone()))
    }
    /// Sum over the goal row.
    fn goal_sum(&self) -> Type {
        Type::sum(Row::Closed(self.goal.clone()))
    }
    /// Creates one fresh variable for each of `tys`, in order.
    fn make_vars<const N: usize>(&mut self, tys: [Type; N]) -> [Var; N] {
        tys.map(|ty| {
            let id = self.supply.supply();
            Var::new(id, ty)
        })
    }
    /// Builds `concat`: `λleft. λright. <goal record>`.
    ///
    /// Each goal field is taken from whichever side it came from. A one-field
    /// goal is returned without a tuple, matching `Type::prod`'s collapse.
    fn concat(&mut self) -> Expr {
        let vars = self.make_vars([self.left_prod(), self.right_prod()]);
        Expr::abss(vars.clone(), {
            let [left, right] = vars;
            let mut elems = self
                .goal_indices
                .iter()
                .map(|(name, row_index)| match row_index {
                    RowIndex::Left(i) => {
                        (name.clone(), unwrap_prj(*i, self.left.len(), left.clone()))
                    }
                    RowIndex::Right(i) => (
                        name.clone(),
                        unwrap_prj(*i, self.right.len(), right.clone()),
                    ),
                });
            if self.goal_indices.len() == 1 {
                elems.next().unwrap().1
            } else {
                Expr::tuple(elems)
            }
        })
    }
    /// Builds `project`: `λgoal. <left record>`.
    ///
    /// For each left field, it selects the goal field that field maps to.
    fn prj_left(&mut self) -> Expr {
        let [goal] = self.make_vars([self.goal_prod()]);
        Expr::abs(goal.clone(), {
            if self.left.len() == 1 {
                unwrap_prj(self.left_indices[0], self.goal.len(), goal)
            } else {
                // Labels are read from the goal row; a left field and the goal
                // field it maps to carry the same label.
                Expr::tuple(self.left_indices.iter().map(|i| {
                    (
                        self.goal_indices[*i].0.clone(),
                        unwrap_prj(*i, self.goal.len(), goal.clone()),
                    )
                }))
            }
        })
    }
    /// Builds `inject`: `λl. case l of …`.
    ///
    /// Each left variant is re-tagged with its goal tag. A single-variant left
    /// "sum" is just its payload, so the lone branch is applied directly
    /// instead of building a `case`.
    fn inj_left(&mut self) -> Expr {
        let [left_var] = self.make_vars([self.left_sum()]);
        Expr::abs(left_var.clone(), {
            let branches = self
                .left_indices
                .clone()
                .into_iter()
                .zip(self.left.clone())
                .map(|(i, ty)| {
                    let [branch_var] = self.make_vars([ty.1]);
                    Expr::branch(branch_var.clone(), {
                        unwrap_single(self.goal.len(), branch_var, |ir| {
                            Expr::tag(self.goal_sum(), i, ir)
                        })
                    })
                })
                .collect::<Vec<_>>();
            if self.left.len() == 1 {
                Expr::app(branches[0].as_fun(), Expr::Variable(left_var))
            } else {
                Expr::case(self.goal_sum(), Expr::Variable(left_var), branches)
            }
        })
    }
    /// Builds `branch`: `Λa. λon_left. λon_right. λg. case g of …`.
    ///
    /// Each goal variant is sent to the handler for the side it came from,
    /// re-tagged into that side's sum. Everything under the new binder `a`
    /// (index 0) is shifted so existing type variables keep referring to the
    /// same outer binders. A single-variant goal applies its lone branch
    /// directly.
    fn branch(&mut self) -> Expr {
        let left_sum = self.left_sum().shifted();
        let right_sum = self.right_sum().shifted();
        let goal_sum = self.goal_sum().shifted();
        let ret_ty = Type::Var(TypeVar(0));
        let vars = self.make_vars([
            Type::fun(left_sum.clone(), ret_ty.clone()),
            Type::fun(right_sum.clone(), ret_ty.clone()),
            goal_sum,
        ]);
        Expr::ty_abs(
            Kind::Type,
            Expr::abss(vars.clone(), {
                let [left_var, right_var, goal_var] = vars;
                let goal_len = self.goal.len();
                let mut branches = self.goal_indices.clone().into_iter().map(|(_, row_index)| {
                    let (i, ty, len, var, sum) = match row_index {
                        RowIndex::Left(i) => (
                            i,
                            self.left[i].1.clone().shifted(),
                            self.left.len(),
                            left_var.clone(),
                            left_sum.clone(),
                        ),
                        RowIndex::Right(i) => (
                            i,
                            self.right[i].1.clone().shifted(),
                            self.right.len(),
                            right_var.clone(),
                            right_sum.clone(),
                        ),
                    };
                    let [case_var] = self.make_vars([ty]);
                    Expr::branch(case_var.clone(), {
                        Expr::app(
                            Expr::Variable(var),
                            unwrap_single(len, case_var, |ir| Expr::tag(sum, i, ir)),
                        )
                    })
                });
                if goal_len == 1 {
                    Expr::app(branches.next().unwrap().as_fun(), Expr::Variable(goal_var))
                } else {
                    Expr::case(ret_ty, Expr::Variable(goal_var), branches)
                }
            }),
        )
    }
    /// Assembles the evidence record.
    ///
    /// Field order must match `LowerTypes::lower_ev_ty` and the
    /// `*_FIELD_INDEX` constants.
    fn lower_ev_term(mut self) -> Expr {
        Expr::tuple([
            ("project".to_string(), self.prj_left()),
            ("concat".to_string(), self.concat()),
            ("inject".to_string(), self.inj_left()),
            ("branch".to_string(), self.branch()),
        ])
    }
}
// Field positions in the evidence record. They must agree with the field order
// produced by `LowerSolvedEv::lower_ev_term` and `LowerTypes::lower_ev_ty`
// (`project`, `concat`, `inject`, `branch`).
/// Position of the `project` function in an evidence record.
const PROJECT_FIELD_INDEX: usize = 0;
/// Position of the `concat` function in an evidence record.
const CONCAT_FIELD_INDEX: usize = 1;
/// Position of the `inject` function in an evidence record.
const INJECT_FIELD_INDEX: usize = 2;
/// Position of the `branch` function in an evidence record.
const BRANCH_FIELD_INDEX: usize = 3;
/// State for lowering one item's body from `crust` to IR.
struct LowerExpr<'a> {
    /// Term-variable allocator for this item.
    supply: VarSupply,
    /// Type-lowering environment for the item's scheme.
    types: LowerTypes,
    /// Variable holding each piece of evidence: either a scheme parameter or a
    /// locally solved term.
    ev_to_var: HashMap<Evidence, Var>,
    /// Locally solved evidence terms. `lower_with_items` binds them around the
    /// body with `Expr::Local`.
    solved: Vec<(Var, Expr)>,
    /// Evidence required by each row-operation AST node.
    row_to_ev: BTreeMap<NodeId, Evidence>,
    /// Result type of each `Branch` AST node.
    branch_to_ret_ty: BTreeMap<NodeId, crust::Type>,
    /// Type, row, and evidence arguments to apply at each item reference.
    item_wrappers: BTreeMap<NodeId, crust::ItemWrapper>,
    /// Lowered type schemes of referenced items.
    item_source: ItemSource,
    /// Module-wide item id allocator.
    item_supply: &'a mut ItemSupply,
}
impl<'a> LowerExpr<'a> {
#[instrument(skip(self), ret ( level = tracing::Level::TRACE))]
fn lookup_ev(&mut self, ev: Evidence) -> Var {
self.ev_to_var
.entry(ev)
.or_insert_with_key(|ev| {
let Evidence::RowEquation {
left: crust::Row::Closed(left),
right: crust::Row::Closed(right),
goal: crust::Row::Closed(goal),
} = ev
else {
panic!("ICE: Unsolved evidence appeared in AST that wasn't in type scheme");
};
let param = self.supply.supply();
let mut left_indices = vec![0; left.fields.len()];
let mut right_indices = vec![0; right.fields.len()];
let goal_indices = goal
.fields
.iter()
.enumerate()
.map(|(goal_indx, field)| {
(
field.clone(),
left.fields
.binary_search(field)
.map(|left_indx| {
left_indices[left_indx] = goal_indx;
RowIndex::Left(left_indx)
})
.or_else(|_| {
right.fields.binary_search(field).map(|right_indx| {
right_indices[right_indx] = goal_indx;
RowIndex::Right(right_indx)
})
})
.expect("ICE: Invalid solved row combination."),
)
})
.collect::<Vec<_>>();
let left_values = self.types.lower_closed_row_ty(left.clone());
let right_values = self.types.lower_closed_row_ty(right.clone());
let goal_values = self.types.lower_closed_row_ty(goal.clone());
let lower_solved_ev = LowerSolvedEv {
supply: &mut self.supply,
left: left_values,
right: right_values,
goal: goal_values,
goal_indices,
left_indices,
};
let term = lower_solved_ev.lower_ev_term();
let ty = self.types.lower_ev_ty(ev.clone());
let var = Var::new(param, ty);
self.solved.push((var.clone(), term));
var
})
.clone()
}
fn lower_expr(&mut self, expr: crust::Expr<TypedVar>) -> Expr {
match expr {
crust::Expr::Variable(_, TypedVar(var, ty)) => Expr::Variable(Var::new(
self.supply.supply_for(var),
self.types.lower_ty(ty),
)),
crust::Expr::Unit(_) => Expr::Unit,
crust::Expr::Integer(_, i) => Expr::Integer(i),
crust::Expr::Float(_, f) => Expr::Float(f),
crust::Expr::String(_, f) => Expr::String(f),
crust::Expr::Abstraction {
parameter: TypedVar(var, ty),
body,
..
} => {
let expr_typ = self.types.lower_ty(ty);
let expr_var = self.supply.supply_for(var);
let expr_body = self.lower_expr(*body);
Expr::abs(Var::new(expr_var, expr_typ), expr_body)
}
crust::Expr::Application {
abstraction: function,
parameter,
..
} => {
let expr_abs = self.lower_expr(*function);
let expr_param = self.lower_expr(*parameter);
Expr::app(expr_abs, expr_param)
}
crust::Expr::Label { expr, .. } => self.lower_expr(*expr),
crust::Expr::Unlabel { expr, .. } => self.lower_expr(*expr),
crust::Expr::Project(id, body) => {
let param = self
.row_to_ev
.get(&id)
.cloned()
.map(|ev| self.lookup_ev(ev))
.expect("ICE: Project AST node lacks an expected evidence");
let term = self.lower_expr(*body);
let prj_fun = Expr::field(Expr::Variable(param), PROJECT_FIELD_INDEX);
Expr::app(prj_fun, term)
}
crust::Expr::Concatenate { id, left, right } => {
let param = self
.row_to_ev
.get(&id)
.cloned()
.map(|ev| self.lookup_ev(ev))
.expect("ICE: Concat AST node lacks an expected evidence");
let left = self.lower_expr(*left);
let right = self.lower_expr(*right);
let concat_fun = Expr::field(Expr::Variable(param), CONCAT_FIELD_INDEX);
Expr::app(Expr::app(concat_fun, left), right)
}
crust::Expr::Inject(id, expr) => {
let param = self
.row_to_ev
.get(&id)
.cloned()
.map(|ev| self.lookup_ev(ev))
.expect("ICE: Inject AST node lacks an expected evidence");
let term = self.lower_expr(*expr);
let inj_fn = Expr::field(Expr::Variable(param), INJECT_FIELD_INDEX);
Expr::app(inj_fn, term)
}
crust::Expr::Branch { id, left, right } => {
let ev = self
.row_to_ev
.get(&id)
.expect("ICE: Branch AST node lacks expected evidence");
let param = self.lookup_ev(ev.clone());
let ret_ty = self
.branch_to_ret_ty
.get(&id)
.map(|ty| self.types.lower_ty(ty.clone()))
.expect("ICE: Branch AST node lacks expected type");
let left = self.lower_expr(*left);
let right = self.lower_expr(*right);
let branch = Expr::ty_app(
Expr::field(Expr::Variable(param), BRANCH_FIELD_INDEX),
TypApp::Ty(ret_ty),
);
Expr::app(Expr::app(branch, left), right)
}
crust::Expr::Item(id, item_id, symbol) => {
let ty = self.item_source.lookup_item(item_id);
let item_expr = Expr::Item(ty, self.item_supply.supply_for(item_id), symbol.into());
let wrapper = self
.item_wrappers
.get(&id)
.cloned()
.expect("ICE: Item lacks expected wrapper");
let ty_expr = wrapper.types.into_iter().fold(item_expr, |expr, ty| {
Expr::ty_app(expr, TypApp::Ty(self.types.lower_ty(ty)))
});
let row_expr = wrapper.rows.into_iter().fold(ty_expr, |expr, row| {
Expr::ty_app(expr, TypApp::Row(self.types.lower_row_ty(row)))
});
wrapper.evidence.into_iter().fold(row_expr, |expr, ev| {
let param = self.lookup_ev(ev);
Expr::app(expr, Expr::Variable(param))
})
}
}
}
}
/// Lowered type scheme of every known item, keyed by its `crust` id.
struct ItemSource {
    items: HashMap<crust::ItemId, Type>,
}
impl ItemSource {
    /// Returns a clone of the lowered type scheme of `item`.
    ///
    /// # Panics
    /// If `item` is not in the source.
    fn lookup_item(&self, item: crust::ItemId) -> Type {
        let scheme: &Type = &self.items[&item];
        scheme.clone()
    }
}
fn lower_item_source(items: &crust::ItemSource) -> ItemSource {
ItemSource {
items: items
.types
.iter()
.map(|(item_id, ty_scheme)| {
let lowered_ty_scheme = lower_ty_scheme(ty_scheme.clone());
(*item_id, lowered_ty_scheme.scheme)
})
.collect(),
}
}
pub fn lower_module(item_source: &crust::ItemSource, module: crust::TypedModule) -> Module {
let mut item_supply = ItemSupply::default();
let items = module
.items
.into_iter()
.map(|(id, item)| match item {
crust::TypedItem::Native(item) => (
item_supply.supply_for(id),
Item::Native(lower_with_items(item_source, item, &mut item_supply)),
),
crust::TypedItem::External(item) => (
item_supply.supply_for(id),
Item::External(ExternalItem {
symbol: item.symbol.into(),
external_type: item.external_type,
}),
),
})
.collect();
Module {
items,
item_supply,
instances: Default::default(),
}
}
/// Lowers one native item to IR.
///
/// The resulting expression has the shape
/// `Λ(kinds…). λfun_var. λ(ev params…). let (solved…) in body`, where:
/// - `fun_var` and `body` come from the item's own top-level abstraction.
/// - The evidence parameters are the scheme's evidence, in `ev_to_ty` key order.
/// - `solved` binds the evidence terms built while lowering the body.
///
/// `typ` is the lowered type scheme from `lower_ty_scheme`.
///
/// NOTE(review): `typ` is `∀…. ev… -> typ` (evidence first), and item call
/// sites in `LowerExpr::lower_expr` apply evidence first. The expression here
/// binds `fun_var` *outside* the evidence parameters, so the two only agree
/// when the scheme has no evidence. Confirm which order is intended.
/// NOTE(review): the parameters follow the `BTreeMap` order of `ev_to_ty`,
/// while `typ` follows the order of `scheme.evidence`. Confirm these coincide.
/// NOTE(review): `lower_item_source` re-lowers every item's scheme on each
/// call, i.e. once per native item; consider computing it once per module.
///
/// # Panics
/// If the lowered body is not an abstraction.
fn lower_with_items(
    item_source: &crust::ItemSource,
    item: crust::TypedNativeItem,
    item_supply: &mut ItemSupply,
) -> NativeItem {
    let lowered_scheme = lower_ty_scheme(item.scheme);
    let mut supply = VarSupply::default();
    let mut params = vec![];
    // Each piece of scheme evidence becomes a parameter variable.
    let ev_to_var: HashMap<crust::Evidence, Var> = lowered_scheme
        .ev_to_ty
        .into_iter()
        .map(|(ev, ty)| {
            let param = supply.supply();
            let var = Var::new(param, ty);
            params.push(var.clone());
            (ev, var)
        })
        .collect();
    let mut lower_expr = LowerExpr {
        supply,
        types: lowered_scheme.lower_types,
        ev_to_var,
        solved: vec![],
        row_to_ev: item.row_to_ev,
        branch_to_ret_ty: item.branch_to_ret_typ,
        item_wrappers: item.item_wrappers,
        item_source: lower_item_source(item_source),
        item_supply,
    };
    let expr = lower_expr.lower_expr(item.typed_expr);
    trace_alt!(expr, "lower expr");
    let Expr::Abstraction(fun_var, body) = expr else {
        panic!("item should be a function: {}", item.symbol.field)
    };
    // Wrap the body in a `Local` for each solved evidence term; the first
    // solved term ends up innermost.
    let solved_expr = lower_expr
        .solved
        .into_iter()
        .fold(*body, |expr, (var, solved)| Expr::local(var, solved, expr));
    trace_alt!(solved_expr, "solved expr");
    // First evidence parameter outermost.
    let param_expr = params
        .into_iter()
        .rfold(solved_expr, |expr, var| Expr::abs(var, expr));
    trace_alt!(param_expr, "param expr");
    // `kinds[0]` is the innermost binder, so fold from the front.
    let bound_expr = lowered_scheme
        .kinds
        .into_iter()
        .fold(Expr::abs(fun_var, param_expr), |expr, kind| {
            Expr::ty_abs(kind, expr)
        });
    trace_alt!(bound_expr, "bound expr");
    NativeItem {
        symbol: item.symbol.into(),
        expr: bound_expr,
        typ: lowered_scheme.scheme,
    }
}