use std::collections::{BTreeMap, HashMap};
use itertools::{Either, Itertools as _};
use tracing::instrument;
use crate::compiler::sexpr::ToSExpr;
use crate::{debug_alt, trace_alt};
use super::{
Branch, Expr, Item, ItemId, ItemSupply, Module, NativeItem, Row, Symbol, TypApp, Type,
simplify::subst_typ,
};
pub fn monomorph_module(module: Module, instances: &HashMap<String, Vec<Type>>) -> Module {
let (module_mono_items, poly_items): (BTreeMap<_, _>, BTreeMap<_, _>) = module
.items
.into_iter()
.partition_map(|(item_id, item)| match item {
Item::Native(native_item) => {
if is_mono_type(&native_item.typ) {
Either::Left((item_id, Item::Native(native_item)))
} else {
Either::Right((item_id, native_item))
}
}
item @ Item::External(_) => Either::Left((item_id, item)),
});
let mut mono_items = Default::default();
let mut new_instances = Default::default();
let mut item_supply = module.item_supply;
module_mono_items
.into_iter()
.chain(
poly_items
.values()
.flat_map(|item| {
instances
.get(&item.symbol.field)
.map(|typs| {
typs.iter()
.map(|typ| {
(
item_supply.new_supply(),
Item::Native(NativeItem {
symbol: monomorph_symbol(
typ.clone(),
item.symbol.clone(),
),
expr: Expr::TypApp(
Box::new(item.expr.clone()),
TypApp::Ty(typ.clone()),
),
typ: item.typ.clone(),
}),
)
})
.collect::<Vec<_>>()
})
.unwrap_or_default()
})
.collect::<BTreeMap<_, _>>(),
)
.for_each(|(id, item)| match item {
Item::Native(item) => {
let expr = monomorph(
item.expr,
vec![],
&poly_items,
&mut mono_items,
&mut item_supply,
&mut new_instances,
);
debug_alt!(expr, "monomorphic expr");
let typ = expr.type_of();
mono_items.insert(id, Item::Native(NativeItem { expr, typ, ..item }));
}
Item::External(item) => {
mono_items.insert(id, Item::External(item));
}
});
Module {
items: mono_items,
item_supply,
instances: new_instances,
}
}
/// Rewrites `expr` so that it contains no type abstractions or type
/// applications, specialising referenced generic items along the way.
///
/// `types` is a stack of pending type arguments. A `TypApp` pushes its
/// argument and then recurses into the applied expression. For every other
/// compound expression, the pending stack is passed unchanged to each of its
/// sub-expressions.
///
/// * A `TypAbs` with pending types is instantiated with all of them (see
///   `instantiate`). With no pending types it is replaced by `Expr::Unit`.
/// * A reference to a generic item (an `Item` whose type is a `TypAbs`) gets
///   a fresh id and a symbol mangled with the top of the pending stack.
///   - If the item's body is in `poly_items`, the specialised body is inserted
///     into `mono_items` under that id.
///   - Otherwise the requested type is recorded under the item's symbol in
///     `instances`.
///
/// # Panics
///
/// Panics if a generic item is referenced with no pending type argument.
/// Row type applications hit `todo!`.
fn monomorph(
    expr: Expr,
    mut types: Vec<Type>,
    poly_items: &BTreeMap<ItemId, NativeItem>,
    mono_items: &mut BTreeMap<ItemId, Item>,
    item_supply: &mut ItemSupply,
    instances: &mut HashMap<Symbol, Vec<Type>>,
) -> Expr {
    match expr {
        Expr::TypAbs(_kind, body) => {
            if types.is_empty() {
                // A type abstraction with no type argument applied is erased.
                Expr::Unit
            } else {
                instantiate(*body, types, poly_items, mono_items, item_supply, instances)
            }
        }
        Expr::TypApp(fun, TypApp::Ty(ty)) => {
            types.push(ty);
            monomorph(*fun, types, poly_items, mono_items, item_supply, instances)
        }
        Expr::TypApp(_, _) => todo!("row monomorph not implemented"),
        Expr::Abstraction(var, body) => Expr::abs(
            var,
            monomorph(*body, types, poly_items, mono_items, item_supply, instances),
        ),
        Expr::Application(fun, parameter) => Expr::app(
            monomorph(
                *fun,
                types.clone(),
                poly_items,
                mono_items,
                item_supply,
                instances,
            ),
            monomorph(
                *parameter,
                types,
                poly_items,
                mono_items,
                item_supply,
                instances,
            ),
        ),
        Expr::Local(var, defn, body) => Expr::local(
            var,
            monomorph(
                *defn,
                types.clone(),
                poly_items,
                mono_items,
                item_supply,
                instances,
            ),
            monomorph(*body, types, poly_items, mono_items, item_supply, instances),
        ),
        Expr::Tuple(fields) => Expr::tuple(fields.into_iter().map(|(name, field)| {
            (
                name,
                monomorph(
                    field,
                    types.clone(),
                    poly_items,
                    mono_items,
                    item_supply,
                    instances,
                ),
            )
        })),
        Expr::Field(body, index) => Expr::field(
            monomorph(*body, types, poly_items, mono_items, item_supply, instances),
            index,
        ),
        Expr::Tag(ty, tag, expr) => Expr::tag(
            ty,
            tag,
            monomorph(*expr, types, poly_items, mono_items, item_supply, instances),
        ),
        Expr::Case(ty, expr, branches) => Expr::case(
            ty,
            monomorph(
                *expr,
                types.clone(),
                poly_items,
                mono_items,
                item_supply,
                instances,
            ),
            branches.into_iter().map(|branch| Branch {
                param: branch.param,
                body: monomorph(
                    branch.body,
                    types.clone(),
                    poly_items,
                    mono_items,
                    item_supply,
                    instances,
                ),
            }),
        ),
        Expr::Item(Type::TypAbs(_kind, ty), item_id, symbol) => {
            // Validate before doing any work. A reference without a type
            // argument is a bug, and we should not specialise the item body
            // (or allocate an id) before reporting it.
            let ty_param = types
                .last()
                .cloned()
                .expect("missing TyApp to generic item");
            let mono_item_id = item_supply.new_supply();
            // The pending stack is not used after this point, so it is moved
            // (not cloned) into the specialisation of the item body.
            let local_mono_expr = poly_items.get(&item_id).map(|item| {
                monomorph(
                    item.expr.clone(),
                    types,
                    poly_items,
                    mono_items,
                    item_supply,
                    instances,
                )
            });
            let mono_symbol = monomorph_symbol(ty_param.clone(), symbol.clone());
            let ty = ty.subst_typ(ty_param.clone());
            if let Some(local_mono_expr) = local_mono_expr {
                mono_items.insert(
                    mono_item_id,
                    Item::Native(NativeItem {
                        symbol: mono_symbol.clone(),
                        expr: local_mono_expr,
                        typ: ty.clone(),
                    }),
                );
            } else {
                // The item has no body here, so record the requested
                // instantiation instead.
                // NOTE(review): every reference appends, so the same type can
                // be recorded several times for one symbol. Confirm that
                // consumers deduplicate.
                instances.entry(symbol).or_default().push(ty_param);
            }
            Expr::Item(ty, mono_item_id, mono_symbol)
        }
        expr => expr,
    }
}
/// Instantiates the body of a type abstraction with the pending `types`.
///
/// Each type is substituted into the expression in turn via
/// `simplify::subst_typ`. After every substitution the result is
/// re-monomorphised with an empty pending-type stack.
///
/// NOTE(review): `types` is consumed front to back, i.e. in the order the
/// enclosing `TypApp`s pushed them (outermost application first). Each step
/// re-monomorphises with no pending types, so a nested `TypAbs` that remains
/// after a substitution is erased to `Expr::Unit`. Confirm this is intended
/// for items with more than one type parameter.
fn instantiate(
    expr: Expr,
    types: Vec<Type>,
    poly_items: &BTreeMap<ItemId, NativeItem>,
    mono_items: &mut BTreeMap<ItemId, Item>,
    item_supply: &mut ItemSupply,
    instances: &mut HashMap<Symbol, Vec<Type>>,
) -> Expr {
    trace_alt!(expr, types, "instantiate");
    let mut expr = expr;
    for ty in types {
        let substituted = subst_typ(expr, ty);
        expr = monomorph(
            substituted,
            Vec::new(),
            poly_items,
            mono_items,
            item_supply,
            instances,
        );
    }
    trace_alt!(expr, "instantiated");
    expr
}
/// Builds the symbol of a specialisation of `symbol` at type `ty`.
///
/// The module is kept as-is. The field becomes `field[<ty as s-expression>]`;
/// for example, `+` specialised at `float` becomes `+[(float)]`.
pub fn monomorph_symbol(ty: Type, symbol: Symbol) -> Symbol {
    let Symbol { module, field } = symbol;
    let suffix = ty.to_sexpr(&()).to_string();
    Symbol {
        module,
        field: format!("{field}[{suffix}]"),
    }
}
/// Returns `true` when `typ` contains no type variables, type abstractions or
/// open rows, i.e. when it is already fully concrete.
#[instrument(ret(level=tracing::Level::TRACE))]
fn is_mono_type(typ: &Type) -> bool {
    match typ {
        // Anything mentioning a type variable, a type binder or an open row is
        // still generic.
        Type::Var(_) | Type::TypAbs(..) => false,
        Type::Prod(Row::Open(_)) | Type::Sum(Row::Open(_)) => false,
        // Base types are always concrete.
        Type::Unit | Type::Int | Type::Float | Type::String | Type::DataFrame => true,
        // Composite types are concrete exactly when all their parts are.
        Type::Abs(param, ret) => is_mono_type(param) && is_mono_type(ret),
        Type::Prod(Row::Closed(fields)) | Type::Sum(Row::Closed(fields)) => {
            fields.iter().all(|(_name, field)| is_mono_type(field))
        }
    }
}
// Snapshot tests: run a source snippet through the front end and check the
// module both before and after monomorphisation.
#[cfg(test)]
mod tests {
use crate::{
compiler::{
mantle::{lower_module, simplify_module},
parse,
},
debug_alt,
runtime::binary,
};
use super::*;
use expect_test::{Expect, expect};
use test_log::test;
use tracing::debug;
use super::super::crust;
/// Runs `src` through the pipeline that feeds this pass:
/// 1. parse;
/// 2. crust lowering, with the binary operator functions registered;
/// 3. type inference;
/// 4. mantle lowering;
/// 5. simplification.
///
/// The simplified module is checked against the `poly` snapshot. It is then
/// monomorphised with no pre-requested instances and checked against `mono`.
fn test_monomorph(src: &str, poly: Expect, mono: Expect) {
let module = parse(src).expect("source should parse");
let mut item_source = crust::ItemSource::default();
binary::register_binary_operator_functions(&mut item_source);
let module = crust::lower(module, &mut item_source, "main");
debug_alt!(module, "crust");
let module = crust::type_infer_with_items(item_source.clone(), module)
.expect("should be valid types");
debug_alt!(module, "crust typed");
let module = lower_module(&item_source, module);
debug!(?module, "mantle");
let module = simplify_module(module);
debug!(?module, "simplified");
poly.assert_debug_eq(&module);
let module = monomorph_module(module, &Default::default());
debug!(?module, "monomorphic");
mono.assert_debug_eq(&module);
}
/// `1.0 + 1.0` uses the generic `std::ops::+` at `float`.
/// Before monomorphisation the call head is a `ty_app` of the generic item.
/// Afterwards it refers to a fresh item `+[(float)]` whose type is the
/// concrete `float -> float -> float`.
#[test]
fn polymorphic_add() {
test_monomorph(
"
let i = 1 + 1;
1.0 + 1.0
",
expect![[r#"
(mod
(pub
(symbol "" main)
(abs
(var (var_id 0) (unit))
(app
(ty_app
(item
(itm_id 1)
(symbol std::ops +)
(typ_abs (type) (fn (typ_id 0) (fn (typ_id 0) (typ_id 0)))))
(ty (float)))
(f 1.0)
(f 1.0)))
(fn (unit) (float))))
"#]],
expect![[r#"
(mod
(pub
(symbol "" main)
(abs
(var (var_id 0) (unit))
(app
(item
(itm_id 2)
(symbol std::ops "+[(float)]")
(fn (float) (fn (float) (float))))
(f 1.0)
(f 1.0)))
(fn (unit) (float))))
"#]],
)
}
}