use super::expr_facts::{
expr_is_side_effect_free, expr_truthiness, fold_associative_duplicate_and,
fold_associative_duplicate_or,
};
use super::walk::{ExprRewritePass, rewrite_proto_exprs};
use crate::hir::common::{HirExpr, HirLogicalExpr, HirProto};
#[cfg(test)]
use crate::hir::common::HirBlock;
#[cfg(test)]
use crate::hir::common::HirStmt;
/// Runs the logical-expression simplifier (`LogicalExprPass`) over the expressions of `proto`.
///
/// Returns the flag reported by `rewrite_proto_exprs` (presumably whether any expression was
/// rewritten — confirm against `walk`).
pub(super) fn simplify_logical_exprs_in_proto(proto: &mut HirProto) -> bool {
    let mut pass = LogicalExprPass;
    rewrite_proto_exprs(proto, &mut pass)
}
/// Stateless rewrite pass: first applies this module's `and`/`or` shape simplifications, then
/// `decision::naturalize_pure_logical_expr`, to every expression it is handed.
struct LogicalExprPass;

impl ExprRewritePass for LogicalExprPass {
    /// Rewrites `expr` in place and reports whether either step replaced it.
    fn rewrite_expr(&mut self, expr: &mut HirExpr) -> bool {
        let shape_changed = match simplify_logical_shape(expr) {
            Some(replacement) => {
                *expr = replacement;
                true
            }
            None => false,
        };
        // Naturalization sees the expression as left by the shape step above.
        let naturalized = match super::decision::naturalize_pure_logical_expr(expr) {
            Some(replacement) => {
                *expr = replacement;
                true
            }
            None => false,
        };
        shape_changed || naturalized
    }
}
/// Dispatches a top-level `and`/`or` to its simplifier; any other expression yields `None`.
fn simplify_logical_shape(expr: &HirExpr) -> Option<HirExpr> {
    if let HirExpr::LogicalAnd(and_expr) = expr {
        simplify_logical_and(&and_expr.lhs, &and_expr.rhs)
    } else if let HirExpr::LogicalOr(or_expr) = expr {
        simplify_logical_or(&or_expr.lhs, &or_expr.rhs)
    } else {
        None
    }
}
fn simplify_logical_and(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
if lhs == rhs {
return Some(lhs.clone());
}
if let Some(replacement) = fold_associative_duplicate_and(lhs, rhs) {
return Some(replacement);
}
if let Some(replacement) = fold_constant_short_circuit_and(lhs, rhs) {
return Some(replacement);
}
match rhs {
HirExpr::LogicalOr(inner) if lhs == &inner.lhs => Some(lhs.clone()),
_ => match lhs {
HirExpr::LogicalOr(inner) if rhs == &inner.lhs || rhs == &inner.rhs => {
Some(rhs.clone())
}
_ => None,
},
}
}
fn simplify_logical_or(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
if lhs == rhs {
return Some(lhs.clone());
}
if let Some(replacement) = fold_associative_duplicate_or(lhs, rhs) {
return Some(replacement);
}
if let Some(replacement) = fold_constant_short_circuit_or(lhs, rhs) {
return Some(replacement);
}
if let Some(replacement) = factor_shared_and_guards(lhs, rhs) {
return Some(replacement);
}
if let Some(replacement) = pull_shared_or_tail(lhs, rhs) {
return Some(replacement);
}
if let Some(replacement) = simplify_or_chain(lhs, rhs) {
return Some(replacement);
}
if let Some(replacement) = fold_shared_fallback_or(lhs, rhs) {
return Some(replacement);
}
match rhs {
HirExpr::LogicalAnd(inner) if lhs == &inner.lhs => Some(lhs.clone()),
_ => match lhs {
HirExpr::LogicalAnd(inner) if rhs == &inner.lhs || rhs == &inner.rhs => {
Some(rhs.clone())
}
_ => None,
},
}
}
/// Tries `factor_shared_and_guards_one_side` with the operands in source order first, then
/// with them swapped.
fn factor_shared_and_guards(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    match factor_shared_and_guards_one_side(lhs, rhs) {
        Some(factored) => Some(factored),
        None => factor_shared_and_guards_one_side(rhs, lhs),
    }
}
/// Factors a guard shared by both `and` operands of an `or`.
///
/// * `(g and x) or (g and y)` becomes `g and (x or y)` when `g` is side-effect free. The
///   original evaluates `g` a second time whenever the first operand is falsy.
/// * `(x and g) or (y and g)` becomes `(x or y) and g` when `g`, `x` and `y` are all
///   side-effect free.
///
/// Returns `None` when neither operand is an `and` of the required shape.
fn factor_shared_and_guards_one_side(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let HirExpr::LogicalAnd(lhs_and) = lhs else {
        return None;
    };
    let HirExpr::LogicalAnd(rhs_and) = rhs else {
        return None;
    };
    if lhs_and.lhs == rhs_and.lhs && expr_is_side_effect_free(&lhs_and.lhs) {
        return Some(HirExpr::LogicalAnd(Box::new(HirLogicalExpr {
            lhs: lhs_and.lhs.clone(),
            rhs: HirExpr::LogicalOr(Box::new(HirLogicalExpr {
                lhs: lhs_and.rhs.clone(),
                rhs: rhs_and.rhs.clone(),
            })),
        })));
    }
    // When `x` is truthy and `g` is falsy, the original still evaluates `y`. The factored form
    // skips it, so `y` must be pure. `x` must be pure too, which keeps the check symmetric.
    // Callers also try this helper with the operands swapped. A swapped match produces
    // `(y or x) and g`, which evaluates `y` first and can skip `x`.
    // NOTE(review): if `x` is truthy, `g` falsy and `y` falsy, the original yields `y` while
    // the factored form yields `g`. Both are falsy; confirm that only truthiness matters.
    if lhs_and.rhs == rhs_and.rhs
        && expr_is_side_effect_free(&lhs_and.rhs)
        && expr_is_side_effect_free(&lhs_and.lhs)
        && expr_is_side_effect_free(&rhs_and.lhs)
    {
        return Some(HirExpr::LogicalAnd(Box::new(HirLogicalExpr {
            lhs: HirExpr::LogicalOr(Box::new(HirLogicalExpr {
                lhs: lhs_and.lhs.clone(),
                rhs: rhs_and.lhs.clone(),
            })),
            rhs: lhs_and.rhs.clone(),
        })));
    }
    None
}
/// Tries `pull_shared_or_tail_one_side` with the operands in source order first, then swapped.
///
/// NOTE(review): the swapped attempt turns `c or (a and (b or c))` into `(a and b) or c`. That
/// evaluates `a`/`b` before `c`, and even when `c` is truthy, while only `c` is checked for
/// purity. Confirm this reordering is acceptable.
fn pull_shared_or_tail(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    if let Some(pulled) = pull_shared_or_tail_one_side(lhs, rhs) {
        return Some(pulled);
    }
    pull_shared_or_tail_one_side(rhs, lhs)
}
/// Rewrites `(a and (b or c)) or c` into `(a and b) or c`.
///
/// Requires `c` to be side-effect free, because the original may evaluate it twice. Returns
/// `None` for any other shape.
fn pull_shared_or_tail_one_side(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let (guard, tail_or) = match lhs {
        HirExpr::LogicalAnd(and_expr) => match &and_expr.rhs {
            HirExpr::LogicalOr(or_expr) => (&and_expr.lhs, or_expr),
            _ => return None,
        },
        _ => return None,
    };
    let shares_pure_tail = rhs == &tail_or.rhs && expr_is_side_effect_free(rhs);
    if !shares_pure_tail {
        return None;
    }
    let narrowed = HirExpr::LogicalAnd(Box::new(HirLogicalExpr {
        lhs: guard.clone(),
        rhs: tail_or.lhs.clone(),
    }));
    Some(HirExpr::LogicalOr(Box::new(HirLogicalExpr {
        lhs: narrowed,
        rhs: rhs.clone(),
    })))
}
/// Looks for a cheaper form of a flattened `or` chain with at least three terms.
///
/// Each pair of terms `(left, right)` with `left < right` is offered to
/// `factor_shared_and_guards_one_side` and `pull_shared_or_tail_one_side`, in both operand
/// orders. On success, the rewritten term replaces `terms[left]`, `terms[right]` is dropped,
/// and the chain is rebuilt left-associatively.
///
/// The candidate with the lowest `expr_cost` wins, provided it is strictly cheaper than the
/// original `lhs or rhs`. On a tie, the earliest candidate is kept. Returns `None` when no
/// candidate is cheaper.
///
/// NOTE(review): the merged term takes `left`'s position, so content from `terms[right]` is
/// evaluated before the terms between `left` and `right`. That preserves truthiness but not
/// evaluation order; confirm the terms are expected to be side-effect free.
fn simplify_or_chain(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let terms = flatten_or_chain_exprs(lhs, rhs);
    if terms.len() < 3 {
        return None;
    }
    // Equals `expr_cost` of `lhs or rhs`. It is computed once here rather than rebuilding (and
    // deep-cloning) that expression for every candidate.
    let original_cost = 1 + expr_cost(lhs) + expr_cost(rhs);
    let mut best: Option<(usize, HirExpr)> = None;
    for left in 0..terms.len() {
        for right in left + 1..terms.len() {
            let (first, second) = (&terms[left], &terms[right]);
            let Some(rewritten) = factor_shared_and_guards_one_side(first, second)
                .or_else(|| factor_shared_and_guards_one_side(second, first))
                .or_else(|| pull_shared_or_tail_one_side(first, second))
                .or_else(|| pull_shared_or_tail_one_side(second, first))
            else {
                continue;
            };
            // Original term order, with `rewritten` at `left` and `terms[right]` removed.
            let mut rebuilt = Vec::with_capacity(terms.len() - 1);
            rebuilt.extend_from_slice(&terms[..left]);
            rebuilt.push(rewritten);
            rebuilt.extend_from_slice(&terms[left + 1..right]);
            rebuilt.extend_from_slice(&terms[right + 1..]);
            let candidate = rebuild_or_chain(rebuilt);
            let candidate_cost = expr_cost(&candidate);
            // Any accepted candidate is already cheaper than the original. Beating the best so
            // far therefore implies beating the original, and ties keep the earlier candidate.
            let bound = best.as_ref().map_or(original_cost, |(cost, _)| *cost);
            if candidate_cost < bound {
                best = Some((candidate_cost, candidate));
            }
        }
    }
    best.map(|(_, expr)| expr)
}
/// Collects the non-`or` leaves of `lhs` followed by those of `rhs`, in left-to-right order.
fn flatten_or_chain_exprs(lhs: &HirExpr, rhs: &HirExpr) -> Vec<HirExpr> {
    let mut leaves = Vec::new();
    for side in [lhs, rhs] {
        collect_or_chain_exprs(side, &mut leaves);
    }
    leaves
}
/// Appends clones of the leaves of a (possibly nested) `or` tree to `out`, left to right.
///
/// Any node that is not an `or`, including an `and`, is treated as a leaf.
fn collect_or_chain_exprs(expr: &HirExpr, out: &mut Vec<HirExpr>) {
    if let HirExpr::LogicalOr(or_expr) = expr {
        for operand in [&or_expr.lhs, &or_expr.rhs] {
            collect_or_chain_exprs(operand, out);
        }
    } else {
        out.push(expr.clone());
    }
}
/// Left-folds `terms` into the chain `((t0 or t1) or t2) or ...`.
///
/// # Panics
/// Panics with "or chain rebuild requires at least one term" if `terms` is empty. The previous
/// `drain(..1)` panicked with a range error before this message could ever be reached, and it
/// also shifted every remaining element.
fn rebuild_or_chain(terms: Vec<HirExpr>) -> HirExpr {
    terms
        .into_iter()
        .reduce(|lhs, rhs| HirExpr::LogicalOr(Box::new(HirLogicalExpr { lhs, rhs })))
        .expect("or chain rebuild requires at least one term")
}
/// Rough size metric used to rank rewrites: one per node.
///
/// Recurses through `and`/`or`, unary and binary operators. Any other expression counts as a
/// single node, whatever it contains.
fn expr_cost(expr: &HirExpr) -> usize {
    let children = match expr {
        HirExpr::LogicalAnd(logical) | HirExpr::LogicalOr(logical) => {
            expr_cost(&logical.lhs) + expr_cost(&logical.rhs)
        }
        HirExpr::Unary(unary) => expr_cost(&unary.expr),
        HirExpr::Binary(binary) => expr_cost(&binary.lhs) + expr_cost(&binary.rhs),
        _ => 0,
    };
    1 + children
}
/// Folds `lhs and rhs` when `expr_truthiness` knows the truthiness of `lhs`.
///
/// A truthy `lhs` yields `rhs`. A falsy `lhs` yields `lhs`, but only when `rhs` is side-effect
/// free.
/// NOTE(review): `rhs` is never evaluated when `lhs` is falsy, so the purity check is stricter
/// than evaluation requires. Confirm whether keeping impure dead code is intentional.
fn fold_constant_short_circuit_and(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let lhs_is_truthy = expr_truthiness(lhs)?;
    if lhs_is_truthy {
        return Some(rhs.clone());
    }
    expr_is_side_effect_free(rhs).then(|| lhs.clone())
}
/// Folds `lhs or rhs` when `expr_truthiness` knows the truthiness of `lhs`.
///
/// A falsy `lhs` yields `rhs`. A truthy `lhs` yields `lhs`, but only when `rhs` is side-effect
/// free.
/// NOTE(review): `rhs` is never evaluated when `lhs` is truthy, so the purity check is stricter
/// than evaluation requires. Confirm whether keeping impure dead code is intentional.
fn fold_constant_short_circuit_or(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    if expr_truthiness(lhs)? {
        expr_is_side_effect_free(rhs).then(|| lhs.clone())
    } else {
        Some(rhs.clone())
    }
}
/// Tries `shared_fallback_or_one_side` with the operands in source order first, then swapped.
///
/// In either order, a match folds the whole expression to its `g or x` operand.
fn fold_shared_fallback_or(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    if let Some(folded) = shared_fallback_or_one_side(lhs, rhs) {
        return Some(folded);
    }
    shared_fallback_or_one_side(rhs, lhs)
}
/// Folds `(not g and x) or (g or x)` to its `g or x` operand (`rhs`).
///
/// The original evaluates `g` twice on most paths (once negated, once plain), and `x` up to
/// twice. The folded form evaluates each at most once, so both `g` and `x` must be side-effect
/// free. Previously only `x` was checked. Returns `None` for any other shape.
fn shared_fallback_or_one_side(lhs: &HirExpr, rhs: &HirExpr) -> Option<HirExpr> {
    let HirExpr::LogicalAnd(lhs_and) = lhs else {
        return None;
    };
    let HirExpr::LogicalOr(rhs_or) = rhs else {
        return None;
    };
    let guard = strip_negation(&lhs_and.lhs)?;
    if guard != rhs_or.lhs || lhs_and.rhs != rhs_or.rhs {
        return None;
    }
    // `then` (not `then_some`) so `rhs` is only cloned when the fold actually applies.
    (expr_is_side_effect_free(&guard) && expr_is_side_effect_free(&lhs_and.rhs))
        .then(|| rhs.clone())
}
/// If `expr` is a logical `not inner`, returns a clone of `inner`; otherwise `None`.
fn strip_negation(expr: &HirExpr) -> Option<HirExpr> {
    let HirExpr::Unary(unary) = expr else {
        return None;
    };
    matches!(unary.op, crate::hir::common::HirUnaryOpKind::Not).then(|| unary.expr.clone())
}
#[cfg(test)]
mod tests;