runmat-hir 0.4.6

High-level IR for RunMat with type inference and lowering utilities
Documentation
use super::expr::infer_expr_type_with_env;
use super::shared::{
    apply_lvalue_type_effects, apply_struct_field_assertions, collect_function_defs,
    collect_struct_field_assertions, join_env, refine_multi_assign_outputs_from_func, Analysis,
    FuncDef,
};
use crate::{HirClassMember, HirExpr, HirExprKind, HirProgram, HirStmt, Type, VarId};
use std::collections::HashMap;

pub fn infer_function_output_types(prog: &HirProgram) -> HashMap<String, Vec<Type>> {
    fn infer_expr_type(
        expr: &HirExpr,
        env: &HashMap<VarId, Type>,
        func_returns: &HashMap<String, Vec<Type>>,
    ) -> Type {
        infer_expr_type_with_env(expr, env, func_returns)
    }

    #[allow(clippy::type_complexity)]
    fn analyze_stmts(
        stmts: &[HirStmt],
        mut env: HashMap<VarId, Type>,
        returns: &HashMap<String, Vec<Type>>,
        func_defs: &HashMap<String, FuncDef>,
    ) -> Analysis {
        let mut exits = Vec::new();
        let mut i = 0usize;
        while i < stmts.len() {
            match &stmts[i] {
                HirStmt::Assign(var, expr, _, _) => {
                    let t = infer_expr_type(expr, &env, returns);
                    env.insert(*var, t);
                }
                HirStmt::MultiAssign(vars, expr, _, _) => {
                    if let HirExprKind::FuncCall(ref name, _) = expr.kind {
                        let mut per_out: Vec<Type> = returns.get(name).cloned().unwrap_or_default();
                        refine_multi_assign_outputs_from_func(
                            name,
                            &mut per_out,
                            returns,
                            func_defs,
                            infer_expr_type,
                        );
                        for (i, v) in vars.iter().enumerate() {
                            if let Some(id) = v {
                                env.insert(*id, per_out.get(i).cloned().unwrap_or(Type::Unknown));
                            }
                        }
                    } else {
                        let t = infer_expr_type(expr, &env, returns);
                        for v in vars.iter().flatten() {
                            env.insert(*v, t.clone());
                        }
                    }
                }
                HirStmt::ExprStmt(_, _, _) | HirStmt::Break(_) | HirStmt::Continue(_) => {}
                HirStmt::Return(_) => {
                    exits.push(env.clone());
                    return Analysis {
                        exits,
                        fallthrough: None,
                    };
                }
                HirStmt::If {
                    cond,
                    then_body,
                    elseif_blocks,
                    else_body,
                    ..
                } => {
                    let mut assertions: Vec<(VarId, String)> = Vec::new();
                    collect_struct_field_assertions(cond, &mut assertions);
                    let mut then_env = env.clone();
                    apply_struct_field_assertions(&mut then_env, assertions);
                    let then_a = analyze_stmts(then_body, then_env, returns, func_defs);
                    let mut out_env = then_a.fallthrough.clone().unwrap_or_else(|| env.clone());
                    let mut all_exits = then_a.exits.clone();
                    for (c, b) in elseif_blocks {
                        let mut elseif_env = env.clone();
                        let mut els_assertions: Vec<(VarId, String)> = Vec::new();
                        collect_struct_field_assertions(c, &mut els_assertions);
                        apply_struct_field_assertions(&mut elseif_env, els_assertions);
                        let a = analyze_stmts(b, elseif_env, returns, func_defs);
                        if let Some(f) = a.fallthrough {
                            out_env = join_env(&out_env, &f);
                        }
                        all_exits.extend(a.exits);
                    }
                    if let Some(else_body) = else_body {
                        let a = analyze_stmts(else_body, env.clone(), returns, func_defs);
                        if let Some(f) = a.fallthrough {
                            out_env = join_env(&out_env, &f);
                        }
                        all_exits.extend(a.exits);
                    } else {
                        out_env = join_env(&out_env, &env);
                    }
                    env = out_env;
                    exits.extend(all_exits);
                }
                HirStmt::While { body, .. } => {
                    let a = analyze_stmts(body, env.clone(), returns, func_defs);
                    if let Some(f) = a.fallthrough {
                        env = join_env(&env, &f);
                    }
                    exits.extend(a.exits);
                }
                HirStmt::For {
                    var, expr, body, ..
                } => {
                    let t = infer_expr_type(expr, &env, returns);
                    env.insert(*var, t);
                    let a = analyze_stmts(body, env.clone(), returns, func_defs);
                    if let Some(f) = a.fallthrough {
                        env = join_env(&env, &f);
                    }
                    exits.extend(a.exits);
                }
                HirStmt::Switch {
                    cases, otherwise, ..
                } => {
                    let mut out_env: Option<HashMap<VarId, Type>> = None;
                    for (_v, b) in cases {
                        let a = analyze_stmts(b, env.clone(), returns, func_defs);
                        if let Some(f) = a.fallthrough {
                            out_env = Some(match out_env {
                                Some(curr) => join_env(&curr, &f),
                                None => f,
                            });
                        }
                        exits.extend(a.exits);
                    }
                    if let Some(otherwise) = otherwise {
                        let a = analyze_stmts(otherwise, env.clone(), returns, func_defs);
                        if let Some(f) = a.fallthrough {
                            out_env = Some(match out_env {
                                Some(curr) => join_env(&curr, &f),
                                None => f,
                            });
                        }
                        exits.extend(a.exits);
                    } else {
                        out_env = Some(match out_env {
                            Some(curr) => join_env(&curr, &env),
                            None => env.clone(),
                        });
                    }
                    if let Some(f) = out_env {
                        env = f;
                    }
                }
                HirStmt::TryCatch {
                    try_body,
                    catch_body,
                    ..
                } => {
                    let a_try = analyze_stmts(try_body, env.clone(), returns, func_defs);
                    let a_catch = analyze_stmts(catch_body, env.clone(), returns, func_defs);
                    let mut out_env = a_try.fallthrough.clone().unwrap_or_else(|| env.clone());
                    if let Some(f) = a_catch.fallthrough {
                        out_env = join_env(&out_env, &f);
                    }
                    env = out_env;
                    exits.extend(a_try.exits);
                    exits.extend(a_catch.exits);
                }
                HirStmt::Global(_, _) | HirStmt::Persistent(_, _) => {}
                HirStmt::Function { .. } => {}
                HirStmt::ClassDef { .. } => {}
                HirStmt::AssignLValue(lv, expr, _, _) => {
                    apply_lvalue_type_effects(&mut env, lv);
                    let _ = infer_expr_type(expr, &env, returns);
                }
                HirStmt::Import { .. } => {}
            }
            i += 1;
        }
        Analysis {
            exits,
            fallthrough: Some(env),
        }
    }

    fn collect_function_names(stmts: &[HirStmt], acc: &mut Vec<String>) {
        for s in stmts {
            match s {
                HirStmt::Function { name, .. } => acc.push(name.clone()),
                HirStmt::ClassDef { members, .. } => {
                    for m in members {
                        if let HirClassMember::Methods { body, .. } = m {
                            collect_function_names(body, acc);
                        }
                    }
                }
                _ => {}
            }
        }
    }

    let mut function_names: Vec<String> = Vec::new();
    collect_function_names(&prog.body, &mut function_names);
    let mut returns: HashMap<String, Vec<Type>> = function_names
        .iter()
        .map(|n| (n.clone(), Vec::new()))
        .collect();

    let func_defs = collect_function_defs(prog);

    for stmt in &prog.body {
        if let HirStmt::Function {
            name,
            outputs,
            body,
            ..
        } = stmt
        {
            let mut per_output: Vec<Type> = vec![Type::Unknown; outputs.len()];
            let analysis = analyze_stmts(body, HashMap::new(), &returns, &func_defs);
            let mut accumulate = |env: &HashMap<VarId, Type>| {
                for (i, out_id) in outputs.iter().enumerate() {
                    if let Some(t) = env.get(out_id) {
                        per_output[i] = per_output[i].unify(t);
                    }
                }
            };
            if let Some(f) = &analysis.fallthrough {
                accumulate(f);
            }
            for e in &analysis.exits {
                accumulate(e);
            }
            returns.insert(name.clone(), per_output);
        }
    }

    let mut changed = true;
    let mut iter = 0usize;
    let max_iters = 3usize;
    while changed && iter < max_iters {
        changed = false;
        iter += 1;
        for stmt in &prog.body {
            match stmt {
                HirStmt::Function {
                    name,
                    outputs,
                    body,
                    ..
                } => {
                    let analysis = analyze_stmts(body, HashMap::new(), &returns, &func_defs);
                    let mut per_output: Vec<Type> = vec![Type::Unknown; outputs.len()];
                    let mut accumulate = |env: &HashMap<VarId, Type>| {
                        for (i, out_id) in outputs.iter().enumerate() {
                            if let Some(t) = env.get(out_id) {
                                per_output[i] = per_output[i].unify(t);
                            }
                        }
                    };
                    for e in &analysis.exits {
                        accumulate(e);
                    }
                    if let Some(f) = &analysis.fallthrough {
                        accumulate(f);
                    }
                    if returns.get(name) != Some(&per_output) {
                        returns.insert(name.clone(), per_output);
                        changed = true;
                    }
                }
                HirStmt::ClassDef { members, .. } => {
                    for m in members {
                        if let HirClassMember::Methods { body, .. } = m {
                            for s in body {
                                if let HirStmt::Function {
                                    name,
                                    outputs,
                                    body,
                                    ..
                                } = s
                                {
                                    let analysis =
                                        analyze_stmts(body, HashMap::new(), &returns, &func_defs);
                                    let mut per_output: Vec<Type> =
                                        vec![Type::Unknown; outputs.len()];
                                    let mut accumulate = |env: &HashMap<VarId, Type>| {
                                        for (i, out_id) in outputs.iter().enumerate() {
                                            if let Some(t) = env.get(out_id) {
                                                per_output[i] = per_output[i].unify(t);
                                            }
                                        }
                                    };
                                    for e in &analysis.exits {
                                        accumulate(e);
                                    }
                                    if let Some(f) = &analysis.fallthrough {
                                        accumulate(f);
                                    }
                                    if returns.get(name) != Some(&per_output) {
                                        returns.insert(name.clone(), per_output);
                                        changed = true;
                                    }
                                }
                            }
                        }
                    }
                }
                _ => {}
            }
        }
    }

    returns
}