use crate::hir::common::{
HirBlock, HirCallExpr, HirDecisionExpr, HirDecisionTarget, HirExpr, HirLValue, HirProto,
HirStmt, HirTableConstructor, HirTableField, HirTableKey,
};
/// Hooks invoked by the HIR walker in this module.
///
/// Every hook receives a mutable node and returns a `bool` that the walker
/// ORs into its overall "changed" result. All hooks have defaults, so an
/// implementor only overrides the node kinds it cares about.
pub(super) trait HirRewritePass {
    /// Called for a block after all of its statements have been walked.
    /// The default does nothing and reports `false`.
    fn rewrite_block(&mut self, _block: &mut HirBlock) -> bool {
        false
    }

    /// Called for a statement after its nested expressions/blocks have been
    /// walked. The default does nothing and reports `false`.
    fn rewrite_stmt(&mut self, _stmt: &mut HirStmt) -> bool {
        false
    }

    /// Called for an expression after its sub-expressions have been walked.
    /// The default does nothing and reports `false`.
    fn rewrite_expr(&mut self, _expr: &mut HirExpr) -> bool {
        false
    }

    /// Called for an assignment target after its nested expressions have been
    /// walked. The default does nothing and reports `false`.
    fn rewrite_lvalue(&mut self, _lvalue: &mut HirLValue) -> bool {
        false
    }

    /// Called for a call after its callee and arguments have been walked.
    /// The default does nothing and reports `false`.
    fn rewrite_call(&mut self, _call: &mut HirCallExpr) -> bool {
        false
    }

    /// Called for an expression used as a condition (`if`/`while`/`repeat`
    /// conditions and decision-node tests), after the expression itself has
    /// been walked. The default forwards to [`Self::rewrite_expr`].
    fn rewrite_condition_expr(&mut self, expr: &mut HirExpr) -> bool {
        self.rewrite_expr(expr)
    }
}
/// Runs `pass` over the body block of `proto`.
///
/// Returns the combined "changed" flag reported by the walk.
pub(super) fn rewrite_proto(proto: &mut HirProto, pass: &mut impl HirRewritePass) -> bool {
    let body = &mut proto.body;
    rewrite_block(body, pass)
}
/// Runs `pass` over each statement of `stmts`, in order.
///
/// Every statement is visited even after an earlier one reported a change;
/// the result is `true` if any visit reported `true`.
pub(super) fn rewrite_stmts(stmts: &mut [HirStmt], pass: &mut impl HirRewritePass) -> bool {
    let mut changed = false;
    for stmt in stmts.iter_mut() {
        // `|=` never short-circuits, so no statement is skipped.
        changed |= rewrite_stmt_tree(stmt, pass);
    }
    changed
}
/// A pass that only rewrites expressions.
///
/// Used with [`rewrite_proto_exprs`], which adapts it to the full
/// [`HirRewritePass`] walker.
pub(super) trait ExprRewritePass {
    /// Rewrites a single expression; the returned flag is ORed into the
    /// walker's "changed" result.
    fn rewrite_expr(&mut self, expr: &mut HirExpr) -> bool;

    /// Rewrites an expression that appears in condition position.
    /// The default forwards to [`Self::rewrite_expr`].
    fn rewrite_condition_expr(&mut self, expr: &mut HirExpr) -> bool {
        self.rewrite_expr(expr)
    }
}
/// Runs an expression-only `pass` over `proto` by wrapping it in an
/// `ExprRewritePassAdapter` and delegating to [`rewrite_proto`].
pub(super) fn rewrite_proto_exprs(proto: &mut HirProto, pass: &mut impl ExprRewritePass) -> bool {
    rewrite_proto(proto, &mut ExprRewritePassAdapter { pass })
}
/// Calls `visit` on each block directly owned by `stmt`.
///
/// Only the immediate child blocks are visited (for `If`, the `then` block
/// followed by the `else` block when present); this function does not recurse
/// into them. Statements that own no block are a no-op.
pub(super) fn for_each_nested_block_mut(
    stmt: &mut HirStmt,
    visit: &mut impl FnMut(&mut HirBlock),
) {
    match stmt {
        // Statements without a nested block: nothing to visit.
        HirStmt::LocalDecl(_)
        | HirStmt::Assign(_)
        | HirStmt::TableSetList(_)
        | HirStmt::ErrNil(_)
        | HirStmt::ToBeClosed(_)
        | HirStmt::Close(_)
        | HirStmt::CallStmt(_)
        | HirStmt::Return(_)
        | HirStmt::Break
        | HirStmt::Continue
        | HirStmt::Goto(_)
        | HirStmt::Label(_) => {}

        // The only statement with two child blocks.
        HirStmt::If(branch) => {
            visit(&mut branch.then_block);
            if let Some(otherwise) = branch.else_block.as_mut() {
                visit(otherwise);
            }
        }

        // Statements with exactly one child block.
        HirStmt::Block(inner) => visit(inner),
        HirStmt::While(looping) => visit(&mut looping.body),
        HirStmt::Repeat(looping) => visit(&mut looping.body),
        HirStmt::NumericFor(looping) => visit(&mut looping.body),
        HirStmt::GenericFor(looping) => visit(&mut looping.body),
        HirStmt::Unstructured(region) => visit(&mut region.body),
    }
}
/// Applies `rewrite_block` to every block directly nested in `stmt`.
///
/// All nested blocks are processed regardless of earlier results; returns
/// `true` if the callback returned `true` for at least one of them.
pub(super) fn rewrite_nested_blocks_in_stmt(
    stmt: &mut HirStmt,
    rewrite_block: &mut impl FnMut(&mut HirBlock) -> bool,
) -> bool {
    let mut any_changed = false;
    let mut visit = |nested: &mut HirBlock| {
        if rewrite_block(nested) {
            any_changed = true;
        }
    };
    for_each_nested_block_mut(stmt, &mut visit);
    any_changed
}
/// Module-visible entry point for walking a single statement and everything
/// nested inside it with `pass`.
///
/// Thin wrapper over the private `rewrite_stmt` walker.
pub(super) fn rewrite_stmt_tree(stmt: &mut HirStmt, pass: &mut impl HirRewritePass) -> bool {
    rewrite_stmt(stmt, pass)
}
struct ExprRewritePassAdapter<'a, P> {
pass: &'a mut P,
}
impl<P: ExprRewritePass> HirRewritePass for ExprRewritePassAdapter<'_, P> {
fn rewrite_expr(&mut self, expr: &mut HirExpr) -> bool {
self.pass.rewrite_expr(expr)
}
fn rewrite_condition_expr(&mut self, expr: &mut HirExpr) -> bool {
self.pass.rewrite_condition_expr(expr)
}
}
/// Walks a block: first every statement in order, then the pass's
/// `rewrite_block` hook on the block itself.
///
/// Returns `true` if any statement walk or the block hook reported a change.
fn rewrite_block(block: &mut HirBlock, pass: &mut impl HirRewritePass) -> bool {
    let mut changed = false;
    for stmt in block.stmts.iter_mut() {
        changed |= rewrite_stmt(stmt, pass);
    }
    // The block hook runs last, after all statements have been walked.
    changed |= pass.rewrite_block(block);
    changed
}
/// Walks a statement: first its nested expressions, lvalues and blocks, then
/// the pass's `rewrite_stmt` hook on the statement itself.
///
/// Every child is always visited (no short-circuiting). Conditions go through
/// `rewrite_condition_expr`; for `Repeat` the body is walked before the
/// condition, for `If`/`While` the condition comes first.
///
/// Returns `true` if any child walk or the statement hook reported a change.
fn rewrite_stmt(stmt: &mut HirStmt, pass: &mut impl HirRewritePass) -> bool {
    let mut changed = false;

    match stmt {
        HirStmt::LocalDecl(local_decl) => {
            for value in local_decl.values.iter_mut() {
                changed |= rewrite_expr(value, pass);
            }
        }
        HirStmt::Assign(assign) => {
            // Targets are walked before the assigned values.
            for target in assign.targets.iter_mut() {
                changed |= rewrite_lvalue(target, pass);
            }
            for value in assign.values.iter_mut() {
                changed |= rewrite_expr(value, pass);
            }
        }
        HirStmt::TableSetList(set_list) => {
            changed |= rewrite_expr(&mut set_list.base, pass);
            for value in set_list.values.iter_mut() {
                changed |= rewrite_expr(value, pass);
            }
            if let Some(trailing) = set_list.trailing_multivalue.as_mut() {
                changed |= rewrite_expr(trailing, pass);
            }
        }
        HirStmt::ErrNil(err_nil) => {
            changed |= rewrite_expr(&mut err_nil.value, pass);
        }
        HirStmt::ToBeClosed(to_be_closed) => {
            changed |= rewrite_expr(&mut to_be_closed.value, pass);
        }
        HirStmt::CallStmt(call_stmt) => {
            changed |= rewrite_call_expr(&mut call_stmt.call, pass);
        }
        HirStmt::Return(ret) => {
            for value in ret.values.iter_mut() {
                changed |= rewrite_expr(value, pass);
            }
        }
        HirStmt::If(if_stmt) => {
            changed |= rewrite_condition_expr(&mut if_stmt.cond, pass);
            changed |= rewrite_block(&mut if_stmt.then_block, pass);
            if let Some(else_block) = if_stmt.else_block.as_mut() {
                changed |= rewrite_block(else_block, pass);
            }
        }
        HirStmt::While(while_stmt) => {
            changed |= rewrite_condition_expr(&mut while_stmt.cond, pass);
            changed |= rewrite_block(&mut while_stmt.body, pass);
        }
        HirStmt::Repeat(repeat_stmt) => {
            // Body first, then the trailing condition.
            changed |= rewrite_block(&mut repeat_stmt.body, pass);
            changed |= rewrite_condition_expr(&mut repeat_stmt.cond, pass);
        }
        HirStmt::NumericFor(numeric_for) => {
            changed |= rewrite_expr(&mut numeric_for.start, pass);
            changed |= rewrite_expr(&mut numeric_for.limit, pass);
            changed |= rewrite_expr(&mut numeric_for.step, pass);
            changed |= rewrite_block(&mut numeric_for.body, pass);
        }
        HirStmt::GenericFor(generic_for) => {
            for iter_expr in generic_for.iterator.iter_mut() {
                changed |= rewrite_expr(iter_expr, pass);
            }
            changed |= rewrite_block(&mut generic_for.body, pass);
        }
        HirStmt::Block(block) => {
            changed |= rewrite_block(block, pass);
        }
        HirStmt::Unstructured(unstructured) => {
            changed |= rewrite_block(&mut unstructured.body, pass);
        }
        // No children are walked for these statements.
        HirStmt::Break
        | HirStmt::Close(_)
        | HirStmt::Continue
        | HirStmt::Goto(_)
        | HirStmt::Label(_) => {}
    }

    // The statement hook runs last, after all children have been walked.
    changed |= pass.rewrite_stmt(stmt);
    changed
}
/// Walks an assignment target: for a table access the base and key
/// expressions are walked first (in that order), then the pass's
/// `rewrite_lvalue` hook runs on the lvalue itself.
///
/// Returns `true` if any nested walk or the lvalue hook reported a change.
fn rewrite_lvalue(lvalue: &mut HirLValue, pass: &mut impl HirRewritePass) -> bool {
    let mut changed = false;

    match lvalue {
        HirLValue::TableAccess(access) => {
            changed |= rewrite_expr(&mut access.base, pass);
            changed |= rewrite_expr(&mut access.key, pass);
        }
        // These targets contain no nested expressions.
        HirLValue::Temp(_)
        | HirLValue::Local(_)
        | HirLValue::Upvalue(_)
        | HirLValue::Global(_) => {}
    }

    changed |= pass.rewrite_lvalue(lvalue);
    changed
}
/// Walks a call: the callee expression, then each argument in order, then the
/// pass's `rewrite_call` hook on the call itself.
///
/// Returns `true` if any of those steps reported a change.
fn rewrite_call_expr(call: &mut HirCallExpr, pass: &mut impl HirRewritePass) -> bool {
    let mut changed = rewrite_expr(&mut call.callee, pass);
    for arg in call.args.iter_mut() {
        changed |= rewrite_expr(arg, pass);
    }
    // The call hook runs last, after callee and arguments have been walked.
    changed |= pass.rewrite_call(call);
    changed
}
fn rewrite_expr(expr: &mut HirExpr, pass: &mut impl HirRewritePass) -> bool {
let nested_changed = match expr {
HirExpr::TableAccess(access) => {
let base_changed = rewrite_expr(&mut access.base, pass);
let key_changed = rewrite_expr(&mut access.key, pass);
base_changed || key_changed
}
HirExpr::Unary(unary) => rewrite_expr(&mut unary.expr, pass),
HirExpr::Binary(binary) => {
let lhs_changed = rewrite_expr(&mut binary.lhs, pass);
let rhs_changed = rewrite_expr(&mut binary.rhs, pass);
lhs_changed || rhs_changed
}
HirExpr::LogicalAnd(logical) | HirExpr::LogicalOr(logical) => {
let lhs_changed = rewrite_expr(&mut logical.lhs, pass);
let rhs_changed = rewrite_expr(&mut logical.rhs, pass);
lhs_changed || rhs_changed
}
HirExpr::Decision(decision) => rewrite_decision_expr(decision, pass),
HirExpr::Call(call) => rewrite_call_expr(call, pass),
HirExpr::TableConstructor(table) => rewrite_table_constructor(table, pass),
HirExpr::Closure(closure) => closure.captures.iter_mut().fold(false, |changed, capture| {
rewrite_expr(&mut capture.value, pass) || changed
}),
HirExpr::Nil
| HirExpr::Boolean(_)
| HirExpr::Integer(_)
| HirExpr::Number(_)
| HirExpr::String(_)
| HirExpr::Int64(_)
| HirExpr::UInt64(_)
| HirExpr::Complex { .. }
| HirExpr::ParamRef(_)
| HirExpr::LocalRef(_)
| HirExpr::UpvalueRef(_)
| HirExpr::TempRef(_)
| HirExpr::GlobalRef(_)
| HirExpr::VarArg
| HirExpr::Unresolved(_) => false,
};
let expr_changed = pass.rewrite_expr(expr);
expr_changed || nested_changed
}
/// Walks an expression used as a condition.
///
/// The expression first goes through the regular `rewrite_expr` walk (which
/// already invokes the pass's `rewrite_expr` hook on the root), and then the
/// pass's `rewrite_condition_expr` hook is called on the same root. With the
/// trait's default condition hook this means the root sees `rewrite_expr`
/// a second time.
///
/// Returns `true` if either step reported a change.
fn rewrite_condition_expr(expr: &mut HirExpr, pass: &mut impl HirRewritePass) -> bool {
    let walked_changed = rewrite_expr(expr, pass);
    let condition_changed = pass.rewrite_condition_expr(expr);
    walked_changed || condition_changed
}
/// Walks every node of a decision expression, in order.
///
/// For each node the test is walked as a condition, followed by the truthy
/// and then the falsy target. All nodes and all three parts are always
/// visited; returns `true` if any of them reported a change.
fn rewrite_decision_expr(decision: &mut HirDecisionExpr, pass: &mut impl HirRewritePass) -> bool {
    let mut changed = false;
    for node in decision.nodes.iter_mut() {
        changed |= rewrite_condition_expr(&mut node.test, pass);
        changed |= rewrite_decision_target(&mut node.truthy, pass);
        changed |= rewrite_decision_target(&mut node.falsy, pass);
    }
    changed
}
/// Walks a decision target.
///
/// Only `Expr` targets carry an expression to walk; `Node` and
/// `CurrentValue` targets are left untouched and report `false`.
fn rewrite_decision_target(target: &mut HirDecisionTarget, pass: &mut impl HirRewritePass) -> bool {
    match target {
        HirDecisionTarget::Node(_) | HirDecisionTarget::CurrentValue => false,
        HirDecisionTarget::Expr(inner) => rewrite_expr(inner, pass),
    }
}
/// Walks the expressions inside a table constructor.
///
/// Fields are visited in order: array fields walk their value; record fields
/// walk an expression key (name keys have nothing to walk) and then the
/// value. The optional trailing multivalue expression is walked last.
/// Everything is always visited; returns `true` if any walk reported a change.
fn rewrite_table_constructor(
    table: &mut HirTableConstructor,
    pass: &mut impl HirRewritePass,
) -> bool {
    let mut changed = false;

    for field in table.fields.iter_mut() {
        match field {
            HirTableField::Array(value) => {
                changed |= rewrite_expr(value, pass);
            }
            HirTableField::Record(record) => {
                match &mut record.key {
                    HirTableKey::Expr(key) => {
                        changed |= rewrite_expr(key, pass);
                    }
                    HirTableKey::Name(_) => {}
                }
                changed |= rewrite_expr(&mut record.value, pass);
            }
        }
    }

    if let Some(trailing) = table.trailing_multivalue.as_mut() {
        changed |= rewrite_expr(trailing, pass);
    }

    changed
}