use std::collections::BTreeMap;
use super::super::common::{AstBindingRef, AstBlock, AstModule, AstStmt};
use super::ReadabilityContext;
use super::binding_flow::{count_binding_mentions_in_block, count_binding_uses_in_stmts};
use super::walk::{self, AstRewritePass, BlockKind};
/// Entry point of the cleanup readability pass.
///
/// Drives [`CleanupPass`] over `module` through `walk::rewrite_module` and hands
/// back the flag that call reports (presumably "did any block change").
/// `_context` is not consulted by this pass.
pub(super) fn apply(module: &mut AstModule, _context: ReadabilityContext) -> bool {
    let mut pass = CleanupPass;
    walk::rewrite_module(module, &mut pass)
}
/// Stateless rewrite pass that applies `cleanup_block` to each visited block.
struct CleanupPass;

impl AstRewritePass for CleanupPass {
    /// Cleans one block; a trailing bare `return` may only be dropped from
    /// function bodies and the module body.
    fn rewrite_block(&mut self, block: &mut AstBlock, kind: BlockKind) -> bool {
        let may_elide_trailing_return =
            matches!(kind, BlockKind::FunctionBody | BlockKind::ModuleBody);
        cleanup_block(block, may_elide_trailing_return)
    }
}
/// Runs every local cleanup on `block`, in a fixed order, and reports whether
/// any of them changed it.
///
/// Steps (one helper each, below):
/// 1. replace `do <stmt> end` with `<stmt>` when `can_elide_single_stmt_do_block` allows it;
/// 2. splice trailing `do ... end` blocks into `block` unless they directly declare globals;
/// 3. delete the `local x = value` declarations found by `collect_discardable_unused_locals`;
/// 4. drop `Temp`/`SyntheticLocal` bindings with no counted uses from value-less `local`s;
/// 5. delete `local` declarations left with neither bindings nor values;
/// 6. if `allow_trailing_empty_return_elision`, drop a trailing `return` with no values.
///
/// The order matters: e.g. step 5 removes declarations that step 4 emptied.
fn cleanup_block(block: &mut AstBlock, allow_trailing_empty_return_elision: bool) -> bool {
    // `|=` (not `||`) so that every step runs regardless of earlier results.
    let mut changed = flatten_single_stmt_do_blocks(block);
    changed |= splice_trailing_do_blocks(block);
    changed |= remove_discardable_unused_locals(block);
    changed |= prune_unused_mechanical_bindings(block);
    changed |= remove_empty_local_decls(block);
    if allow_trailing_empty_return_elision {
        changed |= elide_trailing_empty_return(block);
    }
    changed
}

/// Replaces each `do ... end` that holds exactly one elidable statement with that statement.
fn flatten_single_stmt_do_blocks(block: &mut AstBlock) -> bool {
    let mut changed = false;
    let old_stmts = std::mem::take(&mut block.stmts);
    let mut flattened_stmts = Vec::with_capacity(old_stmts.len());
    for stmt in old_stmts {
        match stmt {
            AstStmt::DoBlock(nested)
                if nested.stmts.len() == 1 && can_elide_single_stmt_do_block(&nested.stmts[0]) =>
            {
                // `nested` is owned here, so move its lone statement out instead of
                // deep-cloning it (it may be an entire `if`/loop subtree).
                flattened_stmts.extend(nested.stmts);
                changed = true;
            }
            other => flattened_stmts.push(other),
        }
    }
    block.stmts = flattened_stmts;
    changed
}

/// Repeatedly splices a trailing `do ... end` block into `block`, as long as that
/// trailing block has no direct `GlobalDecl` statement.
fn splice_trailing_do_blocks(block: &mut AstBlock) -> bool {
    let mut changed = false;
    // NOTE(review): do-blocks with a direct `global` declaration are kept, presumably
    // because that declaration is scoped to its `do` block — confirm.
    while let Some(AstStmt::DoBlock(nested)) = block.stmts.last()
        && !nested.stmts.iter().any(|s| matches!(s, AstStmt::GlobalDecl(_)))
    {
        let Some(AstStmt::DoBlock(nested)) = block.stmts.pop() else {
            unreachable!();
        };
        block.stmts.extend(nested.stmts);
        changed = true;
    }
    changed
}

/// Deletes the single-binding, single-value `local` declarations whose binding is
/// reported by `collect_discardable_unused_locals`.
fn remove_discardable_unused_locals(block: &mut AstBlock) -> bool {
    let discardable_unused_locals = collect_discardable_unused_locals(block);
    if discardable_unused_locals.is_empty() {
        return false;
    }
    let original_len = block.stmts.len();
    block.stmts.retain(|stmt| {
        !matches!(
            stmt,
            AstStmt::LocalDecl(local_decl)
                if local_decl.bindings.len() == 1
                    && local_decl.values.len() == 1
                    && discardable_unused_locals.contains(&local_decl.bindings[0].id)
        )
    });
    block.stmts.len() != original_len
}

/// Drops `Temp`/`SyntheticLocal` bindings whose use count (from
/// `collect_mechanical_binding_uses`) is zero, from `local` declarations that have
/// no initializer values. `Local` bindings are always kept.
fn prune_unused_mechanical_bindings(block: &mut AstBlock) -> bool {
    let mechanical_binding_uses = collect_mechanical_binding_uses(block);
    let mut changed = false;
    for stmt in &mut block.stmts {
        let AstStmt::LocalDecl(local_decl) = stmt else {
            continue;
        };
        // Only value-less declarations are touched; presumably removing a binding
        // next to initializers would shift which value lands in which name.
        if !local_decl.values.is_empty() {
            continue;
        }
        let original_len = local_decl.bindings.len();
        local_decl.bindings.retain(|binding| match binding.id {
            AstBindingRef::Temp(_) | AstBindingRef::SyntheticLocal(_) => {
                mechanical_binding_uses
                    .get(&binding.id)
                    .copied()
                    .unwrap_or_default()
                    > 0
            }
            AstBindingRef::Local(_) => true,
        });
        changed |= local_decl.bindings.len() != original_len;
    }
    changed
}

/// Deletes `local` declarations that have neither bindings nor values.
fn remove_empty_local_decls(block: &mut AstBlock) -> bool {
    let original_len = block.stmts.len();
    block.stmts.retain(|stmt| match stmt {
        AstStmt::LocalDecl(local_decl) => {
            !(local_decl.bindings.is_empty() && local_decl.values.is_empty())
        }
        _ => true,
    });
    block.stmts.len() != original_len
}

/// Pops the last statement of `block` if it is a `return` carrying no values.
fn elide_trailing_empty_return(block: &mut AstBlock) -> bool {
    let ends_with_bare_return = matches!(
        block.stmts.last(),
        Some(AstStmt::Return(ret)) if ret.values.is_empty()
    );
    if ends_with_bare_return {
        block.stmts.pop();
    }
    ends_with_bare_return
}
fn can_elide_single_stmt_do_block(stmt: &AstStmt) -> bool {
matches!(
stmt,
AstStmt::Assign(_)
| AstStmt::CallStmt(_)
| AstStmt::Return(_)
| AstStmt::If(_)
| AstStmt::While(_)
| AstStmt::Repeat(_)
| AstStmt::NumericFor(_)
| AstStmt::GenericFor(_)
| AstStmt::Break
| AstStmt::Continue
| AstStmt::Goto(_)
| AstStmt::FunctionDecl(_)
)
}
/// Maps every `Temp`/`SyntheticLocal` binding declared by a `local` statement
/// directly in `block` to the number of mentions `count_binding_mentions_in_block`
/// finds for it in `block`.
///
/// A binding captured by a nested function expression is counted at least once,
/// so it is never treated as unused.
fn collect_mechanical_binding_uses(block: &AstBlock) -> BTreeMap<AstBindingRef, usize> {
    let mechanical_ids = block
        .stmts
        .iter()
        .filter_map(|stmt| match stmt {
            AstStmt::LocalDecl(local_decl) => Some(local_decl),
            _ => None,
        })
        .flat_map(|local_decl| local_decl.bindings.iter())
        .map(|binding| binding.id)
        .filter(|id| matches!(id, AstBindingRef::Temp(_) | AstBindingRef::SyntheticLocal(_)));

    let mut uses = BTreeMap::new();
    for id in mechanical_ids {
        // Each id is counted once, on its first declaration.
        uses.entry(id).or_insert_with(|| {
            let mentions = count_binding_mentions_in_block(block, id);
            let captured = block_captures_binding(block, id);
            if captured { mentions.max(1) } else { mentions }
        });
    }
    uses
}
/// Collects the bindings of `local x = value` declarations in `block` that can be
/// deleted outright.
///
/// A declaration qualifies when all of the following hold:
/// - it has exactly one binding and exactly one value;
/// - the binding's origin is `AstLocalOrigin::Recovered`;
/// - the value passes `is_discard_safe_local_value`;
/// - `count_binding_uses_in_stmts` finds no use of the binding in `block.stmts`;
/// - no nested function expression in `block` captures it.
///
/// The cheap per-declaration checks (origin, purity) run before the two
/// whole-block scans, so most candidates (e.g. call initializers) are rejected
/// without traversing the block.
fn collect_discardable_unused_locals(
    block: &AstBlock,
) -> std::collections::BTreeSet<AstBindingRef> {
    block
        .stmts
        .iter()
        .filter_map(|stmt| {
            let AstStmt::LocalDecl(local_decl) = stmt else {
                return None;
            };
            let ([binding], [value]) =
                (local_decl.bindings.as_slice(), local_decl.values.as_slice())
            else {
                return None;
            };
            let discardable = matches!(binding.origin, crate::ast::AstLocalOrigin::Recovered)
                && is_discard_safe_local_value(value)
                && count_binding_uses_in_stmts(&block.stmts, binding.id) == 0
                && !block_captures_binding(block, binding.id);
            discardable.then_some(binding.id)
        })
        .collect()
}
/// Policy hook: may the initializer `value` of an unused `local` be dropped?
///
/// Currently this accepts exactly the expressions `is_definitely_pure_expr` accepts.
fn is_discard_safe_local_value(value: &crate::ast::AstExpr) -> bool {
    is_definitely_pure_expr(value)
}

/// Conservatively decides whether `expr` is pure enough to skip evaluating.
///
/// Accepts literals, `Var` reads, and `SingleValue` wrappers (presumably `(expr)`)
/// around accepted expressions. Every other variant is rejected without further
/// analysis. The match is exhaustive so a new variant forces a decision here.
///
/// NOTE(review): if `Var` can denote a global, reading it may reach an environment
/// `__index` metamethod — confirm that treating it as pure is intended.
fn is_definitely_pure_expr(expr: &crate::ast::AstExpr) -> bool {
    use crate::ast::AstExpr;

    match expr {
        AstExpr::SingleValue(inner) => is_definitely_pure_expr(inner),
        AstExpr::Var(_)
        | AstExpr::Nil
        | AstExpr::Boolean(_)
        | AstExpr::Integer(_)
        | AstExpr::Number(_)
        | AstExpr::String(_)
        | AstExpr::Int64(_)
        | AstExpr::UInt64(_)
        | AstExpr::Complex { .. } => true,
        AstExpr::VarArg
        | AstExpr::Unary(_)
        | AstExpr::Binary(_)
        | AstExpr::LogicalAnd(_)
        | AstExpr::LogicalOr(_)
        | AstExpr::FieldAccess(_)
        | AstExpr::IndexAccess(_)
        | AstExpr::Call(_)
        | AstExpr::MethodCall(_)
        | AstExpr::TableConstructor(_)
        | AstExpr::FunctionExpr(_)
        | AstExpr::Error(_) => false,
    }
}
/// Reports whether any statement of `block` (each searched recursively by
/// `stmt_captures_binding`) contains a function expression capturing `binding`.
fn block_captures_binding(block: &AstBlock, binding: AstBindingRef) -> bool {
    for stmt in &block.stmts {
        if stmt_captures_binding(stmt, binding) {
            return true;
        }
    }
    false
}
/// Reports whether `stmt`, searched recursively, contains a function expression
/// whose captured bindings include `binding`.
///
/// Only captures recorded on nested function expressions count: the sole source
/// of `true` in this family of helpers is `function_expr_captures_binding`.
/// Plain reads or writes of the binding do not count.
fn stmt_captures_binding(stmt: &AstStmt, binding: AstBindingRef) -> bool {
    match stmt {
        // Declarations: only the initializer expressions are searched.
        AstStmt::LocalDecl(local_decl) => local_decl
            .values
            .iter()
            .any(|value| expr_captures_binding(value, binding)),
        AstStmt::GlobalDecl(global_decl) => global_decl
            .values
            .iter()
            .any(|value| expr_captures_binding(value, binding)),
        // Targets are searched too, because field/index targets embed expressions.
        AstStmt::Assign(assign) => {
            assign
                .values
                .iter()
                .any(|value| expr_captures_binding(value, binding))
                || assign
                    .targets
                    .iter()
                    .any(|target| lvalue_captures_binding(target, binding))
        }
        AstStmt::CallStmt(call_stmt) => call_captures_binding(&call_stmt.call, binding),
        AstStmt::Return(ret) => ret
            .values
            .iter()
            .any(|value| expr_captures_binding(value, binding)),
        // Control flow: check header expressions and recurse into every body.
        AstStmt::If(if_stmt) => {
            expr_captures_binding(&if_stmt.cond, binding)
                || block_captures_binding(&if_stmt.then_block, binding)
                || if_stmt
                    .else_block
                    .as_ref()
                    .is_some_and(|else_block| block_captures_binding(else_block, binding))
        }
        AstStmt::While(while_stmt) => {
            expr_captures_binding(&while_stmt.cond, binding)
                || block_captures_binding(&while_stmt.body, binding)
        }
        AstStmt::Repeat(repeat_stmt) => {
            block_captures_binding(&repeat_stmt.body, binding)
                || expr_captures_binding(&repeat_stmt.cond, binding)
        }
        AstStmt::NumericFor(numeric_for) => {
            expr_captures_binding(&numeric_for.start, binding)
                || expr_captures_binding(&numeric_for.limit, binding)
                || expr_captures_binding(&numeric_for.step, binding)
                || block_captures_binding(&numeric_for.body, binding)
        }
        AstStmt::GenericFor(generic_for) => {
            generic_for
                .iterator
                .iter()
                .any(|value| expr_captures_binding(value, binding))
                || block_captures_binding(&generic_for.body, binding)
        }
        AstStmt::DoBlock(block) => block_captures_binding(block, binding),
        // Function declarations carry their own function expression.
        AstStmt::FunctionDecl(function_decl) => {
            function_expr_captures_binding(&function_decl.func, binding)
        }
        AstStmt::LocalFunctionDecl(function_decl) => {
            function_expr_captures_binding(&function_decl.func, binding)
        }
        // Not searched: these are treated as never capturing.
        AstStmt::Break | AstStmt::Continue | AstStmt::Goto(_) | AstStmt::Label(_) | AstStmt::Error(_) => false,
    }
}
fn lvalue_captures_binding(lvalue: &crate::ast::AstLValue, binding: AstBindingRef) -> bool {
match lvalue {
crate::ast::AstLValue::Name(_) => false,
crate::ast::AstLValue::FieldAccess(access) => expr_captures_binding(&access.base, binding),
crate::ast::AstLValue::IndexAccess(access) => {
expr_captures_binding(&access.base, binding)
|| expr_captures_binding(&access.index, binding)
}
}
}
fn call_captures_binding(call: &crate::ast::AstCallKind, binding: AstBindingRef) -> bool {
match call {
crate::ast::AstCallKind::Call(call) => {
expr_captures_binding(&call.callee, binding)
|| call
.args
.iter()
.any(|arg| expr_captures_binding(arg, binding))
}
crate::ast::AstCallKind::MethodCall(call) => {
expr_captures_binding(&call.receiver, binding)
|| call
.args
.iter()
.any(|arg| expr_captures_binding(arg, binding))
}
}
}
fn expr_captures_binding(expr: &crate::ast::AstExpr, binding: AstBindingRef) -> bool {
match expr {
crate::ast::AstExpr::Unary(unary) => expr_captures_binding(&unary.expr, binding),
crate::ast::AstExpr::Binary(binary) => {
expr_captures_binding(&binary.lhs, binding)
|| expr_captures_binding(&binary.rhs, binding)
}
crate::ast::AstExpr::LogicalAnd(logical) | crate::ast::AstExpr::LogicalOr(logical) => {
expr_captures_binding(&logical.lhs, binding)
|| expr_captures_binding(&logical.rhs, binding)
}
crate::ast::AstExpr::FieldAccess(access) => expr_captures_binding(&access.base, binding),
crate::ast::AstExpr::IndexAccess(access) => {
expr_captures_binding(&access.base, binding)
|| expr_captures_binding(&access.index, binding)
}
crate::ast::AstExpr::Call(call) => {
expr_captures_binding(&call.callee, binding)
|| call
.args
.iter()
.any(|arg| expr_captures_binding(arg, binding))
}
crate::ast::AstExpr::MethodCall(call) => {
expr_captures_binding(&call.receiver, binding)
|| call
.args
.iter()
.any(|arg| expr_captures_binding(arg, binding))
}
crate::ast::AstExpr::SingleValue(expr) => expr_captures_binding(expr, binding),
crate::ast::AstExpr::TableConstructor(table) => {
table.fields.iter().any(|field| match field {
crate::ast::AstTableField::Array(value) => expr_captures_binding(value, binding),
crate::ast::AstTableField::Record(record) => {
(match &record.key {
crate::ast::AstTableKey::Name(_) => false,
crate::ast::AstTableKey::Expr(key) => expr_captures_binding(key, binding),
}) || expr_captures_binding(&record.value, binding)
}
})
}
crate::ast::AstExpr::FunctionExpr(function) => {
function_expr_captures_binding(function, binding)
}
crate::ast::AstExpr::Nil
| crate::ast::AstExpr::Boolean(_)
| crate::ast::AstExpr::Integer(_)
| crate::ast::AstExpr::Number(_)
| crate::ast::AstExpr::String(_)
| crate::ast::AstExpr::Int64(_)
| crate::ast::AstExpr::UInt64(_)
| crate::ast::AstExpr::Complex { .. }
| crate::ast::AstExpr::Var(_)
| crate::ast::AstExpr::VarArg
| crate::ast::AstExpr::Error(_) => false,
}
}
/// Reports whether `function` lists `binding` among its own captured bindings, or
/// whether any function expression nested in its body does.
fn function_expr_captures_binding(
    function: &crate::ast::AstFunctionExpr,
    binding: AstBindingRef,
) -> bool {
    if function.captured_bindings.contains(&binding) {
        return true;
    }
    block_captures_binding(&function.body, binding)
}
#[cfg(test)]
mod tests;