use crate::ast::common::{
AstBlock, AstCallKind, AstExpr, AstFunctionExpr, AstLValue, AstModule, AstStmt,
};
pub(super) use super::traverse::BlockKind;
use crate::ast::traverse::{
traverse_call_children, traverse_expr_children, traverse_lvalue_children,
traverse_stmt_children,
};
/// A bottom-up rewrite over the AST, driven by `rewrite_module` and the `rewrite_*` walkers
/// in this module.
///
/// The walkers visit a node's children first and only then call the matching hook, so every
/// hook sees an already-rewritten subtree. Each hook returns `true` when it modified the node
/// it was given. All hooks default to a no-op returning `false`, so an implementor overrides
/// only the node kinds it cares about.
pub(super) trait AstRewritePass {
    /// Hook for an expression node, called after its sub-expressions and any nested function
    /// bodies have been walked.
    fn rewrite_expr(&mut self, _expr: &mut AstExpr) -> bool {
        false
    }

    /// Hook for an expression that appears as a statement's `condition` child.
    ///
    /// The walker calls this only after the generic expression walk of the same node has
    /// finished, and that walk includes `rewrite_expr` on this very root.
    ///
    /// NOTE(review): the default simply forwards to `rewrite_expr`. A pass that does not
    /// override this hook therefore has `rewrite_expr` applied twice to every condition
    /// root. That is harmless for idempotent rewrites; confirm it is intended for any that
    /// are not.
    fn rewrite_condition_expr(&mut self, expr: &mut AstExpr) -> bool {
        self.rewrite_expr(expr)
    }

    /// Hook for an l-value node, called after its child expressions have been walked.
    fn rewrite_lvalue(&mut self, _lvalue: &mut AstLValue) -> bool {
        false
    }

    /// Hook for a statement, called after all of its children have been walked: child
    /// expressions, l-values, nested blocks, conditions, call operands and function bodies.
    fn rewrite_stmt(&mut self, _stmt: &mut AstStmt) -> bool {
        false
    }

    /// Hook for a whole block, called after every statement in it has been walked.
    ///
    /// `kind` distinguishes the module body (`ModuleBody`), function bodies (`FunctionBody`)
    /// and blocks nested inside statements (`Regular`).
    fn rewrite_block(&mut self, _block: &mut AstBlock, _kind: BlockKind) -> bool {
        false
    }
}
/// A rewrite pass that threads a caller-defined `Scope` value down the AST.
///
/// Each block is announced through `enter_block` *before* its statements are walked. The
/// scope it returns is handed to every node inside that block. Nested blocks and function
/// bodies also receive it, as the `outer_scope` of their own `enter_block` call.
///
/// The statement, expression and l-value hooks run bottom-up (after their children). Each
/// receives the scope of its innermost enclosing block. Every method returns `true` when it
/// changed the AST.
pub(super) trait ScopedAstRewritePass {
    /// Per-block context. The default `enter_block` clones it, so a block that does not
    /// customise scoping simply inherits its parent's scope.
    type Scope: Clone;

    /// Called on entry to a block, before any of its statements are rewritten.
    ///
    /// Returns whether the block was modified, paired with the scope to use for the
    /// block's contents. The default leaves the block untouched and inherits `outer_scope`.
    fn enter_block(
        &mut self,
        _block: &mut AstBlock,
        _kind: BlockKind,
        outer_scope: &Self::Scope,
    ) -> (bool, Self::Scope) {
        let inherited = outer_scope.clone();
        (false, inherited)
    }

    /// Hook for an expression node, called after its sub-expressions and nested function
    /// bodies have been walked under the same `scope`.
    fn rewrite_expr(&mut self, _expr: &mut AstExpr, _scope: &Self::Scope) -> bool {
        false
    }

    /// Hook for a statement's `condition` child, called after the generic scoped
    /// expression walk of that node (which already includes `rewrite_expr` on it).
    ///
    /// NOTE(review): with this default, `rewrite_expr` runs twice on each condition root;
    /// confirm that is acceptable for non-idempotent passes.
    fn rewrite_condition_expr(&mut self, expr: &mut AstExpr, scope: &Self::Scope) -> bool {
        self.rewrite_expr(expr, scope)
    }

    /// Hook for an l-value node, called after its child expressions have been walked.
    fn rewrite_lvalue(&mut self, _lvalue: &mut AstLValue, _scope: &Self::Scope) -> bool {
        false
    }

    /// Hook for a statement, called after all of its children have been walked.
    fn rewrite_stmt(&mut self, _stmt: &mut AstStmt, _scope: &Self::Scope) -> bool {
        false
    }
}
/// Runs `pass` over the whole module, treating its body as a `BlockKind::ModuleBody` block.
///
/// Returns `true` if the pass reported any change anywhere in the module.
pub(super) fn rewrite_module(module: &mut AstModule, pass: &mut impl AstRewritePass) -> bool {
    let body = &mut module.body;
    rewrite_block_with_kind(body, BlockKind::ModuleBody, pass)
}
/// Rewrites every statement of `block` in order. It then offers the block itself to
/// `AstRewritePass::rewrite_block`, tagged with `kind`.
///
/// Returns `true` if any statement or the block hook reported a change.
fn rewrite_block_with_kind(
    block: &mut AstBlock,
    kind: BlockKind,
    pass: &mut impl AstRewritePass,
) -> bool {
    let mut any_stmt_changed = false;
    for stmt in block.stmts.iter_mut() {
        // The rewrite runs unconditionally: a change in an earlier statement must never
        // cause later statements to be skipped.
        if rewrite_stmt(stmt, pass) {
            any_stmt_changed = true;
        }
    }
    // Post-order: the block hook sees statements that have already been rewritten.
    pass.rewrite_block(block, kind) || any_stmt_changed
}
pub(super) fn rewrite_module_scoped<P: ScopedAstRewritePass>(
module: &mut AstModule,
scope: &P::Scope,
pass: &mut P,
) -> bool {
rewrite_block_with_kind_scoped(&mut module.body, BlockKind::ModuleBody, scope, pass)
}
/// Scoped block walk.
///
/// First asks the pass to `enter_block`, which yields the scope for this block's contents.
/// Then every statement is rewritten under that inner scope.
///
/// Returns `true` if `enter_block` or any statement reported a change.
fn rewrite_block_with_kind_scoped<P: ScopedAstRewritePass>(
    block: &mut AstBlock,
    kind: BlockKind,
    outer_scope: &P::Scope,
    pass: &mut P,
) -> bool {
    // Unlike the unscoped walker, the block hook runs *before* the children, because it is
    // what produces the scope the children are rewritten under.
    let (entered_changed, inner_scope) = pass.enter_block(block, kind, outer_scope);
    let mut stmts_changed = false;
    for stmt in block.stmts.iter_mut() {
        stmts_changed |= rewrite_stmt_scoped(stmt, &inner_scope, pass);
    }
    entered_changed || stmts_changed
}
/// Rewrites every child of `stmt` bottom-up, then offers the statement itself to
/// `AstRewritePass::rewrite_stmt`.
///
/// The children are child expressions, l-values, nested blocks (walked as
/// `BlockKind::Regular`), function bodies (walked as `BlockKind::FunctionBody`),
/// conditions and call operands.
///
/// Returns `true` if any child or the statement hook reported a change.
pub(super) fn rewrite_stmt(stmt: &mut AstStmt, pass: &mut impl AstRewritePass) -> bool {
    let mut children_changed = false;
    traverse_stmt_children!(
        stmt,
        iter = iter_mut,
        opt = as_mut,
        borrow = [&mut],
        expr(child) => {
            children_changed |= rewrite_expr(child, pass);
        },
        lvalue(target) => {
            children_changed |= rewrite_lvalue(target, pass);
        },
        block(body) => {
            children_changed |= rewrite_block_with_kind(body, BlockKind::Regular, pass);
        },
        function(func) => {
            children_changed |= rewrite_function_expr(func, BlockKind::FunctionBody, pass);
        },
        condition(cond) => {
            children_changed |= rewrite_condition_expr(cond, pass);
        },
        call(call_kind) => {
            children_changed |= rewrite_call(call_kind, pass);
        }
    );
    // Post-order: the statement hook sees children that have already been rewritten.
    pass.rewrite_stmt(stmt) || children_changed
}
/// Scoped counterpart of `rewrite_stmt`.
///
/// Walks every child of `stmt` under `scope`, then hands the statement itself to
/// `ScopedAstRewritePass::rewrite_stmt`. Nested blocks and function bodies receive `scope`
/// as their outer scope, so the pass can derive a fresh one for them in `enter_block`.
///
/// Returns `true` if anything changed.
fn rewrite_stmt_scoped<P: ScopedAstRewritePass>(
    stmt: &mut AstStmt,
    scope: &P::Scope,
    pass: &mut P,
) -> bool {
    let mut children_changed = false;
    traverse_stmt_children!(
        stmt,
        iter = iter_mut,
        opt = as_mut,
        borrow = [&mut],
        expr(child) => {
            children_changed |= rewrite_expr_scoped(child, scope, pass);
        },
        lvalue(target) => {
            children_changed |= rewrite_lvalue_scoped(target, scope, pass);
        },
        block(body) => {
            children_changed |=
                rewrite_block_with_kind_scoped(body, BlockKind::Regular, scope, pass);
        },
        function(func) => {
            children_changed |=
                rewrite_function_expr_scoped(func, BlockKind::FunctionBody, scope, pass);
        },
        condition(cond) => {
            children_changed |= rewrite_condition_expr_scoped(cond, scope, pass);
        },
        call(call_kind) => {
            children_changed |= rewrite_call_scoped(call_kind, scope, pass);
        }
    );
    pass.rewrite_stmt(stmt, scope) || children_changed
}
/// Rewrites the sub-expressions of `expr` and any function bodies it contains (as
/// `BlockKind::FunctionBody`). It then offers `expr` itself to `AstRewritePass::rewrite_expr`.
///
/// Returns `true` if any child or the expression hook reported a change.
pub(super) fn rewrite_expr(expr: &mut AstExpr, pass: &mut impl AstRewritePass) -> bool {
    let mut children_changed = false;
    traverse_expr_children!(
        expr,
        iter = iter_mut,
        borrow = [&mut],
        expr(child) => {
            children_changed |= rewrite_expr(child, pass);
        },
        function(func) => {
            children_changed |= rewrite_function_expr(func, BlockKind::FunctionBody, pass);
        }
    );
    pass.rewrite_expr(expr) || children_changed
}
/// Scoped counterpart of `rewrite_expr`.
///
/// Sub-expressions and contained function bodies are walked under `scope`. Function bodies
/// receive it as their outer scope. Afterwards `ScopedAstRewritePass::rewrite_expr` runs on
/// `expr` itself.
///
/// Returns `true` if anything changed.
fn rewrite_expr_scoped<P: ScopedAstRewritePass>(
    expr: &mut AstExpr,
    scope: &P::Scope,
    pass: &mut P,
) -> bool {
    let mut children_changed = false;
    traverse_expr_children!(
        expr,
        iter = iter_mut,
        borrow = [&mut],
        expr(child) => {
            children_changed |= rewrite_expr_scoped(child, scope, pass);
        },
        function(func) => {
            children_changed |=
                rewrite_function_expr_scoped(func, BlockKind::FunctionBody, scope, pass);
        }
    );
    pass.rewrite_expr(expr, scope) || children_changed
}
/// Rewrites the child expressions of `lvalue`, then offers the l-value itself to
/// `AstRewritePass::rewrite_lvalue`.
///
/// Returns `true` if any child expression or the l-value hook reported a change.
pub(super) fn rewrite_lvalue(lvalue: &mut AstLValue, pass: &mut impl AstRewritePass) -> bool {
    let mut children_changed = false;
    traverse_lvalue_children!(
        lvalue,
        borrow = [&mut],
        expr(child) => {
            children_changed |= rewrite_expr(child, pass);
        }
    );
    pass.rewrite_lvalue(lvalue) || children_changed
}
/// Scoped counterpart of `rewrite_lvalue`.
///
/// The child expressions are walked under `scope`, then `ScopedAstRewritePass::rewrite_lvalue`
/// runs on the l-value itself.
///
/// Returns `true` if anything changed.
fn rewrite_lvalue_scoped<P: ScopedAstRewritePass>(
    lvalue: &mut AstLValue,
    scope: &P::Scope,
    pass: &mut P,
) -> bool {
    let mut children_changed = false;
    traverse_lvalue_children!(
        lvalue,
        borrow = [&mut],
        expr(child) => {
            children_changed |= rewrite_expr_scoped(child, scope, pass);
        }
    );
    pass.rewrite_lvalue(lvalue, scope) || children_changed
}
/// Rewrites a statement's condition expression.
///
/// First comes the full generic expression walk (children plus `rewrite_expr` on the root).
/// Then the condition-specific hook `AstRewritePass::rewrite_condition_expr` runs.
///
/// Returns `true` if either step reported a change.
fn rewrite_condition_expr(expr: &mut AstExpr, pass: &mut impl AstRewritePass) -> bool {
    let generic_changed = rewrite_expr(expr, pass);
    // The hook is evaluated first in this expression, so it always runs.
    pass.rewrite_condition_expr(expr) || generic_changed
}
/// Scoped counterpart of `rewrite_condition_expr`.
///
/// The generic scoped expression walk runs first, then
/// `ScopedAstRewritePass::rewrite_condition_expr`, both under `scope`.
///
/// Returns `true` if either step reported a change.
fn rewrite_condition_expr_scoped<P: ScopedAstRewritePass>(
    expr: &mut AstExpr,
    scope: &P::Scope,
    pass: &mut P,
) -> bool {
    let generic_changed = rewrite_expr_scoped(expr, scope, pass);
    pass.rewrite_condition_expr(expr, scope) || generic_changed
}
/// Rewrites every child expression of a call node.
///
/// There is no pass hook for the call node itself; only its operands are offered to the
/// pass, through `rewrite_expr`.
///
/// Returns `true` if any operand changed.
fn rewrite_call(call: &mut AstCallKind, pass: &mut impl AstRewritePass) -> bool {
    let mut operands_changed = false;
    traverse_call_children!(
        call,
        iter = iter_mut,
        borrow = [&mut],
        expr(operand) => {
            operands_changed |= rewrite_expr(operand, pass);
        }
    );
    operands_changed
}
/// Scoped counterpart of `rewrite_call`.
///
/// Every child expression of the call is rewritten under `scope`. The call node itself has
/// no hook.
///
/// Returns `true` if any operand changed.
fn rewrite_call_scoped<P: ScopedAstRewritePass>(
    call: &mut AstCallKind,
    scope: &P::Scope,
    pass: &mut P,
) -> bool {
    let mut operands_changed = false;
    traverse_call_children!(
        call,
        iter = iter_mut,
        borrow = [&mut],
        expr(operand) => {
            operands_changed |= rewrite_expr_scoped(operand, scope, pass);
        }
    );
    operands_changed
}
/// Rewrites the body of a function expression as a block of the given `kind`.
///
/// Every caller in this module passes `BlockKind::FunctionBody`.
///
/// Returns `true` if anything in the body changed.
fn rewrite_function_expr(
    function: &mut AstFunctionExpr,
    kind: BlockKind,
    pass: &mut impl AstRewritePass,
) -> bool {
    let body = &mut function.body;
    rewrite_block_with_kind(body, kind, pass)
}
/// Scoped counterpart of `rewrite_function_expr`.
///
/// The function body is walked as a block of the given `kind`, with `scope` as the outer
/// scope handed to `enter_block`.
///
/// Returns `true` if anything in the body changed.
fn rewrite_function_expr_scoped<P: ScopedAstRewritePass>(
    function: &mut AstFunctionExpr,
    kind: BlockKind,
    scope: &P::Scope,
    pass: &mut P,
) -> bool {
    let body = &mut function.body;
    rewrite_block_with_kind_scoped(body, kind, scope, pass)
}