use tensorlogic_ir::TLExpr;
use super::helpers::as_constant;
use super::pe_core::pe_rec;
use super::types::{PEConfig, PEEnv, PEStats, PEValue};
pub(super) fn try_pe_logic(
expr: TLExpr,
env: &PEEnv,
config: &PEConfig,
depth: usize,
stats: &mut PEStats,
) -> Result<TLExpr, TLExpr> {
match expr {
TLExpr::Not(inner) => {
let e = pe_rec(*inner, env, config, depth + 1, stats);
if config.fold_logic {
if let Some(v) = as_constant(&e) {
let b = v != 0.0;
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(TLExpr::Constant(if b { 0.0 } else { 1.0 }));
}
}
Ok(TLExpr::Not(Box::new(e)))
}
TLExpr::And(lhs, rhs) => {
let la = pe_rec(*lhs, env, config, depth + 1, stats);
if config.prune_branches {
if let Some(a) = as_constant(&la) {
if a == 0.0 {
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(TLExpr::Constant(0.0));
}
let ra = pe_rec(*rhs, env, config, depth + 1, stats);
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(ra);
}
}
let ra = pe_rec(*rhs, env, config, depth + 1, stats);
if config.prune_branches {
if let Some(b) = as_constant(&ra) {
if b == 0.0 {
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(TLExpr::Constant(0.0));
}
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(la);
}
}
if config.fold_logic {
if let (Some(a), Some(b)) = (as_constant(&la), as_constant(&ra)) {
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(TLExpr::Constant(if a != 0.0 && b != 0.0 {
1.0
} else {
0.0
}));
}
}
Ok(TLExpr::And(Box::new(la), Box::new(ra)))
}
TLExpr::Or(lhs, rhs) => {
let la = pe_rec(*lhs, env, config, depth + 1, stats);
if config.prune_branches {
if let Some(a) = as_constant(&la) {
if a != 0.0 {
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(TLExpr::Constant(1.0));
}
let ra = pe_rec(*rhs, env, config, depth + 1, stats);
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(ra);
}
}
let ra = pe_rec(*rhs, env, config, depth + 1, stats);
if config.prune_branches {
if let Some(b) = as_constant(&ra) {
if b != 0.0 {
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(TLExpr::Constant(1.0));
}
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(la);
}
}
if config.fold_logic {
if let (Some(a), Some(b)) = (as_constant(&la), as_constant(&ra)) {
stats.nodes_reduced = stats.nodes_reduced.saturating_add(1);
return Ok(TLExpr::Constant(if a != 0.0 || b != 0.0 {
1.0
} else {
0.0
}));
}
}
Ok(TLExpr::Or(Box::new(la), Box::new(ra)))
}
TLExpr::Imply(premise, conclusion) => {
let not_a = Box::new(TLExpr::Not(premise));
let expanded = TLExpr::Or(not_a, conclusion);
Ok(pe_rec(expanded, env, config, depth, stats))
}
TLExpr::IfThenElse {
condition,
then_branch,
else_branch,
} => {
let cond = pe_rec(*condition, env, config, depth + 1, stats);
if config.prune_branches {
if let Some(v) = as_constant(&cond) {
if v != 0.0 {
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
return Ok(pe_rec(*then_branch, env, config, depth + 1, stats));
} else {
stats.branches_pruned = stats.branches_pruned.saturating_add(1);
return Ok(pe_rec(*else_branch, env, config, depth + 1, stats));
}
}
}
let tb = pe_rec(*then_branch, env, config, depth + 1, stats);
let eb = pe_rec(*else_branch, env, config, depth + 1, stats);
Ok(TLExpr::IfThenElse {
condition: Box::new(cond),
then_branch: Box::new(tb),
else_branch: Box::new(eb),
})
}
TLExpr::Let { var, value, body } => {
let evaluated_value = pe_rec(*value, env, config, depth + 1, stats);
if config.inline_lets {
let peval = PEValue::from_expr(evaluated_value.clone());
if peval.is_concrete() {
let new_env = env.extend(var.clone(), peval);
let result = pe_rec(*body, &new_env, config, depth + 1, stats);
stats.lets_inlined = stats.lets_inlined.saturating_add(1);
return Ok(result);
}
let new_env = env.extend(var.clone(), PEValue::Symbolic(evaluated_value.clone()));
let new_body = pe_rec(*body, &new_env, config, depth + 1, stats);
Ok(TLExpr::Let {
var,
value: Box::new(evaluated_value),
body: Box::new(new_body),
})
} else {
let shadowed_env =
env.extend(var.clone(), PEValue::Symbolic(evaluated_value.clone()));
let new_body = pe_rec(*body, &shadowed_env, config, depth + 1, stats);
Ok(TLExpr::Let {
var,
value: Box::new(evaluated_value),
body: Box::new(new_body),
})
}
}
other => Err(other),
}
}