use mathcompile::final_tagless::{ASTEval, ASTMathExpr, ASTRepr};
use mathcompile::symbolic::{
CompilationStrategy, OptimizationConfig, RustOptLevel, SymbolicOptimizer,
};
use std::path::PathBuf;
/// Convenience wrapper: build a variable AST node from its name.
fn var(name: &str) -> ASTRepr<f64> {
ASTEval::var_by_name(name)
}
/// Convenience wrapper: build a constant (literal) AST node.
fn constant(value: f64) -> ASTRepr<f64> {
ASTEval::constant(value)
}
/// Convenience wrapper: build an addition node `left + right`.
fn add(left: ASTRepr<f64>, right: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::add(left, right)
}
/// Convenience wrapper: build a multiplication node `left * right`.
fn mul(left: ASTRepr<f64>, right: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::mul(left, right)
}
/// Convenience wrapper: build a subtraction node `left - right`.
fn sub(left: ASTRepr<f64>, right: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::sub(left, right)
}
/// Convenience wrapper: build a power node `base ^ exp`.
fn pow(base: ASTRepr<f64>, exp: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::pow(base, exp)
}
/// Convenience wrapper: build a sine node `sin(x)`.
fn sin(x: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::sin(x)
}
/// Convenience wrapper: build a cosine node `cos(x)`.
fn cos(x: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::cos(x)
}
/// Convenience wrapper: build an exponential node `e^x`.
fn exp(x: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::exp(x)
}
/// Convenience wrapper: build a natural-logarithm node `ln(x)`.
///
/// NOTE(review): named `log` but delegates to `ASTEval::ln`, so this is the
/// natural log, not log10 — callers should not assume base 10.
fn log(x: ASTRepr<f64>) -> ASTRepr<f64> {
ASTEval::ln(x)
}
fn create_var_expr(name: &str) -> ASTRepr<f64> {
ASTEval::var_by_name(name)
}
#[test]
fn test_current_optimization_capabilities() {
    println!("๐งช Testing current optimization capabilities...");
    // Enable egglog rewriting on top of the defaults. Struct-update syntax
    // replaces the default-then-reassign pattern flagged by
    // clippy::field_reassign_with_default.
    let config = OptimizationConfig {
        egglog_optimization: true,
        ..OptimizationConfig::default()
    };
    let mut optimizer = SymbolicOptimizer::with_config(config).unwrap();
    // (x + 0) * 1 + ln(exp(y)) — every sub-expression has an obvious
    // algebraic simplification, so the optimizer must change the tree.
    let x = var("x");
    let y = var("y");
    let zero = constant(0.0);
    let one = constant(1.0);
    let expr = add(mul(add(x, zero), one), log(exp(y)));
    println!("Original expression: {expr:?}");
    let optimized = optimizer.optimize(&expr).unwrap();
    println!("Optimized expression: {optimized:?}");
    // Debug representations are compared because ASTRepr equality semantics
    // are not guaranteed here; any structural change is enough.
    assert_ne!(format!("{expr:?}"), format!("{optimized:?}"));
}
#[test]
fn test_rust_code_generation() {
    println!("๐ฆ Testing Rust code generation...");
    let optimizer = SymbolicOptimizer::new().unwrap();
    // Assemble the quadratic x^2 + 2*x + 1 from named pieces.
    let squared = pow(var("x"), constant(2.0));
    let linear = mul(constant(2.0), var("x"));
    let expr = add(add(squared, linear), constant(1.0));
    let rust_code = optimizer.generate_rust_source(&expr, "quadratic").unwrap();
    println!("Generated Rust code:\n{rust_code}");
    // The generated source must be a C-ABI export with the requested name.
    assert!(rust_code.contains("#[no_mangle]"));
    assert!(rust_code.contains("pub extern \"C\" fn quadratic"));
    // Accept either form the backend may emit for the square term...
    assert!(rust_code.contains("x * x") || rust_code.contains("x.powf(2"));
    // ...and any of the spellings it may emit for the linear term.
    assert!(
        rust_code.contains("2.0 * x")
            || rust_code.contains("2.0_f64 * x")
            || rust_code.contains("2 * x")
    );
}
#[test]
fn test_compilation_strategy_selection() {
    println!("โ๏ธ Testing compilation strategy selection...");
    let mut optimizer = SymbolicOptimizer::new().unwrap();
    // A trivial expression: x + 1.
    let simple_expr = add(var("x"), constant(1.0));
    let approach = optimizer.choose_compilation_approach(&simple_expr, "simple");
    println!("Simple expression approach: {approach:?}");
    // Switch to an adaptive strategy with low thresholds so behaviour can
    // flip within a handful of simulated calls.
    let adaptive = CompilationStrategy::Adaptive {
        call_threshold: 3,
        complexity_threshold: 10,
    };
    optimizer.set_compilation_strategy(adaptive);
    // Drive five calls, recording an execution after each, and watch the
    // chosen approach as call counts accumulate.
    for i in 0..5 {
        let approach = optimizer.choose_compilation_approach(&simple_expr, "adaptive_test");
        println!("Call {i}: {approach:?}");
        optimizer.record_execution("adaptive_test", 1000);
    }
    let stats = optimizer.get_expression_stats();
    println!("Expression stats: {stats:?}");
}
#[test]
fn test_hot_loading_strategy() {
    println!("๐ฅ Testing hot-loading compilation strategy...");
    // Use the platform temp directory instead of a hard-coded "/tmp" so the
    // test also works on Windows (where /tmp does not exist). On Unix this
    // resolves to the same /tmp paths as before.
    let tmp: PathBuf = std::env::temp_dir();
    let strategy = CompilationStrategy::HotLoadRust {
        source_dir: tmp.join("mathcompile_test_sources"),
        lib_dir: tmp.join("mathcompile_test_libs"),
        opt_level: RustOptLevel::O2,
    };
    let optimizer = SymbolicOptimizer::with_strategy(strategy).unwrap();
    // sin(2*x + cos(y)) — nested unary calls exercise the codegen.
    let x = var("x");
    let y = var("y");
    let two = constant(2.0);
    let expr = sin(add(mul(two, x), cos(y)));
    let rust_code = optimizer
        .generate_rust_source(&expr, "complex_func")
        .unwrap();
    println!("Complex function Rust code:\n{rust_code}");
    // Both trig calls and the requested symbol name must survive generation.
    assert!(rust_code.contains("sin"));
    assert!(rust_code.contains("cos"));
    assert!(rust_code.contains("complex_func"));
}
#[test]
fn test_algebraic_optimizations() {
    println!("๐งฎ Testing algebraic optimizations...");
    // Struct-update syntax replaces the default-then-reassign pattern
    // flagged by clippy::field_reassign_with_default.
    let config = OptimizationConfig {
        egglog_optimization: true,
        ..OptimizationConfig::default()
    };
    let mut optimizer = SymbolicOptimizer::with_config(config).unwrap();
    // Candidate rewrite: exp(a) * exp(b) -> exp(a + b).
    let exp_expr = mul(exp(var("a")), exp(var("b")));
    let optimized_exp = optimizer.optimize(&exp_expr).unwrap();
    println!("exp(a) * exp(b) optimized to: {optimized_exp:?}");
    // Candidate rewrite: ln(exp(x)) -> x.
    let log_exp_expr = log(exp(var("x")));
    let optimized_log_exp = optimizer.optimize(&log_exp_expr).unwrap();
    println!("log(exp(x)) optimized to: {optimized_log_exp:?}");
    // Candidate rewrite: x^a * x^b -> x^(a + b).
    let power_expr = mul(pow(var("x"), var("a")), pow(var("x"), var("b")));
    let optimized_power = optimizer.optimize(&power_expr).unwrap();
    println!("x^a * x^b optimized to: {optimized_power:?}");
    // No assertions: this test logs what the rewriter currently produces —
    // it documents capabilities rather than pinning exact output.
}
#[test]
fn test_end_to_end_optimization_and_generation() {
    println!("๐ฏ Testing end-to-end optimization and Rust generation...");
    // Enable the full optimization pipeline in one struct-literal init
    // (avoids clippy::field_reassign_with_default).
    let config = OptimizationConfig {
        egglog_optimization: true,
        constant_folding: true,
        aggressive: true,
        ..OptimizationConfig::default()
    };
    let mut optimizer = SymbolicOptimizer::with_config(config).unwrap();
    // (x + 0) * 1 + (ln(exp(y)) - 0): every sub-term is simplifiable, so
    // the optimized tree must differ from the original.
    let x = var("x");
    let y = var("y");
    let zero = constant(0.0);
    let one = constant(1.0);
    let complex_expr = add(mul(add(x, zero), one), sub(log(exp(y)), constant(0.0)));
    println!("Original complex expression: {complex_expr:?}");
    let optimized = optimizer.optimize(&complex_expr).unwrap();
    println!("Optimized expression: {optimized:?}");
    let rust_code = optimizer
        .generate_rust_source(&optimized, "optimized_func")
        .unwrap();
    println!("Generated Rust code for optimized expression:\n{rust_code}");
    assert_ne!(format!("{complex_expr:?}"), format!("{optimized:?}"));
    assert!(rust_code.contains("optimized_func"));
}
#[cfg(feature = "autodiff")]
#[test]
fn test_autodiff_integration() {
    println!("๐ฌ Testing autodiff integration with symbolic optimization...");
    use ad_trait::forward_ad::adfn::adfn;
    use mathcompile::autodiff::{ForwardAD, convenience};
    // Struct-update init avoids clippy::field_reassign_with_default.
    let config = OptimizationConfig {
        egglog_optimization: true,
        ..OptimizationConfig::default()
    };
    let mut optimizer = SymbolicOptimizer::with_config(config).unwrap();
    // Run the symbolic pipeline first, mirroring the other tests.
    let x = var("x");
    let y = var("y");
    let zero = constant(0.0);
    let one = constant(1.0);
    let expr = add(mul(add(x, zero), one), log(exp(y)));
    println!("Original expression: {expr:?}");
    let optimized = optimizer.optimize(&expr).unwrap();
    println!("Optimized expression: {optimized:?}");
    // Forward-mode AD on f(x) = x + 2: f(3) = 5, f'(3) = 1.
    let forward_ad = ForwardAD::new();
    let simple_func = |x: adfn<1>| {
        // Constant term carries a zero tangent.
        let y = adfn::new(2.0, [0.0]);
        x + y
    };
    let (value, derivative) = forward_ad.differentiate(simple_func, 3.0).unwrap();
    println!("f(3) = {value}, f'(3) = {derivative}");
    assert!((value - 5.0).abs() < 1e-10);
    assert!((derivative - 1.0).abs() < 1e-10);
    // Gradient of g(x, y) = x + y is [1, 1] everywhere.
    let multi_var = |vars: &[f64]| {
        let x = vars[0];
        let y = vars[1];
        x + y
    };
    let gradient = convenience::gradient(multi_var, &[3.0, 2.0]).unwrap();
    println!("Gradient: [{}, {}]", gradient[0], gradient[1]);
    assert!((gradient[0] - 1.0).abs() < 1e-6);
    assert!((gradient[1] - 1.0).abs() < 1e-6);
    // Restored: the original literal was split across two source lines by a
    // raw mojibake control byte (U+2705 "✅" read through a legacy codepage).
    println!("✅ Autodiff integration test passed!");
}