emmylua_code_analysis 0.22.0

A library for analyzing lua code.
Documentation
use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr};
use std::collections::BTreeSet;

use crate::{
    DbIndex, DiagnosticCode, InferFailReason, InferGuard, LuaInferCache, LuaSemanticDeclId,
    LuaType, RenderLevel, SemanticDeclLevel, SemanticModel,
    diagnostic::checker::humanize_lint_type, get_real_type, humanize_type,
    semantic::infer_call_expr_func,
};

use super::{Checker, DiagnosticContext};

pub struct CallNonCallableChecker;

impl Checker for CallNonCallableChecker {
    const CODES: &[DiagnosticCode] = &[DiagnosticCode::CallNonCallable];

    fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) {
        for call_expr in semantic_model.get_root().descendants::<LuaCallExpr>() {
            check_call_expr(context, semantic_model, call_expr);
        }
    }
}

fn check_call_expr(
    context: &mut DiagnosticContext,
    semantic_model: &SemanticModel,
    call_expr: LuaCallExpr,
) -> Option<()> {
    let prefix_expr = call_expr.get_prefix_expr()?;
    let db = semantic_model.get_db();
    let call_expr_type = infer_call_target_type(semantic_model, &prefix_expr)?;
    let mut cache = semantic_model.get_cache().borrow_mut();
    let call_result = infer_call_expr_func(
        db,
        &mut cache,
        call_expr.clone(),
        call_expr_type.clone(),
        &InferGuard::new(),
        None,
    );
    if let Err(reason) = &call_result {
        // Needs-resolve errors (except unresolved operator calls) should not
        // emit "non-callable" diagnostics.
        if reason.is_need_resolve() && !matches!(reason, InferFailReason::UnResolveOperatorCall) {
            return Some(());
        }
    }

    let non_callable_types =
        collect_non_callable_union_types(db, &mut cache, &call_expr, &call_expr_type);
    if call_result.is_ok() && non_callable_types.is_empty() {
        return Some(());
    }

    if !has_non_callable_member(db, &call_expr_type) {
        return Some(());
    }

    let message = if !non_callable_types.is_empty() {
        t!(
            "Cannot call expression of type `%{full}`; non-callable type(s): %{types}.",
            full = humanize_type(db, &call_expr_type, RenderLevel::Detailed),
            types = non_callable_types.join(", "),
        )
        .to_string()
    } else {
        t!(
            "Cannot call expression of type `%{typ}`.",
            typ = humanize_lint_type(db, &call_expr_type),
        )
        .to_string()
    };
    context.add_diagnostic(
        DiagnosticCode::CallNonCallable,
        prefix_expr.get_range(),
        message,
        None,
    );
    Some(())
}

fn infer_call_target_type(
    semantic_model: &SemanticModel,
    prefix_expr: &LuaExpr,
) -> Option<LuaType> {
    let inferred = semantic_model.infer_expr(prefix_expr.clone()).ok();
    let typ = inferred.unwrap_or(LuaType::Unknown);
    if !matches!(
        typ,
        LuaType::Any | LuaType::Unknown | LuaType::SelfInfer | LuaType::Global
    ) {
        return Some(typ);
    }
    let db = semantic_model.get_db();
    let file_id = semantic_model.get_file_id();
    let expr_range = if let LuaExpr::NameExpr(name_expr) = prefix_expr {
        name_expr
            .get_name_token()
            .map(|token| token.get_range())
            .unwrap_or_else(|| prefix_expr.get_range())
    } else {
        prefix_expr.get_range()
    };

    let refs = db.get_reference_index().get_local_reference(&file_id);

    let decl_id = refs
        .and_then(|refs| refs.get_decl_id(&expr_range))
        .or_else(|| {
            let decl = semantic_model.find_decl(
                rowan::NodeOrToken::Node(prefix_expr.syntax().clone()),
                SemanticDeclLevel::default(),
            )?;
            match decl {
                LuaSemanticDeclId::LuaDecl(id) => Some(id),
                _ => None,
            }
        })
        .or_else(|| {
            let LuaExpr::NameExpr(name_expr) = prefix_expr else {
                return None;
            };
            let name_token = name_expr.get_name_token()?;
            let decl = db
                .get_decl_index()
                .get_decl_tree(&file_id)?
                .find_local_decl(name_token.get_name_text(), name_token.get_position())?;
            Some(decl.get_id())
        })?;

    let decl = db.get_decl_index().get_decl(&decl_id)?;
    let value_syntax_id = decl.get_value_syntax_id()?;
    let root = db
        .get_vfs()
        .get_syntax_tree(&decl_id.file_id)?
        .get_red_root();
    let node = value_syntax_id.to_node_from_root(&root)?;
    let expr = LuaExpr::cast(node)?;
    semantic_model.infer_expr(expr).ok()
}

fn has_non_callable_member(db: &DbIndex, typ: &LuaType) -> bool {
    let typ = get_real_type(db, typ).unwrap_or(typ);
    if typ.is_function() || typ.is_call() {
        return false;
    }

    match typ {
        LuaType::Any | LuaType::Unknown | LuaType::SelfInfer | LuaType::Global | LuaType::Nil => {
            false
        }
        LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl
            .get_constraint()
            .is_some_and(|constraint| has_non_callable_member(db, constraint)),
        LuaType::StrTplRef(str_tpl) => str_tpl
            .get_constraint()
            .is_some_and(|constraint| has_non_callable_member(db, constraint)),
        LuaType::Union(union) => union
            .into_vec()
            .iter()
            .any(|t| has_non_callable_member(db, t)),
        LuaType::Intersection(intersection) => intersection
            .get_types()
            .iter()
            .all(|t| has_non_callable_member(db, t)),
        LuaType::MultiLineUnion(union) => union
            .get_unions()
            .iter()
            .any(|(t, _)| has_non_callable_member(db, t)),
        _ => true,
    }
}

fn collect_non_callable_union_types(
    db: &DbIndex,
    cache: &mut LuaInferCache,
    call_expr: &LuaCallExpr,
    typ: &LuaType,
) -> Vec<String> {
    let mut types = BTreeSet::new();
    let mut insert_if_non_callable = |t: &LuaType| {
        let real_type = get_real_type(db, t).unwrap_or(t);
        if *real_type == LuaType::Nil {
            return;
        }
        if !has_non_callable_member(db, real_type) {
            return;
        }
        if infer_call_expr_func(
            db,
            cache,
            call_expr.clone(),
            t.clone(),
            &InferGuard::new(),
            None,
        )
        .is_err()
        {
            types.insert(humanize_lint_type(db, real_type));
        }
    };
    match typ {
        LuaType::Union(union) => {
            for t in union.into_vec() {
                insert_if_non_callable(&t);
            }
        }
        LuaType::MultiLineUnion(union) => {
            for (t, _) in union.get_unions().iter() {
                insert_if_non_callable(t);
            }
        }
        _ => {}
    }

    types.into_iter().collect()
}