use crate::hir::common::{HirBlock, HirExpr, HirLValue, HirProto, HirStmt, LocalId, TempId};
use super::visit::{HirVisitor, visit_stmts};
use super::walk::{HirRewritePass, for_each_nested_block_mut, rewrite_proto};
#[cfg(test)]
mod tests;
/// Half-open range `[start, end)` of statement indices that forms one to-be-closed scope.
///
/// Intervals are produced by `collect_scope_intervals` and consumed by `rebuild_slice`,
/// which wraps the covered statements in a nested block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ScopeInterval {
/// Index of the statement that opens the scope (the statement right before the
/// `ToBeClosed` marker).
start: usize,
/// One past the index of the last statement that still belongs to the scope.
end: usize,
/// `reg_index` copied from the `ToBeClosed` marker that opened the scope.
reg_index: usize,
}
/// The variable named by a `ToBeClosed` marker's value expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScopeBinding {
/// The marker's value was a `HirExpr::LocalRef`.
Local(LocalId),
/// The marker's value was a `HirExpr::TempRef`.
Temp(TempId),
}
/// Description of a recognised scope opening: a `LocalDecl`/`Assign` statement that is
/// immediately followed by a `ToBeClosed` marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ScopeStart {
/// Index of the `LocalDecl`/`Assign` statement (the marker sits at `start + 1`).
start: usize,
/// `reg_index` taken from the `ToBeClosed` marker.
reg_index: usize,
/// Variable referenced by the marker's value expression.
binding: ScopeBinding,
}
/// What a single statement does with respect to one to-be-closed scope.
///
/// The default value has both flags cleared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
struct ScopeActivity {
    /// The statement references, declares, or assigns the scope's binding.
    mentions_binding: bool,
    /// The statement contains a `Close` for the scope's register.
    closes_scope: bool,
}

impl ScopeActivity {
    /// Returns `true` when at least one of the two flags is set.
    fn any(self) -> bool {
        let Self {
            mentions_binding,
            closes_scope,
        } = self;
        mentions_binding || closes_scope
    }
}
/// Runs [`CloseScopePass`] over `proto` via `rewrite_proto` and returns that call's result.
///
/// NOTE(review): the returned `bool` presumably reports whether any block was changed —
/// confirm against `rewrite_proto`.
pub(super) fn materialize_tbc_close_scopes_in_proto(proto: &mut HirProto) -> bool {
rewrite_proto(proto, &mut CloseScopePass)
}
/// Stateless rewrite pass that materializes to-be-closed scopes as nested blocks.
struct CloseScopePass;
impl HirRewritePass for CloseScopePass {
/// Delegates to `materialize_block`; returns `true` when the block's statement list
/// was replaced with a different one.
fn rewrite_block(&mut self, block: &mut HirBlock) -> bool {
materialize_block(block)
}
}
/// Rewrites the statements of `block` in place.
///
/// The statement list is rebuilt with `rewrite_stmt_slice` and only stored back when it
/// differs from the current one. Returns `true` exactly when the block was modified.
fn materialize_block(block: &mut HirBlock) -> bool {
    let rebuilt = rewrite_stmt_slice(&block.stmts);
    let changed = rebuilt != block.stmts;
    if changed {
        block.stmts = rebuilt;
    }
    changed
}
/// Produces the rewritten form of one statement list.
///
/// When at least one scope interval is found, the list is rebuilt with `rebuild_slice`
/// so that every interval becomes a nested block. Otherwise the statements are copied
/// unchanged, except that top-level `Close` statements with `from_reg == 0` are dropped.
fn rewrite_stmt_slice(stmts: &[HirStmt]) -> Vec<HirStmt> {
    let intervals = collect_scope_intervals(stmts);
    if !intervals.is_empty() {
        let mut cursor = 0;
        return rebuild_slice(stmts, 0, stmts.len(), &intervals, &mut cursor, None);
    }

    // No scopes to materialize: keep everything but register-0 closes.
    let mut kept = Vec::new();
    for stmt in stmts {
        let is_reg_zero_close = matches!(stmt, HirStmt::Close(close) if close.from_reg == 0);
        if !is_reg_zero_close {
            kept.push(stmt.clone());
        }
    }
    kept
}
/// Finds every to-be-closed scope in `stmts` and returns them as intervals sorted by
/// `(start, end)`.
///
/// A scope is recorded for each index where `scope_start` recognises an opening and
/// `find_scope_end` finds a matching end. If the resulting intervals are not properly
/// nested, an empty vector is returned instead.
fn collect_scope_intervals(stmts: &[HirStmt]) -> Vec<ScopeInterval> {
    let mut found: Vec<ScopeInterval> = Vec::new();
    for index in 0..stmts.len() {
        let Some(opening) = scope_start(stmts, index) else {
            continue;
        };
        // Skip the opening statement and its `ToBeClosed` marker when searching for the end.
        let scan_from = opening.start + 2;
        let Some(end) = find_scope_end(stmts, scan_from, opening.binding, opening.reg_index)
        else {
            continue;
        };
        if opening.start < end {
            found.push(ScopeInterval {
                start: opening.start,
                end,
                reg_index: opening.reg_index,
            });
        }
    }

    found.sort_by_key(|interval| (interval.start, interval.end));
    if !well_nested_scope_intervals(&found) {
        // Partially overlapping scopes cannot be expressed as nested blocks.
        return Vec::new();
    }
    found
}
/// Recognises a scope opening at `index`.
///
/// An opening is a `LocalDecl` or `Assign` statement at `index` directly followed by a
/// `ToBeClosed` marker whose value is a plain local or temp reference. Returns `None`
/// for any other shape, including when either index is out of bounds.
fn scope_start(stmts: &[HirStmt], index: usize) -> Option<ScopeStart> {
    let opener = stmts.get(index);
    let follower = stmts.get(index + 1);

    if !matches!(opener, Some(HirStmt::LocalDecl(_) | HirStmt::Assign(_))) {
        return None;
    }
    let Some(HirStmt::ToBeClosed(to_be_closed)) = follower else {
        return None;
    };

    let binding = binding_from_expr(&to_be_closed.value)?;
    Some(ScopeStart {
        start: index,
        reg_index: to_be_closed.reg_index,
        binding,
    })
}
/// Extracts the binding named by `expr`.
///
/// Only a bare `LocalRef` or `TempRef` yields a binding; every other expression gives `None`.
fn binding_from_expr(expr: &HirExpr) -> Option<ScopeBinding> {
    let binding = match expr {
        HirExpr::LocalRef(local) => ScopeBinding::Local(*local),
        HirExpr::TempRef(temp) => ScopeBinding::Temp(*temp),
        _ => return None,
    };
    Some(binding)
}
/// Determines where the scope for `binding` / `reg_index` ends.
///
/// Every statement from `start_index` to the end of `stmts` is inspected. The result is
/// one past the index of the last statement that mentions the binding or closes the
/// register, but only if at least one inspected statement closes the register;
/// otherwise `None`.
fn find_scope_end(
    stmts: &[HirStmt],
    start_index: usize,
    binding: ScopeBinding,
    reg_index: usize,
) -> Option<usize> {
    let (close_seen, scope_end) = stmts.iter().enumerate().skip(start_index).fold(
        (false, None),
        |(close_seen, scope_end), (index, stmt)| {
            let activity = scope_activity_in_stmt(stmt, binding, reg_index);
            (
                close_seen || activity.closes_scope,
                if activity.any() {
                    Some(index + 1)
                } else {
                    scope_end
                },
            )
        },
    );
    // Without a close for this register there is no scope to materialize.
    scope_end.filter(|_| close_seen)
}
fn well_nested_scope_intervals(intervals: &[ScopeInterval]) -> bool {
let mut stack = Vec::<ScopeInterval>::new();
for interval in intervals {
while let Some(top) = stack.last() {
if interval.start >= top.end {
stack.pop();
} else {
break;
}
}
if let Some(parent) = stack.last()
&& interval.end > parent.end
{
return false;
}
stack.push(*interval);
}
true
}
/// Rebuilds `stmts[start..end]`, turning every scope interval inside that range into a
/// nested `HirStmt::Block`.
///
/// * `intervals` — the scope intervals for the whole of `stmts`; the caller in this file
///   passes them sorted by `(start, end)` and well nested.
/// * `cursor` — index of the next interval to consider; shared across recursive calls so
///   each interval is consumed at most once.
/// * `active_scope_reg` — register of the scope currently being rebuilt (`None` at the
///   top level); `Close` statements for that register are dropped.
///
/// Returns the rewritten statements for the range.
fn rebuild_slice(
stmts: &[HirStmt],
start: usize,
end: usize,
intervals: &[ScopeInterval],
cursor: &mut usize,
active_scope_reg: Option<usize>,
) -> Vec<HirStmt> {
let mut rewritten = Vec::new();
let mut index = start;
while index < end {
// Skip intervals that end at or before the current position.
while *cursor < intervals.len() && intervals[*cursor].end <= index {
*cursor += 1;
}
if *cursor < intervals.len() {
let interval = intervals[*cursor];
// Only open a block when the interval begins exactly here and fits in this range.
if interval.start == index && interval.end <= end {
// Advance first so the recursive call does not pick this same interval again.
*cursor += 1;
let inner = rebuild_slice(
stmts,
interval.start,
interval.end,
intervals,
cursor,
Some(interval.reg_index),
);
let mut block_stmt = HirStmt::Block(Box::new(HirBlock { stmts: inner }));
// The recursive call already dropped closes for the interval's own register; this
// pass additionally drops closes for the enclosing scope's register inside the new
// block. The result is ignored because a non-`Close` statement is always kept.
strip_matching_close_from_stmt(&mut block_stmt, active_scope_reg);
rewritten.push(block_stmt);
// Resume after the statements the block swallowed.
index = interval.end;
continue;
}
}
// Ordinary statement: copy it, unless it is a `Close` that must be dropped.
let mut cloned = stmts[index].clone();
if strip_matching_close_from_stmt(&mut cloned, active_scope_reg) {
rewritten.push(cloned);
}
index += 1;
}
rewritten
}
/// Removes redundant `Close` statements from `stmt` and reports whether `stmt` itself
/// should be kept.
///
/// A `Close` is redundant when its `from_reg` is 0 or equals `active_scope_reg`; for such
/// a statement the function returns `false`. Any other statement is kept (`true`) after
/// its nested blocks have been filtered with `strip_matching_close_from_block`.
fn strip_matching_close_from_stmt(stmt: &mut HirStmt, active_scope_reg: Option<usize>) -> bool {
    match stmt {
        HirStmt::Close(close) => {
            let is_reg_zero = close.from_reg == 0;
            let closes_active_scope = active_scope_reg == Some(close.from_reg);
            !(is_reg_zero || closes_active_scope)
        }
        _ => {
            for_each_nested_block_mut(stmt, &mut |block| {
                strip_matching_close_from_block(block, active_scope_reg);
            });
            true
        }
    }
}
/// Filters the statements of `block` in place, keeping only those for which
/// `strip_matching_close_from_stmt` returns `true`.
///
/// Because that function also descends into nested blocks, the filtering is recursive.
fn strip_matching_close_from_block(block: &mut HirBlock, active_scope_reg: Option<usize>) {
block
.stmts
.retain_mut(|stmt| strip_matching_close_from_stmt(stmt, active_scope_reg));
}
/// Reports how `stmt` interacts with the scope identified by `binding` and `reg_index`.
///
/// The statement is visited with a fresh `ScopeActivityCollector`, and the flags it
/// accumulated are returned.
fn scope_activity_in_stmt(
    stmt: &HirStmt,
    binding: ScopeBinding,
    reg_index: usize,
) -> ScopeActivity {
    // `visit_stmts` takes a slice, so present the single statement as one.
    let single = std::slice::from_ref(stmt);
    let mut probe = ScopeActivityCollector {
        binding,
        reg_index,
        activity: ScopeActivity::default(),
    };
    visit_stmts(single, &mut probe);
    let ScopeActivityCollector { activity, .. } = probe;
    activity
}
/// Visitor state used to detect whether visited HIR mentions a binding or closes a register.
struct ScopeActivityCollector {
/// Binding whose mentions are being looked for.
binding: ScopeBinding,
/// Register whose `Close` statements are being looked for.
reg_index: usize,
/// Flags accumulated so far; they are only ever set, never cleared, during a visit.
activity: ScopeActivity,
}
impl ScopeActivityCollector {
    /// Returns `true` when the tracked binding is the local `local`.
    fn binding_matches_local(&self, local: LocalId) -> bool {
        matches!(self.binding, ScopeBinding::Local(tracked) if tracked == local)
    }

    /// Returns `true` when the tracked binding is the temp `temp`.
    fn binding_matches_temp(&self, temp: TempId) -> bool {
        matches!(self.binding, ScopeBinding::Temp(tracked) if tracked == temp)
    }
}
/// Records scope activity for each statement, expression, and lvalue handed to the visitor.
///
/// Each hook inspects only the node it receives. The matches list every variant
/// explicitly so that adding a new variant forces this code to be revisited.
impl HirVisitor for ScopeActivityCollector {
    /// Flags a `Close` for the tracked register, and any statement that binds the tracked
    /// local (local declarations and both kinds of `for` loop).
    fn visit_stmt(&mut self, stmt: &HirStmt) {
        match stmt {
            HirStmt::Close(close) => {
                if close.from_reg == self.reg_index {
                    self.activity.closes_scope = true;
                }
            }
            HirStmt::LocalDecl(local_decl) => {
                let declares_binding = local_decl
                    .bindings
                    .iter()
                    .any(|&local| self.binding_matches_local(local));
                if declares_binding {
                    self.activity.mentions_binding = true;
                }
            }
            HirStmt::NumericFor(numeric_for) => {
                if self.binding_matches_local(numeric_for.binding) {
                    self.activity.mentions_binding = true;
                }
            }
            HirStmt::GenericFor(generic_for) => {
                let binds_binding = generic_for
                    .bindings
                    .iter()
                    .any(|&local| self.binding_matches_local(local));
                if binds_binding {
                    self.activity.mentions_binding = true;
                }
            }
            // These statements carry no binding or close information of their own.
            HirStmt::Assign(_)
            | HirStmt::TableSetList(_)
            | HirStmt::ErrNil(_)
            | HirStmt::ToBeClosed(_)
            | HirStmt::CallStmt(_)
            | HirStmt::Return(_)
            | HirStmt::If(_)
            | HirStmt::While(_)
            | HirStmt::Repeat(_)
            | HirStmt::Block(_)
            | HirStmt::Unstructured(_)
            | HirStmt::Break
            | HirStmt::Continue
            | HirStmt::Goto(_)
            | HirStmt::Label(_) => {}
        }
    }

    /// Flags a direct reference to the tracked local or temp.
    fn visit_expr(&mut self, expr: &HirExpr) {
        match expr {
            HirExpr::LocalRef(local) => {
                if self.binding_matches_local(*local) {
                    self.activity.mentions_binding = true;
                }
            }
            HirExpr::TempRef(temp) => {
                if self.binding_matches_temp(*temp) {
                    self.activity.mentions_binding = true;
                }
            }
            // No other expression form is itself a reference to a local or temp.
            HirExpr::Nil
            | HirExpr::Boolean(_)
            | HirExpr::Integer(_)
            | HirExpr::Number(_)
            | HirExpr::String(_)
            | HirExpr::Int64(_)
            | HirExpr::UInt64(_)
            | HirExpr::Complex { .. }
            | HirExpr::ParamRef(_)
            | HirExpr::UpvalueRef(_)
            | HirExpr::GlobalRef(_)
            | HirExpr::VarArg
            | HirExpr::Unresolved(_)
            | HirExpr::TableAccess(_)
            | HirExpr::Unary(_)
            | HirExpr::Binary(_)
            | HirExpr::LogicalAnd(_)
            | HirExpr::LogicalOr(_)
            | HirExpr::Decision(_)
            | HirExpr::Call(_)
            | HirExpr::TableConstructor(_)
            | HirExpr::Closure(_) => {}
        }
    }

    /// Flags an assignment target that is the tracked local or temp.
    fn visit_lvalue(&mut self, lvalue: &HirLValue) {
        match lvalue {
            HirLValue::Local(local) => {
                if self.binding_matches_local(*local) {
                    self.activity.mentions_binding = true;
                }
            }
            HirLValue::Temp(temp) => {
                if self.binding_matches_temp(*temp) {
                    self.activity.mentions_binding = true;
                }
            }
            HirLValue::Upvalue(_) | HirLValue::Global(_) | HirLValue::TableAccess(_) => {}
        }
    }
}