#![cfg(feature = "proptest_macro_integration")]
use dslcompile::compile_time::{constant, optimize_compile_time, var};
use dslcompile::final_tagless::{ASTEval, DirectEval};
use proptest::prelude::*;
/// Test-side mirror of the expression shapes handed to `optimize_compile_time!`.
///
/// Trees of this type are generated randomly by the property tests and
/// evaluated independently of the macro, so the two results can be compared.
/// Child nodes are boxed so the recursive type has a finite size.
#[derive(Clone, Debug)]
enum MacroExpr {
    /// Input variable, identified by its index into the value slice.
    Var(usize),
    /// Floating-point literal.
    Constant(f64),
    /// `lhs + rhs`
    Add(Box<Self>, Box<Self>),
    /// `lhs * rhs`
    Mul(Box<Self>, Box<Self>),
    /// `lhs - rhs`
    Sub(Box<Self>, Box<Self>),
    /// `sin(arg)`
    Sin(Box<Self>),
    /// `cos(arg)`
    Cos(Box<Self>),
    /// `exp(arg)`
    Exp(Box<Self>),
    /// Natural logarithm `ln(arg)`.
    Ln(Box<Self>),
}
impl MacroExpr {
fn eval_reference(&self, values: &[f64]) -> f64 {
match self {
MacroExpr::Var(idx) => values.get(*idx).copied().unwrap_or(0.0),
MacroExpr::Constant(c) => *c,
MacroExpr::Add(left, right) => {
left.eval_reference(values) + right.eval_reference(values)
}
MacroExpr::Mul(left, right) => {
left.eval_reference(values) * right.eval_reference(values)
}
MacroExpr::Sub(left, right) => {
left.eval_reference(values) - right.eval_reference(values)
}
MacroExpr::Sin(inner) => inner.eval_reference(values).sin(),
MacroExpr::Cos(inner) => inner.eval_reference(values).cos(),
MacroExpr::Exp(inner) => inner.eval_reference(values).exp(),
MacroExpr::Ln(inner) => inner.eval_reference(values).ln(),
}
}
fn has_valid_args(&self, values: &[f64]) -> bool {
match self {
MacroExpr::Var(_) | MacroExpr::Constant(_) => true,
MacroExpr::Add(l, r) | MacroExpr::Mul(l, r) | MacroExpr::Sub(l, r) => {
l.has_valid_args(values) && r.has_valid_args(values)
}
MacroExpr::Sin(inner) | MacroExpr::Cos(inner) | MacroExpr::Exp(inner) => {
inner.has_valid_args(values)
}
MacroExpr::Ln(inner) => {
inner.has_valid_args(values) && inner.eval_reference(values) > 0.0
}
}
}
}
/// Builds a proptest strategy that produces random [`MacroExpr`] trees.
///
/// * `max_depth` - maximum nesting depth, forwarded to `prop_recursive`
///   (which takes a `u32`; values beyond `u32::MAX` saturate).
/// * `var_count` - leaf variables are drawn from `Var(0)` through
///   `Var(var_count - 1)`.
///
/// # Panics
///
/// Panics if `var_count` is zero, because no variable index could be drawn.
fn arb_macro_expr(max_depth: usize, var_count: usize) -> impl Strategy<Value = MacroExpr> {
    assert!(var_count > 0, "arb_macro_expr needs at least one variable");
    // `prop_recursive` expects its depth as `u32`; passing the `usize`
    // directly does not type-check.
    let depth = u32::try_from(max_depth).unwrap_or(u32::MAX);

    let leaf = prop_oneof![
        (0..var_count).prop_map(MacroExpr::Var),
        (-10.0_f64..10.0_f64).prop_map(MacroExpr::Constant),
        // Exact 0, 1 and -1 are listed explicitly: a uniform draw from the
        // range above would almost never produce them.
        Just(MacroExpr::Constant(0.0)),
        Just(MacroExpr::Constant(1.0)),
        Just(MacroExpr::Constant(-1.0)),
    ];
    leaf.prop_recursive(depth, 32, 4, |inner| {
        prop_oneof![
            (inner.clone(), inner.clone())
                .prop_map(|(l, r)| MacroExpr::Add(Box::new(l), Box::new(r))),
            (inner.clone(), inner.clone())
                .prop_map(|(l, r)| MacroExpr::Mul(Box::new(l), Box::new(r))),
            (inner.clone(), inner.clone())
                .prop_map(|(l, r)| MacroExpr::Sub(Box::new(l), Box::new(r))),
            inner.clone().prop_map(|e| MacroExpr::Sin(Box::new(e))),
            inner.clone().prop_map(|e| MacroExpr::Cos(Box::new(e))),
            inner.clone().prop_map(|e| MacroExpr::Exp(Box::new(e))),
            // Only constant arguments can be rejected cheaply here; any other
            // non-positive `ln` argument is skipped later, when the test calls
            // `MacroExpr::has_valid_args`.
            inner
                .prop_filter("positive for ln", |e| match e {
                    MacroExpr::Constant(c) => *c > 0.0,
                    _ => true,
                })
                .prop_map(|e| MacroExpr::Ln(Box::new(e))),
        ]
    })
}
fn test_optimization_patterns() -> Vec<(MacroExpr, &'static str)> {
vec![
(
MacroExpr::Ln(Box::new(MacroExpr::Exp(Box::new(MacroExpr::Var(0))))),
"ln(exp(x))",
),
(
MacroExpr::Exp(Box::new(MacroExpr::Ln(Box::new(MacroExpr::Var(0))))),
"exp(ln(x))",
),
(
MacroExpr::Add(
Box::new(MacroExpr::Var(0)),
Box::new(MacroExpr::Constant(0.0)),
),
"x + 0",
),
(
MacroExpr::Add(
Box::new(MacroExpr::Constant(0.0)),
Box::new(MacroExpr::Var(0)),
),
"0 + x",
),
(
MacroExpr::Mul(
Box::new(MacroExpr::Var(0)),
Box::new(MacroExpr::Constant(1.0)),
),
"x * 1",
),
(
MacroExpr::Mul(
Box::new(MacroExpr::Constant(1.0)),
Box::new(MacroExpr::Var(0)),
),
"1 * x",
),
(
MacroExpr::Mul(
Box::new(MacroExpr::Var(0)),
Box::new(MacroExpr::Constant(0.0)),
),
"x * 0",
),
(
MacroExpr::Mul(
Box::new(MacroExpr::Constant(0.0)),
Box::new(MacroExpr::Var(0)),
),
"0 * x",
),
]
}
/// Evaluates `$expr` with `optimize_compile_time!`, binding `$x` to the local
/// `x` that is passed as the macro's only variable.
macro_rules! test_1var_expr {
    ($expr:expr, $x:expr) => {{
        // `x` is the identifier listed in the `[...]` variable list below.
        let (x,) = ($x,);
        optimize_compile_time!($expr, [x])
    }};
}
/// Evaluates `$expr` with `optimize_compile_time!`, binding `$x` and `$y` to
/// the locals `x` and `y` that are passed as the macro's variables.
macro_rules! test_2var_expr {
    ($expr:expr, $x:expr, $y:expr) => {{
        // Operands are evaluated left to right, then handed over by name.
        let (x, y) = ($x, $y);
        optimize_compile_time!($expr, [x, y])
    }};
}
/// Evaluates `$expr` with `optimize_compile_time!`, binding `$x`, `$y` and
/// `$z` to the locals `x`, `y` and `z` that are passed as the macro's variables.
macro_rules! test_3var_expr {
    ($expr:expr, $x:expr, $y:expr, $z:expr) => {{
        // Operands are evaluated left to right, then handed over by name.
        let (x, y, z) = ($x, $y, $z);
        optimize_compile_time!($expr, [x, y, z])
    }};
}
/// Cross-checks `optimize_compile_time!` against [`MacroExpr::eval_reference`].
///
/// Only a fixed set of expression shapes has a hand-written macro
/// counterpart, because the macro needs the expression spelled out in
/// source. The following cases are skipped and report `Ok(())`:
/// * any other shape;
/// * any input for which [`MacroExpr::has_valid_args`] is false;
/// * a `values` length other than 1 or 2.
///
/// # Errors
///
/// Returns a description of the mismatch when the macro result and the
/// reference result are not equivalent within a tolerance of `1e-10`.
fn test_macro_expression(expr: &MacroExpr, values: &[f64]) -> Result<(), String> {
    if !expr.has_valid_args(values) {
        return Ok(());
    }
    let reference_result = expr.eval_reference(values);
    let macro_result = match *values {
        [x] => eval_1var_macro(expr, x),
        [x, y] => eval_2var_macro(expr, x, y),
        _ => None,
    };
    let macro_result = match macro_result {
        Some(result) => result,
        // No macro counterpart for this shape/arity: nothing to compare.
        None => return Ok(()),
    };
    if !is_numerically_equivalent(reference_result, macro_result, 1e-10) {
        return Err(format!(
            "Macro result differs from reference: {} vs {} (diff: {})",
            macro_result,
            reference_result,
            (macro_result - reference_result).abs()
        ));
    }
    Ok(())
}

/// Evaluates the single-variable shapes that have a macro counterpart, with
/// `x` bound to `Var(0)`. Returns `None` for every other shape.
fn eval_1var_macro(expr: &MacroExpr, x: f64) -> Option<f64> {
    let result = match expr {
        MacroExpr::Var(0) => test_1var_expr!(var::<0>(), x),
        MacroExpr::Constant(c) => test_1var_expr!(constant(*c), x),
        MacroExpr::Add(l, r) => match (l.as_ref(), r.as_ref()) {
            (MacroExpr::Var(0), MacroExpr::Constant(c)) => {
                test_1var_expr!(var::<0>().add(constant(*c)), x)
            }
            (MacroExpr::Constant(c), MacroExpr::Var(0)) => {
                test_1var_expr!(constant(*c).add(var::<0>()), x)
            }
            _ => return None,
        },
        MacroExpr::Mul(l, r) => match (l.as_ref(), r.as_ref()) {
            (MacroExpr::Var(0), MacroExpr::Constant(c)) => {
                test_1var_expr!(var::<0>().mul(constant(*c)), x)
            }
            (MacroExpr::Constant(c), MacroExpr::Var(0)) => {
                test_1var_expr!(constant(*c).mul(var::<0>()), x)
            }
            _ => return None,
        },
        MacroExpr::Sin(inner) if matches!(inner.as_ref(), MacroExpr::Var(0)) => {
            test_1var_expr!(var::<0>().sin(), x)
        }
        MacroExpr::Cos(inner) if matches!(inner.as_ref(), MacroExpr::Var(0)) => {
            test_1var_expr!(var::<0>().cos(), x)
        }
        MacroExpr::Exp(inner) => match inner.as_ref() {
            MacroExpr::Var(0) => test_1var_expr!(var::<0>().exp(), x),
            // exp(ln(x)). The nested binding `arg` must live in the same pattern
            // that uses it; a separate `matches!` cannot see it.
            MacroExpr::Ln(arg) if matches!(arg.as_ref(), MacroExpr::Var(0)) => {
                test_1var_expr!(var::<0>().ln().exp(), x)
            }
            _ => return None,
        },
        MacroExpr::Ln(inner) => match inner.as_ref() {
            MacroExpr::Var(0) => test_1var_expr!(var::<0>().ln(), x),
            // ln(exp(x))
            MacroExpr::Exp(arg) if matches!(arg.as_ref(), MacroExpr::Var(0)) => {
                test_1var_expr!(var::<0>().exp().ln(), x)
            }
            _ => return None,
        },
        _ => return None,
    };
    Some(result)
}

/// Evaluates `x op y` (with `x` = `Var(0)` and `y` = `Var(1)`) for `+`, `*`
/// and `-`. Returns `None` for every other shape.
fn eval_2var_macro(expr: &MacroExpr, x: f64, y: f64) -> Option<f64> {
    let is_x_then_y = |l: &MacroExpr, r: &MacroExpr| {
        matches!((l, r), (MacroExpr::Var(0), MacroExpr::Var(1)))
    };
    let result = match expr {
        MacroExpr::Add(l, r) if is_x_then_y(l.as_ref(), r.as_ref()) => {
            test_2var_expr!(var::<0>().add(var::<1>()), x, y)
        }
        MacroExpr::Mul(l, r) if is_x_then_y(l.as_ref(), r.as_ref()) => {
            test_2var_expr!(var::<0>().mul(var::<1>()), x, y)
        }
        MacroExpr::Sub(l, r) if is_x_then_y(l.as_ref(), r.as_ref()) => {
            test_2var_expr!(var::<0>().sub(var::<1>()), x, y)
        }
        _ => return None,
    };
    Some(result)
}
/// Loose floating-point equality used to compare macro and reference results.
///
/// * Two NaNs count as equal; NaN never equals a non-NaN value.
/// * If either side is infinite, the two values must be the same infinity.
/// * Finite values match if their absolute difference is within `tolerance`,
///   or within `tolerance` scaled by the larger magnitude (relative check).
fn is_numerically_equivalent(a: f64, b: f64, tolerance: f64) -> bool {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => return true,
        (false, false) => {}
        _ => return false,
    }
    if !(a.is_finite() && b.is_finite()) {
        // NaN is ruled out above, so at least one side is +/-infinity here:
        // only identical infinities compare equal.
        return a == b;
    }
    let gap = (a - b).abs();
    let scale = a.abs().max(b.abs());
    gap <= tolerance || gap <= tolerance * scale
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(500))]

    // Random one-variable expressions, checked against the reference evaluator.
    #[test]
    fn prop_1var_macro_expressions(
        expr in arb_macro_expr(3, 1),
        x in -10.0_f64..10.0_f64
    ) {
        // Map the input into [0.1, 10.1] so that `ln(x)` is always defined.
        let values = [x.abs() + 0.1];
        // `test_macro_expression` reports mismatches as a `String`. proptest only
        // converts `std::error::Error` types through `?`, so wrap it explicitly.
        test_macro_expression(&expr, &values).map_err(TestCaseError::fail)?;
    }

    // Random two-variable expressions, checked against the reference evaluator.
    #[test]
    fn prop_2var_macro_expressions(
        expr in arb_macro_expr(2, 2),
        x in -5.0_f64..5.0_f64,
        y in -5.0_f64..5.0_f64
    ) {
        let values = [x, y];
        test_macro_expression(&expr, &values).map_err(TestCaseError::fail)?;
    }

    // Basic arithmetic through the macro must match direct f64 arithmetic.
    #[test]
    fn prop_simple_arithmetic(
        x in -10.0_f64..10.0_f64,
        y in -10.0_f64..10.0_f64,
        c in -5.0_f64..5.0_f64
    ) {
        let sum = test_2var_expr!(var::<0>().add(var::<1>()), x, y);
        prop_assert!((sum - (x + y)).abs() < 1e-10, "x + y: got {}, expected {}", sum, x + y);

        let product = test_2var_expr!(var::<0>().mul(var::<1>()), x, y);
        prop_assert!(
            (product - x * y).abs() < 1e-10,
            "x * y: got {}, expected {}",
            product,
            x * y
        );

        let shifted = test_1var_expr!(var::<0>().add(constant(c)), x);
        prop_assert!(
            (shifted - (x + c)).abs() < 1e-10,
            "x + c: got {}, expected {}",
            shifted,
            x + c
        );
    }

    // Transcendental functions through the macro must match `f64` methods.
    // The input stays positive so that `ln` is defined.
    #[test]
    fn prop_transcendental_functions(
        x in 0.1_f64..10.0_f64
    ) {
        let sin = test_1var_expr!(var::<0>().sin(), x);
        prop_assert!((sin - x.sin()).abs() < 1e-10, "sin: got {}, expected {}", sin, x.sin());

        let cos = test_1var_expr!(var::<0>().cos(), x);
        prop_assert!((cos - x.cos()).abs() < 1e-10, "cos: got {}, expected {}", cos, x.cos());

        let exp = test_1var_expr!(var::<0>().exp(), x);
        prop_assert!((exp - x.exp()).abs() < 1e-10, "exp: got {}, expected {}", exp, x.exp());

        let ln = test_1var_expr!(var::<0>().ln(), x);
        prop_assert!((ln - x.ln()).abs() < 1e-10, "ln: got {}, expected {}", ln, x.ln());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every hand-picked pattern from `test_optimization_patterns`
    /// through the macro and compares it with the reference evaluator.
    #[test]
    fn test_known_optimization_patterns() {
        for (expr, name) in test_optimization_patterns() {
            println!("Testing optimization pattern: {}", name);
            // Every pattern uses only `Var(0)`, so one positive value is enough
            // (positivity keeps `ln` defined). Literals close to PI are avoided
            // because clippy's deny-by-default `approx_constant` lint rejects them.
            let x = match name {
                "exp(ln(x))" | "ln(exp(x))" => 2.5,
                _ => 3.25,
            };
            if let Err(msg) = test_macro_expression(&expr, &[x]) {
                panic!("optimization pattern `{}` failed: {}", name, msg);
            }
        }
    }

    /// `ln(exp(x)) + y * 1 + 0 * z` must evaluate to `x + y`. `z` cannot
    /// affect the result because it is multiplied by zero.
    #[test]
    fn test_complex_optimization_example() {
        let x = 1.5_f64;
        let y = 2.5_f64;
        let z = 999.0_f64;
        let result = test_3var_expr!(
            var::<0>()
                .exp()
                .ln()
                .add(var::<1>().mul(constant(1.0)))
                .add(constant(0.0).mul(var::<2>())),
            x,
            y,
            z
        );
        let expected = x + y;
        assert!(
            (result - expected).abs() < 1e-10,
            "got {}, expected {}",
            result,
            expected
        );
    }

    /// Additive and multiplicative identities: `x + 0`, `x * 1`, `x * 0`.
    #[test]
    fn test_identity_optimizations() {
        let x = 3.25_f64;

        let plus_zero = test_1var_expr!(var::<0>().add(constant(0.0)), x);
        assert!((plus_zero - x).abs() < 1e-10, "x + 0 gave {}", plus_zero);

        let times_one = test_1var_expr!(var::<0>().mul(constant(1.0)), x);
        assert!((times_one - x).abs() < 1e-10, "x * 1 gave {}", times_one);

        let times_zero = test_1var_expr!(var::<0>().mul(constant(0.0)), x);
        assert!(times_zero.abs() < 1e-10, "x * 0 gave {}", times_zero);
    }

    /// `ln(exp(x))` and `exp(ln(x))` must both give back `x` for positive `x`.
    #[test]
    fn test_inverse_function_optimizations() {
        let x = 2.5_f64;

        let ln_exp = test_1var_expr!(var::<0>().exp().ln(), x);
        assert!((ln_exp - x).abs() < 1e-10, "ln(exp(x)) gave {}", ln_exp);

        let exp_ln = test_1var_expr!(var::<0>().ln().exp(), x);
        assert!((exp_ln - x).abs() < 1e-10, "exp(ln(x)) gave {}", exp_ln);
    }
}