// emmylua_code_analysis 0.22.0
//
// A library for analyzing Lua code.
// Documentation
use std::collections::HashSet;

use emmylua_parser::{
    LuaAstNode, LuaClosureExpr, LuaDocTagParam, LuaDocTagReturn, LuaDocTagReturnOverload, LuaStat,
};

use crate::{
    DiagnosticCode, LuaSemanticDeclId, LuaSignatureId, LuaType, SemanticDeclLevel, SemanticModel,
    SignatureReturnStatus,
};

use super::{Checker, DiagnosticContext, get_closure_expr_comment, get_return_stats};

/// Diagnostic checker that flags closures whose doc comments are missing or incomplete.
///
/// Closures bound to a global declaration are reported under `MissingGlobalDoc`; all other
/// closures are reported under `IncompleteSignatureDoc`.
pub struct IncompleteSignatureDocChecker;

impl Checker for IncompleteSignatureDocChecker {
    const CODES: &[DiagnosticCode] = &[
        DiagnosticCode::IncompleteSignatureDoc,
        DiagnosticCode::MissingGlobalDoc,
    ];

    /// Visits every closure expression in the file and checks its documentation.
    fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) {
        let chunk = semantic_model.get_root();
        chunk
            .descendants::<LuaClosureExpr>()
            .for_each(|closure| {
                // The per-closure `Option` only signals an early exit; there is nothing to
                // propagate from it.
                check_doc(context, semantic_model, &closure);
            });
    }
}

/// Checks a single closure's doc comment for completeness and reports what is missing.
///
/// The closure is only checked when `find_decl` resolves a semantic declaration for it;
/// otherwise nothing is reported. When that declaration is a `LuaDecl`, its name and
/// global-ness are used. For any other declaration kind the function name is empty and the
/// closure is treated as non-global.
///
/// - Global closures report under `MissingGlobalDoc`; all others under
///   `IncompleteSignatureDoc`.
/// - Without any doc comment, one diagnostic is emitted over the statement returned by
///   `get_parent::<LuaStat>()`. If there is no such statement, nothing is emitted.
/// - With a doc comment, undocumented parameters and return values exceeding the documented
///   return count are reported.
fn check_doc(
    context: &mut DiagnosticContext,
    semantic_model: &SemanticModel,
    closure_expr: &LuaClosureExpr,
) -> Option<()> {
    let semantic_decl = semantic_model.find_decl(
        rowan::NodeOrToken::Node(closure_expr.syntax().clone()),
        SemanticDeclLevel::default(),
    )?;
    let (is_global, function_name) = match semantic_decl {
        LuaSemanticDeclId::LuaDecl(decl_id) => {
            let decl = semantic_model
                .get_db()
                .get_decl_index()
                .get_decl(&decl_id)?;
            (decl.is_global(), decl.get_name().to_string())
        }
        _ => (false, String::new()),
    };

    let code = if is_global {
        DiagnosticCode::MissingGlobalDoc
    } else {
        DiagnosticCode::IncompleteSignatureDoc
    };

    // No doc comment at all: report once for the whole statement and stop.
    let Some(comment) = get_closure_expr_comment(closure_expr) else {
        let message = if is_global {
            t!(
                "Missing comment for global function `%{name}`.",
                name = function_name
            )
        } else {
            t!(
                "Missing comment for function `%{name}`.",
                name = function_name
            )
        };
        if let Some(stat) = closure_expr.get_parent::<LuaStat>() {
            context.add_diagnostic(code, stat.get_range(), message.to_string(), None);
        }
        return Some(());
    };

    let doc_param_names: HashSet<String> = comment
        .children::<LuaDocTagParam>()
        .filter_map(|param| {
            param
                .get_name_token()
                .map(|token| token.get_name_text().to_string())
        })
        .collect();

    // Prefer the return arity of a doc-resolved signature. Otherwise approximate it from the
    // raw tags: the larger of the summed `LuaDocTagReturn` type counts and the longest
    // `LuaDocTagReturnOverload`.
    let doc_return_len =
        get_doc_return_max_len(semantic_model, closure_expr).unwrap_or_else(|| {
            let doc_return_len: usize = comment
                .children::<LuaDocTagReturn>()
                .map(|return_doc| return_doc.get_types().count())
                .sum();
            let doc_return_overload_max_len = comment
                .children::<LuaDocTagReturnOverload>()
                .map(|return_doc| return_doc.get_types().count())
                .max()
                .unwrap_or(0);

            Some(doc_return_len.max(doc_return_overload_max_len))
        });

    check_params(
        context,
        closure_expr,
        &doc_param_names,
        code,
        is_global,
        &function_name,
    );

    check_returns(
        context,
        semantic_model,
        closure_expr,
        doc_return_len,
        code,
        is_global,
        &function_name,
    );

    Some(())
}

/// Reports each named parameter of `closure_expr` that has no matching `@param` tag.
///
/// A parameter named exactly `_` is never reported. Parameters without a name token are
/// skipped. Global functions get a message that also names the function.
fn check_params(
    context: &mut DiagnosticContext,
    closure_expr: &LuaClosureExpr,
    doc_param_names: &HashSet<String>,
    code: DiagnosticCode,
    is_global: bool,
    function_name: &str,
) {
    let params_list = match closure_expr.get_params_list() {
        Some(list) => list,
        None => return,
    };

    for param in params_list.get_params() {
        let token = match param.get_name_token() {
            Some(token) => token,
            None => continue,
        };

        let param_name = token.get_name_text();
        // `_` marks an intentionally ignored parameter; documented ones need no report.
        if param_name == "_" || doc_param_names.contains(param_name) {
            continue;
        }

        let message = if !is_global {
            t!(
                "Incomplete signature. Missing @param annotation for parameter `%{name}`.",
                name = param_name
            )
        } else {
            t!(
                "Missing @param annotation for parameter `%{name}` in global function `%{function_name}`.",
                name = param_name,
                function_name = function_name
            )
        };

        context.add_diagnostic(code, param.get_range(), message.to_string(), None);
    }
}

/// Reports return values that go beyond the documented return count.
///
/// `doc_return_len` is the number of documented return values. `None` means the documented
/// count is unbounded, and nothing is reported in that case.
///
/// Lua adjusts every expression in a return list to exactly one value, except the last
/// expression, which may expand to several values (e.g. a call or `...`). Only the last
/// expression's inferred type therefore contributes a variable arity. If its type cannot be
/// inferred, or its minimum arity is unknown, it is not counted. In that case the remaining
/// return statements are still checked.
///
/// Always returns `Some(())`.
fn check_returns(
    context: &mut DiagnosticContext,
    semantic_model: &SemanticModel,
    closure_expr: &LuaClosureExpr,
    doc_return_len: Option<usize>,
    code: DiagnosticCode,
    is_global: bool,
    function_name: &str,
) -> Option<()> {
    // Without a finite documented count there is nothing to compare against.
    let Some(doc_return_len) = doc_return_len else {
        return Some(());
    };

    for return_stat in get_return_stats(closure_expr) {
        let exprs: Vec<_> = return_stat.get_expr_list().collect();
        let last_index = exprs.len().saturating_sub(1);
        let mut return_stat_len: usize = 0;

        for (i, expr) in exprs.iter().enumerate() {
            let expr_return_count = if i < last_index {
                // Non-final expressions are always truncated/padded to one value.
                1
            } else {
                match semantic_model.infer_expr(expr.clone()) {
                    Ok(LuaType::Variadic(variadic)) => match variadic.get_min_len() {
                        Some(len) => len,
                        // Unknown minimum arity: skip this expression instead of aborting
                        // the check for every remaining return statement.
                        None => continue,
                    },
                    Ok(_) => 1,
                    Err(_) => continue,
                }
            };

            return_stat_len += expr_return_count;

            if return_stat_len > doc_return_len {
                // When the final expression expands to several values, the first
                // undocumented value may lie past this expression's own position.
                let index = (i + 1).max(doc_return_len + 1);
                let message = if is_global {
                    t!(
                        "Missing @return annotation at index `%{index}` in global function `%{function_name}`.",
                        index = index,
                        function_name = function_name
                    )
                } else {
                    t!(
                        "Incomplete signature. Missing @return annotation at index `%{index}`.",
                        index = index
                    )
                };

                context.add_diagnostic(code, expr.get_range(), message.to_string(), None);
            }
        }
    }

    Some(())
}

/// Returns the maximum number of return values declared by a doc-resolved signature.
///
/// The outer `Option` is `None` when no signature exists for the closure, or when its return
/// type was not resolved from doc annotations. Callers then fall back to counting tags.
/// The inner `Option` is whatever `get_max_len` yields for a variadic return type. Otherwise
/// it is `Some(0)` for `nil` and `Some(1)` for every other type.
fn get_doc_return_max_len(
    semantic_model: &SemanticModel,
    closure_expr: &LuaClosureExpr,
) -> Option<Option<usize>> {
    let file_id = semantic_model.get_file_id();
    let signature_index = semantic_model.get_db().get_signature_index();
    let signature = signature_index.get(&LuaSignatureId::from_closure(file_id, closure_expr))?;

    // Only trust the signature when its return type came from doc annotations.
    if signature.resolve_return != SignatureReturnStatus::DocResolve {
        return None;
    }

    let max_len = match signature.get_return_type() {
        LuaType::Variadic(variadic) => variadic.get_max_len(),
        LuaType::Nil => Some(0),
        // `any`, `unknown`, and every other type count as a single value.
        _ => Some(1),
    };
    Some(max_len)
}