use std::collections::BTreeMap;
use crate::cfg::DefId;
use crate::hir::common::{HirExpr, HirLValue, HirStmt, TempId};
use crate::hir::traverse::{
traverse_hir_call_children, traverse_hir_decision_children, traverse_hir_expr_children,
traverse_hir_lvalue_children, traverse_hir_stmt_children,
traverse_hir_table_constructor_children,
};
/// Applies `target_overrides` to each statement in `stmts`.
///
/// For every statement, reads of overridden temps are first replaced with the
/// expression form of their override lvalue (see `temp_expr_overrides`). Then
/// assignment targets that name an overridden temp are swapped for the
/// override lvalue (see `rewrite_stmt_targets`).
///
/// Does nothing when `target_overrides` is empty.
pub(super) fn apply_loop_rewrites(
    stmts: &mut [HirStmt],
    target_overrides: &BTreeMap<TempId, HirLValue>,
) {
    if !target_overrides.is_empty() {
        // Computed once and shared by every statement.
        let expr_map = temp_expr_overrides(target_overrides);
        stmts.iter_mut().for_each(|stmt| {
            rewrite_stmt_exprs(stmt, &expr_map);
            rewrite_stmt_targets(stmt, target_overrides);
        });
    }
}
/// Builds the expression-level counterpart of `target_overrides`.
///
/// Each overridden temp whose replacement lvalue can be read as an expression
/// (per `lvalue_as_expr`) is mapped to that expression. Overrides with no
/// expression form, such as table accesses, are left out of the result.
pub(super) fn temp_expr_overrides(
    target_overrides: &BTreeMap<TempId, HirLValue>,
) -> BTreeMap<TempId, HirExpr> {
    let mut exprs = BTreeMap::new();
    for (&temp, lvalue) in target_overrides {
        if let Some(expr) = lvalue_as_expr(lvalue) {
            exprs.insert(temp, expr);
        }
    }
    exprs
}
/// Converts a simple lvalue into the expression that reads the same location.
///
/// Temps, locals, upvalues and globals map to their `*Ref` expression.
/// `HirLValue::TableAccess` is not converted and yields `None`.
pub(super) fn lvalue_as_expr(lvalue: &HirLValue) -> Option<HirExpr> {
    let expr = match lvalue {
        HirLValue::TableAccess(_) => return None,
        HirLValue::Temp(temp) => HirExpr::TempRef(*temp),
        HirLValue::Local(local) => HirExpr::LocalRef(*local),
        HirLValue::Upvalue(upvalue) => HirExpr::UpvalueRef(*upvalue),
        HirLValue::Global(global) => HirExpr::GlobalRef(global.clone()),
    };
    Some(expr)
}
/// Converts a plain reference expression back into the lvalue it names.
///
/// This is the inverse of `lvalue_as_expr` for temp, local, upvalue and global
/// references. Any other expression kind yields `None`.
pub(super) fn expr_as_lvalue(expr: &HirExpr) -> Option<HirLValue> {
    let lvalue = match expr {
        HirExpr::TempRef(temp) => HirLValue::Temp(*temp),
        HirExpr::LocalRef(local) => HirLValue::Local(*local),
        HirExpr::UpvalueRef(upvalue) => HirLValue::Upvalue(*upvalue),
        HirExpr::GlobalRef(global) => HirLValue::Global(global.clone()),
        _ => return None,
    };
    Some(lvalue)
}
pub(super) fn shared_expr_for_defs<I>(
fixed_temps: &[TempId],
defs: I,
target_overrides: &BTreeMap<TempId, HirLValue>,
) -> Option<HirExpr>
where
I: IntoIterator<Item = DefId>,
{
let mut shared_expr = None;
for def in defs {
let temp = *fixed_temps.get(def.index())?;
let lvalue = target_overrides.get(&temp)?;
let expr = lvalue_as_expr(lvalue)?;
if shared_expr
.as_ref()
.is_some_and(|known_expr: &HirExpr| *known_expr != expr)
{
return None;
}
shared_expr = Some(expr);
}
shared_expr
}
/// Returns the one override lvalue that all of `defs` share, if any.
///
/// Each def is resolved through `fixed_temps` (indexed by `def.index()`) and
/// then `target_overrides`. The lvalue only qualifies when `lvalue_as_expr`
/// can express it, so table accesses are rejected.
///
/// Returns `None` if any lookup fails, if any target is not expressible, if
/// two defs disagree, or if `defs` is empty.
pub(super) fn shared_lvalue_for_defs<I>(
    fixed_temps: &[TempId],
    defs: I,
    target_overrides: &BTreeMap<TempId, HirLValue>,
) -> Option<HirLValue>
where
    I: IntoIterator<Item = DefId>,
{
    // Track a borrow while scanning and clone only the final answer.
    let mut agreed: Option<&HirLValue> = None;
    for def in defs {
        let temp = fixed_temps.get(def.index())?;
        let target = target_overrides.get(temp)?;
        if lvalue_as_expr(target).is_none() {
            return None;
        }
        if agreed.is_some_and(|existing| existing != target) {
            return None;
        }
        agreed = Some(target);
    }
    agreed.cloned()
}
/// Records `target` as the override for the temp of every def in `defs`.
///
/// Each def's temp is found in `fixed_temps` at `def.index()`. Defs whose
/// index is out of range are skipped. Any existing override for a temp is
/// replaced.
pub(super) fn install_def_target_overrides(
    fixed_temps: &[TempId],
    defs: impl IntoIterator<Item = DefId>,
    target: &HirLValue,
    overrides: &mut BTreeMap<TempId, HirLValue>,
) {
    let entries = defs
        .into_iter()
        .filter_map(|def| fixed_temps.get(def.index()))
        .map(|&temp| (temp, target.clone()));
    overrides.extend(entries);
}
/// Replaces assignment targets that name an overridden temp.
///
/// Only `HirStmt::Assign` is inspected. Each of its targets that is a
/// `HirLValue::Temp` with an entry in `target_overrides` is replaced by a
/// clone of that entry. All other statements and targets are left unchanged.
pub(super) fn rewrite_stmt_targets(
    stmt: &mut HirStmt,
    target_overrides: &BTreeMap<TempId, HirLValue>,
) {
    if let HirStmt::Assign(assign) = stmt {
        for target in &mut assign.targets {
            let replacement = match *target {
                HirLValue::Temp(temp) => target_overrides.get(&temp),
                _ => None,
            };
            if let Some(replacement) = replacement {
                *target = replacement.clone();
            }
        }
    }
}
/// Rewrites temp reads in the expressions directly owned by `stmt`.
///
/// Every expression child, condition, call argument/callee child, and
/// expression nested inside an lvalue child (e.g. the parts of a table access)
/// is passed to `rewrite_expr_temps` with `expr_overrides`. The lvalue children
/// themselves are not replaced here; target replacement is a separate step.
///
/// Nested blocks are deliberately not visited (the `block` arm is a no-op), so
/// temps referenced inside a nested block's statements are left as-is.
/// NOTE(review): confirm callers handle nested blocks separately if needed.
pub(super) fn rewrite_stmt_exprs(stmt: &mut HirStmt, expr_overrides: &BTreeMap<TempId, HirExpr>) {
traverse_hir_stmt_children!(
stmt,
iter = iter_mut,
opt = as_mut,
borrow = [&mut],
expr(e) => { rewrite_expr_temps(e, expr_overrides); },
lvalue(lv) => {
// Only the expressions inside the lvalue are rewritten, not the lvalue itself.
traverse_hir_lvalue_children!(
lv,
borrow = [&mut],
expr(e) => { rewrite_expr_temps(e, expr_overrides); }
);
},
// Nested blocks are intentionally skipped.
block(_b) => {},
call(c) => {
traverse_hir_call_children!(
c,
iter = iter_mut,
borrow = [&mut],
expr(e) => { rewrite_expr_temps(e, expr_overrides); }
);
},
condition(cond) => { rewrite_expr_temps(cond, expr_overrides); }
);
}
/// Recursively replaces `TempRef` expressions that have an entry in
/// `expr_overrides`.
///
/// If `expr` itself is an overridden `TempRef`, it is replaced by a clone of
/// the override and the function returns immediately. The replacement is not
/// traversed, so overrides are not applied transitively.
///
/// Otherwise the function recurses into the children of `expr`:
/// - sub-expressions,
/// - call children,
/// - decision children and conditions,
/// - table-constructor children.
pub(super) fn rewrite_expr_temps(expr: &mut HirExpr, expr_overrides: &BTreeMap<TempId, HirExpr>) {
if let HirExpr::TempRef(temp) = expr
&& let Some(replacement) = expr_overrides.get(temp)
{
*expr = replacement.clone();
// Do not descend into the freshly inserted replacement.
return;
}
traverse_hir_expr_children!(
expr,
iter = iter_mut,
borrow = [&mut],
expr(e) => { rewrite_expr_temps(e, expr_overrides); },
call(c) => {
traverse_hir_call_children!(
c,
iter = iter_mut,
borrow = [&mut],
expr(e) => { rewrite_expr_temps(e, expr_overrides); }
);
},
decision(d) => {
traverse_hir_decision_children!(
d,
iter = iter_mut,
borrow = [&mut],
expr(e) => { rewrite_expr_temps(e, expr_overrides); },
condition(cond) => { rewrite_expr_temps(cond, expr_overrides); }
);
},
table_constructor(t) => {
traverse_hir_table_constructor_children!(
t,
iter = iter_mut,
opt = as_mut,
borrow = [&mut],
expr(e) => { rewrite_expr_temps(e, expr_overrides); }
);
}
);
}