use std::collections::{BTreeMap, BTreeSet};
use crate::hir::TempId;
use super::super::common::{
AstBindingRef, AstBlock, AstCallKind, AstExpr, AstFunctionExpr, AstLValue, AstModule,
AstNameRef, AstStmt, AstSyntheticLocalId, AstTableField, AstTableKey,
};
use super::ReadabilityContext;
use super::visit::{self, AstVisitor};
use super::walk::{self, AstRewritePass, BlockKind};
/// Runs the temp-materialization pass over `module`.
///
/// The readability context is accepted for signature uniformity but is not consulted. The
/// return value is whatever `walk::rewrite_module` reports for this pass.
pub(super) fn apply(module: &mut AstModule, _context: ReadabilityContext) -> bool {
    let mut pass = MaterializeTempsPass;
    walk::rewrite_module(module, &mut pass)
}
/// Rewrite pass that replaces the temps found in a block with synthetic locals.
///
/// The pass itself is stateless; the temp-to-local mapping is rebuilt for every block handed to
/// `rewrite_block`.
struct MaterializeTempsPass;

impl AstRewritePass for MaterializeTempsPass {
    /// Collects the temps reachable from `block` and rewrites them in place.
    ///
    /// Returns `false` when the collector finds no temps and `true` otherwise. The block kind
    /// is not consulted.
    fn rewrite_block(&mut self, block: &mut AstBlock, _kind: BlockKind) -> bool {
        let temps = collect_function_temps_in_block(block);
        if temps.is_empty() {
            return false;
        }

        // Every temp `t` is paired with the synthetic local that wraps the same id.
        let mut mapping = BTreeMap::new();
        for temp in temps {
            mapping.insert(temp, AstSyntheticLocalId(temp));
        }

        rewrite_function_block(block, &mapping);
        true
    }
}
fn collect_function_temps_in_block(block: &AstBlock) -> BTreeSet<TempId> {
let mut collector = FunctionTempCollector::default();
visit::visit_block(block, &mut collector);
collector.temps
}
/// Visitor state that accumulates every `TempId` reported by its `AstVisitor` hooks.
#[derive(Default)]
struct FunctionTempCollector {
    /// Temps recorded so far; the set deduplicates repeated sightings.
    temps: BTreeSet<TempId>,
}
impl AstVisitor for FunctionTempCollector {
fn visit_stmt(&mut self, stmt: &AstStmt) {
match stmt {
AstStmt::LocalDecl(local_decl) => {
for binding in &local_decl.bindings {
if let AstBindingRef::Temp(temp) = binding.id {
self.temps.insert(temp);
}
}
}
AstStmt::FunctionDecl(function_decl) => {
collect_function_temps_in_function_name(&function_decl.target, &mut self.temps);
}
AstStmt::LocalFunctionDecl(local_function_decl) => {
if let AstBindingRef::Temp(temp) = local_function_decl.name {
self.temps.insert(temp);
}
}
AstStmt::NumericFor(numeric_for) => {
if let AstBindingRef::Temp(temp) = numeric_for.binding {
self.temps.insert(temp);
}
}
AstStmt::GenericFor(generic_for) => {
for binding in &generic_for.bindings {
if let AstBindingRef::Temp(temp) = binding {
self.temps.insert(*temp);
}
}
}
AstStmt::GlobalDecl(_)
| AstStmt::Assign(_)
| AstStmt::CallStmt(_)
| AstStmt::Return(_)
| AstStmt::If(_)
| AstStmt::While(_)
| AstStmt::Repeat(_)
| AstStmt::DoBlock(_)
| AstStmt::Break
| AstStmt::Continue
| AstStmt::Goto(_)
| AstStmt::Label(_) | AstStmt::Error(_) => {}
}
}
fn visit_lvalue(&mut self, target: &AstLValue) {
if let AstLValue::Name(AstNameRef::Temp(temp)) = target {
self.temps.insert(*temp);
}
}
fn visit_expr(&mut self, expr: &AstExpr) {
if let AstExpr::Var(AstNameRef::Temp(temp)) = expr {
self.temps.insert(*temp);
}
}
fn visit_function_expr(&mut self, function: &AstFunctionExpr) -> bool {
if let Some(AstBindingRef::Temp(temp)) = function.named_vararg {
self.temps.insert(temp);
}
for binding in &function.captured_bindings {
if let AstBindingRef::Temp(temp) = binding {
self.temps.insert(*temp);
}
}
false
}
}
/// Adds the root of a function-declaration target to `temps` when that root is a temp.
///
/// Both plain and method-style names carry a path; only the path's root name is inspected.
fn collect_function_temps_in_function_name(
    target: &super::super::common::AstFunctionName,
    temps: &mut BTreeSet<TempId>,
) {
    use super::super::common::AstFunctionName;

    // The two variants are the only ones, so this or-pattern is irrefutable.
    let (AstFunctionName::Plain(path) | AstFunctionName::Method(path, _)) = target;
    if let AstNameRef::Temp(temp) = path.root {
        temps.insert(temp);
    }
}
/// Applies [`rewrite_function_stmt`] to every statement of `block`, in order.
fn rewrite_function_block(block: &mut AstBlock, mapping: &BTreeMap<TempId, AstSyntheticLocalId>) {
    block
        .stmts
        .iter_mut()
        .for_each(|stmt| rewrite_function_stmt(stmt, mapping));
}
fn rewrite_function_stmt(stmt: &mut AstStmt, mapping: &BTreeMap<TempId, AstSyntheticLocalId>) {
match stmt {
AstStmt::LocalDecl(local_decl) => {
for binding in &mut local_decl.bindings {
if let AstBindingRef::Temp(temp) = binding.id
&& let Some(&synthetic) = mapping.get(&temp)
{
binding.id = AstBindingRef::SyntheticLocal(synthetic);
}
}
for value in &mut local_decl.values {
rewrite_function_expr(value, mapping);
}
}
AstStmt::GlobalDecl(global_decl) => {
for value in &mut global_decl.values {
rewrite_function_expr(value, mapping);
}
}
AstStmt::Assign(assign) => {
for target in &mut assign.targets {
rewrite_function_lvalue(target, mapping);
}
for value in &mut assign.values {
rewrite_function_expr(value, mapping);
}
}
AstStmt::CallStmt(call_stmt) => rewrite_function_call(&mut call_stmt.call, mapping),
AstStmt::Return(ret) => {
for value in &mut ret.values {
rewrite_function_expr(value, mapping);
}
}
AstStmt::If(if_stmt) => {
rewrite_function_expr(&mut if_stmt.cond, mapping);
rewrite_function_block(&mut if_stmt.then_block, mapping);
if let Some(else_block) = &mut if_stmt.else_block {
rewrite_function_block(else_block, mapping);
}
}
AstStmt::While(while_stmt) => {
rewrite_function_expr(&mut while_stmt.cond, mapping);
rewrite_function_block(&mut while_stmt.body, mapping);
}
AstStmt::Repeat(repeat_stmt) => {
rewrite_function_block(&mut repeat_stmt.body, mapping);
rewrite_function_expr(&mut repeat_stmt.cond, mapping);
}
AstStmt::NumericFor(numeric_for) => {
rewrite_function_expr(&mut numeric_for.start, mapping);
rewrite_function_expr(&mut numeric_for.limit, mapping);
rewrite_function_expr(&mut numeric_for.step, mapping);
rewrite_function_block(&mut numeric_for.body, mapping);
}
AstStmt::GenericFor(generic_for) => {
for expr in &mut generic_for.iterator {
rewrite_function_expr(expr, mapping);
}
rewrite_function_block(&mut generic_for.body, mapping);
}
AstStmt::DoBlock(block) => rewrite_function_block(block, mapping),
AstStmt::FunctionDecl(function_decl) => {
rewrite_function_name(&mut function_decl.target, mapping)
}
AstStmt::LocalFunctionDecl(local_function_decl) => {
if let AstBindingRef::Temp(temp) = local_function_decl.name
&& let Some(&synthetic) = mapping.get(&temp)
{
local_function_decl.name = AstBindingRef::SyntheticLocal(synthetic);
}
}
AstStmt::Break | AstStmt::Continue | AstStmt::Goto(_) | AstStmt::Label(_) | AstStmt::Error(_) => {}
}
}
/// Remaps the root name of a function-declaration target through `mapping`.
///
/// Plain and method-style names are handled identically: only the path's root is rewritten.
fn rewrite_function_name(
    target: &mut super::super::common::AstFunctionName,
    mapping: &BTreeMap<TempId, AstSyntheticLocalId>,
) {
    use super::super::common::AstFunctionName;

    match target {
        AstFunctionName::Plain(path) | AstFunctionName::Method(path, _) => {
            rewrite_name_ref(&mut path.root, mapping);
        }
    }
}
/// Rewrites temps in a call: first the callee (or method receiver), then each argument in order.
fn rewrite_function_call(call: &mut AstCallKind, mapping: &BTreeMap<TempId, AstSyntheticLocalId>) {
    match call {
        AstCallKind::Call(plain) => {
            rewrite_function_expr(&mut plain.callee, mapping);
            plain
                .args
                .iter_mut()
                .for_each(|arg| rewrite_function_expr(arg, mapping));
        }
        AstCallKind::MethodCall(method) => {
            rewrite_function_expr(&mut method.receiver, mapping);
            method
                .args
                .iter_mut()
                .for_each(|arg| rewrite_function_expr(arg, mapping));
        }
    }
}
/// Rewrites temps inside an assignment target.
///
/// A bare name is remapped directly; field and index targets have their sub-expressions
/// rewritten instead.
fn rewrite_function_lvalue(
    target: &mut AstLValue,
    mapping: &BTreeMap<TempId, AstSyntheticLocalId>,
) {
    match target {
        AstLValue::Name(name) => {
            rewrite_name_ref(name, mapping);
        }
        AstLValue::FieldAccess(field) => {
            rewrite_function_expr(&mut field.base, mapping);
        }
        AstLValue::IndexAccess(indexed) => {
            rewrite_function_expr(&mut indexed.base, mapping);
            rewrite_function_expr(&mut indexed.index, mapping);
        }
    }
}
fn rewrite_function_expr(expr: &mut AstExpr, mapping: &BTreeMap<TempId, AstSyntheticLocalId>) {
match expr {
AstExpr::Var(name) => rewrite_name_ref(name, mapping),
AstExpr::FieldAccess(access) => rewrite_function_expr(&mut access.base, mapping),
AstExpr::IndexAccess(access) => {
rewrite_function_expr(&mut access.base, mapping);
rewrite_function_expr(&mut access.index, mapping);
}
AstExpr::Unary(unary) => rewrite_function_expr(&mut unary.expr, mapping),
AstExpr::Binary(binary) => {
rewrite_function_expr(&mut binary.lhs, mapping);
rewrite_function_expr(&mut binary.rhs, mapping);
}
AstExpr::LogicalAnd(logical) | AstExpr::LogicalOr(logical) => {
rewrite_function_expr(&mut logical.lhs, mapping);
rewrite_function_expr(&mut logical.rhs, mapping);
}
AstExpr::Call(call) => {
rewrite_function_expr(&mut call.callee, mapping);
for arg in &mut call.args {
rewrite_function_expr(arg, mapping);
}
}
AstExpr::MethodCall(call) => {
rewrite_function_expr(&mut call.receiver, mapping);
for arg in &mut call.args {
rewrite_function_expr(arg, mapping);
}
}
AstExpr::SingleValue(expr) => rewrite_function_expr(expr, mapping),
AstExpr::TableConstructor(table) => {
for field in &mut table.fields {
match field {
AstTableField::Array(value) => rewrite_function_expr(value, mapping),
AstTableField::Record(record) => {
if let AstTableKey::Expr(key) = &mut record.key {
rewrite_function_expr(key, mapping);
}
rewrite_function_expr(&mut record.value, mapping);
}
}
}
}
AstExpr::FunctionExpr(function) => {
rewrite_function_capture_bindings(function, mapping);
}
AstExpr::Nil
| AstExpr::Boolean(_)
| AstExpr::Integer(_)
| AstExpr::Number(_)
| AstExpr::String(_)
| AstExpr::Int64(_)
| AstExpr::UInt64(_)
| AstExpr::Complex { .. }
| AstExpr::VarArg | AstExpr::Error(_) => {}
}
}
/// Replaces `name` with its synthetic local when it is a temp that has an entry in `mapping`.
///
/// Any other name, and any temp missing from `mapping`, is left unchanged.
fn rewrite_name_ref(name: &mut AstNameRef, mapping: &BTreeMap<TempId, AstSyntheticLocalId>) {
    let AstNameRef::Temp(temp) = *name else {
        return;
    };
    let Some(&synthetic) = mapping.get(&temp) else {
        return;
    };
    *name = AstNameRef::SyntheticLocal(synthetic);
}
fn rewrite_function_capture_bindings(
function: &mut AstFunctionExpr,
mapping: &BTreeMap<TempId, AstSyntheticLocalId>,
) {
if function.captured_bindings.is_empty() {
return;
}
function.captured_bindings = function
.captured_bindings
.iter()
.map(|binding| match binding {
AstBindingRef::Temp(temp) => mapping
.get(temp)
.copied()
.map(AstBindingRef::SyntheticLocal)
.unwrap_or(AstBindingRef::Temp(*temp)),
_ => *binding,
})
.collect();
}
#[cfg(test)]
mod tests;