#![cfg(test)]
use tensorlogic_ir::TLExpr;
use super::{partially_evaluate, specialize, specialize_batch, PEConfig, PEEnv};
/// Builds a zero-argument predicate named `name`; these tests treat it as a symbolic variable.
fn var(name: &str) -> TLExpr {
    let no_args = Vec::new();
    TLExpr::pred(name, no_args)
}
/// Wraps `value` in a `TLExpr::Constant` literal node.
fn cnst(value: f64) -> TLExpr {
    TLExpr::Constant(value)
}
/// A predicate with no binding in the environment is returned unchanged (same name, no args)
/// and reported in `residual_vars`.
#[test]
fn test_free_variable_stays_symbolic() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = var("x");
    let outcome = partially_evaluate(&input, &bindings, &config);
    let untouched = matches!(
        outcome.expr,
        TLExpr::Pred { ref name, ref args } if name == "x" && args.is_empty()
    );
    assert!(untouched);
    let x_name = "x".to_string();
    assert!(outcome.residual_vars.contains(&x_name));
}
/// Binding `x` to `42.0` substitutes the constant and removes `x` from `residual_vars`.
#[test]
fn test_bound_variable_replaced() {
    let config = PEConfig::default();
    let bindings = PEEnv::new().with_f64("x", 42.0);
    let input = var("x");
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if (v - 42.0).abs() < 1e-12));
    let x_name = "x".to_string();
    assert!(!outcome.residual_vars.contains(&x_name));
}
/// `2 + 3` has no free variables and folds all the way to the constant `5`.
#[test]
fn test_fully_concrete_add() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::add(cnst(2.0), cnst(3.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(sum) if (sum - 5.0).abs() < 1e-12));
}
/// Additive identity: `x + 0` simplifies to the bare predicate `x`.
#[test]
fn test_add_identity_x_plus_zero() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::add(var("x"), cnst(0.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Pred { ref name, .. } if name == "x"));
}
/// Annihilator: `x * 0` folds to the constant `0` even though `x` is unbound.
#[test]
fn test_mul_by_zero() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::mul(var("x"), cnst(0.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if v == 0.0));
}
/// Multiplicative identity on the left: `1 * y` simplifies to `y`.
#[test]
fn test_mul_by_one() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::mul(cnst(1.0), var("y"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Pred { ref name, .. } if name == "y"));
}
/// Dividing by one is a no-op: `z / 1` simplifies to `z`.
#[test]
fn test_div_by_one() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::div(var("z"), cnst(1.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Pred { ref name, .. } if name == "z"));
}
/// `5 / 0` is not folded: the result is still a `Div` node rather than a constant.
#[test]
fn test_div_by_zero_no_fold() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::div(cnst(5.0), cnst(0.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Div(_, _)));
}
/// `0 AND x` folds to `0` and records at least one pruned branch.
#[test]
fn test_and_short_circuit_false() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::and(cnst(0.0), var("x"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if v == 0.0));
    assert!(outcome.stats.branches_pruned > 0);
}
/// `1 AND x` reduces to `x` and records at least one pruned branch.
#[test]
fn test_and_short_circuit_true() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::and(cnst(1.0), var("x"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Pred { ref name, .. } if name == "x"));
    assert!(outcome.stats.branches_pruned > 0);
}
/// `1 OR x` folds to `1` and records at least one pruned branch.
#[test]
fn test_or_short_circuit_true() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::or(cnst(1.0), var("x"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if v == 1.0));
    assert!(outcome.stats.branches_pruned > 0);
}
/// `0 OR x` reduces to `x` and records at least one pruned branch.
#[test]
fn test_or_short_circuit_false() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::or(cnst(0.0), var("x"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Pred { ref name, .. } if name == "x"));
    assert!(outcome.stats.branches_pruned > 0);
}
/// `negate` applied to the constant `1` folds to `0`.
#[test]
fn test_not_true() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::negate(cnst(1.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if v == 0.0));
}
/// `negate` applied to the constant `0` folds to `1`.
#[test]
fn test_not_false() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::negate(cnst(0.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if v == 1.0));
}
/// `let a = 5 in a + 3` inlines the concrete binding, folds to `8`, and counts the inlining
/// in `stats.lets_inlined`.
#[test]
fn test_let_concrete_inlines() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let bound_value = cnst(5.0);
    let body = TLExpr::add(var("a"), cnst(3.0));
    let input = TLExpr::Let {
        var: "a".to_string(),
        value: Box::new(bound_value),
        body: Box::new(body),
    };
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if (v - 8.0).abs() < 1e-12));
    assert!(outcome.stats.lets_inlined > 0);
}
/// `let a = x in a + 3` keeps the `Let` node because the bound value `x` is symbolic;
/// `x` is reported in `residual_vars`.
#[test]
fn test_let_symbolic_keeps_let() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let bound_value = var("x");
    let body = TLExpr::add(var("a"), cnst(3.0));
    let input = TLExpr::Let {
        var: "a".to_string(),
        value: Box::new(bound_value),
        body: Box::new(body),
    };
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Let { .. }));
    let x_name = "x".to_string();
    assert!(outcome.residual_vars.contains(&x_name));
}
/// In `(x + 2) * (3 + 4)` the concrete right factor folds to `7`, while the `Mul` node
/// itself survives because the left factor depends on `x`.
#[test]
fn test_nested_partial_eval() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let symbolic_factor = TLExpr::add(var("x"), cnst(2.0));
    let concrete_factor = TLExpr::add(cnst(3.0), cnst(4.0));
    let input = TLExpr::mul(symbolic_factor, concrete_factor);
    let outcome = partially_evaluate(&input, &bindings, &config);
    match &outcome.expr {
        TLExpr::Mul(_, rhs) => {
            assert!(matches!(rhs.as_ref(), TLExpr::Constant(v) if (v - 7.0).abs() < 1e-12));
        }
        other => panic!("Expected Mul, got {:?}", other),
    }
}
/// Folding `1 + 2` increments `stats.nodes_reduced`.
#[test]
fn test_stats_nodes_reduced() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::add(cnst(1.0), cnst(2.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(outcome.stats.nodes_reduced > 0);
}
/// Short-circuiting `0 AND x` increments `stats.branches_pruned`.
#[test]
fn test_stats_branches_pruned() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::and(cnst(0.0), var("x"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(outcome.stats.branches_pruned > 0);
}
/// In `(x + y) + 5` with only `x` bound, `y` stays in `residual_vars` and `x` does not.
#[test]
fn test_residual_vars() {
    let config = PEConfig::default();
    let bindings = PEEnv::new().with_f64("x", 1.0);
    let inner_sum = TLExpr::add(var("x"), var("y"));
    let input = TLExpr::add(inner_sum, cnst(5.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    let (x_name, y_name) = ("x".to_string(), "y".to_string());
    assert!(!outcome.residual_vars.contains(&x_name));
    assert!(outcome.residual_vars.contains(&y_name));
}
/// `specialize` with `x = 4` turns `x * x` into the constant `16` and leaves no residual
/// variables.
#[test]
fn test_specialize_fully_concrete() {
    let config = PEConfig::default();
    let bindings = vec![("x".to_string(), 4.0)];
    let input = TLExpr::mul(var("x"), var("x"));
    let outcome = specialize(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Constant(v) if (v - 16.0).abs() < 1e-12));
    assert!(outcome.residual_vars.is_empty());
}
/// `specialize_batch` returns one result per binding set, in input order. Each result is
/// `x + 10` folded to a constant for that set's value of `x`.
#[test]
fn test_specialize_batch() {
    let expr = TLExpr::add(var("x"), cnst(10.0));
    let binding_sets = vec![
        vec![("x".to_string(), 1.0)],
        vec![("x".to_string(), 2.0)],
        vec![("x".to_string(), 3.0)],
    ];
    let cfg = PEConfig::default();
    let results = specialize_batch(&expr, &binding_sets, &cfg);
    // Check the length first: `zip` below would silently truncate on a mismatch.
    assert_eq!(results.len(), 3);
    let expected = [11.0, 12.0, 13.0];
    for (i, (r, want)) in results.iter().zip(expected).enumerate() {
        // Report which binding set failed, and what it produced, instead of a bare panic.
        match &r.expr {
            TLExpr::Constant(v) => assert!(
                (*v - want).abs() < 1e-12,
                "binding set {}: expected {}, got {}",
                i,
                want,
                v
            ),
            other => panic!("binding set {}: expected Constant, got {:?}", i, other),
        }
    }
}
/// With `fold_arithmetic` disabled, `2 + 3` stays an `Add` node instead of folding.
#[test]
fn test_config_no_fold_arithmetic() {
    let config = PEConfig {
        fold_arithmetic: false,
        ..PEConfig::default()
    };
    let bindings = PEEnv::new();
    let input = TLExpr::add(cnst(2.0), cnst(3.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Add(_, _)));
}
/// With branch pruning and logic folding both disabled, `0 AND x` is kept as an `And` node
/// and no pruned branches are recorded.
#[test]
fn test_config_no_prune_branches() {
    // NOTE(review): both `prune_branches` and `fold_logic` are turned off here, so this does
    // not isolate the effect of `prune_branches` alone.
    let config = PEConfig {
        prune_branches: false,
        fold_logic: false,
        ..PEConfig::default()
    };
    let bindings = PEEnv::new();
    let input = TLExpr::and(cnst(0.0), var("x"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::And(_, _)));
    assert_eq!(outcome.stats.branches_pruned, 0);
}
/// After folding `1 + 2`, `stats.reduction_rate()` lies in the half-open interval `(0, 1]`.
#[test]
fn test_reduction_rate() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::add(cnst(1.0), cnst(2.0));
    let outcome = partially_evaluate(&input, &bindings, &config);
    let ratio = outcome.stats.reduction_rate();
    assert!(ratio > 0.0);
    assert!(ratio <= 1.0);
}
/// The `PEEnv` builder methods accumulate bindings (numeric and boolean alike), and `lookup`
/// distinguishes bound names from unknown ones.
#[test]
fn test_env_builder_api() {
    let numeric_only = PEEnv::new().with_f64("a", 1.0).with_f64("b", 2.0);
    let bindings = numeric_only.with_bool("flag", true);
    assert_eq!(bindings.len(), 3);
    assert!(!bindings.is_empty());
    assert!(bindings.lookup("a").is_some());
    assert!(bindings.lookup("missing").is_none());
}
/// An implication with the constant premise `1` reduces to its conclusion `x`.
#[test]
fn test_imply_folds() {
    let config = PEConfig::default();
    let bindings = PEEnv::new();
    let input = TLExpr::imply(cnst(1.0), var("x"));
    let outcome = partially_evaluate(&input, &bindings, &config);
    assert!(matches!(outcome.expr, TLExpr::Pred { ref name, .. } if name == "x"));
}
/// `pow` identities with one symbolic operand:
/// - `x^0` folds to `1`
/// - `x^1` folds to `x`
/// - `0^x` folds to `0`
/// - `1^x` folds to `1`
#[test]
fn test_pow_special_cases() {
    let bindings = PEEnv::new();
    let config = PEConfig::default();

    // x^0 -> 1
    let zero_exponent = TLExpr::pow(var("x"), cnst(0.0));
    let out = partially_evaluate(&zero_exponent, &bindings, &config);
    assert!(matches!(out.expr, TLExpr::Constant(v) if (v - 1.0).abs() < 1e-12));

    // x^1 -> x
    let unit_exponent = TLExpr::pow(var("x"), cnst(1.0));
    let out = partially_evaluate(&unit_exponent, &bindings, &config);
    assert!(matches!(out.expr, TLExpr::Pred { ref name, .. } if name == "x"));

    // 0^x -> 0
    // NOTE(review): this rewrite is only sound for x > 0 (0^0 is 1 and a negative power of 0
    // is infinite); confirm the evaluator really intends it for a symbolic exponent.
    let zero_base = TLExpr::pow(cnst(0.0), var("x"));
    let out = partially_evaluate(&zero_base, &bindings, &config);
    assert!(matches!(out.expr, TLExpr::Constant(v) if v == 0.0));

    // 1^x -> 1
    let unit_base = TLExpr::pow(cnst(1.0), var("x"));
    let out = partially_evaluate(&unit_base, &bindings, &config);
    assert!(matches!(out.expr, TLExpr::Constant(v) if (v - 1.0).abs() < 1e-12));
}