use emmylua_parser::{
LuaAstNode, LuaAstToken, LuaBlock, LuaIfStat, LuaStat, LuaSyntaxKind, LuaTokenKind,
};
use crate::{
DiagnosticCode, SemanticModel,
diagnostic::checker::{Checker, DiagnosticContext},
};
/// Diagnostic checker that suggests inverting an `if`/`else` statement whose `else`
/// branch only exits early, so that the main body can be un-nested.
pub struct InvertIfChecker;

impl Checker for InvertIfChecker {
    const CODES: &[DiagnosticCode] = &[DiagnosticCode::InvertIf];

    /// Visits every `if` statement below the file's syntax root and runs the
    /// early-return pattern check on each one.
    fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) {
        let chunk = semantic_model.get_root().clone();
        chunk
            .descendants::<LuaIfStat>()
            .for_each(|if_stat| check_early_return_pattern(context, &if_stat));
    }
}
/// Reports `DiagnosticCode::InvertIf` on the `if` keyword of `if_statement` when the
/// statement would read better as a guard clause followed by the main body.
///
/// All of the following must hold for a diagnostic to be emitted:
/// - the statement has an `else` branch and no `elseif` branches;
/// - the `else` block is a lone early exit per `get_early_exit_type`, where a `break`
///   exit only counts when the statement is inside a loop;
/// - the `if` block does not itself end with `return`/`break`;
/// - at least one non-empty statement follows the `if` statement;
/// - the `if` block contains three or more non-empty statements.
fn check_early_return_pattern(context: &mut DiagnosticContext, if_statement: &LuaIfStat) {
    // Only plain `if ... else ... end` shapes are candidates.
    let Some(else_clause) = if_statement.get_else_clause() else {
        return;
    };
    if if_statement.get_else_if_clause_list().next().is_some() {
        return;
    }
    let (Some(then_block), Some(else_block)) =
        (if_statement.get_block(), else_clause.get_block())
    else {
        return;
    };

    let exits_early = match get_early_exit_type(&else_block) {
        EarlyExitType::None => false,
        EarlyExitType::Return => true,
        // A `break` is only an early exit when there is an enclosing loop to leave.
        EarlyExitType::Break => is_in_loop(if_statement),
    };

    // Skip when inverting would not reduce nesting in a meaningful way.
    if !exits_early
        || block_ends_with_exit(&then_block)
        || !has_code_after_if(if_statement)
        || count_meaningful_statements(&then_block) < 3
    {
        return;
    }

    let Some(if_token) = if_statement.token_by_kind(LuaTokenKind::TkIf) else {
        return;
    };
    context.add_diagnostic(
        DiagnosticCode::InvertIf,
        if_token.get_range(),
        t!("Consider inverting 'if' statement to reduce nesting").to_string(),
        None,
    );
}
/// How a block leaves the surrounding control flow, as classified by
/// `get_early_exit_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum EarlyExitType {
    /// The block is not a lone early exit.
    None,
    /// The block is a lone `return` with at most one expression.
    Return,
    /// The block is a lone `break`.
    Break,
}
/// Classifies whether `block` consists solely of an early exit.
///
/// Returns `EarlyExitType::Return` when the block's only non-empty statement is a
/// `return` with at most one expression, `EarlyExitType::Break` when it is a lone
/// `break`, and `EarlyExitType::None` otherwise (including an empty block).
///
/// Empty statements (`;`) are ignored, so `else break; end` is classified the same
/// as `else break end`. This matches how `count_meaningful_statements` and
/// `has_code_after_if` treat empty statements.
fn get_early_exit_type(block: &LuaBlock) -> EarlyExitType {
    let mut stats = block
        .get_stats()
        .filter(|stat| !matches!(stat, LuaStat::EmptyStat(_)));
    // Exactly one meaningful statement is required: take the first one and make
    // sure nothing else follows it (tuple elements are evaluated left to right).
    let (Some(only), None) = (stats.next(), stats.next()) else {
        return EarlyExitType::None;
    };
    match only {
        LuaStat::ReturnStat(return_stat) if return_stat.get_expr_list().count() <= 1 => {
            EarlyExitType::Return
        }
        LuaStat::BreakStat(_) => EarlyExitType::Break,
        _ => EarlyExitType::None,
    }
}
/// Returns `true` when the last non-empty statement of `block` is a `return` or a
/// `break`; returns `false` for a block with no non-empty statements.
///
/// Trailing empty statements (`;`) are skipped. Without this, `... break; end` would
/// be treated as not ending with an exit, and the invert-if suggestion would be
/// reported for a branch that already leaves the enclosing construct.
fn block_ends_with_exit(block: &LuaBlock) -> bool {
    matches!(
        block
            .get_stats()
            .filter(|stat| !matches!(stat, LuaStat::EmptyStat(_)))
            .last(),
        Some(LuaStat::ReturnStat(_) | LuaStat::BreakStat(_))
    )
}
/// Counts the statements yielded by `LuaBlock::get_stats` for `block`, ignoring
/// empty statements (`;`).
fn count_meaningful_statements(block: &LuaBlock) -> usize {
    let mut total = 0;
    for stat in block.get_stats() {
        if !matches!(stat, LuaStat::EmptyStat(_)) {
            total += 1;
        }
    }
    total
}
/// Reports whether `if_statement` is nested in a loop without an intervening
/// function boundary.
///
/// Walks the syntax ancestors from the statement upward. The first `WhileStat`,
/// `RepeatStat`, `ForStat` or `ForRangeStat` node found yields `true`. Reaching a
/// `ClosureExpr`, `FuncStat`, `LocalFuncStat` or the `Chunk` root first yields
/// `false`, as does running out of ancestors.
fn is_in_loop(if_statement: &LuaIfStat) -> bool {
    if_statement
        .syntax()
        .ancestors()
        .find_map(|node| {
            let kind: LuaSyntaxKind = node.kind().into();
            match kind {
                LuaSyntaxKind::WhileStat
                | LuaSyntaxKind::RepeatStat
                | LuaSyntaxKind::ForStat
                | LuaSyntaxKind::ForRangeStat => Some(true),
                LuaSyntaxKind::ClosureExpr
                | LuaSyntaxKind::FuncStat
                | LuaSyntaxKind::LocalFuncStat
                | LuaSyntaxKind::Chunk => Some(false),
                _ => None,
            }
        })
        .unwrap_or(false)
}
fn has_code_after_if(if_statement: &LuaIfStat) -> bool {
let mut next = if_statement.syntax().next_sibling();
while let Some(sibling) = next {
if let Some(stat) = LuaStat::cast(sibling.clone()) {
if !matches!(stat, LuaStat::EmptyStat(_)) {
return true;
}
}
next = sibling.next_sibling();
}
false
}