#![cfg(feature = "parallel")]
use crate::deriv::log::{DerivationLog, DerivedExpr};
use crate::kernel::{ExprData, ExprId, ExprPool};
use crate::simplify::engine::SimplifyConfig;
use crate::simplify::rules::RewriteRule;
use rayon::prelude::*;
use std::sync::Arc;
/// Minimum number of children an `Add` or `Mul` node must have before
/// `simplify_children_par` simplifies those children in parallel with rayon;
/// nodes with fewer children (and all other node kinds) are handled sequentially.
const PAR_THRESHOLD: usize = 4;
/// Simplifies `expr` with the parallel simplifier using `SimplifyConfig::default()`.
///
/// Convenience wrapper around [`simplify_par_with_config`]; returns the simplified
/// expression id together with the derivation log of every rewrite applied.
pub fn simplify_par(expr: ExprId, pool: &ExprPool) -> DerivedExpr<ExprId> {
    let defaults = SimplifyConfig::default();
    simplify_par_with_config(expr, pool, &defaults)
}
pub fn simplify_par_with_config(
expr: ExprId,
pool: &ExprPool,
config: &SimplifyConfig,
) -> DerivedExpr<ExprId> {
let rules: Arc<Vec<Box<dyn RewriteRule + Send + Sync>>> =
Arc::new(rules_for_config_par(config));
let mut current = expr;
let mut full_log = DerivationLog::new();
for _ in 0..config.max_iterations {
let result = simplify_node_par(current, pool, &rules);
full_log = full_log.merge(result.log);
if result.value == current {
break;
}
current = result.value;
}
DerivedExpr::with_log(current, full_log)
}
/// Simplifies a single node bottom-up.
///
/// First the node's children are simplified and the node is rebuilt
/// (`simplify_children_par`). Then rules are applied to the rebuilt node: the first
/// rule in `rules` that rewrites it wins, and the scan restarts from the first rule
/// after every successful rewrite. This repeats until no rule makes progress.
///
/// A rule that returns `Some` with the very id it was given is treated as not having
/// fired. It is skipped (its log is discarded) and the next rule is tried. Without this
/// check such a no-op rewrite would be re-selected on every scan and the loop would
/// never end.
/// NOTE(review): a rule set that cycles between distinct ids (A -> B -> A) can still
/// loop forever; this function does not bound the number of rewrites per node.
///
/// Returns the simplified id with the children's log followed by the log of the
/// rewrites applied at this node.
fn simplify_node_par(
    expr: ExprId,
    pool: &ExprPool,
    rules: &Arc<Vec<Box<dyn RewriteRule + Send + Sync>>>,
) -> DerivedExpr<ExprId> {
    let data = pool.get(expr);
    let (mut current, child_log) = simplify_children_par(data, pool, rules);
    let mut rule_log = DerivationLog::new();
    loop {
        let mut fired = false;
        for rule in rules.iter() {
            if let Some((new_expr, step_log)) = rule.apply(current, pool) {
                if new_expr == current {
                    // No-op rewrite: restarting the scan here would re-select this
                    // same rule forever, so fall through to the next rule instead.
                    continue;
                }
                rule_log = rule_log.merge(step_log);
                current = new_expr;
                fired = true;
                break;
            }
        }
        if !fired {
            break;
        }
    }
    DerivedExpr::with_log(current, child_log.merge(rule_log))
}
fn simplify_children_par(
data: ExprData,
pool: &ExprPool,
rules: &Arc<Vec<Box<dyn RewriteRule + Send + Sync>>>,
) -> (ExprId, DerivationLog) {
match data {
ExprData::Add(args) if args.len() >= PAR_THRESHOLD => {
let results: Vec<DerivedExpr<ExprId>> = args
.par_iter()
.map(|&a| simplify_node_par(a, pool, rules))
.collect();
let new_args: Vec<ExprId> = results.iter().map(|r| r.value).collect();
let mut log = DerivationLog::new();
for r in results {
log = log.merge(r.log);
}
(pool.add(new_args), log)
}
ExprData::Mul(args) if args.len() >= PAR_THRESHOLD => {
let results: Vec<DerivedExpr<ExprId>> = args
.par_iter()
.map(|&a| simplify_node_par(a, pool, rules))
.collect();
let new_args: Vec<ExprId> = results.iter().map(|r| r.value).collect();
let mut log = DerivationLog::new();
for r in results {
log = log.merge(r.log);
}
(pool.mul(new_args), log)
}
ExprData::Add(args) => {
let mut log = DerivationLog::new();
let new_args: Vec<ExprId> = args
.into_iter()
.map(|a| {
let r = simplify_node_par(a, pool, rules);
log = std::mem::take(&mut log).merge(r.log);
r.value
})
.collect();
(pool.add(new_args), log)
}
ExprData::Mul(args) => {
let mut log = DerivationLog::new();
let new_args: Vec<ExprId> = args
.into_iter()
.map(|a| {
let r = simplify_node_par(a, pool, rules);
log = std::mem::take(&mut log).merge(r.log);
r.value
})
.collect();
(pool.mul(new_args), log)
}
ExprData::Pow { base, exp } => {
let rb = simplify_node_par(base, pool, rules);
let re = simplify_node_par(exp, pool, rules);
let log = rb.log.merge(re.log);
(pool.pow(rb.value, re.value), log)
}
ExprData::Func { name, args } => {
let mut log = DerivationLog::new();
let new_args: Vec<ExprId> = args
.into_iter()
.map(|a| {
let r = simplify_node_par(a, pool, rules);
log = std::mem::take(&mut log).merge(r.log);
r.value
})
.collect();
(pool.func(&name, new_args), log)
}
ExprData::Piecewise { branches, default } => {
let mut log = DerivationLog::new();
let new_branches: Vec<(ExprId, ExprId)> = branches
.into_iter()
.map(|(cond, val)| {
let rv = simplify_node_par(val, pool, rules);
log = std::mem::take(&mut log).merge(rv.log);
(cond, rv.value)
})
.collect();
let rd = simplify_node_par(default, pool, rules);
log = log.merge(rd.log);
(pool.piecewise(new_branches, rd.value), log)
}
ExprData::Predicate { kind, args } => {
let mut log = DerivationLog::new();
let new_args: Vec<ExprId> = args
.into_iter()
.map(|a| {
let r = simplify_node_par(a, pool, rules);
log = std::mem::take(&mut log).merge(r.log);
r.value
})
.collect();
(pool.predicate(kind, new_args), log)
}
ExprData::Forall { var, body } => {
let rb = simplify_node_par(body, pool, rules);
(pool.forall(var, rb.value), rb.log)
}
ExprData::Exists { var, body } => {
let rb = simplify_node_par(body, pool, rules);
(pool.exists(var, rb.value), rb.log)
}
ExprData::BigO(arg) => {
let r = simplify_node_par(arg, pool, rules);
(pool.big_o(r.value), r.log)
}
leaf => (pool.intern(leaf), DerivationLog::new()),
}
}
/// Builds the rewrite-rule list used by the parallel simplifier.
///
/// `simplify_node_par` scans rules in the order returned here and applies the first one
/// that fires, so the order is significant. `ExpandMul` is appended last, and only when
/// `config.expand` is set. Every rule is boxed as `Send + Sync` so the list can be
/// shared with the rayon worker threads.
pub fn rules_for_config_par(config: &SimplifyConfig) -> Vec<Box<dyn RewriteRule + Send + Sync>> {
    use crate::simplify::rules::{
        AddZero, CanonicalOrder, ConstFold, DivSelf, ExpandMul, FlattenAdd, FlattenMul, MulOne,
        MulZero, PowOne, PowZero, SubSelf,
    };
    let expansion: Option<Box<dyn RewriteRule + Send + Sync>> = if config.expand {
        Some(Box::new(ExpandMul))
    } else {
        None
    };
    let mut ordered: Vec<Box<dyn RewriteRule + Send + Sync>> = vec![
        Box::new(FlattenMul),
        Box::new(FlattenAdd),
        Box::new(MulZero),
        Box::new(AddZero),
        Box::new(MulOne),
        Box::new(PowZero),
        Box::new(PowOne),
        Box::new(ConstFold),
        Box::new(SubSelf),
        Box::new(DivSelf),
        Box::new(CanonicalOrder),
    ];
    ordered.extend(expansion);
    ordered
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};
    use crate::simplify::simplify;

    /// A fresh, empty expression pool for each test.
    fn fresh_pool() -> ExprPool {
        ExprPool::new()
    }

    /// Asserts that the sequential and parallel simplifiers produce the same id for `expr`
    /// (sequential simplifier runs first).
    fn assert_seq_then_par_agree(expr: ExprId, pool: &ExprPool) {
        let sequential = simplify(expr, pool);
        let parallel = simplify_par(expr, pool);
        assert_eq!(sequential.value, parallel.value);
    }

    #[test]
    fn par_matches_sequential_add() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let mut terms = vec![x];
        terms.extend(std::iter::repeat(zero).take(5));
        let expr = pool.add(terms);
        assert_seq_then_par_agree(expr, &pool);
    }

    #[test]
    fn par_matches_sequential_mul() {
        let pool = fresh_pool();
        let x = pool.symbol("x", Domain::Real);
        let one = pool.integer(1_i32);
        let mut factors = vec![x];
        factors.extend(std::iter::repeat(one).take(5));
        let expr = pool.mul(factors);
        assert_seq_then_par_agree(expr, &pool);
    }

    #[test]
    fn par_constant_folding() {
        let pool = fresh_pool();
        let terms: Vec<ExprId> = [2_i32, 3, 4, 5]
            .iter()
            .map(|&v| pool.integer(v))
            .collect();
        let expr = pool.add(terms);
        let folded = simplify_par(expr, &pool);
        assert_eq!(folded.value, pool.integer(14_i32));
    }

    #[test]
    fn par_large_sum() {
        let pool = fresh_pool();
        let args: Vec<ExprId> = (1..=20).map(|i| pool.integer(i)).collect();
        let expr = pool.add(args);
        // Parallel first here, mirroring the original test's call order.
        let parallel = simplify_par(expr, &pool);
        let sequential = simplify(expr, &pool);
        assert_eq!(parallel.value, sequential.value);
    }
}