runmat-hir 0.4.4

High-level IR for RunMat with type inference and lowering utilities
Documentation
use super::shared::{
    eval_const_num, infer_dataset_type_from_literal_path, infer_slice_result_shape,
    literal_path_arg, logical_binary_result, resolve_context_from_args, shape_rank,
};
use crate::{HirExpr, HirExprKind, Type, VarId};
use runmat_parser as parser;
use std::collections::HashMap;

/// Infers the static [`Type`] of a HIR expression.
///
/// * `env` holds the types already inferred for variables; a variable that is
///   not present in `env` is typed [`Type::Unknown`].
/// * `func_returns` maps user-defined function names to their output types. A
///   call to such a function is typed by its first output (`Unknown` if it has
///   none) and is preferred over a builtin of the same name. The special-cased
///   `data.open`, `data.create`, `data.import` and `Dataset.array` names are
///   checked before `func_returns`.
///
/// Other calls and method calls are resolved through
/// `runmat_builtins::builtin_functions()`. Methods on `DataDataset`,
/// `DataArray` and `DataTransaction` receivers are special-cased so that array
/// metadata recovered from literal dataset paths / array names flows through to
/// e.g. `read` results.
///
/// Every expression gets a type: anything that cannot be resolved becomes
/// [`Type::Unknown`] or `Type::tensor()`.
pub fn infer_expr_type_with_env(
    expr: &HirExpr,
    env: &HashMap<VarId, Type>,
    func_returns: &HashMap<String, Vec<Type>>,
) -> Type {
    use HirExprKind as K;

    // Infers a sub-expression under the same environment.
    let infer = |e: &HirExpr| infer_expr_type_with_env(e, env, func_returns);

    match &expr.kind {
        K::Number(_) | K::Constant(_) => Type::Num,
        K::String(_) => Type::String,
        K::Var(id) => env.get(id).cloned().unwrap_or(Type::Unknown),
        // NOTE(review): every unary operator passes its operand type through
        // unchanged, so a transpose keeps the operand's shape as written —
        // confirm whether transposes should swap the first two dimensions.
        K::Unary(_, e) => infer(e),
        K::Binary(a, op, b) => {
            let ta = infer(a);
            let tb = infer(b);
            match op {
                parser::BinOp::Mul => runmat_builtins::shape_rules::matmul_output_type(&ta, &tb),
                parser::BinOp::LeftDiv => {
                    runmat_builtins::shape_rules::left_divide_output_type(&ta, &tb)
                }
                parser::BinOp::RightDiv => {
                    runmat_builtins::shape_rules::right_divide_output_type(&ta, &tb)
                }
                // `Pow` (presumably matrix power `^`) shares the element-wise rule:
                // combined with a scalar it keeps the tensor operand's shape.
                parser::BinOp::Add
                | parser::BinOp::Sub
                | parser::BinOp::Pow
                | parser::BinOp::ElemMul
                | parser::BinOp::ElemDiv
                | parser::BinOp::ElemPow
                | parser::BinOp::ElemLeftDiv => elementwise_result_type(&ta, &tb),
                parser::BinOp::Equal
                | parser::BinOp::NotEqual
                | parser::BinOp::Less
                | parser::BinOp::LessEqual
                | parser::BinOp::Greater
                | parser::BinOp::GreaterEqual
                | parser::BinOp::AndAnd
                | parser::BinOp::OrOr
                | parser::BinOp::BitAnd
                | parser::BinOp::BitOr => logical_binary_result(&ta, &tb),
                // `a:b` — the shape is only known when both bounds are constants.
                parser::BinOp::Colon => runmat_builtins::shape_rules::infer_range_shape(
                    eval_const_num(a),
                    None,
                    eval_const_num(b),
                )
                .map(|shape| Type::Tensor { shape: Some(shape) })
                .unwrap_or_else(Type::tensor),
            }
        }
        K::Tensor(rows) => {
            let row_types: Vec<Vec<Type>> = rows
                .iter()
                .map(|row| row.iter().map(&infer).collect())
                .collect();
            if let Some(shape) = runmat_builtins::shape_rules::concat_shape(&row_types) {
                return Type::Tensor { shape: Some(shape) };
            }
            // Fall back to counting elements. The count only equals the shape when
            // no element is itself an array: a tensor element whose shape
            // `concat_shape` could not use makes the result shape unknown.
            let has_tensor_element = row_types
                .iter()
                .flatten()
                .any(|t| matches!(t, Type::Tensor { .. }));
            let r = row_types.len();
            let c = row_types.iter().map(Vec::len).max().unwrap_or(0);
            if r > 0 && !has_tensor_element && row_types.iter().all(|row| row.len() == c) {
                Type::tensor_with_shape(vec![r, c])
            } else {
                Type::tensor()
            }
        }
        K::Cell(rows) => {
            // Element type is the unification of every element's type; the length
            // counts all elements across all rows.
            let elements: Vec<Type> = rows.iter().flatten().map(&infer).collect();
            let length = elements.len();
            Type::Cell {
                element_type: elements
                    .into_iter()
                    .reduce(|acc, t| acc.unify(&t))
                    .map(Box::new),
                length: Some(length),
            }
        }
        K::Index(base, idxs) => {
            let base_ty = infer(base);
            let idx_types: Vec<Type> = idxs.iter().map(&infer).collect();
            runmat_builtins::shape_rules::index_output_type(&base_ty, &idx_types)
        }
        K::IndexCell(base, idxs) => match infer(base) {
            // `c{i}` with exactly one Int/Num/Bool/Tensor-typed index yields the
            // cell's element type; every other form is left as `Unknown`.
            Type::Cell {
                element_type: Some(elem),
                ..
            } if idxs.len() == 1
                && matches!(
                    infer(&idxs[0]),
                    Type::Int | Type::Num | Type::Bool | Type::Tensor { .. }
                ) =>
            {
                *elem
            }
            _ => Type::Unknown,
        },
        // `start:step:end` — the shape is only known when the bounds are constants.
        K::Range(start, step, end) => runmat_builtins::shape_rules::infer_range_shape(
            eval_const_num(start),
            step.as_ref().and_then(|s| eval_const_num(s)),
            eval_const_num(end),
        )
        .map(|shape| Type::Tensor { shape: Some(shape) })
        .unwrap_or_else(Type::tensor),
        K::FuncCall(name, args) => {
            if name == "data.open" {
                // A literal path may let the dataset's array metadata be recovered;
                // otherwise the dataset's contents are unknown.
                return args
                    .first()
                    .and_then(literal_path_arg)
                    .and_then(|path| infer_dataset_type_from_literal_path(&path))
                    .unwrap_or(Type::DataDataset { arrays: None });
            }
            if name == "data.create" || name == "data.import" {
                return Type::DataDataset { arrays: None };
            }
            if name == "Dataset.array" {
                // `Dataset.array(ds, "name")`: look the array up in the dataset's
                // known metadata when both the metadata and a literal name exist.
                if let (Some(base), Some(array_name)) =
                    (args.first(), args.get(1).and_then(literal_path_arg))
                {
                    if let Type::DataDataset {
                        arrays: Some(arrays),
                    } = infer(base)
                    {
                        if let Some(info) = arrays.get(&array_name) {
                            return Type::DataArray {
                                dtype: info.dtype.clone(),
                                shape: info.shape.clone(),
                                chunk_shape: info.chunk_shape.clone(),
                                codec: info.codec.clone(),
                            };
                        }
                    }
                }
                return unknown_data_array();
            }
            if let Some(outputs) = func_returns.get(name) {
                return outputs.first().cloned().unwrap_or(Type::Unknown);
            }
            let arg_types: Vec<Type> = args.iter().map(&infer).collect();
            let ctx = resolve_context_from_args(args);
            let builtins = runmat_builtins::builtin_functions();
            builtins
                .iter()
                .find(|b| b.name == *name)
                .map_or(Type::Unknown, |b| {
                    b.infer_return_type_with_context(&arg_types, &ctx)
                })
        }
        K::MethodCall(base, method, args) | K::DottedInvoke(base, method, args) => {
            let base_ty = infer(base);

            // Fixed result types for the data-layer receivers.
            let special = match (&base_ty, method.as_str()) {
                (Type::DataDataset { arrays }, "array") => {
                    let known = args
                        .first()
                        .and_then(literal_path_arg)
                        .and_then(|name| arrays.as_ref().and_then(|m| m.get(&name)))
                        .map(|info| Type::DataArray {
                            dtype: info.dtype.clone(),
                            shape: info.shape.clone(),
                            chunk_shape: info.chunk_shape.clone(),
                            codec: info.codec.clone(),
                        });
                    Some(known.unwrap_or_else(unknown_data_array))
                }
                (Type::DataDataset { .. }, "begin") => Some(Type::DataTransaction),
                (Type::DataDataset { .. }, "arrays") => Some(Type::cell_of(Type::String)),
                (Type::DataDataset { .. }, "has_array" | "set_attr" | "set_attrs") => {
                    Some(Type::Bool)
                }
                (Type::DataDataset { .. }, "id" | "path" | "version" | "snapshot") => {
                    Some(Type::String)
                }
                (Type::DataDataset { .. }, "attrs") => Some(Type::Struct { known_fields: None }),
                (Type::DataDataset { .. }, "get_attr") => Some(Type::Unknown),
                (Type::DataDataset { arrays }, "refresh") => Some(Type::DataDataset {
                    arrays: arrays.clone(),
                }),
                (Type::DataArray { .. }, "name" | "dtype" | "codec") => Some(Type::String),
                (Type::DataArray { .. }, "rank") => Some(Type::Int),
                // `shape` / `chunk_shape` are 1-by-rank row vectors.
                (Type::DataArray { shape, .. }, "shape") => Some(Type::Tensor {
                    shape: Some(vec![Some(1), shape_rank(shape)]),
                }),
                (Type::DataArray { chunk_shape, .. }, "chunk_shape") => Some(Type::Tensor {
                    shape: Some(vec![Some(1), shape_rank(chunk_shape)]),
                }),
                // `read()` returns the whole array; `read(slice)` the sliced shape.
                (Type::DataArray { shape, .. }, "read") => Some(Type::Tensor {
                    shape: match args.first() {
                        Some(slice_expr) => infer_slice_result_shape(shape, slice_expr),
                        None => shape.clone(),
                    },
                }),
                (Type::DataArray { .. }, "write" | "resize" | "fill") => Some(Type::Bool),
                (Type::DataTransaction, "id" | "status") => Some(Type::String),
                (
                    Type::DataTransaction,
                    "write" | "resize" | "fill" | "set_attr" | "set_attrs" | "delete_array"
                    | "create_array" | "commit" | "abort",
                ) => Some(Type::Bool),
                _ => None,
            };
            if let Some(ty) = special {
                return ty;
            }

            // Otherwise consult every builtin whose name ends in `.{method}`, passing
            // the receiver as the first argument, and unify their return types.
            let mut arg_types = Vec::with_capacity(args.len() + 1);
            arg_types.push(base_ty);
            arg_types.extend(args.iter().map(&infer));
            let ctx = resolve_context_from_args(args);
            let suffix = format!(".{method}");
            let builtins = runmat_builtins::builtin_functions();
            builtins
                .iter()
                .filter(|b| b.name.ends_with(&suffix))
                .map(|b| b.infer_return_type_with_context(&arg_types, &ctx))
                .reduce(|acc, next| acc.unify(&next))
                .unwrap_or(Type::Unknown)
        }
        K::Member(..) | K::MemberDynamic(..) | K::End => Type::Unknown,
        K::AnonFunc { .. } | K::FuncHandle(_) => Type::Function {
            params: vec![Type::Unknown],
            returns: Box::new(Type::Unknown),
        },
        K::MetaClass(_) => Type::String,
        K::Colon => Type::tensor(),
    }
}

/// Result type of a broadcasting element-wise binary operator whose operands
/// have types `a` and `b`:
///
/// * two tensors of known shape broadcast dimension by dimension
///   (see [`broadcast_shapes`]);
/// * a tensor combined with a `Num`/`Int` scalar keeps the tensor's shape;
/// * any other combination involving a tensor is `Type::tensor()`;
/// * when neither operand is a tensor the result is `Type::Num`. This includes
///   `Unknown` operands, which are therefore treated as scalars here.
fn elementwise_result_type(a: &Type, b: &Type) -> Type {
    match (a, b) {
        (Type::Tensor { shape: Some(sa) }, Type::Tensor { shape: Some(sb) }) => Type::Tensor {
            shape: Some(broadcast_shapes(sa, sb)),
        },
        (Type::Tensor { shape }, Type::Num | Type::Int)
        | (Type::Num | Type::Int, Type::Tensor { shape }) => Type::Tensor {
            shape: shape.clone(),
        },
        (Type::Tensor { .. }, _) | (_, Type::Tensor { .. }) => Type::tensor(),
        _ => Type::Num,
    }
}

/// Broadcasts two tensor shapes whose extents may be unknown (`None`).
///
/// Dimensions beyond an operand's rank are implicit trailing singletons. A
/// known extent of 1 adopts the other operand's extent, including an unknown
/// one (the unknown side may be larger than 1). An unknown extent paired with a
/// known extent `n != 1` yields `n`. Two different known extents, neither 1,
/// cannot broadcast and yield an unknown extent.
fn broadcast_shapes(sa: &[Option<usize>], sb: &[Option<usize>]) -> Vec<Option<usize>> {
    let rank = sa.len().max(sb.len());
    (0..rank)
        .map(|i| {
            let da = sa.get(i).copied().unwrap_or(Some(1));
            let db = sb.get(i).copied().unwrap_or(Some(1));
            match (da, db) {
                (Some(x), Some(y)) if x == y => Some(x),
                (Some(1), other) | (other, Some(1)) => other,
                (Some(x), None) | (None, Some(x)) => Some(x),
                _ => None,
            }
        })
        .collect()
}

/// A `DataArray` type with no statically known metadata.
fn unknown_data_array() -> Type {
    Type::DataArray {
        dtype: None,
        shape: None,
        chunk_shape: None,
        codec: None,
    }
}