use super::{
algebraic::{simplify_algebraic, AlgebraicSimplificationStats},
constant_folding::{fold_constants, ConstantFoldingStats},
dead_code::{eliminate_dead_code, DeadCodeStats},
distributivity::{optimize_distributivity, DistributivityStats},
negation::{optimize_negations, NegationOptStats},
quantifier_opt::{optimize_quantifiers, QuantifierOptStats},
strength_reduction::{reduce_strength, StrengthReductionStats},
};
use tensorlogic_ir::TLExpr;
/// Configuration for [`OptimizationPipeline`]: which passes run in each
/// iteration and how the iteration loop is bounded.
///
/// Start from a preset ([`PipelineConfig::all`], [`PipelineConfig::none`],
/// [`PipelineConfig::aggressive`], ...) and refine it with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Run the negation pass (`optimize_negations`).
    pub enable_negation_opt: bool,
    /// Run the constant-folding pass (`fold_constants`).
    pub enable_constant_folding: bool,
    /// Run the algebraic-simplification pass (`simplify_algebraic`).
    pub enable_algebraic_simplification: bool,
    /// Run the strength-reduction pass (`reduce_strength`).
    pub enable_strength_reduction: bool,
    /// Run the distributivity pass (`optimize_distributivity`).
    pub enable_distributivity: bool,
    /// Run the quantifier pass (`optimize_quantifiers`).
    pub enable_quantifier_opt: bool,
    /// Run the dead-code-elimination pass (`eliminate_dead_code`).
    pub enable_dead_code_elimination: bool,
    /// Upper bound on the number of pipeline iterations.
    pub max_iterations: usize,
    /// End the run as soon as an iteration leaves the expression unchanged.
    pub stop_on_fixed_point: bool,
}

impl Default for PipelineConfig {
    /// Every pass enabled, at most 10 iterations, stopping early at a fixed point.
    fn default() -> Self {
        Self {
            enable_negation_opt: true,
            enable_constant_folding: true,
            enable_algebraic_simplification: true,
            enable_strength_reduction: true,
            enable_distributivity: true,
            enable_quantifier_opt: true,
            enable_dead_code_elimination: true,
            max_iterations: 10,
            stop_on_fixed_point: true,
        }
    }
}

impl PipelineConfig {
    /// Every pass enabled; identical to [`PipelineConfig::default`].
    #[must_use]
    pub fn all() -> Self {
        Self::default()
    }

    /// Every pass disabled and a single iteration.
    ///
    /// This is the base the single-pass presets are built from.
    #[must_use]
    pub fn none() -> Self {
        Self {
            enable_negation_opt: false,
            enable_constant_folding: false,
            enable_algebraic_simplification: false,
            enable_strength_reduction: false,
            enable_distributivity: false,
            enable_quantifier_opt: false,
            enable_dead_code_elimination: false,
            max_iterations: 1,
            stop_on_fixed_point: true,
        }
    }

    /// Only constant folding, for a single iteration.
    #[must_use]
    pub fn constant_folding_only() -> Self {
        Self {
            enable_constant_folding: true,
            ..Self::none()
        }
    }

    /// Only algebraic simplification, for a single iteration.
    #[must_use]
    pub fn algebraic_only() -> Self {
        Self {
            enable_algebraic_simplification: true,
            ..Self::none()
        }
    }

    /// Every pass enabled with a larger budget of 20 iterations.
    #[must_use]
    pub fn aggressive() -> Self {
        Self {
            max_iterations: 20,
            ..Self::default()
        }
    }

    /// Enables or disables the negation pass.
    #[must_use = "builder methods return the updated config"]
    pub fn with_negation_opt(mut self, enable: bool) -> Self {
        self.enable_negation_opt = enable;
        self
    }

    /// Enables or disables constant folding.
    #[must_use = "builder methods return the updated config"]
    pub fn with_constant_folding(mut self, enable: bool) -> Self {
        self.enable_constant_folding = enable;
        self
    }

    /// Enables or disables algebraic simplification.
    #[must_use = "builder methods return the updated config"]
    pub fn with_algebraic_simplification(mut self, enable: bool) -> Self {
        self.enable_algebraic_simplification = enable;
        self
    }

    /// Enables or disables strength reduction.
    #[must_use = "builder methods return the updated config"]
    pub fn with_strength_reduction(mut self, enable: bool) -> Self {
        self.enable_strength_reduction = enable;
        self
    }

    /// Enables or disables the distributivity pass.
    #[must_use = "builder methods return the updated config"]
    pub fn with_distributivity(mut self, enable: bool) -> Self {
        self.enable_distributivity = enable;
        self
    }

    /// Enables or disables the quantifier pass.
    #[must_use = "builder methods return the updated config"]
    pub fn with_quantifier_opt(mut self, enable: bool) -> Self {
        self.enable_quantifier_opt = enable;
        self
    }

    /// Enables or disables dead-code elimination.
    #[must_use = "builder methods return the updated config"]
    pub fn with_dead_code_elimination(mut self, enable: bool) -> Self {
        self.enable_dead_code_elimination = enable;
        self
    }

    /// Sets the maximum number of pipeline iterations.
    #[must_use = "builder methods return the updated config"]
    pub fn with_max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Sets whether the run ends early once an iteration changes nothing.
    #[must_use = "builder methods return the updated config"]
    pub fn with_stop_on_fixed_point(mut self, stop: bool) -> Self {
        self.stop_on_fixed_point = stop;
        self
    }
}
/// Statistics recorded by a single pipeline iteration, one field per pass.
///
/// When produced by the pipeline, passes disabled in the configuration keep
/// their `Default` statistics.
#[derive(Debug, Clone, Default)]
pub struct IterationStats {
    /// Counters reported by the negation pass.
    pub negation: NegationOptStats,
    /// Counters reported by the constant-folding pass.
    pub constant_folding: ConstantFoldingStats,
    /// Counters reported by the algebraic-simplification pass.
    pub algebraic: AlgebraicSimplificationStats,
    /// Counters reported by the strength-reduction pass.
    pub strength_reduction: StrengthReductionStats,
    /// Counters reported by the distributivity pass.
    pub distributivity: DistributivityStats,
    /// Counters reported by the quantifier pass.
    pub quantifier_opt: QuantifierOptStats,
    /// Counters reported by the dead-code-elimination pass.
    pub dead_code: DeadCodeStats,
}

impl IterationStats {
    /// Every optimization counter of this iteration, flattened into one array
    /// so that `made_changes` and `total_optimizations` share one definition.
    fn optimization_counts(&self) -> [usize; 12] {
        [
            self.negation.double_negations_eliminated,
            self.negation.demorgans_applied,
            self.negation.quantifier_negations_pushed,
            self.constant_folding.binary_ops_folded,
            self.constant_folding.unary_ops_folded,
            self.algebraic.identities_eliminated,
            self.algebraic.annihilations_applied,
            self.algebraic.idempotent_simplified,
            self.strength_reduction.total_optimizations(),
            self.distributivity.total_optimizations(),
            self.quantifier_opt.total_optimizations(),
            self.dead_code.total_optimizations(),
        ]
    }

    /// Returns `true` if any pass reported at least one optimization.
    pub fn made_changes(&self) -> bool {
        self.optimization_counts().iter().any(|&count| count > 0)
    }

    /// Total number of optimizations reported by all passes in this iteration.
    pub fn total_optimizations(&self) -> usize {
        self.optimization_counts().iter().sum()
    }
}
/// Cumulative statistics for a whole pipeline run, plus a per-iteration log.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Number of iterations that were executed.
    pub total_iterations: usize,
    /// Negation-pass counters summed over all iterations.
    pub negation: NegationOptStats,
    /// Constant-folding counters summed over all iterations.
    pub constant_folding: ConstantFoldingStats,
    /// Algebraic-simplification counters summed over all iterations.
    pub algebraic: AlgebraicSimplificationStats,
    /// Strength-reduction counters summed over all iterations.
    pub strength_reduction: StrengthReductionStats,
    /// Distributivity counters summed over all iterations.
    pub distributivity: DistributivityStats,
    /// Quantifier-pass counters summed over all iterations.
    pub quantifier_opt: QuantifierOptStats,
    /// Dead-code-elimination counters summed over all iterations.
    pub dead_code: DeadCodeStats,
    /// Per-iteration statistics, in execution order.
    pub iterations: Vec<IterationStats>,
    /// Set by `OptimizationPipeline::optimize` when it detects a fixed point,
    /// i.e. an iteration that left the expression unchanged.
    pub reached_fixed_point: bool,
    /// Set when the run used its whole `max_iterations` budget without ending
    /// early at a fixed point.
    pub stopped_at_max_iterations: bool,
}

impl PipelineStats {
    /// Total number of optimizations applied across every pass and iteration.
    pub fn total_optimizations(&self) -> usize {
        let negation = self.negation.double_negations_eliminated
            + self.negation.demorgans_applied
            + self.negation.quantifier_negations_pushed;
        let folding =
            self.constant_folding.binary_ops_folded + self.constant_folding.unary_ops_folded;
        let algebraic = self.algebraic.identities_eliminated
            + self.algebraic.annihilations_applied
            + self.algebraic.idempotent_simplified;
        let delegated = self.strength_reduction.total_optimizations()
            + self.distributivity.total_optimizations()
            + self.quantifier_opt.total_optimizations()
            + self.dead_code.total_optimizations();
        negation + folding + algebraic + delegated
    }

    /// The iteration with the most optimizations, with its zero-based index.
    ///
    /// Returns `None` when no iterations were recorded. On a tie, the later
    /// iteration is returned (`Iterator::max_by_key` keeps the last maximum).
    pub fn most_productive_iteration(&self) -> Option<(usize, &IterationStats)> {
        let indexed = self.iterations.iter().enumerate();
        indexed.max_by_key(|&(_, iteration)| iteration.total_optimizations())
    }
}
impl std::fmt::Display for PipelineStats {
    /// Writes a multi-line, human-readable report.
    ///
    /// The report starts with a run summary, then has one section per pass
    /// listing every counter the pipeline accumulates. `stopped_at_max_iterations`,
    /// `common_terms_extracted`, `quantifiers_fused` and
    /// `identity_simplifications` are accumulated by the pipeline too, so they
    /// are reported here as well.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Pipeline Statistics:")?;
        writeln!(f, " Iterations: {}", self.total_iterations)?;
        writeln!(f, " Reached fixed point: {}", self.reached_fixed_point)?;
        writeln!(
            f,
            " Stopped at max iterations: {}",
            self.stopped_at_max_iterations
        )?;
        writeln!(f, " Total optimizations: {}", self.total_optimizations())?;

        let neg = &self.negation;
        writeln!(f, "\nNegation Optimization:")?;
        writeln!(
            f,
            " Double negations eliminated: {}",
            neg.double_negations_eliminated
        )?;
        writeln!(f, " De Morgan's laws applied: {}", neg.demorgans_applied)?;
        writeln!(
            f,
            " Quantifier negations pushed: {}",
            neg.quantifier_negations_pushed
        )?;

        let folding = &self.constant_folding;
        writeln!(f, "\nConstant Folding:")?;
        writeln!(f, " Binary ops folded: {}", folding.binary_ops_folded)?;
        writeln!(f, " Unary ops folded: {}", folding.unary_ops_folded)?;

        let alg = &self.algebraic;
        writeln!(f, "\nAlgebraic Simplification:")?;
        writeln!(f, " Identities eliminated: {}", alg.identities_eliminated)?;
        writeln!(f, " Annihilations applied: {}", alg.annihilations_applied)?;
        writeln!(f, " Idempotent simplified: {}", alg.idempotent_simplified)?;

        let sr = &self.strength_reduction;
        writeln!(f, "\nStrength Reduction:")?;
        writeln!(f, " Power reductions: {}", sr.power_reductions)?;
        writeln!(f, " Operations eliminated: {}", sr.operations_eliminated)?;
        writeln!(
            f,
            " Special function optimizations: {}",
            sr.special_function_optimizations
        )?;

        let dist = &self.distributivity;
        writeln!(f, "\nDistributivity:")?;
        writeln!(f, " Expressions factored: {}", dist.expressions_factored)?;
        writeln!(f, " Expressions expanded: {}", dist.expressions_expanded)?;
        writeln!(f, " Common terms extracted: {}", dist.common_terms_extracted)?;

        let quant = &self.quantifier_opt;
        writeln!(f, "\nQuantifier Optimization:")?;
        writeln!(f, " Invariants hoisted: {}", quant.invariants_hoisted)?;
        writeln!(f, " Quantifiers reordered: {}", quant.quantifiers_reordered)?;
        writeln!(f, " Quantifiers fused: {}", quant.quantifiers_fused)?;

        let dead = &self.dead_code;
        writeln!(f, "\nDead Code Elimination:")?;
        writeln!(f, " Branches eliminated: {}", dead.branches_eliminated)?;
        writeln!(f, " Short circuits: {}", dead.short_circuits)?;
        writeln!(
            f,
            " Unused quantifiers removed: {}",
            dead.unused_quantifiers_removed
        )?;
        writeln!(
            f,
            " Identity simplifications: {}",
            dead.identity_simplifications
        )?;
        Ok(())
    }
}
/// Runs the configured optimization passes over a `TLExpr` until a fixed
/// point or the iteration limit is reached.
///
/// `Debug` and `Clone` are derived so that pipelines can be logged and
/// duplicated like their [`PipelineConfig`].
#[derive(Debug, Clone)]
pub struct OptimizationPipeline {
    /// Which passes run and how iteration is bounded.
    config: PipelineConfig,
}
impl OptimizationPipeline {
pub fn new() -> Self {
Self {
config: PipelineConfig::default(),
}
}
pub fn with_config(config: PipelineConfig) -> Self {
Self { config }
}
pub fn optimize(&self, expr: &TLExpr) -> (TLExpr, PipelineStats) {
let mut current = expr.clone();
let mut stats = PipelineStats::default();
for iteration in 0..self.config.max_iterations {
let mut iter_stats = IterationStats::default();
let mut changed = false;
if self.config.enable_negation_opt {
let (optimized, neg_stats) = optimize_negations(¤t);
iter_stats.negation = neg_stats;
if optimized != current {
current = optimized;
changed = true;
}
}
if self.config.enable_constant_folding {
let (optimized, fold_stats) = fold_constants(¤t);
iter_stats.constant_folding = fold_stats;
if optimized != current {
current = optimized;
changed = true;
}
}
if self.config.enable_algebraic_simplification {
let (optimized, alg_stats) = simplify_algebraic(¤t);
iter_stats.algebraic = alg_stats;
if optimized != current {
current = optimized;
changed = true;
}
}
if self.config.enable_strength_reduction {
let (optimized, sr_stats) = reduce_strength(¤t);
iter_stats.strength_reduction = sr_stats;
if optimized != current {
current = optimized;
changed = true;
}
}
if self.config.enable_distributivity {
let (optimized, dist_stats) = optimize_distributivity(¤t);
iter_stats.distributivity = dist_stats;
if optimized != current {
current = optimized;
changed = true;
}
}
if self.config.enable_quantifier_opt {
let (optimized, quant_stats) = optimize_quantifiers(¤t);
iter_stats.quantifier_opt = quant_stats;
if optimized != current {
current = optimized;
changed = true;
}
}
if self.config.enable_dead_code_elimination {
let (optimized, dead_stats) = eliminate_dead_code(¤t);
iter_stats.dead_code = dead_stats;
if optimized != current {
current = optimized;
changed = true;
}
}
stats.total_iterations = iteration + 1;
stats.negation.double_negations_eliminated +=
iter_stats.negation.double_negations_eliminated;
stats.negation.demorgans_applied += iter_stats.negation.demorgans_applied;
stats.negation.quantifier_negations_pushed +=
iter_stats.negation.quantifier_negations_pushed;
stats.constant_folding.binary_ops_folded +=
iter_stats.constant_folding.binary_ops_folded;
stats.constant_folding.unary_ops_folded += iter_stats.constant_folding.unary_ops_folded;
stats.constant_folding.total_processed += iter_stats.constant_folding.total_processed;
stats.algebraic.identities_eliminated += iter_stats.algebraic.identities_eliminated;
stats.algebraic.annihilations_applied += iter_stats.algebraic.annihilations_applied;
stats.algebraic.idempotent_simplified += iter_stats.algebraic.idempotent_simplified;
stats.algebraic.total_processed += iter_stats.algebraic.total_processed;
stats.strength_reduction.power_reductions +=
iter_stats.strength_reduction.power_reductions;
stats.strength_reduction.operations_eliminated +=
iter_stats.strength_reduction.operations_eliminated;
stats.strength_reduction.special_function_optimizations +=
iter_stats.strength_reduction.special_function_optimizations;
stats.strength_reduction.total_processed +=
iter_stats.strength_reduction.total_processed;
stats.distributivity.expressions_factored +=
iter_stats.distributivity.expressions_factored;
stats.distributivity.expressions_expanded +=
iter_stats.distributivity.expressions_expanded;
stats.distributivity.common_terms_extracted +=
iter_stats.distributivity.common_terms_extracted;
stats.distributivity.total_processed += iter_stats.distributivity.total_processed;
stats.quantifier_opt.invariants_hoisted += iter_stats.quantifier_opt.invariants_hoisted;
stats.quantifier_opt.quantifiers_reordered +=
iter_stats.quantifier_opt.quantifiers_reordered;
stats.quantifier_opt.quantifiers_fused += iter_stats.quantifier_opt.quantifiers_fused;
stats.quantifier_opt.total_processed += iter_stats.quantifier_opt.total_processed;
stats.dead_code.branches_eliminated += iter_stats.dead_code.branches_eliminated;
stats.dead_code.short_circuits += iter_stats.dead_code.short_circuits;
stats.dead_code.unused_quantifiers_removed +=
iter_stats.dead_code.unused_quantifiers_removed;
stats.dead_code.identity_simplifications +=
iter_stats.dead_code.identity_simplifications;
stats.dead_code.total_processed += iter_stats.dead_code.total_processed;
stats.iterations.push(iter_stats);
if self.config.stop_on_fixed_point && !changed {
stats.reached_fixed_point = true;
break;
}
if iteration + 1 >= self.config.max_iterations {
stats.stopped_at_max_iterations = true;
}
}
(current, stats)
}
pub fn config(&self) -> &PipelineConfig {
&self.config
}
}
impl Default for OptimizationPipeline {
fn default() -> Self {
Self::new()
}
}
// Integration tests for `OptimizationPipeline`. Each test builds a small
// `TLExpr`, runs the pipeline, and checks the accumulated per-pass counters.
// The expected counts depend on the individual passes imported from the
// sibling modules (`super::algebraic`, `super::constant_folding`, ...).
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    // Default pipeline on `negate(and(x + 0, 2 * 3))`: expects a binary
    // constant fold, an identity elimination and a De Morgan rewrite.
    #[test]
    fn test_pipeline_with_all_passes() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::negate(TLExpr::and(
            TLExpr::add(x, TLExpr::Constant(0.0)),
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
        ));
        let pipeline = OptimizationPipeline::new();
        let (optimized, stats) = pipeline.optimize(&expr);
        assert!(stats.total_iterations > 0);
        assert!(stats.constant_folding.binary_ops_folded > 0);
        assert!(stats.algebraic.identities_eliminated > 0);
        assert!(stats.negation.demorgans_applied > 0);
        assert!(optimized != expr);
    }

    // Constant-folding-only preset: `2 + 3 * 4` collapses to one constant via
    // exactly two binary folds, and the disabled passes report nothing.
    #[test]
    fn test_constant_folding_only() {
        let expr = TLExpr::add(
            TLExpr::Constant(2.0),
            TLExpr::mul(TLExpr::Constant(3.0), TLExpr::Constant(4.0)),
        );
        let config = PipelineConfig::constant_folding_only();
        let pipeline = OptimizationPipeline::with_config(config);
        let (optimized, stats) = pipeline.optimize(&expr);
        assert!(matches!(optimized, TLExpr::Constant(_)));
        assert_eq!(stats.constant_folding.binary_ops_folded, 2);
        assert_eq!(stats.algebraic.identities_eliminated, 0);
        assert_eq!(stats.negation.demorgans_applied, 0);
    }

    // Algebraic-only preset: `(x + 0) * 1` holds two identities, and constant
    // folding is disabled so it must not report anything.
    #[test]
    fn test_algebraic_only() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::mul(TLExpr::add(x, TLExpr::Constant(0.0)), TLExpr::Constant(1.0));
        let config = PipelineConfig::algebraic_only();
        let pipeline = OptimizationPipeline::with_config(config);
        let (_optimized, stats) = pipeline.optimize(&expr);
        assert_eq!(stats.algebraic.identities_eliminated, 2);
        assert_eq!(stats.constant_folding.binary_ops_folded, 0);
    }

    // A bare predicate is left unchanged, so the first iteration is already a
    // fixed point and the run stops after one iteration.
    #[test]
    fn test_fixed_point_detection() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let config = PipelineConfig::default().with_max_iterations(10);
        let pipeline = OptimizationPipeline::with_config(config);
        let (optimized, stats) = pipeline.optimize(&x);
        assert_eq!(stats.total_iterations, 1);
        assert!(stats.reached_fixed_point);
        assert!(!stats.stopped_at_max_iterations);
        assert_eq!(optimized, x);
    }

    // With a one-iteration budget and an expression that the first iteration
    // changes, the run must end because of the iteration limit.
    #[test]
    fn test_max_iterations_limit() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::negate(TLExpr::negate(TLExpr::add(x, TLExpr::Constant(0.0))));
        let config = PipelineConfig::default().with_max_iterations(1);
        let pipeline = OptimizationPipeline::with_config(config);
        let (_, stats) = pipeline.optimize(&expr);
        assert_eq!(stats.total_iterations, 1);
        assert!(stats.stopped_at_max_iterations);
    }

    // Aggressive preset on a nested expression with several opportunities;
    // only a lower bound on the total number of optimizations is asserted.
    #[test]
    fn test_aggressive_optimization() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::add(
            TLExpr::negate(TLExpr::and(
                TLExpr::negate(TLExpr::add(x.clone(), TLExpr::Constant(0.0))),
                TLExpr::negate(TLExpr::mul(
                    TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
                    x,
                )),
            )),
            TLExpr::mul(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
        );
        let config = PipelineConfig::aggressive();
        let pipeline = OptimizationPipeline::with_config(config);
        let (_, stats) = pipeline.optimize(&expr);
        assert!(
            stats.total_optimizations() >= 4,
            "Expected at least 4 optimizations, got {}",
            stats.total_optimizations()
        );
        assert!(stats.total_iterations >= 1);
    }

    // With every pass disabled the expression comes back untouched.
    #[test]
    fn test_no_optimization() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::add(x.clone(), TLExpr::Constant(1.0));
        let config = PipelineConfig::none();
        let pipeline = OptimizationPipeline::with_config(config);
        let (optimized, stats) = pipeline.optimize(&expr);
        assert_eq!(optimized, expr);
        assert_eq!(stats.total_optimizations(), 0);
    }

    // The per-iteration log is populated, and the first iteration reports
    // changes for `2 * 3 + 0`.
    #[test]
    fn test_iteration_stats() {
        let expr = TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::Constant(0.0),
        );
        let pipeline = OptimizationPipeline::new();
        let (_, stats) = pipeline.optimize(&expr);
        assert!(!stats.iterations.is_empty());
        assert!(stats.iterations[0].made_changes());
        assert!(stats.iterations[0].total_optimizations() > 0);
    }

    // `most_productive_iteration` returns an index within the executed range
    // and an iteration that actually optimized something.
    #[test]
    fn test_most_productive_iteration() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::negate(TLExpr::negate(TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::mul(x, TLExpr::Constant(1.0)),
        )));
        let pipeline = OptimizationPipeline::new();
        let (_, stats) = pipeline.optimize(&expr);
        let (iter_idx, iter_stats) = stats.most_productive_iteration().unwrap();
        assert!(iter_stats.total_optimizations() > 0);
        assert!(iter_idx < stats.total_iterations);
    }

    // Smoke test for the `Display` report headings.
    #[test]
    fn test_pipeline_display() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let pipeline = OptimizationPipeline::new();
        let (_, stats) = pipeline.optimize(&expr);
        let output = format!("{}", stats);
        assert!(output.contains("Pipeline Statistics"));
        assert!(output.contains("Iterations:"));
        assert!(output.contains("Total optimizations:"));
    }

    // The `with_*` builders overwrite exactly the field they name.
    #[test]
    fn test_builder_pattern() {
        let config = PipelineConfig::default()
            .with_negation_opt(false)
            .with_constant_folding(true)
            .with_algebraic_simplification(false)
            .with_max_iterations(5)
            .with_stop_on_fixed_point(false);
        assert!(!config.enable_negation_opt);
        assert!(config.enable_constant_folding);
        assert!(!config.enable_algebraic_simplification);
        assert_eq!(config.max_iterations, 5);
        assert!(!config.stop_on_fixed_point);
    }

    // `exp((x - max) / 1.0)`: expects at least one identity elimination
    // (presumably the division by the constant 1.0).
    #[test]
    fn test_complex_real_world_expression() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let max_val = TLExpr::pred("max", vec![]);
        let temp = TLExpr::Constant(1.0);
        let expr = TLExpr::exp(TLExpr::div(TLExpr::sub(x, max_val), temp));
        let pipeline = OptimizationPipeline::new();
        let (optimized, stats) = pipeline.optimize(&expr);
        assert!(stats.algebraic.identities_eliminated > 0);
        assert!(optimized != expr);
    }

    // An `IfThenElse` with a constant condition is reduced to one of its
    // branches (here, a predicate) by dead-code elimination.
    #[test]
    fn test_dead_code_elimination_integration() {
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("i")]);
        let expr = TLExpr::IfThenElse {
            condition: Box::new(TLExpr::Constant(1.0)),
            then_branch: Box::new(a.clone()),
            else_branch: Box::new(b),
        };
        let pipeline = OptimizationPipeline::new();
        let (optimized, stats) = pipeline.optimize(&expr);
        assert!(stats.dead_code.branches_eliminated > 0);
        assert!(matches!(optimized, TLExpr::Pred { .. }));
    }

    // One expression with opportunities for several passes: a constant
    // condition, a double negation, `+ 0`, `x ^ 2`, and a shared factor in
    // `a*b + a*c`. Each of those passes must report at least one optimization.
    #[test]
    fn test_all_passes_together() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let a = TLExpr::pred("a", vec![Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("i")]);
        let c = TLExpr::pred("c", vec![Term::var("i")]);
        let expr = TLExpr::IfThenElse {
            condition: Box::new(TLExpr::Constant(1.0)),
            then_branch: Box::new(TLExpr::and(
                TLExpr::negate(TLExpr::negate(TLExpr::add(
                    TLExpr::pow(x, TLExpr::Constant(2.0)),
                    TLExpr::Constant(0.0),
                ))),
                TLExpr::add(
                    TLExpr::mul(a.clone(), b.clone()),
                    TLExpr::mul(a.clone(), c.clone()),
                ),
            )),
            else_branch: Box::new(TLExpr::Constant(0.0)),
        };
        let pipeline = OptimizationPipeline::new();
        let (_, stats) = pipeline.optimize(&expr);
        assert!(
            stats.dead_code.branches_eliminated > 0,
            "Dead code elimination should apply"
        );
        assert!(
            stats.negation.double_negations_eliminated > 0,
            "Negation optimization should apply"
        );
        assert!(
            stats.algebraic.identities_eliminated > 0,
            "Algebraic simplification should apply"
        );
        assert!(
            stats.strength_reduction.power_reductions > 0,
            "Strength reduction should apply"
        );
        assert!(
            stats.distributivity.expressions_factored > 0,
            "Distributivity should apply"
        );
        assert!(
            stats.total_optimizations() >= 5,
            "Should apply at least 5 optimizations"
        );
    }
}