use tensorlogic_ir::TLExpr;
use super::consts::{is_false_const, is_true_const};
use super::types::{DceStats, DeadCodeEliminator};
impl DeadCodeEliminator {
    /// Simplifies the conjunction `lhs AND rhs` when an operand is a constant
    /// recognised by `is_false_const` / `is_true_const`.
    ///
    /// * If either operand is a false constant, the result is
    ///   `TLExpr::Constant(0.0)` and the other operand is discarded.
    /// * Otherwise, if one operand is a true constant, the other operand is
    ///   returned unchanged (`lhs` is checked before `rhs`).
    /// * Otherwise a fresh `TLExpr::And` node is built from both operands.
    ///
    /// Each simplification bumps `stats.constant_folds` by one. The function
    /// always returns `Some`.
    pub(super) fn fold_and(
        &self,
        lhs: TLExpr,
        rhs: TLExpr,
        stats: &mut DceStats,
    ) -> Option<TLExpr> {
        // A false operand absorbs the whole conjunction; `||` short-circuits,
        // so `rhs` is only inspected when `lhs` is not a false constant.
        if is_false_const(&lhs) || is_false_const(&rhs) {
            stats.constant_folds += 1;
            return Some(TLExpr::Constant(0.0));
        }

        // A true operand is the identity: keep the opposite side.
        let survivor = if is_true_const(&lhs) {
            rhs
        } else if is_true_const(&rhs) {
            lhs
        } else {
            // Nothing to fold; rebuild the node as-is.
            return Some(TLExpr::And(Box::new(lhs), Box::new(rhs)));
        };

        stats.constant_folds += 1;
        Some(survivor)
    }

    /// Simplifies the disjunction `lhs OR rhs` when an operand is a constant
    /// recognised by `is_true_const` / `is_false_const`.
    ///
    /// * If either operand is a true constant, the result is
    ///   `TLExpr::Constant(1.0)` and the other operand is discarded.
    /// * Otherwise, if one operand is a false constant, the other operand is
    ///   returned unchanged (`lhs` is checked before `rhs`).
    /// * Otherwise a fresh `TLExpr::Or` node is built from both operands.
    ///
    /// Each simplification bumps `stats.constant_folds` by one. The function
    /// always returns `Some`.
    pub(super) fn fold_or(&self, lhs: TLExpr, rhs: TLExpr, stats: &mut DceStats) -> Option<TLExpr> {
        // A true operand absorbs the whole disjunction; `||` short-circuits,
        // so `rhs` is only inspected when `lhs` is not a true constant.
        if is_true_const(&lhs) || is_true_const(&rhs) {
            stats.constant_folds += 1;
            return Some(TLExpr::Constant(1.0));
        }

        // A false operand is the identity: keep the opposite side.
        let survivor = if is_false_const(&lhs) {
            rhs
        } else if is_false_const(&rhs) {
            lhs
        } else {
            // Nothing to fold; rebuild the node as-is.
            return Some(TLExpr::Or(Box::new(lhs), Box::new(rhs)));
        };

        stats.constant_folds += 1;
        Some(survivor)
    }

    /// Simplifies the negation of `inner` when it is a constant recognised by
    /// `is_true_const` / `is_false_const`.
    ///
    /// A true constant becomes `TLExpr::Constant(0.0)`, a false constant
    /// becomes `TLExpr::Constant(1.0)`, and each such fold bumps
    /// `stats.constant_folds` by one. Any other expression is wrapped in
    /// `TLExpr::Not`. The function always returns `Some`.
    pub(super) fn fold_not(&self, inner: TLExpr, stats: &mut DceStats) -> Option<TLExpr> {
        let negated = if is_true_const(&inner) {
            0.0
        } else if is_false_const(&inner) {
            1.0
        } else {
            // Not a recognised constant: keep the negation node.
            return Some(TLExpr::Not(Box::new(inner)));
        };

        stats.constant_folds += 1;
        Some(TLExpr::Constant(negated))
    }

    /// Resolves a conditional whose condition is a constant recognised by
    /// `is_true_const` / `is_false_const`.
    ///
    /// A true condition yields `then`, a false condition yields `else_`; in
    /// both cases the branch not taken is dropped and
    /// `stats.unreachable_branches` is bumped by one. Otherwise a
    /// `TLExpr::IfThenElse` node is built from all three parts. The function
    /// always returns `Some`.
    pub(super) fn fold_if(
        &self,
        cond: TLExpr,
        then: TLExpr,
        else_: TLExpr,
        stats: &mut DceStats,
    ) -> Option<TLExpr> {
        let taken = if is_true_const(&cond) {
            then
        } else if is_false_const(&cond) {
            else_
        } else {
            // Condition is not a recognised constant: keep both branches.
            return Some(TLExpr::IfThenElse {
                condition: Box::new(cond),
                then_branch: Box::new(then),
                else_branch: Box::new(else_),
            });
        };

        stats.unreachable_branches += 1;
        Some(taken)
    }
}