use std::collections::BTreeSet;
use crate::hir::common::{HirExpr, HirLValue, HirProto, HirStmt, TempId};
use super::visit::{HirVisitor, visit_proto};
use super::walk::{HirRewritePass, rewrite_proto};
/// Replaces closure captures of never-assigned temps with the binding the closure is
/// being stored into (intended, per the pass name, to resolve recursive self-captures).
///
/// The set of "defined" temps is computed once for the whole proto up front. The
/// rewrite is then driven by `rewrite_proto`. Returns the `bool` reported by
/// `rewrite_proto`. NOTE(review): this is presumably `true` when any statement was
/// rewritten; confirm against `walk::rewrite_proto`.
pub(super) fn resolve_recursive_closure_self_captures_in_proto(proto: &mut HirProto) -> bool {
    let defined = collect_defined_temps(proto);
    rewrite_proto(proto, &mut RecursiveClosureSelfCapturePass { defined_temps: &defined })
}
/// Rewrite pass that patches closure captures pointing at temps nobody assigns.
///
/// Only single-binding / single-value `LocalDecl` and `Assign` statements are
/// considered. Every other statement is left alone.
struct RecursiveClosureSelfCapturePass<'a> {
    /// Temps that appear as an `HirLValue::Temp` target of some `HirStmt::Assign`
    /// in the proto (see `collect_defined_temps`).
    defined_temps: &'a BTreeSet<TempId>,
}

impl HirRewritePass for RecursiveClosureSelfCapturePass<'_> {
    /// Returns `true` if the statement's value was a closure and at least one of its
    /// captures was rewritten.
    fn rewrite_stmt(&mut self, stmt: &mut HirStmt) -> bool {
        // Pick out the single value being bound and the expression naming its
        // destination. Anything else (multi-assign, table-field target, other
        // statement kinds) is not a candidate.
        let (value, binding) = match stmt {
            HirStmt::LocalDecl(decl) if decl.bindings.len() == 1 && decl.values.len() == 1 => {
                (&mut decl.values[0], HirExpr::LocalRef(decl.bindings[0]))
            }
            HirStmt::Assign(assign) if assign.targets.len() == 1 && assign.values.len() == 1 => {
                // Table-access targets have no plain expression form, so skip them.
                match lvalue_as_expr(&assign.targets[0]) {
                    Some(binding) => (&mut assign.values[0], binding),
                    None => return false,
                }
            }
            _ => return false,
        };
        // NOTE(review): every undefined temp captured by this closure is replaced with
        // the same `binding`, even if several distinct undefined temps are captured;
        // confirm only the self-slot can be undefined here.
        rewrite_closure_self_captures(value, binding, self.defined_temps)
    }
}
/// If `expr` is a closure, replaces each capture that is a `TempRef` to a temp *not*
/// in `defined_temps` with a clone of `replacement`.
///
/// Returns `true` when at least one capture was replaced. Returns `false` when
/// `expr` is not a closure or nothing matched. Captures of any other expression
/// kind, and of defined temps, are left untouched.
fn rewrite_closure_self_captures(
    expr: &mut HirExpr,
    replacement: HirExpr,
    defined_temps: &BTreeSet<TempId>,
) -> bool {
    match expr {
        HirExpr::Closure(closure) => {
            let mut rewrote_any = false;
            for capture in &mut closure.captures {
                // A temp that is never assigned anywhere in the proto cannot hold a
                // real value at capture time; treat it as the closure's own slot.
                let is_dangling_temp = matches!(
                    &capture.value,
                    HirExpr::TempRef(temp) if !defined_temps.contains(temp)
                );
                if is_dangling_temp {
                    capture.value = replacement.clone();
                    rewrote_any = true;
                }
            }
            rewrote_any
        }
        _ => false,
    }
}
/// Converts an assignment target into the expression that reads the same location.
///
/// Returns `None` for `HirLValue::TableAccess`, which this helper does not convert.
/// Global names are cloned; the other ids are copied.
fn lvalue_as_expr(target: &HirLValue) -> Option<HirExpr> {
    let expr = match target {
        HirLValue::TableAccess(_) => return None,
        HirLValue::Temp(temp) => HirExpr::TempRef(*temp),
        HirLValue::Local(local) => HirExpr::LocalRef(*local),
        HirLValue::Upvalue(upvalue) => HirExpr::UpvalueRef(*upvalue),
        HirLValue::Global(global) => HirExpr::GlobalRef(global.clone()),
    };
    Some(expr)
}
fn collect_defined_temps(proto: &HirProto) -> BTreeSet<TempId> {
let mut collector = DefinedTempCollector::default();
visit_proto(proto, &mut collector);
collector.defined
}
/// Read-only visitor that records temps written by assignment statements.
#[derive(Default)]
struct DefinedTempCollector {
    /// Every temp seen as an `HirLValue::Temp` target of an `HirStmt::Assign`.
    defined: BTreeSet<TempId>,
}

impl HirVisitor for DefinedTempCollector {
    /// Adds each temp target of an `Assign` to `defined`.
    ///
    /// All other statement kinds (including `LocalDecl`) and non-temp targets are
    /// ignored.
    fn visit_stmt(&mut self, stmt: &HirStmt) {
        if let HirStmt::Assign(assign) = stmt {
            let temp_targets = assign.targets.iter().filter_map(|target| match target {
                HirLValue::Temp(temp) => Some(*temp),
                _ => None,
            });
            self.defined.extend(temp_targets);
        }
    }
}
#[cfg(test)]
mod tests;