use dslcompile::final_tagless::{ASTEval, ASTMathExpr, ASTRepr};
use dslcompile::{CompilationStrategy, OptimizationConfig, RustOptLevel, SymbolicOptimizer};
use std::path::PathBuf;
/// Variable node referring to the variable with index `slot`.
fn var(slot: usize) -> ASTRepr<f64> {
    ASTEval::var(slot)
}
/// Literal node holding the constant `literal`.
fn constant(literal: f64) -> ASTRepr<f64> {
    ASTEval::constant(literal)
}
/// Addition node with `lhs` as the left operand and `rhs` as the right.
fn add(lhs: ASTRepr<f64>, rhs: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::add(lhs, rhs)
}
/// Multiplication node with `lhs` as the left operand and `rhs` as the right.
fn mul(lhs: ASTRepr<f64>, rhs: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::mul(lhs, rhs)
}
/// Subtraction node with `lhs` as the left operand and `rhs` as the right.
fn sub(lhs: ASTRepr<f64>, rhs: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::sub(lhs, rhs)
}
/// Power node raising `base` to `exponent`.
///
/// The second parameter is deliberately not called `exp`, so it does not
/// shadow the sibling `exp` builder inside this function.
fn pow(base: ASTRepr<f64>, exponent: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::pow(base, exponent)
}
/// Sine node applied to `arg`.
fn sin(arg: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::sin(arg)
}
/// Cosine node applied to `arg`.
fn cos(arg: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::cos(arg)
}
/// Exponential node applied to `arg`.
fn exp(arg: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::exp(arg)
}
/// Natural-logarithm node applied to `arg`.
///
/// Named `log` for readability in the tests, but it builds `ASTEval::ln`.
fn log(arg: ASTRepr<f64>) -> ASTRepr<f64> {
    ASTEval::ln(arg)
}
fn create_var_expr(index: usize) -> ASTRepr<f64> {
ASTEval::var(index)
}
/// Smoke test: with egglog optimization enabled, the optimizer accepts `x + 0`
/// and returns an optimized expression without error. The result is printed,
/// not asserted.
#[test]
fn test_current_optimization_capabilities() {
    // Test-tube emoji, written as an escape so the source stays ASCII-only and
    // cannot be mangled by a non-UTF-8 round trip.
    println!("\u{1F9EA} Testing current optimization capabilities...");

    let mut config = OptimizationConfig::default();
    config.egglog_optimization = true;
    let mut optimizer = SymbolicOptimizer::with_config(config)
        .expect("SymbolicOptimizer should build with egglog optimization enabled");

    // x + 0: an additive-identity case.
    let expr = add(var(0), constant(0.0));
    println!("Original expression: {expr:?}");

    let optimized = optimizer
        .optimize(&expr)
        .expect("optimizing `x + 0` should not fail");
    println!("Optimized expression: {optimized:?}");
    println!("Optimization completed successfully");
}
/// Generates Rust source for `x^2 + 2*x + 1` and checks three things:
/// - the source exposes a `#[no_mangle] pub extern "C" fn poly_func`;
/// - it contains a recognisable square term;
/// - it contains a recognisable linear term.
#[test]
fn test_rust_code_generation() {
    // Crab emoji, written as an escape so the source stays ASCII-only.
    println!("\u{1F980} Testing Rust code generation...");
    let optimizer = SymbolicOptimizer::new().expect("default SymbolicOptimizer should build");

    // x^2 + 2*x + 1, where x is variable 0.
    let square = pow(var(0), constant(2.0));
    let linear = mul(constant(2.0), var(0));
    let expr = add(add(square, linear), constant(1.0));

    let rust_code = optimizer
        .generate_rust_source(&expr, "poly_func")
        .expect("Rust source generation should succeed");
    println!("Generated Rust code:\n{rust_code}");

    assert!(
        rust_code.contains("#[no_mangle]"),
        "missing #[no_mangle] in generated code:\n{rust_code}"
    );
    assert!(
        rust_code.contains("pub extern \"C\" fn poly_func"),
        "missing `pub extern \"C\" fn poly_func` in generated code:\n{rust_code}"
    );

    // The backend may emit the square as a product or as a powf/powi call.
    let has_square = rust_code.contains("x * x")
        || rust_code.contains("x.powf(2")
        || rust_code.contains("x.powi(2");
    assert!(has_square, "no recognisable x^2 term in generated code:\n{rust_code}");

    // Accept the coefficient with or without a float suffix.
    let has_linear = rust_code.contains("2.0 * x")
        || rust_code.contains("2.0_f64 * x")
        || rust_code.contains("2 * x");
    assert!(has_linear, "no recognisable 2*x term in generated code:\n{rust_code}");
}
/// Exercises `choose_compilation_approach` for a trivial expression in two phases:
/// - once under the default strategy;
/// - repeatedly under `CompilationStrategy::Adaptive`, recording an execution
///   after each call.
///
/// Choices and final expression stats are printed, not asserted.
#[test]
fn test_compilation_strategy_selection() {
    // Gear emoji (U+2699 + variation selector), escaped to keep the source ASCII-only.
    println!("\u{2699}\u{FE0F} Testing compilation strategy selection...");
    let mut optimizer = SymbolicOptimizer::new().expect("default SymbolicOptimizer should build");
    let simple_expr = add(var(0), constant(1.0));

    let initial_approach = optimizer.choose_compilation_approach(&simple_expr, "simple");
    println!("Simple expression approach: {initial_approach:?}");

    optimizer.set_compilation_strategy(CompilationStrategy::Adaptive {
        call_threshold: 3,
        complexity_threshold: 10,
    });

    // Five calls exceed `call_threshold` (3); presumably this lets the adaptive
    // strategy change its choice partway through the loop.
    for call in 0..5 {
        let approach = optimizer.choose_compilation_approach(&simple_expr, "adaptive_test");
        println!("Call {call}: {approach:?}");
        optimizer.record_execution("adaptive_test", 1000);
    }

    let stats = optimizer.get_expression_stats();
    println!("Expression stats: {stats:?}");
}
/// Builds an optimizer with the `HotLoadRust` strategy and checks that the
/// Rust source for `sin(2*x + cos(y))` mentions `sin`, `cos` and the function name.
#[test]
fn test_hot_loading_strategy() {
    // Fire emoji, written as an escape so the source stays ASCII-only.
    println!("\u{1F525} Testing hot-loading compilation strategy...");

    // Put scratch directories under the platform temp dir rather than a
    // hard-coded `/tmp`, which does not exist on every OS (e.g. Windows).
    let tmp: PathBuf = std::env::temp_dir();
    let strategy = CompilationStrategy::HotLoadRust {
        source_dir: tmp.join("dslcompile_test_sources"),
        lib_dir: tmp.join("dslcompile_test_libs"),
        opt_level: RustOptLevel::O2,
    };
    let optimizer = SymbolicOptimizer::with_strategy(strategy)
        .expect("SymbolicOptimizer should build with the HotLoadRust strategy");

    // sin(2*x + cos(y)), where x = var(0) and y = var(1).
    let expr = sin(add(mul(constant(2.0), var(0)), cos(var(1))));

    let rust_code = optimizer
        .generate_rust_source(&expr, "complex_func")
        .expect("Rust source generation should succeed");
    println!("Complex function Rust code:\n{rust_code}");

    assert!(rust_code.contains("sin"), "no `sin` in generated code:\n{rust_code}");
    assert!(rust_code.contains("cos"), "no `cos` in generated code:\n{rust_code}");
    assert!(
        rust_code.contains("complex_func"),
        "function name `complex_func` missing from generated code:\n{rust_code}"
    );
}
/// Runs three algebraic identities through the egglog-enabled optimizer:
/// `exp(a) * exp(b)`, `ln(exp(x))`, and `x^a * x^b`.
/// The optimized forms are printed, not asserted.
#[test]
fn test_algebraic_optimizations() {
    // Abacus emoji, written as an escape so the source stays ASCII-only.
    println!("\u{1F9EE} Testing algebraic optimizations...");
    let mut config = OptimizationConfig::default();
    config.egglog_optimization = true;
    let mut optimizer = SymbolicOptimizer::with_config(config)
        .expect("SymbolicOptimizer should build with egglog optimization enabled");

    // exp(a) * exp(b), with a = var(0), b = var(1).
    let exp_expr = mul(exp(var(0)), exp(var(1)));
    let optimized_exp = optimizer
        .optimize(&exp_expr)
        .expect("optimizing exp(a) * exp(b) should not fail");
    println!("exp(a) * exp(b) optimized to: {optimized_exp:?}");

    // ln(exp(x)), with x = var(2).
    let log_exp_expr = log(exp(var(2)));
    let optimized_log_exp = optimizer
        .optimize(&log_exp_expr)
        .expect("optimizing log(exp(x)) should not fail");
    println!("log(exp(x)) optimized to: {optimized_log_exp:?}");

    // x^a * x^b, with x = var(3), a = var(4), b = var(5). The same `x` node
    // appears in both factors, so it is cloned for the first one.
    let x = var(3);
    let power_expr = mul(pow(x.clone(), var(4)), pow(x, var(5)));
    let optimized_power = optimizer
        .optimize(&power_expr)
        .expect("optimizing x^a * x^b should not fail");
    println!("x^a * x^b optimized to: {optimized_power:?}");
}
/// Optimizes `(x + 0) * 1 + (ln(exp(y)) - 0)` with egglog, constant folding and
/// aggressive mode enabled. It then generates Rust source for the result and
/// checks that the source mentions the requested function name.
#[test]
fn test_end_to_end_optimization_and_generation() {
    // Direct-hit emoji, written as an escape so the source stays ASCII-only.
    println!("\u{1F3AF} Testing end-to-end optimization and Rust generation...");
    let mut config = OptimizationConfig::default();
    config.egglog_optimization = true;
    config.constant_folding = true;
    config.aggressive = true;
    let mut optimizer = SymbolicOptimizer::with_config(config)
        .expect("SymbolicOptimizer should build with aggressive optimization enabled");

    // (x + 0) * 1 + (ln(exp(y)) - 0), with x = var(6), y = var(7).
    let complex_expr = add(
        mul(add(var(6), constant(0.0)), constant(1.0)),
        sub(log(exp(var(7))), constant(0.0)),
    );
    println!("Original complex expression: {complex_expr:?}");

    let optimized = optimizer
        .optimize(&complex_expr)
        .expect("optimizing the complex expression should not fail");
    println!("Optimized expression: {optimized:?}");

    let rust_code = optimizer
        .generate_rust_source(&optimized, "optimized_func")
        .expect("Rust source generation for the optimized expression should succeed");
    println!("Generated Rust code for optimized expression:\n{rust_code}");
    println!("End-to-end optimization and generation completed successfully");

    assert!(
        rust_code.contains("optimized_func"),
        "function name `optimized_func` missing from generated code:\n{rust_code}"
    );
}
/// Optimizes `(x + 0) * 1 + ln(exp(y))` and then requests its symbolic gradient
/// with respect to variables "0" and "1". Checks that both are present in the result.
#[cfg(feature = "ad_trait")]
#[test]
fn test_autodiff_integration() {
    use dslcompile::symbolic::symbolic_ad::convenience;

    // Microscope emoji, written as an escape so the source stays ASCII-only.
    println!("\u{1F52C} Testing autodiff integration with symbolic optimization...");
    let mut config = OptimizationConfig::default();
    config.egglog_optimization = true;
    let mut optimizer = SymbolicOptimizer::with_config(config)
        .expect("SymbolicOptimizer should build with egglog optimization enabled");

    // Build the expression from variables 0 and 1 so they match the names
    // passed to `gradient` below. The keys checked at the end are then the
    // same whether the result is keyed by requested name or by variable index.
    let expr = add(
        mul(add(var(0), constant(0.0)), constant(1.0)),
        log(exp(var(1))),
    );
    println!("Original expression: {expr:?}");

    let optimized = optimizer
        .optimize(&expr)
        .expect("optimizing the expression should not fail");
    println!("Optimized expression: {optimized:?}");

    let gradient = convenience::gradient(&optimized, &["0", "1"])
        .expect("symbolic gradient computation should succeed");
    println!("Gradient computed");

    assert!(gradient.contains_key("0"), "gradient is missing the partial for variable 0");
    assert!(gradient.contains_key("1"), "gradient is missing the partial for variable 1");

    // Check-mark emoji, written as an escape so the source stays ASCII-only.
    println!("\u{2705} Autodiff integration test passed!");
}