use std::collections::BTreeMap;
use crate::cfg::DefId;
use crate::hir::common::{HirExpr, HirLValue, HirStmt, TempId};
/// Substitutes overridden temps throughout `stmts`.
///
/// `target_overrides` maps a temp to the lvalue that should take its place.
/// For every statement, the expression operands are rewritten first, using
/// the subset of overrides that `temp_expr_overrides` can express as plain
/// expressions. Then the statement's bare-temp assignment targets are
/// replaced. Only the expressions `rewrite_stmt_exprs` visits are touched.
/// An empty override map leaves `stmts` unchanged.
pub(super) fn apply_loop_rewrites(
    stmts: &mut [HirStmt],
    target_overrides: &BTreeMap<TempId, HirLValue>,
) {
    // Only build the derived expression map when there is something to apply.
    if !target_overrides.is_empty() {
        let expr_overrides = temp_expr_overrides(target_overrides);
        stmts.iter_mut().for_each(|stmt| {
            rewrite_stmt_exprs(stmt, &expr_overrides);
            rewrite_stmt_targets(stmt, target_overrides);
        });
    }
}
/// Builds the expression-level view of `target_overrides`.
///
/// Each override whose lvalue has a direct expression form (per
/// `lvalue_as_expr`) is carried over under the same temp. Overrides that
/// `lvalue_as_expr` rejects, i.e. table accesses, are omitted.
pub(super) fn temp_expr_overrides(
    target_overrides: &BTreeMap<TempId, HirLValue>,
) -> BTreeMap<TempId, HirExpr> {
    let mut expr_overrides = BTreeMap::new();
    for (&temp, lvalue) in target_overrides {
        if let Some(expr) = lvalue_as_expr(lvalue) {
            expr_overrides.insert(temp, expr);
        }
    }
    expr_overrides
}
/// Converts a simple lvalue into the reference expression naming the same slot.
///
/// Temps, locals, upvalues and globals map to their `*Ref` expression
/// counterparts. Table accesses are not converted, so this returns `None`
/// for them.
pub(super) fn lvalue_as_expr(lvalue: &HirLValue) -> Option<HirExpr> {
    let expr = match lvalue {
        HirLValue::TableAccess(_) => return None,
        HirLValue::Temp(temp) => HirExpr::TempRef(*temp),
        HirLValue::Local(local) => HirExpr::LocalRef(*local),
        HirLValue::Upvalue(upvalue) => HirExpr::UpvalueRef(*upvalue),
        HirLValue::Global(global) => HirExpr::GlobalRef(global.clone()),
    };
    Some(expr)
}
/// Converts a reference expression back into the lvalue naming the same slot.
///
/// This is the inverse of `lvalue_as_expr`: only `TempRef`, `LocalRef`,
/// `UpvalueRef` and `GlobalRef` convert. Every other expression, including
/// `ParamRef` and table accesses, yields `None`.
pub(super) fn expr_as_lvalue(expr: &HirExpr) -> Option<HirLValue> {
    let lvalue = match expr {
        HirExpr::TempRef(temp) => HirLValue::Temp(*temp),
        HirExpr::LocalRef(local) => HirLValue::Local(*local),
        HirExpr::UpvalueRef(upvalue) => HirLValue::Upvalue(*upvalue),
        HirExpr::GlobalRef(global) => HirLValue::Global(global.clone()),
        _ => return None,
    };
    Some(lvalue)
}
pub(super) fn shared_expr_for_defs<I>(
fixed_temps: &[TempId],
defs: I,
target_overrides: &BTreeMap<TempId, HirLValue>,
) -> Option<HirExpr>
where
I: IntoIterator<Item = DefId>,
{
let mut shared_expr = None;
for def in defs {
let temp = *fixed_temps.get(def.index())?;
let lvalue = target_overrides.get(&temp)?;
let expr = lvalue_as_expr(lvalue)?;
if shared_expr
.as_ref()
.is_some_and(|known_expr: &HirExpr| *known_expr != expr)
{
return None;
}
shared_expr = Some(expr);
}
shared_expr
}
/// Replaces overridden temp targets of an assignment with their override lvalue.
///
/// Only `HirStmt::Assign` statements are affected, and within them only
/// targets that are a bare `HirLValue::Temp` with an entry in
/// `target_overrides`. Such a target is replaced by a clone of the mapped
/// lvalue. Every other statement and target is left untouched.
pub(super) fn rewrite_stmt_targets(
    stmt: &mut HirStmt,
    target_overrides: &BTreeMap<TempId, HirLValue>,
) {
    if let HirStmt::Assign(assign) = stmt {
        for target in &mut assign.targets {
            let replacement = match &*target {
                HirLValue::Temp(temp) => target_overrides.get(temp),
                _ => None,
            };
            if let Some(lvalue) = replacement {
                *target = lvalue.clone();
            }
        }
    }
}
/// Rewrites temp references in the expressions held directly by `stmt`.
///
/// The following are visited:
/// - local-declaration values;
/// - assignment targets (table-access base/key only) and values;
/// - table set-list operands;
/// - `ErrNil` / `ToBeClosed` values;
/// - call statements;
/// - return values;
/// - the condition or header expressions of `If`, `While`, `Repeat`,
///   `NumericFor` and `GenericFor`.
///
/// Statement bodies nested inside those constructs are not visited here.
/// `Break`, `Close`, `Continue`, `Goto`, `Label`, `Block` and `Unstructured`
/// are left unchanged.
pub(super) fn rewrite_stmt_exprs(stmt: &mut HirStmt, expr_overrides: &BTreeMap<TempId, HirExpr>) {
    match stmt {
        // Statements without expressions that this pass rewrites.
        HirStmt::Break
        | HirStmt::Continue
        | HirStmt::Close(_)
        | HirStmt::Goto(_)
        | HirStmt::Label(_)
        | HirStmt::Block(_)
        | HirStmt::Unstructured(_) => {}

        // Data-carrying statements.
        HirStmt::LocalDecl(local_decl) => {
            for value in &mut local_decl.values {
                rewrite_expr_temps(value, expr_overrides);
            }
        }
        HirStmt::Assign(assign) => {
            // Bare temp targets are handled by `rewrite_stmt_targets`; here only
            // sub-expressions of table-access targets are rewritten.
            for target in &mut assign.targets {
                rewrite_lvalue_exprs(target, expr_overrides);
            }
            for value in &mut assign.values {
                rewrite_expr_temps(value, expr_overrides);
            }
        }
        HirStmt::TableSetList(set_list) => {
            rewrite_expr_temps(&mut set_list.base, expr_overrides);
            for value in &mut set_list.values {
                rewrite_expr_temps(value, expr_overrides);
            }
            if let Some(trailing) = set_list.trailing_multivalue.as_mut() {
                rewrite_expr_temps(trailing, expr_overrides);
            }
        }
        HirStmt::ErrNil(err_nil) => rewrite_expr_temps(&mut err_nil.value, expr_overrides),
        HirStmt::ToBeClosed(to_be_closed) => {
            rewrite_expr_temps(&mut to_be_closed.value, expr_overrides)
        }
        HirStmt::CallStmt(call_stmt) => rewrite_call_expr_temps(&mut call_stmt.call, expr_overrides),
        HirStmt::Return(ret) => {
            for value in &mut ret.values {
                rewrite_expr_temps(value, expr_overrides);
            }
        }

        // Control-flow statements: only their guarding / header expressions.
        HirStmt::If(if_stmt) => rewrite_expr_temps(&mut if_stmt.cond, expr_overrides),
        HirStmt::While(while_stmt) => rewrite_expr_temps(&mut while_stmt.cond, expr_overrides),
        HirStmt::Repeat(repeat_stmt) => rewrite_expr_temps(&mut repeat_stmt.cond, expr_overrides),
        HirStmt::NumericFor(numeric_for) => {
            rewrite_expr_temps(&mut numeric_for.start, expr_overrides);
            rewrite_expr_temps(&mut numeric_for.limit, expr_overrides);
            rewrite_expr_temps(&mut numeric_for.step, expr_overrides);
        }
        HirStmt::GenericFor(generic_for) => {
            for value in &mut generic_for.iterator {
                rewrite_expr_temps(value, expr_overrides);
            }
        }
    }
}
/// Rewrites temp references in a call's callee, then in each argument in order.
fn rewrite_call_expr_temps(
    call: &mut crate::hir::common::HirCallExpr,
    expr_overrides: &BTreeMap<TempId, HirExpr>,
) {
    rewrite_expr_temps(&mut call.callee, expr_overrides);
    call.args
        .iter_mut()
        .for_each(|arg| rewrite_expr_temps(arg, expr_overrides));
}
/// Rewrites temp references in the base and key of a table-access lvalue.
///
/// Any other lvalue kind is left unchanged. Replacing a bare temp target is
/// done by `rewrite_stmt_targets`, not here.
fn rewrite_lvalue_exprs(lvalue: &mut HirLValue, expr_overrides: &BTreeMap<TempId, HirExpr>) {
    let HirLValue::TableAccess(access) = lvalue else {
        return;
    };
    rewrite_expr_temps(&mut access.base, expr_overrides);
    rewrite_expr_temps(&mut access.key, expr_overrides);
}
/// Recursively replaces `TempRef`s in `expr` according to `expr_overrides`.
///
/// Every `TempRef` whose temp has an entry in the map is overwritten with a
/// clone of the mapped expression. The substituted expression is not itself
/// revisited. Composite expressions are walked into:
/// - table accesses, unary/binary/logical operators, decisions, calls;
/// - table constructors (array values, expression keys, record values and
///   the trailing multivalue);
/// - closure capture values.
///
/// All other leaf variants are left unchanged.
pub(super) fn rewrite_expr_temps(expr: &mut HirExpr, expr_overrides: &BTreeMap<TempId, HirExpr>) {
    use crate::hir::common::{HirTableField, HirTableKey};

    match expr {
        // Leaves that cannot contain a temp reference.
        HirExpr::Nil
        | HirExpr::Boolean(_)
        | HirExpr::Integer(_)
        | HirExpr::Number(_)
        | HirExpr::String(_)
        | HirExpr::Int64(_)
        | HirExpr::UInt64(_)
        | HirExpr::Complex { .. }
        | HirExpr::ParamRef(_)
        | HirExpr::LocalRef(_)
        | HirExpr::UpvalueRef(_)
        | HirExpr::GlobalRef(_)
        | HirExpr::VarArg
        | HirExpr::Unresolved(_) => {}

        // The substitution point itself.
        HirExpr::TempRef(temp) => {
            if let Some(replacement) = expr_overrides.get(temp).cloned() {
                *expr = replacement;
            }
        }

        // Operators and accesses: recurse into every operand.
        HirExpr::Unary(unary) => rewrite_expr_temps(&mut unary.expr, expr_overrides),
        HirExpr::TableAccess(access) => {
            rewrite_expr_temps(&mut access.base, expr_overrides);
            rewrite_expr_temps(&mut access.key, expr_overrides);
        }
        HirExpr::Binary(binary) => {
            rewrite_expr_temps(&mut binary.lhs, expr_overrides);
            rewrite_expr_temps(&mut binary.rhs, expr_overrides);
        }
        HirExpr::LogicalAnd(logical) | HirExpr::LogicalOr(logical) => {
            rewrite_expr_temps(&mut logical.lhs, expr_overrides);
            rewrite_expr_temps(&mut logical.rhs, expr_overrides);
        }
        HirExpr::Call(call) => rewrite_call_expr_temps(call, expr_overrides),
        HirExpr::Decision(decision) => {
            for node in &mut decision.nodes {
                rewrite_expr_temps(&mut node.test, expr_overrides);
                rewrite_decision_target_temps(&mut node.truthy, expr_overrides);
                rewrite_decision_target_temps(&mut node.falsy, expr_overrides);
            }
        }

        // Aggregates.
        HirExpr::TableConstructor(table) => {
            for field in &mut table.fields {
                match field {
                    HirTableField::Array(value) => rewrite_expr_temps(value, expr_overrides),
                    HirTableField::Record(record) => {
                        if let HirTableKey::Expr(key) = &mut record.key {
                            rewrite_expr_temps(key, expr_overrides);
                        }
                        rewrite_expr_temps(&mut record.value, expr_overrides);
                    }
                }
            }
            if let Some(trailing) = table.trailing_multivalue.as_mut() {
                rewrite_expr_temps(trailing, expr_overrides);
            }
        }
        HirExpr::Closure(closure) => {
            for capture in &mut closure.captures {
                rewrite_expr_temps(&mut capture.value, expr_overrides);
            }
        }
    }
}
/// Rewrites temp references inside a decision target that holds an expression.
///
/// Targets of any other kind are left unchanged.
fn rewrite_decision_target_temps(
    target: &mut crate::hir::common::HirDecisionTarget,
    expr_overrides: &BTreeMap<TempId, HirExpr>,
) {
    use crate::hir::common::HirDecisionTarget;

    let HirDecisionTarget::Expr(expr) = target else {
        return;
    };
    rewrite_expr_temps(expr, expr_overrides);
}