emmylua_code_analysis 0.22.0

A library for analyzing lua code.
Documentation
use emmylua_parser::{BinaryOperator, LuaAst, LuaAstNode, LuaBinaryExpr, LuaExpr};

use crate::{
    DiagnosticCode, LuaMemberKey, LuaType, LuaTypeDeclId, SemanticModel,
    diagnostic::checker::humanize_lint_type,
};

use super::{Checker, DiagnosticContext};

pub struct EnumValueMismatchChecker;

impl Checker for EnumValueMismatchChecker {
    const CODES: &[DiagnosticCode] = &[DiagnosticCode::EnumValueMismatch];

    fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) {
        let root = semantic_model.get_root().clone();
        for node in root.descendants::<LuaAst>() {
            let condition_expr = match node {
                LuaAst::LuaIfStat(if_stat) => if_stat.get_condition_expr(),
                LuaAst::LuaElseIfClauseStat(elseif_stat) => elseif_stat.get_condition_expr(),
                _ => None,
            };

            if let Some(expr) = condition_expr {
                check_condition_expr(context, semantic_model, expr);
            }
        }
    }
}

fn check_condition_expr(
    context: &mut DiagnosticContext,
    semantic_model: &SemanticModel,
    condition_expr: LuaExpr,
) -> Option<()> {
    if let LuaExpr::BinaryExpr(binary_expr) = condition_expr {
        check_binary_expr(context, semantic_model, binary_expr);
    }
    Some(())
}

fn check_binary_expr(
    context: &mut DiagnosticContext,
    semantic_model: &SemanticModel,
    binary_expr: LuaBinaryExpr,
) -> Option<()> {
    let op_token = binary_expr.get_op_token()?;
    let operator = op_token.get_op();

    if !matches!(operator, BinaryOperator::OpEq | BinaryOperator::OpNe) {
        return Some(());
    }

    let (left_expr, right_expr) = binary_expr.get_exprs()?;
    let left_type = semantic_model.infer_expr(left_expr.clone()).ok()?;
    let right_type = semantic_model.infer_expr(right_expr.clone()).ok()?;

    if check_enum_value_pair(context, &right_expr, &left_type, &right_type).is_some() {
        return Some(());
    }
    if check_enum_value_pair(context, &left_expr, &right_type, &left_type).is_some() {
        return Some(());
    }

    Some(())
}

fn check_enum_value_pair(
    context: &mut DiagnosticContext,
    value_expr: &LuaExpr,
    enum_type: &LuaType,
    value_type: &LuaType,
) -> Option<()> {
    let enum_decl_id = match &enum_type {
        LuaType::Ref(id) | LuaType::Def(id) => id,
        _ => return None,
    };
    let type_decl = context.db.get_type_index().get_type_decl(enum_decl_id)?;
    if !type_decl.is_enum() {
        return None;
    }

    let constant_type = get_constant_type(value_type)?;
    let enum_value_types = get_enum_value_types(context, enum_decl_id)?;

    if !enum_value_types.contains(constant_type) {
        let constant_value_str = humanize_lint_type(context.db, constant_type);
        let enum_values_str: Vec<String> = enum_value_types
            .iter()
            .map(|typ| humanize_lint_type(context.db, typ))
            .collect();
        context.add_diagnostic(
            DiagnosticCode::EnumValueMismatch,
            value_expr.get_range(),
            t!(
                "Value '%{value}' does not match any enum value. Expected one of: %{enum_values}",
                value = constant_value_str,
                enum_values = enum_values_str.join(", ")
            )
            .to_string(),
            None,
        );
    }

    Some(())
}

fn get_enum_value_types(
    context: &DiagnosticContext,
    enum_decl_id: &LuaTypeDeclId,
) -> Option<Vec<LuaType>> {
    let type_decl = context.db.get_type_index().get_type_decl(enum_decl_id)?;
    let mut values = Vec::new();
    let is_enum_key = type_decl.is_enum_key();

    if let Some(members) = context
        .db
        .get_member_index()
        .get_members(&enum_decl_id.clone().into())
    {
        for member in members {
            if is_enum_key {
                let key = member.get_key();
                match key {
                    LuaMemberKey::Name(name) => {
                        values.push(LuaType::StringConst(name.clone().into()));
                    }
                    LuaMemberKey::Integer(i) => {
                        values.push(LuaType::IntegerConst(*i));
                    }
                    LuaMemberKey::ExprType(typ) => {
                        if let Some(value) = get_constant_type(typ) {
                            values.push(value.clone());
                        }
                    }
                    _ => {}
                }
            } else if let Some(type_cache) = context
                .db
                .get_type_index()
                .get_type_cache(&member.get_id().into())
                && let Some(value) = get_constant_type(type_cache.as_type())
            {
                values.push(value.clone());
            }
        }
    }

    Some(values)
}

fn get_constant_type(typ: &LuaType) -> Option<&LuaType> {
    match typ {
        LuaType::StringConst(_)
        | LuaType::DocStringConst(_)
        | LuaType::IntegerConst(_)
        | LuaType::DocIntegerConst(_)
        | LuaType::FloatConst(_)
        | LuaType::BooleanConst(_)
        | LuaType::DocBooleanConst(_)
        | LuaType::TableConst(_) => Some(typ),
        _ => None,
    }
}