emmylua_code_analysis 0.22.0

A library for analyzing lua code.
Documentation
use emmylua_parser::{
    LuaAstNode, LuaExpr, LuaForStat, LuaIndexKey, LuaIndexMemberExpr, LuaNameExpr, LuaUnaryExpr,
    NumberResult, UnaryOperator,
};

use crate::{
    DbIndex, InferFailReason, LuaArrayLen, LuaArrayType, LuaInferCache, LuaType, TypeOps,
    infer_expr, semantic::infer::narrow::get_var_expr_var_ref_id,
};

/// Infers the type produced by indexing the array type `array_type` with `index_member_expr`.
///
/// - Table-constructor fields yield the element type directly.
/// - With `strict.array_index` disabled, every integer key (literal or integer-typed
///   expression) yields the element type.
/// - With `strict.array_index` enabled, the element type is returned only when the index is
///   provably in bounds:
///   - a constant within `1..=max_len` for arrays with a known maximum length, or
///   - an expression accepted by [`check_iter_var_range`].
///
///   Otherwise the element type is widened with `nil`; `any`/`unknown` are left as-is.
///
/// # Errors
/// - `InferFailReason::None` when the key or the prefix expression is missing.
/// - `InferFailReason::FieldNotFound` for keys that are neither integer literals nor
///   integer-typed expressions.
/// - Any error from inferring the key expression is propagated.
pub fn infer_array_member(
    db: &DbIndex,
    cache: &mut LuaInferCache,
    array_type: &LuaArrayType,
    index_member_expr: LuaIndexMemberExpr,
) -> Result<LuaType, InferFailReason> {
    let key = index_member_expr
        .get_index_key()
        .ok_or(InferFailReason::None)?;
    let index_prefix_expr = match index_member_expr {
        LuaIndexMemberExpr::TableField(_) => {
            return Ok(array_type.get_base().clone());
        }
        _ => index_member_expr
            .get_prefix_expr()
            .ok_or(InferFailReason::None)?,
    };

    let base_type = array_type.get_base();
    let strict = db.get_emmyrc().strict.array_index;

    match key {
        LuaIndexKey::Integer(i) => {
            if !strict {
                return Ok(base_type.clone());
            }

            if let LuaArrayLen::Max(max_len) = array_type.get_len() {
                if let NumberResult::Int(index_value) = i.get_number_value() {
                    if index_value > 0 && index_value <= *max_len {
                        return Ok(base_type.clone());
                    }
                }
            }

            Ok(nullable_element_type(db, base_type))
        }
        LuaIndexKey::Expr(expr) => {
            let expr_type = infer_expr(db, cache, expr.clone())?;
            if !expr_type.is_integer() {
                return Err(InferFailReason::FieldNotFound);
            }

            // Same rule as the literal-key branch: outside strict mode any integer index is
            // assumed to hit an element, so the bounds / loop analysis below is skipped.
            if !strict {
                return Ok(base_type.clone());
            }

            match (array_type.get_len(), expr_type) {
                (
                    LuaArrayLen::Max(max_len),
                    LuaType::IntegerConst(index_value) | LuaType::DocIntegerConst(index_value),
                ) => {
                    if index_value > 0 && index_value <= *max_len {
                        return Ok(base_type.clone());
                    }
                }
                _ => {
                    if check_iter_var_range(db, cache, &expr, index_prefix_expr)
                        .unwrap_or(false)
                    {
                        return Ok(base_type.clone());
                    }
                }
            }

            Ok(nullable_element_type(db, base_type))
        }
        _ => Err(InferFailReason::FieldNotFound),
    }
}

/// Widens an array element type with `nil` to model a read that may fall outside the array.
/// `any` and `unknown` are returned unchanged (no `nil` member is added).
fn nullable_element_type(db: &DbIndex, base_type: &LuaType) -> LuaType {
    match base_type {
        LuaType::Any | LuaType::Unknown => base_type.clone(),
        _ => TypeOps::Union.apply(db, base_type, &LuaType::Nil),
    }
}

/// Tries to prove that the index expression `may_iter_var` stays within the bounds of the
/// array referenced by `prefix_expr`.
///
/// Two expression kinds are understood:
/// - a name expression, handled by `check_index_var_in_range` (numeric-`for` loop variables);
/// - a unary expression, handled by `check_is_len` (`#prefix`).
///
/// Every other expression kind yields `None` ("cannot tell").
pub fn check_iter_var_range(
    db: &DbIndex,
    cache: &mut LuaInferCache,
    may_iter_var: &LuaExpr,
    prefix_expr: LuaExpr,
) -> Option<bool> {
    if let LuaExpr::NameExpr(loop_var) = may_iter_var {
        return check_index_var_in_range(db, cache, loop_var, prefix_expr);
    }

    if let LuaExpr::UnaryExpr(len_candidate) = may_iter_var {
        return check_is_len(db, cache, len_candidate, prefix_expr);
    }

    None
}

fn check_index_var_in_range(
    db: &DbIndex,
    cache: &mut LuaInferCache,
    iter_var: &LuaNameExpr,
    prefix_expr: LuaExpr,
) -> Option<bool> {
    let decl_id = db
        .get_reference_index()
        .get_var_reference_decl(&cache.get_file_id(), iter_var.get_range())?;

    let decl = db.get_decl_index().get_decl(&decl_id)?;
    let decl_syntax_id = decl.get_syntax_id();
    if !decl_syntax_id.is_token() {
        return None;
    }

    let root = prefix_expr.get_root();
    let token = decl_syntax_id.to_token_from_root(&root)?;
    let parent_node = token.parent()?;
    let for_stat = LuaForStat::cast(parent_node)?;
    let iter_exprs = for_stat.get_iter_expr().collect::<Vec<_>>();
    let test_len_expr = match iter_exprs.len() {
        2 => {
            let LuaExpr::UnaryExpr(unary_expr) = iter_exprs[1].clone() else {
                return None;
            };
            unary_expr
        }
        3 => {
            let step_type = infer_expr(db, cache, iter_exprs[2].clone()).ok()?;
            let LuaType::IntegerConst(step_value) = step_type else {
                return None;
            };
            if step_value > 0 {
                let LuaExpr::UnaryExpr(unary_expr) = iter_exprs[1].clone() else {
                    return None;
                };
                unary_expr
            } else if step_value < 0 {
                let LuaExpr::UnaryExpr(unary_expr) = iter_exprs[0].clone() else {
                    return None;
                };
                unary_expr
            } else {
                return None;
            }
        }
        _ => return None,
    };

    let op = test_len_expr.get_op_token()?;
    if op.get_op() != UnaryOperator::OpLen {
        return None;
    }

    let len_expr = test_len_expr.get_expr()?;
    let len_expr_var_ref_id = get_var_expr_var_ref_id(db, cache, len_expr)?;
    let prefix_expr_var_ref_id = get_var_expr_var_ref_id(db, cache, prefix_expr)?;

    Some(len_expr_var_ref_id == prefix_expr_var_ref_id)
}

/// Checks whether `unary_expr` is a length expression `#x` whose operand `x` is the same
/// variable as `prefix_expr`.
///
/// Returns:
/// - `None` when the operator is not `#` or when either side cannot be resolved to a
///   variable reference;
/// - otherwise `Some` with the result of comparing the two references.
fn check_is_len(
    db: &DbIndex,
    cache: &mut LuaInferCache,
    unary_expr: &LuaUnaryExpr,
    prefix_expr: LuaExpr,
) -> Option<bool> {
    unary_expr
        .get_op_token()
        .filter(|op_token| op_token.get_op() == UnaryOperator::OpLen)?;

    let operand_ref = get_var_expr_var_ref_id(db, cache, unary_expr.get_expr()?)?;
    let prefix_ref = get_var_expr_var_ref_id(db, cache, prefix_expr)?;

    Some(operand_ref == prefix_ref)
}