use criterion::{Criterion, criterion_group, criterion_main};
use std::hint::black_box;
use dslcompile::OptimizationConfig;
use dslcompile::backends::cranelift::CraneliftCompiler;
use dslcompile::final_tagless::{ASTMathExpr, DirectEval, VariableRegistry};
use dslcompile::prelude::*;
use dslcompile::symbolic::symbolic::SymbolicOptimizer;
/// Builds the two-variable benchmark expression `x * x^2 + exp(y)` and returns its AST.
///
/// `x` is the first variable requested from the builder and `y` the second. The benchmarks
/// in this file evaluate the result with `DirectEval::eval_two_vars(&expr, x, y)`.
fn create_complex_expression() -> ASTRepr<f64> {
    let math = MathBuilder::new();
    let x = math.var();
    let y = math.var();

    // `x` appears twice, so one occurrence is a borrow and the other a clone.
    // NOTE(review): the cube is spelled `x * x^2` rather than `x^3`, presumably so the
    // optimizer has something to simplify -- confirm before "tidying" the expression.
    let result: TypedBuilderExpr<f64> = &x * x.clone().pow(math.constant(2.0)) + y.exp();
    result.into_ast()
}
/// Benchmarks evaluation speed of one expression before and after symbolic optimization.
///
/// The expression from [`create_complex_expression`] is optimized twice: once with a default
/// [`SymbolicOptimizer`] and once with egglog optimization, constant folding and aggressive
/// mode enabled. Operation counts and a one-off numeric comparison at `x = 2.5`, `y = 1.8`
/// are printed, then the evaluation benchmarks are registered in the
/// `optimization_performance` group. With the `cranelift` feature enabled, the
/// advanced-optimized expression is additionally JIT-compiled and benchmarked.
///
/// # Panics
///
/// Panics if an optimizer cannot be constructed, if optimization fails, or (with the
/// `cranelift` feature) if JIT compilation fails.
fn bench_optimization_performance(c: &mut Criterion) {
    let mut group = c.benchmark_group("optimization_performance");
    let complex_expr = create_complex_expression();

    // Baseline optimizer with the library's default configuration.
    let mut basic_optimizer = SymbolicOptimizer::new().unwrap();

    // Optimizer with every optimization switch used by this benchmark turned on.
    let mut config = OptimizationConfig::default();
    config.egglog_optimization = true;
    config.constant_folding = true;
    config.aggressive = true;
    let mut advanced_optimizer = SymbolicOptimizer::with_config(config).unwrap();

    // Optimize once up front so the benchmarks below measure evaluation only.
    let basic_optimized = basic_optimizer.optimize(&complex_expr).unwrap();
    let advanced_optimized = advanced_optimizer.optimize(&complex_expr).unwrap();

    println!("\n📊 Expression Analysis:");
    println!(
        "Original expression operations: {}",
        complex_expr.count_operations()
    );
    println!(
        "Basic optimized operations: {}",
        basic_optimized.count_operations()
    );
    println!(
        "Advanced optimized operations: {}",
        advanced_optimized.count_operations()
    );

    println!("\n🚀 Optimization Results:");
    println!("Original: {complex_expr:?}");
    println!("Advanced: {advanced_optimized:?}");

    // Fixed sample point shared by the sanity check and every benchmark.
    let x = 2.5;
    let y = 1.8;

    // Informational only: the difference is printed, not asserted.
    let original_result = DirectEval::eval_two_vars(&complex_expr, x, y);
    let optimized_result = DirectEval::eval_two_vars(&advanced_optimized, x, y);
    println!("\n✅ Correctness Check:");
    println!("Original result: {original_result}");
    println!("Optimized result: {optimized_result}");
    println!("Difference: {}", (original_result - optimized_result).abs());

    group.bench_function("original_expression", |b| {
        b.iter(|| DirectEval::eval_two_vars(black_box(&complex_expr), black_box(x), black_box(y)));
    });
    group.bench_function("basic_optimized", |b| {
        b.iter(|| {
            DirectEval::eval_two_vars(black_box(&basic_optimized), black_box(x), black_box(y))
        });
    });
    group.bench_function("advanced_optimized", |b| {
        b.iter(|| {
            DirectEval::eval_two_vars(black_box(&advanced_optimized), black_box(x), black_box(y))
        });
    });

    #[cfg(feature = "cranelift")]
    {
        let mut jit_compiler = CraneliftCompiler::new_default().unwrap();
        let registry = VariableRegistry::for_expression(&advanced_optimized);
        let jit_func = jit_compiler
            .compile_expression(&advanced_optimized, &registry)
            .unwrap();

        // Both benchmark IDs time calls into the same already-compiled function. They are
        // kept as separate IDs so previously saved Criterion baselines remain comparable,
        // but the expression is compiled only once.
        for name in ["cranelift_jit", "precompiled_jit"] {
            group.bench_function(name, |b| {
                b.iter(|| jit_func.call(&[black_box(x), black_box(y)]));
            });
        }

        println!("\n🔧 JIT Compilation Stats:");
        // Unit label follows the metadata field name (`compilation_time_ms`).
        println!("Compilation time: {} ms", jit_func.metadata().compilation_time_ms);
        println!("Expression complexity: {} operations", jit_func.metadata().expression_complexity);
    }

    group.finish();
}
/// Benchmarks the cost of running the optimizer and reports the evaluation speedup it buys.
///
/// Registers `optimization_time` in the `optimization_tradeoff` group, which measures
/// building a [`SymbolicOptimizer`] (egglog + constant folding) and optimizing the expression
/// from [`create_complex_expression`]. Afterwards it hand-times 10 000 evaluations of the
/// original and optimized expressions (and of the JIT-compiled function when the `cranelift`
/// feature is enabled) and prints the durations and speedup ratios.
///
/// # Panics
///
/// Panics if an optimizer cannot be constructed, if optimization fails, or (with the
/// `cranelift` feature) if JIT compilation fails.
fn bench_optimization_tradeoff(c: &mut Criterion) {
    /// Evaluations per hand-timed loop; the report labels this as "10k evals".
    const EVAL_ITERATIONS: u32 = 10_000;

    /// Runs `eval` `EVAL_ITERATIONS` times and returns the elapsed wall-clock time.
    ///
    /// Each result goes through `black_box` so the compiler cannot discard the call.
    fn time_evals<R>(mut eval: impl FnMut() -> R) -> std::time::Duration {
        let start = std::time::Instant::now();
        for _ in 0..EVAL_ITERATIONS {
            black_box(eval());
        }
        start.elapsed()
    }

    /// Returns how many times faster `candidate` is than `baseline`.
    fn speedup(baseline: std::time::Duration, candidate: std::time::Duration) -> f64 {
        baseline.as_secs_f64() / candidate.as_secs_f64()
    }

    let mut group = c.benchmark_group("optimization_tradeoff");
    let complex_expr = create_complex_expression();

    // Optimizer construction is deliberately inside the timed closure, so this benchmark
    // covers the full cost of going from a raw expression to an optimized one.
    group.bench_function("optimization_time", |b| {
        b.iter(|| {
            let mut config = OptimizationConfig::default();
            config.egglog_optimization = true;
            config.constant_folding = true;
            let mut optimizer = SymbolicOptimizer::with_config(config).unwrap();
            optimizer.optimize(black_box(&complex_expr)).unwrap()
        });
    });

    // Optimize once more outside Criterion to get an expression for the manual timings.
    let mut config = OptimizationConfig::default();
    config.egglog_optimization = true;
    config.constant_folding = true;
    let mut optimizer = SymbolicOptimizer::with_config(config).unwrap();
    let optimized = optimizer.optimize(&complex_expr).unwrap();

    let x = 2.5;
    let y = 1.8;

    let original_duration = time_evals(|| {
        DirectEval::eval_two_vars(black_box(&complex_expr), black_box(x), black_box(y))
    });
    let optimized_duration = time_evals(|| {
        DirectEval::eval_two_vars(black_box(&optimized), black_box(x), black_box(y))
    });

    // Compilation happens before the timed loop, so only the calls are measured.
    #[cfg(feature = "cranelift")]
    let jit_duration = {
        let mut jit_compiler = CraneliftCompiler::new_default().unwrap();
        let registry = VariableRegistry::for_expression(&optimized);
        let jit_func = jit_compiler
            .compile_expression(&optimized, &registry)
            .unwrap();
        time_evals(|| jit_func.call(&[black_box(x), black_box(y)]))
    };

    println!("\n📈 Performance Analysis:");
    println!("Original (10k evals): {original_duration:?}");
    println!("Optimized (10k evals): {optimized_duration:?}");
    #[cfg(feature = "cranelift")]
    {
        println!("JIT (10k evals): {jit_duration:?}");
    }

    let speedup_opt = speedup(original_duration, optimized_duration);
    println!("Optimization speedup: {speedup_opt:.2}x");

    #[cfg(feature = "cranelift")]
    {
        let speedup_jit = speedup(original_duration, jit_duration);
        let jit_vs_opt = speedup(optimized_duration, jit_duration);
        println!("JIT speedup: {speedup_jit:.2}x");
        println!("JIT vs Optimized: {jit_vs_opt:.2}x");
    }

    group.finish();
}
// Benchmark harness wiring: `benches` bundles both benchmark functions and runs them with
// Criterion's default configuration; `criterion_main!` generates the `main` entry point.
criterion_group! {
    name = benches;
    config = Criterion::default();
    targets = bench_optimization_performance, bench_optimization_tradeoff
}
criterion_main!(benches);