use super::rules::{
AddZero, CanonicalOrder, ConstFold, DivSelf, ExpandMul, FlattenAdd, FlattenMul, MulOne,
MulZero, PowOne, PowZero, RewriteRule, SqrtInteger, SubSelf,
};
use super::rulesets::PatternRuleSet;
use crate::deriv::log::{DerivationLog, DerivedExpr};
use crate::kernel::{ExprData, ExprId, ExprPool};
use std::collections::HashMap;
/// Options that steer a simplification run.
#[derive(Debug, Clone)]
pub struct SimplifyConfig {
    /// Upper bound on the number of whole-expression passes the fixpoint
    /// drivers (`simplify_with`, `simplify_with_pattern_rules`,
    /// `simplify_batch`) perform before giving up on convergence.
    pub max_iterations: usize,
    /// When `true`, `rules_for_config` appends `ExpandMul` after the core
    /// rules.
    pub expand: bool,
    /// NOTE(review): not read anywhere in the visible part of this file —
    /// presumably gates rewrites that are only valid away from branch cuts;
    /// confirm against the rule implementations.
    pub allow_branch_cut_rewrites: bool,
    /// Side conditions supplied by the caller. When this list is non-empty,
    /// `simplify_with` and `simplify_with_pattern_rules` pass the fixpoint
    /// result and these conditions to
    /// `colored_egraph::apply_colored_if_needed`.
    pub assumptions: Vec<crate::deriv::log::SideCondition>,
}
impl Default for SimplifyConfig {
fn default() -> Self {
SimplifyConfig {
max_iterations: 100,
expand: false,
allow_branch_cut_rewrites: false,
assumptions: vec![],
}
}
}
/// Builds the ordered rule list used by the simplifier for `config`.
///
/// The order matters: `simplify_node` tries rules front to back and restarts
/// from the front after every successful rewrite. `ExpandMul` is appended
/// last, and only when `config.expand` is set.
pub fn rules_for_config(config: &SimplifyConfig) -> Vec<Box<dyn RewriteRule>> {
    // Core rules, always present.
    let mut selected: Vec<Box<dyn RewriteRule>> = vec![
        Box::new(FlattenMul),
        Box::new(FlattenAdd),
        Box::new(MulZero),
        Box::new(AddZero),
        Box::new(MulOne),
        Box::new(PowZero),
        Box::new(PowOne),
        Box::new(ConstFold),
        Box::new(SqrtInteger),
        Box::new(SubSelf),
        Box::new(DivSelf),
        Box::new(CanonicalOrder),
    ];

    // Optional trailing rule: an `Option` extends the list by zero or one item.
    let expansion = config
        .expand
        .then(|| Box::new(ExpandMul) as Box<dyn RewriteRule>);
    selected.extend(expansion);

    selected
}
/// Returns the rule list that corresponds to `SimplifyConfig::default()`.
pub fn default_rules() -> Vec<Box<dyn RewriteRule>> {
    let baseline = SimplifyConfig::default();
    rules_for_config(&baseline)
}
/// Simplifies `expr` bottom-up with a linear scan over `rules`.
///
/// Children are simplified first (via `simplify_children`), then rules are
/// applied at the rebuilt root until none of them fires. After each
/// successful rewrite the scan restarts from the first rule.
///
/// `memo` maps an input id to its simplified id. On a memo hit the cached id
/// is returned wrapped in `DerivedExpr::new`, i.e. the log recorded when the
/// entry was first computed is not attached again.
fn simplify_node(
    expr: ExprId,
    pool: &ExprPool,
    rules: &[Box<dyn RewriteRule>],
    memo: &mut HashMap<ExprId, ExprId>,
) -> DerivedExpr<ExprId> {
    if let Some(&known) = memo.get(&expr) {
        return DerivedExpr::new(known);
    }

    let (mut node, children_log) = simplify_children(pool.get(expr), pool, rules, memo);
    let mut root_log = DerivationLog::new();

    loop {
        // First rule (in list order) that rewrites the current node, if any.
        let rewrite = rules.iter().find_map(|rule| rule.apply(node, pool));
        match rewrite {
            Some((rewritten, step_log)) => {
                root_log = root_log.merge(step_log);
                node = rewritten;
            }
            None => break,
        }
    }

    // Child steps come before root steps in the combined log.
    let outcome = DerivedExpr::with_log(node, children_log.merge(root_log));
    memo.insert(expr, outcome.value);
    outcome
}
/// Simplifies `expr` like `simplify_node`, but picks root-level rules through
/// the index of `rule_set` instead of scanning every rule.
///
/// Children are still simplified through `simplify_children` with
/// `child_rules`. At the root, only the candidate indices reported by
/// `rule_set.index().candidates(..)` are tried, in the order the index yields
/// them; after a successful rewrite the candidates are recomputed for the new
/// node.
///
/// `memo` has the same meaning as in `simplify_node`: a hit returns the cached
/// id wrapped in `DerivedExpr::new`.
fn simplify_node_indexed(
    expr: ExprId,
    pool: &ExprPool,
    rule_set: &PatternRuleSet,
    child_rules: &[Box<dyn RewriteRule>],
    memo: &mut HashMap<ExprId, ExprId>,
) -> DerivedExpr<ExprId> {
    if let Some(&known) = memo.get(&expr) {
        return DerivedExpr::new(known);
    }

    let (mut node, children_log) = simplify_children(pool.get(expr), pool, child_rules, memo);
    let mut root_log = DerivationLog::new();

    loop {
        // First candidate rule that actually rewrites the current node.
        let rewrite = rule_set
            .index()
            .candidates(node, pool)
            .into_iter()
            .find_map(|idx| rule_set.rules()[idx].apply(node, pool));
        match rewrite {
            Some((rewritten, step_log)) => {
                root_log = root_log.merge(step_log);
                node = rewritten;
            }
            None => break,
        }
    }

    // Child steps come before root steps in the combined log.
    let outcome = DerivedExpr::with_log(node, children_log.merge(root_log));
    memo.insert(expr, outcome.value);
    outcome
}
/// Rebuilds the node described by `data` from simplified children.
///
/// Every child is simplified with `simplify_node` in left-to-right order and
/// its log is appended to the returned log in that same order. The node is
/// then re-created in `pool` from the simplified children.
///
/// Details visible in the match below:
/// * `Piecewise`: only branch values and the default are simplified; branch
///   conditions are passed through unchanged.
/// * `Forall` / `Exists`: only the body is simplified; the bound variable is
///   kept as-is.
/// * Any other variant is treated as an atom and re-interned unchanged.
fn simplify_children(
    data: ExprData,
    pool: &ExprPool,
    rules: &[Box<dyn RewriteRule>],
    memo: &mut HashMap<ExprId, ExprId>,
) -> (ExprId, DerivationLog) {
    /// Simplifies a single child and appends its derivation steps to `log`.
    fn visit(
        child: ExprId,
        pool: &ExprPool,
        rules: &[Box<dyn RewriteRule>],
        memo: &mut HashMap<ExprId, ExprId>,
        log: &mut DerivationLog,
    ) -> ExprId {
        let derived = simplify_node(child, pool, rules, memo);
        *log = std::mem::take(log).merge(derived.log);
        derived.value
    }

    /// Simplifies a sequence of children in order, appending each log to `log`.
    fn visit_all(
        children: impl IntoIterator<Item = ExprId>,
        pool: &ExprPool,
        rules: &[Box<dyn RewriteRule>],
        memo: &mut HashMap<ExprId, ExprId>,
        log: &mut DerivationLog,
    ) -> Vec<ExprId> {
        let mut simplified = Vec::new();
        for child in children {
            simplified.push(visit(child, pool, rules, memo, log));
        }
        simplified
    }

    let mut log = DerivationLog::new();
    let rebuilt = match data {
        ExprData::Add(args) => {
            let terms = visit_all(args, pool, rules, memo, &mut log);
            pool.add(terms)
        }
        ExprData::Mul(args) => {
            let factors = visit_all(args, pool, rules, memo, &mut log);
            pool.mul(factors)
        }
        ExprData::Pow { base, exp } => {
            // Base first, then exponent, so the log keeps that order.
            let new_base = visit(base, pool, rules, memo, &mut log);
            let new_exp = visit(exp, pool, rules, memo, &mut log);
            pool.pow(new_base, new_exp)
        }
        ExprData::Func { name, args } => {
            let new_args = visit_all(args, pool, rules, memo, &mut log);
            pool.func(name, new_args)
        }
        ExprData::Piecewise { branches, default } => {
            let mut new_branches: Vec<(ExprId, ExprId)> = Vec::new();
            for (cond, val) in branches {
                // The condition is carried over untouched.
                let new_val = visit(val, pool, rules, memo, &mut log);
                new_branches.push((cond, new_val));
            }
            let new_default = visit(default, pool, rules, memo, &mut log);
            pool.piecewise(new_branches, new_default)
        }
        ExprData::Predicate { kind, args } => {
            let new_args = visit_all(args, pool, rules, memo, &mut log);
            pool.predicate(kind, new_args)
        }
        ExprData::Forall { var, body } => {
            let new_body = visit(body, pool, rules, memo, &mut log);
            pool.forall(var, new_body)
        }
        ExprData::Exists { var, body } => {
            let new_body = visit(body, pool, rules, memo, &mut log);
            pool.exists(var, new_body)
        }
        ExprData::BigO(arg) => {
            let new_arg = visit(arg, pool, rules, memo, &mut log);
            pool.big_o(new_arg)
        }
        atom => pool.intern(atom),
    };

    (rebuilt, log)
}
/// Simplifies `expr` to a fixpoint using the given `rules`.
///
/// Runs up to `config.max_iterations` whole-expression passes of
/// `simplify_node`, each with a fresh memo table, and stops early once a pass
/// returns the id it was given. The logs of all passes are concatenated in
/// order.
///
/// If `config.assumptions` is non-empty, the fixpoint result is then passed to
/// `colored_egraph::apply_colored_if_needed` together with the assumptions,
/// and that step's value and log become the final result.
pub fn simplify_with(
    expr: ExprId,
    pool: &ExprPool,
    rules: &[Box<dyn RewriteRule>],
    config: SimplifyConfig,
) -> DerivedExpr<ExprId> {
    let mut current = DerivedExpr::new(expr);

    for _ in 0..config.max_iterations {
        let mut memo: HashMap<ExprId, ExprId> = HashMap::new();
        let pass = simplify_node(current.value, pool, rules, &mut memo);

        // Record the pass (its log included) whether or not it changed anything.
        let converged = pass.value == current.value;
        current = DerivedExpr::with_log(pass.value, current.log.merge(pass.log));
        if converged {
            break;
        }
    }

    if config.assumptions.is_empty() {
        return current;
    }

    let colored =
        super::colored_egraph::apply_colored_if_needed(current.value, pool, &config.assumptions);
    DerivedExpr::with_log(colored.value, current.log.merge(colored.log))
}
/// Simplifies `expr` to a fixpoint using an indexed `PatternRuleSet`.
///
/// Same driver as `simplify_with`, except each pass goes through
/// `simplify_node_indexed`: the root of each pass uses the rule index, while
/// children are simplified with the plain rule list obtained from
/// `rule_set.as_dyn_rules()`.
///
/// Runs up to `config.max_iterations` passes, each with a fresh memo table,
/// stopping early when a pass returns the id it was given. If
/// `config.assumptions` is non-empty, the result is post-processed by
/// `colored_egraph::apply_colored_if_needed`.
pub fn simplify_with_pattern_rules(
    expr: ExprId,
    pool: &ExprPool,
    rule_set: &PatternRuleSet,
    config: SimplifyConfig,
) -> DerivedExpr<ExprId> {
    let child_rules = rule_set.as_dyn_rules();
    let mut current = DerivedExpr::new(expr);

    for _ in 0..config.max_iterations {
        let mut memo: HashMap<ExprId, ExprId> = HashMap::new();
        let pass = simplify_node_indexed(current.value, pool, rule_set, &child_rules, &mut memo);

        // Record the pass (its log included) whether or not it changed anything.
        let converged = pass.value == current.value;
        current = DerivedExpr::with_log(pass.value, current.log.merge(pass.log));
        if converged {
            break;
        }
    }

    if config.assumptions.is_empty() {
        return current;
    }

    let colored =
        super::colored_egraph::apply_colored_if_needed(current.value, pool, &config.assumptions);
    DerivedExpr::with_log(colored.value, current.log.merge(colored.log))
}
/// Simplifies `expr` with the default configuration and its matching rules.
pub fn simplify(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
    let config = SimplifyConfig::default();
    let rules = rules_for_config(&config);
    simplify_with(expr, pool, &rules, config)
}
/// Simplifies several expressions together with the default configuration.
///
/// Each pass shares one memo table across all inputs, so a subexpression that
/// occurs in more than one input is simplified once per pass. An input is
/// marked finished as soon as a pass returns it unchanged. Iteration stops
/// when no input changed in a pass or after `max_iterations` passes.
///
/// Results are returned in input order, one per input. No assumption-based
/// post-processing is invoked here.
pub fn simplify_batch(exprs: &[ExprId], pool: &ExprPool) -> Vec<DerivedExpr<ExprId>> {
    let config = SimplifyConfig::default();
    let rules = rules_for_config(&config);

    let mut current: Vec<ExprId> = exprs.to_vec();
    let mut logs: Vec<DerivationLog> = vec![DerivationLog::new(); exprs.len()];
    let mut done = vec![false; exprs.len()];

    for _ in 0..config.max_iterations {
        let mut memo: HashMap<ExprId, ExprId> = HashMap::new();
        let mut any_changed = false;

        let slots = current.iter_mut().zip(logs.iter_mut()).zip(done.iter_mut());
        for ((value, log), finished) in slots {
            if *finished {
                continue;
            }
            let pass = simplify_node(*value, pool, &rules, &mut memo);
            *log = std::mem::take(log).merge(pass.log);
            if pass.value == *value {
                *finished = true;
            } else {
                *value = pass.value;
                any_changed = true;
            }
        }

        if !any_changed {
            break;
        }
    }

    let mut results = Vec::with_capacity(current.len());
    for (value, log) in current.into_iter().zip(logs) {
        results.push(DerivedExpr::with_log(value, log));
    }
    results
}
/// Simplifies `expr` with expansion enabled: the default configuration with
/// `expand` set, so `ExpandMul` is part of the rule list.
pub fn simplify_expanded(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
    let mut config = SimplifyConfig::default();
    config.expand = true;
    let rules = rules_for_config(&config);
    simplify_with(expr, pool, &rules, config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Creates an empty expression pool for one test.
    fn fresh_pool() -> ExprPool {
        ExprPool::new()
    }

    // ---- single-rule identities -------------------------------------------

    #[test]
    fn simplify_x_plus_zero() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let sum = pool.add(vec![x, zero]);

        let simplified = simplify(sum, &pool);

        assert_eq!(simplified.value, x);
        assert!(!simplified.log.is_empty(), "should have logged a step");
        let mentions_add_zero = simplified
            .log
            .steps()
            .iter()
            .any(|step| step.rule_name == "add_zero");
        assert!(mentions_add_zero, "log should mention add_zero");
    }

    #[test]
    fn simplify_x_times_one() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let one = pool.integer(1_i32);
        let product = pool.mul(vec![x, one]);

        let simplified = simplify(product, &pool);

        assert_eq!(simplified.value, x);
    }

    #[test]
    fn simplify_x_times_zero() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let product = pool.mul(vec![x, zero]);

        let simplified = simplify(product, &pool);

        assert_eq!(simplified.value, pool.integer(0_i32));
    }

    #[test]
    fn simplify_x_pow_one() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let one = pool.integer(1_i32);
        let power = pool.pow(x, one);

        let simplified = simplify(power, &pool);

        assert_eq!(simplified.value, x);
    }

    #[test]
    fn simplify_x_pow_zero() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let power = pool.pow(x, zero);

        let simplified = simplify(power, &pool);

        assert_eq!(simplified.value, pool.integer(1_i32));
        // At least one logged step must carry a side condition.
        let has_side_condition = simplified
            .log
            .steps()
            .iter()
            .any(|step| !step.side_conditions.is_empty());
        assert!(has_side_condition, "pow_zero should record side condition");
    }

    // ---- constant folding --------------------------------------------------

    #[test]
    fn simplify_const_fold_add() {
        let pool = fresh_pool();
        let two = pool.integer(2_i32);
        let three = pool.integer(3_i32);
        let sum = pool.add(vec![two, three]);

        let simplified = simplify(sum, &pool);

        assert_eq!(simplified.value, pool.integer(5_i32));
    }

    #[test]
    fn simplify_const_fold_mul() {
        let pool = fresh_pool();
        let four = pool.integer(4_i32);
        let five = pool.integer(5_i32);
        let product = pool.mul(vec![four, five]);

        let simplified = simplify(product, &pool);

        assert_eq!(simplified.value, pool.integer(20_i32));
    }

    #[test]
    fn simplify_const_fold_pow() {
        let pool = fresh_pool();
        let two = pool.integer(2_i32);
        let ten = pool.integer(10_i32);
        let power = pool.pow(two, ten);

        let simplified = simplify(power, &pool);

        assert_eq!(simplified.value, pool.integer(1024_i32));
    }

    // ---- cancellation ------------------------------------------------------

    #[test]
    fn simplify_sub_self() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        // x + (-1 * x)
        let minus_one = pool.integer(-1_i32);
        let negated = pool.mul(vec![minus_one, x]);
        let difference = pool.add(vec![x, negated]);

        let simplified = simplify(difference, &pool);

        assert_eq!(simplified.value, pool.integer(0_i32));
    }

    #[test]
    fn simplify_div_self() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        // x * x^(-1)
        let minus_one = pool.integer(-1_i32);
        let reciprocal = pool.pow(x, minus_one);
        let quotient = pool.mul(vec![x, reciprocal]);

        let simplified = simplify(quotient, &pool);

        assert_eq!(simplified.value, pool.integer(1_i32));
    }

    // ---- composition and drivers -------------------------------------------

    #[test]
    fn simplify_nested() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        // (x + 0) * 1
        let zero = pool.integer(0_i32);
        let inner = pool.add(vec![x, zero]);
        let one = pool.integer(1_i32);
        let outer = pool.mul(vec![inner, one]);

        let simplified = simplify(outer, &pool);

        assert_eq!(simplified.value, x);
    }

    #[test]
    fn simplify_idempotent_on_already_simple() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);

        let simplified = simplify(x, &pool);

        assert_eq!(simplified.value, x);
        assert!(simplified.log.is_empty());
    }

    #[test]
    fn simplify_batch_matches_individual() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let zero = pool.integer(0_i32);
        // `a` is shared: it is both an input and a factor of `b`.
        let a = pool.add(vec![x, zero]);
        let y_plus_zero = pool.add(vec![y, zero]);
        let b = pool.mul(vec![y_plus_zero, a]);
        let one = pool.integer(1_i32);
        let c = pool.pow(x, one);
        let inputs = [a, b, c];

        let batched = simplify_batch(&inputs, &pool);

        assert_eq!(batched.len(), inputs.len());
        for (i, (&input, from_batch)) in inputs.iter().zip(&batched).enumerate() {
            let individual = simplify(input, &pool);
            assert_eq!(
                from_batch.value, individual.value,
                "batch result for input {i} must equal simplify()"
            );
        }
    }

    #[test]
    fn simplify_with_custom_config() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let sum = pool.add(vec![x, zero]);
        // A single pass must already be enough for `x + 0`.
        let mut config = SimplifyConfig::default();
        config.max_iterations = 1;

        let simplified = simplify_with(sum, &pool, &default_rules(), config);

        assert_eq!(simplified.value, x);
    }

    #[test]
    fn simplify_with_assumptions_sqrt_square() {
        use crate::deriv::log::SideCondition;

        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        // sqrt(x^2) with x assumed positive
        let two = pool.integer(2_i32);
        let squared = pool.pow(x, two);
        let root = pool.func("sqrt", vec![squared]);
        let mut config = SimplifyConfig::default();
        config.assumptions = vec![SideCondition::Positive(x)];

        let simplified = simplify_with(root, &pool, &default_rules(), config);

        assert_eq!(simplified.value, x);
    }

    // ---- shared-subexpression (DAG) inputs ---------------------------------

    #[test]
    fn simplify_dag_shared_subexpr_correct() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let seed = pool.add(vec![x, zero]);
        // Double the expression 20 times, reusing the same child id each time.
        let doubled = (0..20).fold(seed, |acc, _| pool.add(vec![acc, acc]));

        let simplified = simplify(doubled, &pool);

        let s = pool.display(simplified.value).to_string();
        assert!(
            !s.contains("+ 0") && !s.contains("0 +"),
            "simplify should eliminate '+ 0' from shared expression: {s}"
        );
    }

    #[test]
    fn diff_dag_shared_subexpr_correct() {
        use crate::diff::diff;

        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        // (x^2 + x) + (x^2 + x), both operands being the same node
        let two = pool.integer(2_i32);
        let x_squared = pool.pow(x, two);
        let inner = pool.add(vec![x_squared, x]);
        let doubled = pool.add(vec![inner, inner]);

        let derivative = diff(doubled, x, &pool).unwrap();

        let s = pool.display(derivative.value).to_string();
        assert!(
            !s.is_empty(),
            "diff of shared DAG expression returned empty string"
        );
    }

    #[test]
    fn eval_interp_dag_shared_subexpr_correct() {
        use crate::jit::{compile, eval_interp};

        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        // (x + 1) * (x + 1), both factors being the same node
        let one = pool.integer(1_i32);
        let shared = pool.add(vec![x, one]);
        let square = pool.mul(vec![shared, shared]);

        // Interpreter path: x = 3 gives (3 + 1)^2 = 16.
        let env: std::collections::HashMap<_, _> = std::iter::once((x, 3.0f64)).collect();
        let interpreted = eval_interp(square, &env, &pool);
        assert_eq!(interpreted, Some(16.0), "eval_interp shared DAG: expected 16");

        // Compiled path must agree.
        let compiled = compile(square, &[x], &pool).unwrap();
        assert!((compiled.call(&[3.0]) - 16.0).abs() < 1e-10);
    }
}