blr-lang 0.1.0

A language implementation that provides type-safe dataframes
Documentation
use std::collections::{BTreeMap, HashMap};

use itertools::{Either, Itertools as _};
use tracing::instrument;

use crate::compiler::sexpr::ToSExpr;
use crate::{debug_alt, trace_alt};

use super::{
    Branch, Expr, Item, ItemId, ItemSupply, Module, NativeItem, Row, Symbol, TypApp, Type,
    simplify::subst_typ,
};

/// Builds a module whose native items have been run through [`monomorph`].
///
/// The pass works in three steps:
///
/// 1. The items of `module` are split in two: native items whose type passes
///    [`is_mono_type`] together with every external item on one side, the
///    remaining (polymorphic) native items on the other.
/// 2. For each polymorphic item whose `symbol.field` has an entry in
///    `instances`, one new native item is created per listed type. It gets a
///    fresh id from the item supply, a symbol mangled by [`monomorph_symbol`],
///    and an expression that applies the original expression to that type.
/// 3. Every native item from steps 1 and 2 is rewritten by [`monomorph`] and
///    its type is recomputed with `type_of`; external items are copied over
///    unchanged.
///
/// The polymorphic items themselves are not copied into the result; they are
/// only handed to [`monomorph`] as a lookup table.
///
/// Returns a [`Module`] holding the resulting items, the advanced item supply
/// and the instances that [`monomorph`] recorded while rewriting.
pub fn monomorph_module(module: Module, instances: &HashMap<String, Vec<Type>>) -> Module {
    // Step 1: separate the items that can be processed as-is from the
    // polymorphic ones.
    let mut ready_items = BTreeMap::new();
    let mut poly_items = BTreeMap::new();
    for (item_id, item) in module.items {
        match item {
            Item::Native(native_item) => {
                if is_mono_type(&native_item.typ) {
                    ready_items.insert(item_id, Item::Native(native_item));
                } else {
                    poly_items.insert(item_id, native_item);
                }
            }
            // All external items are monomorphic
            Item::External(external) => {
                ready_items.insert(item_id, Item::External(external));
            }
        }
    }

    let mut item_supply = module.item_supply;

    // Step 2: materialise the explicitly requested instances of polymorphic
    // items. They are keyed by their new id so they are processed in id order.
    let mut requested_items = BTreeMap::new();
    for item in poly_items.values() {
        if let Some(typs) = instances.get(&item.symbol.field) {
            for typ in typs {
                // Create new item_id for new instance
                let instance_id = item_supply.new_supply();
                let instance = NativeItem {
                    symbol: monomorph_symbol(typ.clone(), item.symbol.clone()),
                    expr: Expr::TypApp(Box::new(item.expr.clone()), TypApp::Ty(typ.clone())),
                    // Placeholder: replaced by `type_of` in step 3.
                    typ: item.typ.clone(),
                };
                requested_items.insert(instance_id, Item::Native(instance));
            }
        }
    }

    // Step 3: rewrite every native item and collect the results.
    let mut mono_items = BTreeMap::new();
    let mut new_instances = HashMap::new();
    for (id, item) in ready_items.into_iter().chain(requested_items) {
        match item {
            Item::Native(item) => {
                let expr = monomorph(
                    item.expr,
                    Vec::new(),
                    &poly_items,
                    &mut mono_items,
                    &mut item_supply,
                    &mut new_instances,
                );
                debug_alt!(expr, "monomorphic expr");
                let typ = expr.type_of();
                mono_items.insert(id, Item::Native(NativeItem { expr, typ, ..item }));
            }
            Item::External(item) => {
                mono_items.insert(id, Item::External(item));
            }
        }
    }

    Module {
        items: mono_items,
        item_supply,
        instances: new_instances,
    }
}

/// Recursively rewrites `expr`, resolving type applications along the way.
///
/// `types` holds the type arguments collected from enclosing `Expr::TypApp`
/// nodes that have not been consumed yet:
///
/// * `Expr::TypApp` with a `TypApp::Ty` argument pushes that type and continues
///   with the applied expression. Any other type application hits `todo!`.
/// * `Expr::TypAbs` hands its body and all pending types to [`instantiate`].
///   With no pending types it is replaced by `Expr::Unit`.
/// * `Expr::Item` whose type is a `Type::TypAbs` pops one pending type and is
///   replaced by a reference to a freshly allocated item id under a symbol
///   mangled by [`monomorph_symbol`]. If the referenced id is present in
///   `poly_items`, a rewritten copy of that item is inserted into `mono_items`;
///   otherwise the (symbol, type) pair is appended to `instances`.
/// * Structural nodes (abstraction, application, local, tuple, field, tag,
///   case) are rebuilt with their sub-expressions rewritten; each
///   sub-expression receives its own copy of the pending types.
/// * Every other expression is returned unchanged.
///
/// # Panics
///
/// Panics if a generic item reference is reached with no pending type, and on
/// type applications that are not `TypApp::Ty` (not implemented).
///
/// NOTE(review): the item case consumes the most recently pushed type (`pop`),
/// whereas [`instantiate`] walks the pending types front to back. With more
/// than one pending type these orders differ — confirm this is intended.
fn monomorph(
    expr: Expr,
    mut types: Vec<Type>,
    poly_items: &BTreeMap<ItemId, NativeItem>,
    mono_items: &mut BTreeMap<ItemId, Item>,
    item_supply: &mut ItemSupply,
    instances: &mut HashMap<Symbol, Vec<Type>>,
) -> Expr {
    // Recursive call that threads the shared state through unchanged.
    macro_rules! recur {
        ($sub:expr, $pending:expr) => {
            monomorph(
                $sub,
                $pending,
                poly_items,
                mono_items,
                item_supply,
                instances,
            )
        };
    }

    match expr {
        // HACK:
        // We guarantee no bare TypAbs, this must be dead code that simplify was not able to
        // remove.
        //
        // TODO: Improve dead code elimination for unused fields of tuples.
        Expr::TypAbs(..) if types.is_empty() => Expr::Unit,
        Expr::TypAbs(_kind, body) => {
            instantiate(*body, types, poly_items, mono_items, item_supply, instances)
        }
        Expr::TypApp(fun, TypApp::Ty(ty)) => {
            types.push(ty);
            recur!(*fun, types)
        }
        Expr::TypApp(..) => todo!("row monomorph not implemented"),
        Expr::Abstraction(var, body) => Expr::abs(var, recur!(*body, types)),
        Expr::Application(fun, parameter) => {
            let fun = recur!(*fun, types.clone());
            let parameter = recur!(*parameter, types);
            Expr::app(fun, parameter)
        }
        Expr::Local(var, defn, body) => {
            let defn = recur!(*defn, types.clone());
            let body = recur!(*body, types);
            Expr::local(var, defn, body)
        }
        Expr::Tuple(fields) => Expr::tuple(
            fields
                .into_iter()
                .map(|(name, field)| (name, recur!(field, types.clone()))),
        ),
        Expr::Field(body, index) => Expr::field(recur!(*body, types), index),
        Expr::Tag(ty, tag, expr) => Expr::tag(ty, tag, recur!(*expr, types)),
        Expr::Case(ty, expr, branches) => Expr::case(
            ty,
            recur!(*expr, types.clone()),
            branches.into_iter().map(|branch| Branch {
                param: branch.param,
                body: recur!(branch.body, types.clone()),
            }),
        ),
        Expr::Item(Type::TypAbs(_kind, ty), item_id, symbol) => {
            let mono_item_id = item_supply.new_supply();

            // Items defined in this module are rewritten with the full list of
            // pending types (taken before one of them is popped below).
            let local_mono_expr = poly_items
                .get(&item_id)
                .map(|item| recur!(item.expr.clone(), types.clone()));

            let ty_param = types.pop().expect("missing TyApp to generic item");
            let mono_symbol = monomorph_symbol(ty_param.clone(), symbol.clone());
            let ty = ty.subst_typ(ty_param.clone());

            match local_mono_expr {
                Some(local_mono_expr) => {
                    mono_items.insert(
                        mono_item_id,
                        Item::Native(NativeItem {
                            symbol: mono_symbol.clone(),
                            expr: local_mono_expr,
                            typ: ty.clone(),
                        }),
                    );
                }
                // When we create an instance of a non-local item record it.
                None => instances.entry(symbol).or_default().push(ty_param),
            }
            Expr::Item(ty, mono_item_id, mono_symbol)
        }
        other => other,
    }
}

/// Instantiates the body of a type abstraction with the given `types`.
///
/// The types are consumed front to back. For each one the type is substituted
/// into the current expression with `subst_typ`, and the result is rewritten by
/// [`monomorph`] with an empty list of pending types. The expression obtained
/// after the last type is returned; with no types `expr` is returned as is.
fn instantiate(
    expr: Expr,
    types: Vec<Type>,
    poly_items: &BTreeMap<ItemId, NativeItem>,
    mono_items: &mut BTreeMap<ItemId, Item>,
    item_supply: &mut ItemSupply,
    instances: &mut HashMap<Symbol, Vec<Type>>,
) -> Expr {
    trace_alt!(expr, types, "instantiate");

    let mut expr = expr;
    for ty in types {
        let substituted = subst_typ(expr, ty);
        expr = monomorph(
            substituted,
            Vec::new(),
            poly_items,
            mono_items,
            item_supply,
            instances,
        );
    }

    trace_alt!(expr, "instantiated");
    expr
}

/// Derives the symbol of the instance of `symbol` at type `ty`.
///
/// The module is kept and the field name is suffixed with the s-expression
/// rendering of `ty` in square brackets, e.g. `+` at float becomes
/// `+[(float)]`.
pub fn monomorph_symbol(ty: Type, symbol: Symbol) -> Symbol {
    // TODO make explicit name mangling scheme
    let Symbol { module, field } = symbol;
    let rendered_ty = ty.to_sexpr(&()).to_string();
    Symbol {
        module,
        field: format!("{field}[{rendered_ty}]"),
    }
}

/// Reports whether `typ` is free of type variables and type abstractions.
///
/// Primitive types are monomorphic. A function type is monomorphic when both
/// its parameter and return types are. A product or sum is monomorphic when
/// its row is closed and every field type is. Type variables, type
/// abstractions and open rows are not monomorphic.
///
/// The call is instrumented with `tracing`, recording the result at TRACE
/// level.
#[instrument(ret(level=tracing::Level::TRACE))]
fn is_mono_type(typ: &Type) -> bool {
    match typ {
        // Anything that still mentions a type variable is polymorphic.
        Type::Var(_) | Type::TypAbs(..) => false,
        Type::Prod(Row::Open(_)) | Type::Sum(Row::Open(_)) => false,
        // Composite types are monomorphic only if all their parts are.
        Type::Prod(Row::Closed(fields)) | Type::Sum(Row::Closed(fields)) => fields
            .iter()
            .all(|(_name, field_typ)| is_mono_type(field_typ)),
        Type::Abs(param, ret) => is_mono_type(param) && is_mono_type(ret),
        Type::Unit | Type::Int | Type::Float | Type::String | Type::DataFrame => true,
    }
}

// Snapshot tests for the monomorphisation pass: each test compiles a source
// snippet and compares the module before and after `monomorph_module` against
// inline `expect_test` snapshots.
#[cfg(test)]
mod tests {
    use crate::{
        compiler::{
            mantle::{lower_module, simplify_module},
            parse,
        },
        debug_alt,
        runtime::binary,
    };

    use super::*;

    use expect_test::{Expect, expect};
    use test_log::test;
    use tracing::debug;

    use super::super::crust;

    /// Runs `src` through parsing, lowering to crust, type inference, lowering
    /// to mantle and simplification, then checks the simplified module against
    /// `poly` and the result of `monomorph_module` against `mono`.
    ///
    /// Panics if `src` does not parse or does not type check.
    fn test_monomorph(src: &str, poly: Expect, mono: Expect) {
        let module = parse(src).expect("source should parse");
        // Make the binary operator functions available as items before lowering.
        let mut item_source = crust::ItemSource::default();
        binary::register_binary_operator_functions(&mut item_source);
        let module = crust::lower(module, &mut item_source, "main");
        debug_alt!(module, "crust");
        let module = crust::type_infer_with_items(item_source.clone(), module)
            .expect("should be valid types");
        debug_alt!(module, "crust typed");
        let module = lower_module(&item_source, module);
        debug!(?module, "mantle");
        let module = simplify_module(module);
        debug!(?module, "simplified");
        // Snapshot of the module as it enters the monomorphisation pass.
        poly.assert_debug_eq(&module);
        // No explicitly requested instances: only uses found in the module are
        // instantiated.
        let module = monomorph_module(module, &Default::default());
        debug!(?module, "monomorphic");
        mono.assert_debug_eq(&module);
    }

    // The generic `+` applied at float must become a reference to a new item
    // with the mangled symbol `+[(float)]` and a concrete function type. The
    // `let i = 1 + 1;` binding does not appear in either snapshot.
    #[test]
    fn polymorphic_add() {
        test_monomorph(
            "
            let i = 1 + 1;
            1.0 + 1.0
            ",
            expect![[r#"
                (mod
                  (pub
                    (symbol "" main)
                    (abs
                      (var (var_id 0) (unit))
                      (app
                        (ty_app
                          (item
                            (itm_id 1)
                            (symbol std::ops +)
                            (typ_abs (type) (fn (typ_id 0) (fn (typ_id 0) (typ_id 0)))))
                          (ty (float)))
                        (f 1.0)
                        (f 1.0)))
                    (fn (unit) (float))))
            "#]],
            expect![[r#"
                (mod
                  (pub
                    (symbol "" main)
                    (abs
                      (var (var_id 0) (unit))
                      (app
                        (item
                          (itm_id 2)
                          (symbol std::ops "+[(float)]")
                          (fn (float) (fn (float) (float))))
                        (f 1.0)
                        (f 1.0)))
                    (fn (unit) (float))))
            "#]],
        )
    }
}