use dslcompile::ast::ASTRepr;
use dslcompile::ast::normalization::{count_operations, denormalize, is_canonical, normalize};
use dslcompile::final_tagless::{ASTEval, ASTMathExpr};
#[cfg(feature = "optimization")]
use dslcompile::symbolic::native_egglog::optimize_with_native_egglog;
/// `x0 - x1` must normalize to the canonical form `x0 + (-x1)`.
#[test]
fn test_basic_subtraction_normalization() {
    let difference = ASTRepr::Sub(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::<f64>::Variable(1)),
    );
    let canonical = normalize(&difference);
    assert!(is_canonical(&canonical));

    let ASTRepr::Add(minuend, negated) = canonical else {
        panic!("Expected Add operation after normalization");
    };
    assert!(matches!(minuend.as_ref(), ASTRepr::Variable(0)));

    // The subtrahend should now be wrapped in a unary negation.
    let ASTRepr::Neg(subtrahend) = negated.as_ref() else {
        panic!("Expected Neg operation in normalized subtraction");
    };
    assert!(matches!(subtrahend.as_ref(), ASTRepr::Variable(1)));
}
/// `x0 / x1` must normalize to `x0 * x1^(-1)`.
#[test]
fn test_basic_division_normalization() {
    let quotient = ASTRepr::Div(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::<f64>::Variable(1)),
    );
    let canonical = normalize(&quotient);
    assert!(is_canonical(&canonical));

    let ASTRepr::Mul(numerator, reciprocal) = canonical else {
        panic!("Expected Mul operation after normalization");
    };
    assert!(matches!(numerator.as_ref(), ASTRepr::Variable(0)));

    let ASTRepr::Pow(base, exponent) = reciprocal.as_ref() else {
        panic!("Expected Pow operation in normalized division");
    };
    assert!(matches!(base.as_ref(), ASTRepr::Variable(1)));

    // The reciprocal is encoded as a power whose exponent is the constant -1.
    let ASTRepr::Constant(power) = exponent.as_ref() else {
        panic!("Expected Constant(-1.0) in power exponent");
    };
    assert!((*power + 1.0_f64).abs() < 1e-12);
}
/// `(x0 - x1) / (x2 + x3)`: normalization removes every Sub/Div node and
/// introduces additional Add/Mul nodes in their place.
#[test]
fn test_complex_expression_normalization() {
    let numerator = ASTRepr::Sub(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::<f64>::Variable(1)),
    );
    let denominator = ASTRepr::Add(
        Box::new(ASTRepr::<f64>::Variable(2)),
        Box::new(ASTRepr::<f64>::Variable(3)),
    );
    let expr = ASTRepr::Div(Box::new(numerator), Box::new(denominator));

    let normalized = normalize(&expr);
    assert!(is_canonical(&normalized));

    // Tuples are ordered (add, mul, sub, div).
    let before = count_operations(&expr);
    let after = count_operations(&normalized);

    assert_eq!(before, (1, 0, 1, 1));
    assert!(after.0 > before.0);
    assert!(after.1 > before.1);
    assert_eq!((after.2, after.3), (0, 0));
}
/// Normalizing `x0 - x1` and then denormalizing must give back a `Sub` node
/// with the original operands.
#[test]
fn test_denormalization_roundtrip() {
    let original = ASTRepr::Sub(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::<f64>::Variable(1)),
    );
    let restored = denormalize(&normalize(&original));

    let ASTRepr::Sub(lhs, rhs) = restored else {
        panic!("Expected Sub operation after denormalization");
    };
    assert!(matches!(lhs.as_ref(), ASTRepr::Variable(0)));
    assert!(matches!(rhs.as_ref(), ASTRepr::Variable(1)));
}
/// Normalizing `x0 / x1` and then denormalizing must give back a `Div` node
/// with the original operands.
#[test]
fn test_division_denormalization_roundtrip() {
    let original = ASTRepr::Div(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::<f64>::Variable(1)),
    );
    let restored = denormalize(&normalize(&original));

    let ASTRepr::Div(dividend, divisor) = restored else {
        panic!("Expected Div operation after denormalization");
    };
    assert!(matches!(dividend.as_ref(), ASTRepr::Variable(0)));
    assert!(matches!(divisor.as_ref(), ASTRepr::Variable(1)));
}
/// A division nested inside a subtraction (`x0 - x1 / x2`) is rewritten too:
/// neither Sub nor Div may survive normalization.
#[test]
fn test_nested_operations_normalization() {
    let ratio = ASTRepr::Div(
        Box::new(ASTRepr::<f64>::Variable(1)),
        Box::new(ASTRepr::<f64>::Variable(2)),
    );
    let expr = ASTRepr::Sub(Box::new(ASTRepr::<f64>::Variable(0)), Box::new(ratio));

    let normalized = normalize(&expr);
    assert!(is_canonical(&normalized));

    // Only the trailing (sub, div) counters matter here.
    let (.., subtractions, divisions) = count_operations(&normalized);
    assert_eq!(subtractions, 0);
    assert_eq!(divisions, 0);
}
/// `Sin`/`Ln` nodes pass through normalization intact:
/// `sin(x0) - ln(x1)` becomes `sin(x0) + (-ln(x1))`.
#[test]
fn test_transcendental_functions_preserved() {
    let sine = ASTRepr::Sin(Box::new(ASTRepr::<f64>::Variable(0)));
    let logarithm = ASTRepr::Ln(Box::new(ASTRepr::<f64>::Variable(1)));
    let normalized = normalize(&ASTRepr::Sub(Box::new(sine), Box::new(logarithm)));
    assert!(is_canonical(&normalized));

    let ASTRepr::Add(head, tail) = normalized else {
        panic!("Expected Add operation after normalization");
    };
    assert!(matches!(head.as_ref(), ASTRepr::Sin(_)));

    let ASTRepr::Neg(negated) = tail.as_ref() else {
        panic!("Expected Neg(Ln(_)) in normalized expression");
    };
    assert!(matches!(negated.as_ref(), ASTRepr::Ln(_)));
}
/// Constant operands keep their values: `5 - 3` becomes `5 + (-3)`.
#[test]
fn test_constants_preserved() {
    const TOLERANCE: f64 = 1e-12;

    let normalized = normalize(&ASTRepr::Sub(
        Box::new(ASTRepr::Constant(5.0_f64)),
        Box::new(ASTRepr::Constant(3.0_f64)),
    ));
    assert!(is_canonical(&normalized));

    let ASTRepr::Add(minuend, negated) = normalized else {
        panic!("Expected Add operation after normalization");
    };
    assert!(
        matches!(minuend.as_ref(), ASTRepr::Constant(c) if (*c - 5.0_f64).abs() < TOLERANCE)
    );

    let ASTRepr::Neg(subtrahend) = negated.as_ref() else {
        panic!("Expected Neg(Constant(3.0)) in normalized expression");
    };
    assert!(
        matches!(subtrahend.as_ref(), ASTRepr::Constant(c) if (*c - 3.0_f64).abs() < TOLERANCE)
    );
}
/// An expression that is already canonical (`x0 + (-x1)`) keeps the same
/// operation counts after normalization.
#[test]
fn test_already_canonical_expressions() {
    let negated = ASTRepr::Neg(Box::new(ASTRepr::<f64>::Variable(1)));
    let expr = ASTRepr::Add(Box::new(ASTRepr::<f64>::Variable(0)), Box::new(negated));
    let normalized = normalize(&expr);

    assert!(is_canonical(&normalized));
    assert!(is_canonical(&expr));

    // Compare all four (add, mul, sub, div) counters at once.
    assert_eq!(count_operations(&expr), count_operations(&normalized));
}
/// `(x0 - x1) + (x2 / x3)`: every Sub/Div node disappears, in exchange for
/// extra Add/Mul nodes.
#[test]
fn test_operation_count_reduction() {
    let difference = ASTRepr::Sub(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::<f64>::Variable(1)),
    );
    let ratio = ASTRepr::Div(
        Box::new(ASTRepr::<f64>::Variable(2)),
        Box::new(ASTRepr::<f64>::Variable(3)),
    );
    let expr = ASTRepr::Add(Box::new(difference), Box::new(ratio));
    let normalized = normalize(&expr);

    let (adds_before, muls_before, subs_before, divs_before) = count_operations(&expr);
    let (adds_after, muls_after, subs_after, divs_after) = count_operations(&normalized);

    // The input really contains the operations normalization should eliminate.
    assert!(subs_before > 0);
    assert!(divs_before > 0);

    assert_eq!(subs_after, 0);
    assert_eq!(divs_after, 0);
    assert!(adds_after > adds_before);
    assert!(muls_after > muls_before);
}
/// An expression built through the final-tagless `ASTEval` builder normalizes
/// to a canonical form. Denormalizing it yields a form `is_canonical` rejects.
#[test]
fn test_ergonomic_builder_integration() {
    let built: ASTRepr<f64> = ASTEval::sub(ASTEval::var(0), ASTEval::var(1));

    let normalized = normalize(&built);
    assert!(is_canonical(&normalized));

    let restored = denormalize(&normalized);
    assert!(!is_canonical(&restored));
}
/// Normalizes `x0 + 0`, then smoke-tests the native egglog optimizer.
///
/// The test does not fail when the optimizer cannot be constructed or an
/// optimization returns an error. Those outcomes are only reported with
/// `println!`, because they are treated as acceptable here.
///
/// The whole test is compiled only when the `optimization` feature is
/// enabled. Any feature gating inside the body would therefore be redundant,
/// and a `not(feature = "optimization")` branch would be dead code.
#[cfg(feature = "optimization")]
#[test]
fn test_native_egglog_integration_with_normalization() {
    use dslcompile::symbolic::native_egglog::NativeEgglogOptimizer;

    let expr = ASTRepr::Add(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::Constant(0.0_f64)),
    );
    let normalized = normalize(&expr);
    assert!(is_canonical(&normalized));

    // Constructing the optimizer is allowed to fail; just report it and stop.
    let mut optimizer = match NativeEgglogOptimizer::new() {
        Ok(optimizer) => optimizer,
        Err(e) => {
            println!("Native egglog optimizer creation failed (acceptable in test): {e}");
            return;
        }
    };
    println!("Native egglog optimizer created successfully");

    let simple_expr = ASTRepr::<f64>::Variable(0);

    // Exercise the optimizer instance directly.
    match optimizer.optimize(&simple_expr) {
        Ok(optimized) => println!("Simple optimization succeeded: {optimized:?}"),
        Err(e) => println!("Simple optimization failed (acceptable): {e}"),
    }

    // Exercise the free-function convenience wrapper.
    match optimize_with_native_egglog(&simple_expr) {
        Ok(optimized) => println!("Helper function optimization succeeded: {optimized:?}"),
        Err(e) => println!("Helper function optimization failed (acceptable): {e}"),
    }
}
/// `(x0 * x1 - 2) / (x2 + 1)` normalizes to a canonical form. Denormalizing the
/// result yields a form `is_canonical` rejects again.
#[test]
fn test_complex_mixed_operations() {
    let product = ASTRepr::Mul(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::<f64>::Variable(1)),
    );
    let numerator = ASTRepr::Sub(Box::new(product), Box::new(ASTRepr::Constant(2.0_f64)));
    let denominator = ASTRepr::Add(
        Box::new(ASTRepr::<f64>::Variable(2)),
        Box::new(ASTRepr::Constant(1.0_f64)),
    );
    let expr = ASTRepr::Div(Box::new(numerator), Box::new(denominator));

    let normalized = normalize(&expr);
    assert!(is_canonical(&normalized));
    assert!(!is_canonical(&denormalize(&normalized)));
}
/// `x0^2 / x1`: the existing `Pow` numerator is kept as-is, and the divisor
/// is rewritten as the reciprocal power `x1^(-1)`.
#[test]
fn test_power_operations_preserved() {
    let square = ASTRepr::Pow(
        Box::new(ASTRepr::<f64>::Variable(0)),
        Box::new(ASTRepr::Constant(2.0_f64)),
    );
    let quotient = ASTRepr::Div(Box::new(square), Box::new(ASTRepr::<f64>::Variable(1)));

    let normalized = normalize(&quotient);
    assert!(is_canonical(&normalized));

    // Only the trailing div counter matters here.
    let (.., divisions) = count_operations(&normalized);
    assert_eq!(divisions, 0);

    let ASTRepr::Mul(kept_power, reciprocal) = normalized else {
        panic!("Expected Mul operation after normalization");
    };
    assert!(matches!(kept_power.as_ref(), ASTRepr::Pow(_, _)));

    let ASTRepr::Pow(base, exponent) = reciprocal.as_ref() else {
        panic!("Expected Pow operation for reciprocal");
    };
    assert!(matches!(base.as_ref(), ASTRepr::Variable(1)));
    assert!(
        matches!(exponent.as_ref(), ASTRepr::Constant(e) if (*e + 1.0_f64).abs() < 1e-12)
    );
}