use crate::error::EmlError;
use crate::lower::LoweredOp;
use tensorlogic_ir::{Pattern, RewriteRule, TLExpr, Term};
/// Encodes variable index `i` as a unary predicate.
///
/// The predicate name and its single `Term::var` argument are both `x<i>`
/// (e.g. `x3(x3)`). `match_var_pred` recognises this shape.
fn var_pred(i: usize) -> TLExpr {
    let label = format!("x{i}");
    let arg = Term::var(label.clone());
    TLExpr::pred(label, vec![arg])
}
/// Recognises the variable encoding produced by `var_pred` and returns its index.
///
/// Matches a `Pred` with exactly one argument when all of the following hold:
/// * the argument is a `Term::Var` whose name equals the predicate name;
/// * the name is `x` followed by the canonical decimal spelling of a `usize`.
///
/// Only the exact spelling `var_pred` emits is accepted. `str::parse::<usize>`
/// on its own would also accept `x+5` and `x05`, which would silently conflate
/// distinct predicates with `x5`.
fn match_var_pred(expr: &TLExpr) -> Option<usize> {
    let (name, args) = match expr {
        TLExpr::Pred { name, args } => (name, args),
        _ => return None,
    };
    if args.len() != 1 {
        return None;
    }
    match &args[0] {
        Term::Var(arg_name) if arg_name == name => {}
        _ => return None,
    }
    let digits = name.strip_prefix('x')?;
    let idx: usize = digits.parse().ok()?;
    // Reject non-canonical spellings (leading `+`, leading zeros) so that the
    // mapping name -> index is injective and round-trips through `var_pred`.
    if idx.to_string() == digits {
        Some(idx)
    } else {
        None
    }
}
pub fn to_tlexpr(op: &LoweredOp) -> TLExpr {
match op {
LoweredOp::Const(v) => TLExpr::Constant(*v),
LoweredOp::NamedConst(nc) => TLExpr::Constant(nc.value()),
LoweredOp::Var(i) => var_pred(*i),
LoweredOp::Add(a, b) => TLExpr::add(to_tlexpr(a), to_tlexpr(b)),
LoweredOp::Sub(a, b) => TLExpr::sub(to_tlexpr(a), to_tlexpr(b)),
LoweredOp::Mul(a, b) => TLExpr::mul(to_tlexpr(a), to_tlexpr(b)),
LoweredOp::Div(a, b) => TLExpr::div(to_tlexpr(a), to_tlexpr(b)),
LoweredOp::Pow(a, b) => TLExpr::pow(to_tlexpr(a), to_tlexpr(b)),
LoweredOp::Exp(a) => TLExpr::exp(to_tlexpr(a)),
LoweredOp::Ln(a) => TLExpr::log(to_tlexpr(a)),
LoweredOp::Sin(a) => TLExpr::sin(to_tlexpr(a)),
LoweredOp::Cos(a) => TLExpr::cos(to_tlexpr(a)),
LoweredOp::Neg(a) => TLExpr::sub(TLExpr::Constant(0.0), to_tlexpr(a)),
LoweredOp::Tan(a) => TLExpr::tan(to_tlexpr(a)),
LoweredOp::Sinh(a) => {
let x = to_tlexpr(a);
let neg_x = TLExpr::sub(TLExpr::Constant(0.0), x.clone());
TLExpr::div(
TLExpr::sub(TLExpr::exp(x), TLExpr::exp(neg_x)),
TLExpr::Constant(2.0),
)
}
LoweredOp::Cosh(a) => {
let x = to_tlexpr(a);
let neg_x = TLExpr::sub(TLExpr::Constant(0.0), x.clone());
TLExpr::div(
TLExpr::add(TLExpr::exp(x), TLExpr::exp(neg_x)),
TLExpr::Constant(2.0),
)
}
LoweredOp::Tanh(a) => {
let x = to_tlexpr(a);
let neg_x = TLExpr::sub(TLExpr::Constant(0.0), x.clone());
let sinh_x = TLExpr::div(
TLExpr::sub(TLExpr::exp(x.clone()), TLExpr::exp(neg_x.clone())),
TLExpr::Constant(2.0),
);
let cosh_x = TLExpr::div(
TLExpr::add(TLExpr::exp(x), TLExpr::exp(neg_x)),
TLExpr::Constant(2.0),
);
TLExpr::div(sinh_x, cosh_x)
}
LoweredOp::Arcsinh(a) => {
let x = to_tlexpr(a);
let x_sq = TLExpr::pow(x.clone(), TLExpr::Constant(2.0));
let inner = TLExpr::add(x_sq, TLExpr::Constant(1.0));
TLExpr::log(TLExpr::add(x, TLExpr::sqrt(inner)))
}
LoweredOp::Arccosh(a) => {
let x = to_tlexpr(a);
let x_sq = TLExpr::pow(x.clone(), TLExpr::Constant(2.0));
let inner = TLExpr::sub(x_sq, TLExpr::Constant(1.0));
TLExpr::log(TLExpr::add(x, TLExpr::sqrt(inner)))
}
LoweredOp::Arctanh(a) => {
let x = to_tlexpr(a);
let one_plus_x = TLExpr::add(TLExpr::Constant(1.0), x.clone());
let one_minus_x = TLExpr::sub(TLExpr::Constant(1.0), x);
TLExpr::div(
TLExpr::sub(TLExpr::log(one_plus_x), TLExpr::log(one_minus_x)),
TLExpr::Constant(2.0),
)
}
LoweredOp::Arctan(_) | LoweredOp::Arcsin(_) | LoweredOp::Arccos(_) => {
TLExpr::Constant(f64::NAN)
}
}
}
/// Converts a `TLExpr` back into a `LoweredOp` tree.
///
/// Supported shapes, converted recursively:
/// * `Constant(v)` → `Const(v)`;
/// * a predicate following the `x<usize>` variable convention (as built by
///   `var_pred`) → `Var(i)`;
/// * `Add`, `Sub`, `Mul`, `Div`, `Pow`, `Exp`, `Log` (→ `Ln`), `Sin`, `Cos`, `Tan`;
/// * `Sqrt(a)` → `Pow(a, Const(0.5))`. `LoweredOp` has no square-root node,
///   and without this case the `sqrt` that `to_tlexpr` emits for
///   `Arcsinh`/`Arccosh` could not be read back. (`powf(x, 0.5)` matches
///   `sqrt` for ordinary finite inputs; IEEE edge cases such as `-0.0` and
///   `-inf` may differ.)
///
/// # Errors
///
/// Returns `EmlError::UnsupportedTlExpr` in two cases: a predicate that does
/// not follow the variable convention, and any other variant (the message
/// names the variant).
pub fn from_tlexpr(expr: &TLExpr) -> Result<LoweredOp, EmlError> {
    match expr {
        TLExpr::Constant(v) => Ok(LoweredOp::Const(*v)),
        TLExpr::Pred { name, args } => match match_var_pred(expr) {
            Some(idx) => Ok(LoweredOp::Var(idx)),
            None => Err(EmlError::UnsupportedTlExpr(format!(
                "Pred {{ name: {name:?}, arity: {} }} does not match the \
                 OxiEML variable convention `x<usize>`",
                args.len()
            ))),
        },
        TLExpr::Add(a, b) => Ok(LoweredOp::Add(
            Box::new(from_tlexpr(a)?),
            Box::new(from_tlexpr(b)?),
        )),
        TLExpr::Sub(a, b) => Ok(LoweredOp::Sub(
            Box::new(from_tlexpr(a)?),
            Box::new(from_tlexpr(b)?),
        )),
        TLExpr::Mul(a, b) => Ok(LoweredOp::Mul(
            Box::new(from_tlexpr(a)?),
            Box::new(from_tlexpr(b)?),
        )),
        TLExpr::Div(a, b) => Ok(LoweredOp::Div(
            Box::new(from_tlexpr(a)?),
            Box::new(from_tlexpr(b)?),
        )),
        TLExpr::Pow(a, b) => Ok(LoweredOp::Pow(
            Box::new(from_tlexpr(a)?),
            Box::new(from_tlexpr(b)?),
        )),
        // `LoweredOp` has no sqrt node; express it as a half power.
        TLExpr::Sqrt(a) => Ok(LoweredOp::Pow(
            Box::new(from_tlexpr(a)?),
            Box::new(LoweredOp::Const(0.5)),
        )),
        TLExpr::Exp(a) => Ok(LoweredOp::Exp(Box::new(from_tlexpr(a)?))),
        TLExpr::Log(a) => Ok(LoweredOp::Ln(Box::new(from_tlexpr(a)?))),
        TLExpr::Sin(a) => Ok(LoweredOp::Sin(Box::new(from_tlexpr(a)?))),
        TLExpr::Cos(a) => Ok(LoweredOp::Cos(Box::new(from_tlexpr(a)?))),
        TLExpr::Tan(a) => Ok(LoweredOp::Tan(Box::new(from_tlexpr(a)?))),
        other => Err(EmlError::UnsupportedTlExpr(describe_variant(other))),
    }
}
/// Returns the bare variant name of `expr` (e.g. `"And"`, `"Sqrt"`).
///
/// This is used in `UnsupportedTlExpr` error messages. The match is
/// deliberately exhaustive with no wildcard arm, so adding a `TLExpr` variant
/// forces this function to be updated.
fn describe_variant(expr: &TLExpr) -> String {
    String::from(match expr {
        // Leaves.
        TLExpr::Pred { .. } => "Pred",
        TLExpr::Constant(_) => "Constant",
        TLExpr::SymbolLiteral(_) => "SymbolLiteral",
        TLExpr::EmptySet => "EmptySet",

        // Connectives and quantifiers.
        TLExpr::And(_, _) => "And",
        TLExpr::Or(_, _) => "Or",
        TLExpr::Not(_) => "Not",
        TLExpr::Imply(_, _) => "Imply",
        TLExpr::Exists { .. } => "Exists",
        TLExpr::ForAll { .. } => "ForAll",
        TLExpr::Score(_) => "Score",

        // Binary arithmetic.
        TLExpr::Add(_, _) => "Add",
        TLExpr::Sub(_, _) => "Sub",
        TLExpr::Mul(_, _) => "Mul",
        TLExpr::Div(_, _) => "Div",
        TLExpr::Pow(_, _) => "Pow",
        TLExpr::Mod(_, _) => "Mod",
        TLExpr::Min(_, _) => "Min",
        TLExpr::Max(_, _) => "Max",

        // Unary arithmetic.
        TLExpr::Abs(_) => "Abs",
        TLExpr::Floor(_) => "Floor",
        TLExpr::Ceil(_) => "Ceil",
        TLExpr::Round(_) => "Round",
        TLExpr::Sqrt(_) => "Sqrt",
        TLExpr::Exp(_) => "Exp",
        TLExpr::Log(_) => "Log",
        TLExpr::Sin(_) => "Sin",
        TLExpr::Cos(_) => "Cos",
        TLExpr::Tan(_) => "Tan",

        // Comparisons.
        TLExpr::Eq(_, _) => "Eq",
        TLExpr::Lt(_, _) => "Lt",
        TLExpr::Gt(_, _) => "Gt",
        TLExpr::Lte(_, _) => "Lte",
        TLExpr::Gte(_, _) => "Gte",

        // Control flow and binding forms.
        TLExpr::IfThenElse { .. } => "IfThenElse",
        TLExpr::Let { .. } => "Let",
        TLExpr::Aggregate { .. } => "Aggregate",
        TLExpr::Lambda { .. } => "Lambda",
        TLExpr::Apply { .. } => "Apply",
        TLExpr::Match { .. } => "Match",

        // Box / Diamond and the Next .. StrongRelease family.
        TLExpr::Box(_) => "Box",
        TLExpr::Diamond(_) => "Diamond",
        TLExpr::Next(_) => "Next",
        TLExpr::Eventually(_) => "Eventually",
        TLExpr::Always(_) => "Always",
        TLExpr::Until { .. } => "Until",
        TLExpr::Release { .. } => "Release",
        TLExpr::WeakUntil { .. } => "WeakUntil",
        TLExpr::StrongRelease { .. } => "StrongRelease",

        // Fuzzy / soft / weighted forms.
        TLExpr::TNorm { .. } => "TNorm",
        TLExpr::TCoNorm { .. } => "TCoNorm",
        TLExpr::FuzzyNot { .. } => "FuzzyNot",
        TLExpr::FuzzyImplication { .. } => "FuzzyImplication",
        TLExpr::SoftExists { .. } => "SoftExists",
        TLExpr::SoftForAll { .. } => "SoftForAll",
        TLExpr::WeightedRule { .. } => "WeightedRule",
        TLExpr::ProbabilisticChoice { .. } => "ProbabilisticChoice",

        // Set forms.
        TLExpr::SetMembership { .. } => "SetMembership",
        TLExpr::SetUnion { .. } => "SetUnion",
        TLExpr::SetIntersection { .. } => "SetIntersection",
        TLExpr::SetDifference { .. } => "SetDifference",
        TLExpr::SetCardinality { .. } => "SetCardinality",
        TLExpr::SetComprehension { .. } => "SetComprehension",

        // Counting forms.
        TLExpr::CountingExists { .. } => "CountingExists",
        TLExpr::CountingForAll { .. } => "CountingForAll",
        TLExpr::ExactCount { .. } => "ExactCount",
        TLExpr::Majority { .. } => "Majority",

        // Fixpoints.
        TLExpr::LeastFixpoint { .. } => "LeastFixpoint",
        TLExpr::GreatestFixpoint { .. } => "GreatestFixpoint",

        // Nominal / At / Somewhere / Everywhere.
        TLExpr::Nominal { .. } => "Nominal",
        TLExpr::At { .. } => "At",
        TLExpr::Somewhere { .. } => "Somewhere",
        TLExpr::Everywhere { .. } => "Everywhere",

        // Remaining forms.
        TLExpr::AllDifferent { .. } => "AllDifferent",
        TLExpr::GlobalCardinality { .. } => "GlobalCardinality",
        TLExpr::Abducible { .. } => "Abducible",
        TLExpr::Explain { .. } => "Explain",
    })
}
/// Returns the ten canonical `TLExpr` rewrite rules, in this order:
///
/// | name              | rewrite           |
/// |-------------------|-------------------|
/// | `exp_log_inverse` | `exp(log x) → x`  |
/// | `log_exp_inverse` | `log(exp x) → x`  |
/// | `double_negation` | `neg(neg x) → x`  |
/// | `zero_add_left`   | `0 + x → x`       |
/// | `zero_add_right`  | `x + 0 → x`       |
/// | `one_mul_right`   | `x * 1 → x`       |
/// | `one_mul_left`    | `1 * x → x`       |
/// | `div_by_one`      | `x / 1 → x`       |
/// | `pow_zero`        | `x ^ 0 → 1`       |
/// | `pow_one`         | `x ^ 1 → x`       |
///
/// Every rule except `pow_zero` rewrites to the binding of the pattern
/// variable `x`. If that binding is missing, the template yields
/// `Constant(NaN)`.
pub fn canonical_rewrite_rules() -> Vec<RewriteRule> {
    // Looks up pattern variable `key`, falling back to a NaN constant.
    fn bound(bindings: &std::collections::HashMap<String, TLExpr>, key: &str) -> TLExpr {
        match bindings.get(key) {
            Some(found) => found.clone(),
            None => TLExpr::Constant(f64::NAN),
        }
    }

    // Builds a rule whose right-hand side is the binding of `x`.
    let collapse_to_x = |pattern, label: &str| RewriteRule {
        pattern,
        template: |b| bound(b, "x"),
        name: Some(label.to_string()),
    };

    vec![
        collapse_to_x(
            Pattern::exp(Pattern::log(Pattern::var("x"))),
            "exp_log_inverse",
        ),
        collapse_to_x(
            Pattern::log(Pattern::exp(Pattern::var("x"))),
            "log_exp_inverse",
        ),
        collapse_to_x(
            Pattern::neg(Pattern::neg(Pattern::var("x"))),
            "double_negation",
        ),
        collapse_to_x(
            Pattern::add(Pattern::constant(0.0), Pattern::var("x")),
            "zero_add_left",
        ),
        collapse_to_x(
            Pattern::add(Pattern::var("x"), Pattern::constant(0.0)),
            "zero_add_right",
        ),
        collapse_to_x(
            Pattern::mul(Pattern::var("x"), Pattern::constant(1.0)),
            "one_mul_right",
        ),
        collapse_to_x(
            Pattern::mul(Pattern::constant(1.0), Pattern::var("x")),
            "one_mul_left",
        ),
        collapse_to_x(
            Pattern::div(Pattern::var("x"), Pattern::constant(1.0)),
            "div_by_one",
        ),
        // The only rule whose result does not depend on the binding.
        RewriteRule {
            pattern: Pattern::pow(Pattern::var("_x"), Pattern::constant(0.0)),
            template: |_b| TLExpr::Constant(1.0),
            name: Some("pow_zero".to_string()),
        },
        collapse_to_x(
            Pattern::pow(Pattern::var("x"), Pattern::constant(1.0)),
            "pow_one",
        ),
    ]
}
/// Absolute tolerance used when deciding whether a constant counts as 0 or 1.
const CANONICAL_EPS: f64 = 1e-15;

/// `true` when `e` is a `Constant` strictly within `CANONICAL_EPS` of `target`.
fn is_const_near(e: &TLExpr, target: f64) -> bool {
    match e {
        TLExpr::Constant(c) => (c - target).abs() < CANONICAL_EPS,
        _ => false,
    }
}

/// `true` when `e` is a constant that is numerically zero (within tolerance).
fn is_const_zero(e: &TLExpr) -> bool {
    is_const_near(e, 0.0)
}

/// `true` when `e` is a constant that is numerically one (within tolerance).
fn is_const_one(e: &TLExpr) -> bool {
    is_const_near(e, 1.0)
}
/// Runs a single bottom-up simplification sweep over `expr`.
///
/// This applies to `Add`, `Sub`, `Mul`, `Div`, `Pow`, `Exp`, `Log`, `Sin` and
/// `Cos`. The children are simplified first, then the node is rewritten with:
/// * constant folding, applied only when the folded value is finite, so a
///   sweep never manufactures a `NaN`/`±inf` constant. Division also requires
///   a divisor with magnitude above `CANONICAL_EPS`, and `log` requires a
///   positive argument;
/// * identities `0 + x`, `x + 0`, `x - 0`, `1 * x`, `x * 1`, `x / 1`, `x ^ 1`
///   → `x`, and `x ^ 0` → `1`;
/// * annihilation `0 * x`, `x * 0` → `0`;
/// * double negation `0 - (0 - x)` → `x` (`to_tlexpr` encodes negation as `0 - x`);
/// * the inverse pair `exp(log x)` → `x` and `log(exp x)` → `x`.
///
/// Zero/one tests go through `is_const_zero` / `is_const_one`. Any other
/// variant is returned as a clone, without visiting its children.
fn simplify_one_pass(expr: &TLExpr) -> TLExpr {
    // Wraps `v` as a constant only when it is finite. Every fold goes through
    // this helper so the policy is the same for all operators.
    fn finite_const(v: f64) -> Option<TLExpr> {
        if v.is_finite() {
            Some(TLExpr::Constant(v))
        } else {
            None
        }
    }

    match expr {
        TLExpr::Add(a, b) => {
            let a = simplify_one_pass(a);
            let b = simplify_one_pass(b);
            if let (TLExpr::Constant(ac), TLExpr::Constant(bc)) = (&a, &b) {
                if let Some(folded) = finite_const(ac + bc) {
                    return folded;
                }
            }
            if is_const_zero(&a) {
                return b;
            }
            if is_const_zero(&b) {
                return a;
            }
            TLExpr::add(a, b)
        }
        TLExpr::Sub(a, b) => {
            let a = simplify_one_pass(a);
            let b = simplify_one_pass(b);
            if let (TLExpr::Constant(ac), TLExpr::Constant(bc)) = (&a, &b) {
                if let Some(folded) = finite_const(ac - bc) {
                    return folded;
                }
            }
            // Double negation: 0 - (0 - x) → x.
            if is_const_zero(&a) {
                if let TLExpr::Sub(inner_a, inner_x) = &b {
                    if is_const_zero(inner_a) {
                        return (**inner_x).clone();
                    }
                }
            }
            if is_const_zero(&b) {
                return a;
            }
            TLExpr::sub(a, b)
        }
        TLExpr::Mul(a, b) => {
            let a = simplify_one_pass(a);
            let b = simplify_one_pass(b);
            if let (TLExpr::Constant(ac), TLExpr::Constant(bc)) = (&a, &b) {
                if let Some(folded) = finite_const(ac * bc) {
                    return folded;
                }
            }
            if is_const_zero(&a) || is_const_zero(&b) {
                return TLExpr::Constant(0.0);
            }
            if is_const_one(&a) {
                return b;
            }
            if is_const_one(&b) {
                return a;
            }
            TLExpr::mul(a, b)
        }
        TLExpr::Div(a, b) => {
            let a = simplify_one_pass(a);
            let b = simplify_one_pass(b);
            if let (TLExpr::Constant(ac), TLExpr::Constant(bc)) = (&a, &b) {
                if bc.abs() > CANONICAL_EPS {
                    if let Some(folded) = finite_const(ac / bc) {
                        return folded;
                    }
                }
            }
            if is_const_one(&b) {
                return a;
            }
            TLExpr::div(a, b)
        }
        TLExpr::Pow(a, b) => {
            let a = simplify_one_pass(a);
            let b = simplify_one_pass(b);
            if is_const_zero(&b) {
                return TLExpr::Constant(1.0);
            }
            if is_const_one(&b) {
                return a;
            }
            if let (TLExpr::Constant(ac), TLExpr::Constant(bc)) = (&a, &b) {
                if let Some(folded) = finite_const(ac.powf(*bc)) {
                    return folded;
                }
            }
            TLExpr::pow(a, b)
        }
        TLExpr::Exp(a) => {
            let a = simplify_one_pass(a);
            if let TLExpr::Log(inner) = &a {
                return (**inner).clone();
            }
            if let TLExpr::Constant(c) = &a {
                if let Some(folded) = finite_const(c.exp()) {
                    return folded;
                }
            }
            TLExpr::exp(a)
        }
        TLExpr::Log(a) => {
            let a = simplify_one_pass(a);
            if let TLExpr::Exp(inner) = &a {
                return (**inner).clone();
            }
            if let TLExpr::Constant(c) = &a {
                if *c > 0.0 {
                    if let Some(folded) = finite_const(c.ln()) {
                        return folded;
                    }
                }
            }
            TLExpr::log(a)
        }
        TLExpr::Sin(a) => {
            let a = simplify_one_pass(a);
            if let TLExpr::Constant(c) = &a {
                if let Some(folded) = finite_const(c.sin()) {
                    return folded;
                }
            }
            TLExpr::sin(a)
        }
        TLExpr::Cos(a) => {
            let a = simplify_one_pass(a);
            if let TLExpr::Constant(c) = &a {
                if let Some(folded) = finite_const(c.cos()) {
                    return folded;
                }
            }
            TLExpr::cos(a)
        }
        _ => expr.clone(),
    }
}
/// Repeatedly applies `simplify_one_pass` to `expr` until a sweep leaves the
/// tree unchanged, then returns the result.
///
/// At most `MAX_PASSES` sweeps are run. Each sweep works bottom-up, so a
/// fixpoint is normally reached within a pass or two. The cap guarantees
/// termination when the structural `==` check can never succeed, for example
/// on a tree containing `Constant(NaN)` such as the placeholder `to_tlexpr`
/// emits for `Arctan`. (This assumes `TLExpr`'s `PartialEq` compares constants
/// with `f64` equality, under which `NaN != NaN`.) When the cap is hit, the
/// last sweep's result is returned.
pub fn canonical_simplify(expr: &TLExpr) -> TLExpr {
    const MAX_PASSES: usize = 64;

    let mut current = expr.clone();
    for _ in 0..MAX_PASSES {
        let next = simplify_one_pass(&current);
        if next == current {
            return next;
        }
        current = next;
    }
    current
}
/// Converts each discovered formula into a weighted TL rule.
///
/// The `i`-th formula is paired with the `i`-th weight, and the rules are
/// returned in input order.
///
/// # Errors
///
/// Returns `EmlError::DimensionMismatch(formulas.len(), weights.len())` when
/// the two slices have different lengths.
pub fn formulas_to_tl_weighted_rules(
    formulas: &[crate::symreg::DiscoveredFormula],
    weights: &[f64],
) -> Result<Vec<TLExpr>, crate::EmlError> {
    let (n_formulas, n_weights) = (formulas.len(), weights.len());
    if n_formulas == n_weights {
        let rules = formulas
            .iter()
            .zip(weights)
            .map(|(formula, &weight)| formula.to_tl_weighted_rule(weight))
            .collect();
        Ok(rules)
    } else {
        Err(crate::EmlError::DimensionMismatch(n_formulas, n_weights))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `P(a) ∧ Q(b)`: a purely logical formula with no arithmetic content.
    fn p_and_q() -> TLExpr {
        TLExpr::and(
            TLExpr::pred("P", vec![Term::var("a")]),
            TLExpr::pred("Q", vec![Term::var("b")]),
        )
    }

    #[test]
    fn const_roundtrip() {
        let original = LoweredOp::Const(PI);
        let lowered = to_tlexpr(&original);
        assert!(matches!(lowered, TLExpr::Constant(v) if (v - PI).abs() < 1e-15));
        let restored = from_tlexpr(&lowered).expect("const round-trip");
        assert_eq!(original, restored);
    }

    #[test]
    fn var_pred_shape() {
        let original = LoweredOp::Var(7);
        let lowered = to_tlexpr(&original);
        let restored = from_tlexpr(&lowered).expect("var round-trip");
        assert_eq!(original, restored);
        match lowered {
            TLExpr::Pred { name, args } => {
                assert_eq!(name, "x7");
                assert_eq!(args.len(), 1);
                assert!(matches!(&args[0], Term::Var(n) if n == "x7"));
            }
            other => panic!("expected Pred, got {other:?}"),
        }
    }

    #[test]
    fn unsupported_variant_rejected() {
        match from_tlexpr(&p_and_q()).expect_err("And must be rejected") {
            EmlError::UnsupportedTlExpr(desc) => {
                assert!(desc.contains("And"), "got {desc}");
            }
            other => panic!("unexpected error variant {other:?}"),
        }
    }

    #[test]
    fn stray_pred_name_rejected() {
        let stray = TLExpr::pred("foo", vec![Term::var("a")]);
        let err = from_tlexpr(&stray).expect_err("non-x predicate must be rejected");
        assert!(matches!(err, EmlError::UnsupportedTlExpr(_)));
    }

    #[test]
    fn canonical_rewrite_rules_has_ten_rules() {
        assert_eq!(canonical_rewrite_rules().len(), 10);
    }

    #[test]
    fn canonical_simplify_exp_log_eliminates() {
        let x = var_pred(0);
        let wrapped = [
            TLExpr::exp(TLExpr::log(x.clone())),
            TLExpr::log(TLExpr::exp(x.clone())),
        ];
        for expr in &wrapped {
            assert_eq!(canonical_simplify(expr), x);
        }
    }

    #[test]
    fn canonical_simplify_folds_const_arithmetic() {
        // (2 + 3) * 4 - 1 = 19
        let sum = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let product = TLExpr::mul(sum, TLExpr::Constant(4.0));
        let expr = TLExpr::sub(product, TLExpr::Constant(1.0));
        match canonical_simplify(&expr) {
            TLExpr::Constant(v) => assert!(
                (v - 19.0).abs() < 1e-12,
                "expected Constant(19.0), got Constant({v})"
            ),
            other => panic!("expected folded Constant, got {other:?}"),
        }
    }

    #[test]
    fn canonical_simplify_removes_zero_add() {
        let x = var_pred(0);
        let zero = || TLExpr::Constant(0.0);
        let cases = [
            TLExpr::add(x.clone(), zero()),
            TLExpr::add(zero(), x.clone()),
            TLExpr::sub(x.clone(), zero()),
        ];
        for expr in &cases {
            assert_eq!(canonical_simplify(expr), x);
        }
    }

    #[test]
    fn canonical_simplify_removes_one_mul() {
        let x = var_pred(0);
        let cases = [
            (TLExpr::mul(TLExpr::Constant(1.0), x.clone()), x.clone()),
            (TLExpr::mul(x.clone(), TLExpr::Constant(1.0)), x.clone()),
            (TLExpr::mul(x.clone(), TLExpr::Constant(0.0)), TLExpr::Constant(0.0)),
            (TLExpr::mul(TLExpr::Constant(0.0), x.clone()), TLExpr::Constant(0.0)),
            (TLExpr::div(x.clone(), TLExpr::Constant(1.0)), x.clone()),
            (TLExpr::pow(x.clone(), TLExpr::Constant(0.0)), TLExpr::Constant(1.0)),
            (TLExpr::pow(x.clone(), TLExpr::Constant(1.0)), x.clone()),
        ];
        for (input, expected) in &cases {
            assert_eq!(&canonical_simplify(input), expected);
        }
    }

    #[test]
    fn canonical_simplify_double_neg() {
        let x = var_pred(0);
        let negate = |e: TLExpr| TLExpr::sub(TLExpr::Constant(0.0), e);
        assert_eq!(canonical_simplify(&negate(negate(x.clone()))), x);
        let single = negate(x.clone());
        assert_eq!(canonical_simplify(&single), single);
    }

    #[test]
    fn canonical_simplify_leaves_complex_untouched() {
        let untouched = [
            p_and_q(),
            TLExpr::pred("R", vec![Term::var("z")]),
            var_pred(3),
        ];
        for expr in &untouched {
            assert_eq!(&canonical_simplify(expr), expr);
        }
    }
}