use std::collections::BTreeMap;
use crate::hir::common::{
HirAssign, HirBlock, HirExpr, HirLValue, HirLocalDecl, HirLogicalExpr, HirProto, HirStmt,
HirUnaryExpr, HirUnaryOpKind, LocalId, TempId,
};
use super::expr_facts::{expr_is_boolean_valued, expr_is_side_effect_free};
use super::visit::{HirVisitor, visit_proto};
use super::walk::{HirRewritePass, rewrite_proto};
pub(super) fn remove_boolean_materialization_shells_in_proto(proto: &mut HirProto) -> bool {
let use_counts = collect_temp_use_counts(proto);
let mut dead_shell_pass = DeadBooleanShellPass {
use_counts: &use_counts,
};
let mut collapse_shell_pass = CollapseBooleanShellPass;
rewrite_proto(proto, &mut dead_shell_pass) | rewrite_proto(proto, &mut collapse_shell_pass)
}
struct DeadBooleanShellPass<'a> {
use_counts: &'a BTreeMap<TempId, usize>,
}
impl HirRewritePass for DeadBooleanShellPass<'_> {
fn rewrite_block(&mut self, block: &mut HirBlock) -> bool {
remove_dead_materialization_shells_from_block(block, self.use_counts)
}
}
fn remove_dead_materialization_shells_from_block(
block: &mut HirBlock,
use_counts: &BTreeMap<TempId, usize>,
) -> bool {
let mut index = 0;
let mut changed = false;
while index < block.stmts.len() {
if removable_dead_materialization_shell(&block.stmts[index], use_counts) {
block.stmts.remove(index);
changed = true;
continue;
}
index += 1;
}
changed
}
struct CollapseBooleanShellPass;
impl HirRewritePass for CollapseBooleanShellPass {
fn rewrite_block(&mut self, block: &mut HirBlock) -> bool {
collapse_live_boolean_materialization_shells_in_block(block)
}
}
fn collapse_live_boolean_materialization_shells_in_block(block: &mut HirBlock) -> bool {
let mut index = 0;
let mut changed = false;
while index < block.stmts.len() {
let Some((target, value)) =
collapsible_live_boolean_materialization_shell(&block.stmts[index])
else {
index += 1;
continue;
};
if index > 0
&& let HirLValue::Local(local) = &target
&& empty_single_local_decl_binding(&block.stmts[index - 1]) == Some(*local)
{
block.stmts[index - 1] = HirStmt::LocalDecl(Box::new(HirLocalDecl {
bindings: vec![*local],
values: vec![value],
}));
block.stmts.remove(index);
changed = true;
index = index.saturating_sub(1);
continue;
}
block.stmts[index] = HirStmt::Assign(Box::new(HirAssign {
targets: vec![target],
values: vec![value],
}));
changed = true;
index += 1;
}
changed
}
fn collapsible_live_boolean_materialization_shell(stmt: &HirStmt) -> Option<(HirLValue, HirExpr)> {
let HirStmt::If(if_stmt) = stmt else {
return None;
};
let Some(else_block) = &if_stmt.else_block else {
return None;
};
let (then_target, then_value) = pure_assign_pattern(&if_stmt.then_block)?;
let (else_target, else_value) = pure_assign_pattern(else_block)?;
if then_target != else_target {
return None;
}
match (then_value, else_value) {
(HirExpr::Boolean(true), HirExpr::Boolean(false)) => Some((
then_target.clone(),
booleanized_truthiness_expr(if_stmt.cond.clone()),
)),
(HirExpr::Boolean(false), HirExpr::Boolean(true)) => Some((
then_target.clone(),
HirExpr::Unary(Box::new(HirUnaryExpr {
op: HirUnaryOpKind::Not,
expr: if_stmt.cond.clone(),
})),
)),
_ => None,
}
}
fn removable_dead_materialization_shell(
stmt: &HirStmt,
use_counts: &BTreeMap<TempId, usize>,
) -> bool {
let HirStmt::If(if_stmt) = stmt else {
return false;
};
let Some(else_block) = &if_stmt.else_block else {
return false;
};
if !expr_is_side_effect_free(&if_stmt.cond) {
return false;
}
let Some((then_target, then_value)) = pure_assign_pattern(&if_stmt.then_block) else {
return false;
};
let Some((else_target, else_value)) = pure_assign_pattern(else_block) else {
return false;
};
let (HirLValue::Temp(then_temp), HirLValue::Temp(else_temp)) = (then_target, else_target)
else {
return false;
};
if use_counts.get(then_temp).copied().unwrap_or(0) != 0
|| use_counts.get(else_temp).copied().unwrap_or(0) != 0
{
return false;
}
expr_is_side_effect_free(then_value) && expr_is_side_effect_free(else_value)
}
fn pure_assign_pattern(block: &HirBlock) -> Option<(&HirLValue, &HirExpr)> {
let [HirStmt::Assign(assign)] = block.stmts.as_slice() else {
return None;
};
let [target] = assign.targets.as_slice() else {
return None;
};
let [value] = assign.values.as_slice() else {
return None;
};
Some((target, value))
}
fn empty_single_local_decl_binding(stmt: &HirStmt) -> Option<LocalId> {
let HirStmt::LocalDecl(local_decl) = stmt else {
return None;
};
let [binding] = local_decl.bindings.as_slice() else {
return None;
};
if !local_decl.values.is_empty() {
return None;
}
Some(*binding)
}
fn booleanized_truthiness_expr(cond: HirExpr) -> HirExpr {
if expr_is_boolean_valued(&cond) {
cond
} else {
HirExpr::LogicalOr(Box::new(HirLogicalExpr {
lhs: HirExpr::LogicalAnd(Box::new(HirLogicalExpr {
lhs: cond,
rhs: HirExpr::Boolean(true),
})),
rhs: HirExpr::Boolean(false),
}))
}
}
fn collect_temp_use_counts(proto: &HirProto) -> BTreeMap<TempId, usize> {
let mut collector = TempUseCollector::default();
visit_proto(proto, &mut collector);
collector.use_counts
}
#[derive(Default)]
struct TempUseCollector {
use_counts: BTreeMap<TempId, usize>,
}
impl HirVisitor for TempUseCollector {
fn visit_expr(&mut self, expr: &HirExpr) {
if let HirExpr::TempRef(temp) = expr {
*self.use_counts.entry(*temp).or_default() += 1;
}
}
}
#[cfg(test)]
mod tests;