use crate::error::{DSLCompileError, Result};
use crate::final_tagless::ASTRepr;
use std::collections::HashMap;
#[cfg(feature = "optimization")]
use egglog::EGraph;
/// Expression optimizer backed by an in-process egglog `EGraph`.
///
/// This definition is only compiled when the `optimization` feature is enabled.
#[cfg(feature = "optimization")]
pub struct NativeEgglogOptimizer {
/// E-graph loaded with the rule program; expressions are added to it via `let` commands.
egraph: EGraph,
/// Monotonic counter used to build unique binding names (`expr_N`, `interval_expr_N`).
var_counter: usize,
/// Original expressions keyed by binding name; consulted as a fallback when the
/// extracted egglog output cannot be parsed back into an `ASTRepr`.
expr_cache: HashMap<String, ASTRepr<f64>>,
}
#[cfg(feature = "optimization")]
impl NativeEgglogOptimizer {
/// Builds an optimizer whose e-graph is preloaded with the domain-aware rule program.
///
/// # Errors
///
/// Returns `DSLCompileError::Optimization` when egglog rejects the rule program.
pub fn new() -> Result<Self> {
    let rules = Self::create_domain_aware_program();
    let mut graph = EGraph::default();

    if let Err(e) = graph.parse_and_run_program(None, &rules) {
        return Err(DSLCompileError::Optimization(format!(
            "Failed to initialize native egglog with domain analysis: {e}"
        )));
    }

    let optimizer = Self {
        egraph: graph,
        var_counter: 0,
        expr_cache: HashMap::new(),
    };
    Ok(optimizer)
}
/// Returns the egglog program text that `new` loads into the e-graph.
///
/// The program declares the `Math` and `Interval` datatypes, the `ival`,
/// `ival-positive`, `ival-nonneg` and `ival-nonzero` functions, the rules that
/// populate those functions, unconditional algebraic rewrites, and guarded
/// `ln` / `exp` / `sqrt` / `Pow` rules that only fire when the relevant
/// predicate has been set to `true` for their operands.
///
/// The whole literal is runtime input to egglog; its `;` comments are egglog
/// comments, not Rust comments.
fn create_domain_aware_program() -> String {
r"
; ========================================
; CORE DATATYPES
; ========================================
(datatype Math
(Num f64)
(Var String)
(Add Math Math)
(Mul Math Math)
(Neg Math)
(Pow Math Math)
(Ln Math)
(Exp Math)
(Sin Math)
(Cos Math)
(Sqrt Math))
; ========================================
; INTERVAL DOMAIN FOR ABSTRACT INTERPRETATION
; ========================================
(datatype Interval
(IVal f64 f64) ; [lower, upper] bounds
(IBot) ; Bottom (empty interval)
(ITop)) ; Top (all reals)
; Interval analysis function
(function ival (Math) Interval :merge (ITop))
; ========================================
; DOMAIN PREDICATES
; ========================================
; Check if an expression is provably positive
(function ival-positive (Math) bool :merge false)
; Check if an expression is provably non-negative
(function ival-nonneg (Math) bool :merge false)
; Check if an expression is provably non-zero
(function ival-nonzero (Math) bool :merge false)
; ========================================
; BASIC INTERVAL ANALYSIS RULES
; ========================================
; Constants have singleton intervals
(rule ((= e (Num ?x)))
((set (ival e) (IVal ?x ?x))))
; Variables have top interval (unknown bounds)
(rule ((= e (Var ?name)))
((set (ival e) (ITop))))
; Positive constants are positive
(rule ((= e (Num ?x))
(> ?x 0.0))
((set (ival-positive e) true)))
; Non-negative constants are non-negative
(rule ((= e (Num ?x))
(>= ?x 0.0))
((set (ival-nonneg e) true)))
; Non-zero constants are non-zero
(rule ((= e (Num ?x))
(!= ?x 0.0))
((set (ival-nonzero e) true)))
; Exponential is always positive
(rule ((= e (Exp ?x)))
((set (ival-positive e) true)))
; Exponential is always non-negative
(rule ((= e (Exp ?x)))
((set (ival-nonneg e) true)))
; Exponential is always non-zero
(rule ((= e (Exp ?x)))
((set (ival-nonzero e) true)))
; Square root is always non-negative
(rule ((= e (Sqrt ?x)))
((set (ival-nonneg e) true)))
; Product of positive expressions is positive
(rule ((= e (Mul ?a ?b))
(= (ival-positive ?a) true)
(= (ival-positive ?b) true))
((set (ival-positive e) true)))
; ========================================
; BASIC MATHEMATICAL RULES
; ========================================
; Additive identity
(rewrite (Add a (Num 0.0)) a)
(rewrite (Add (Num 0.0) a) a)
; Multiplicative identity
(rewrite (Mul a (Num 1.0)) a)
(rewrite (Mul (Num 1.0) a) a)
; Multiplicative zero
(rewrite (Mul a (Num 0.0)) (Num 0.0))
(rewrite (Mul (Num 0.0) a) (Num 0.0))
; Power rules
(rewrite (Pow a (Num 0.0)) (Num 1.0))
(rewrite (Pow a (Num 1.0)) a)
; Commutativity
(rewrite (Add a b) (Add b a))
(rewrite (Mul a b) (Mul b a))
; Associativity for addition
(rewrite (Add (Add a b) c) (Add a (Add b c)))
; Associativity for multiplication
(rewrite (Mul (Mul a b) c) (Mul a (Mul b c)))
; ========================================
; DOMAIN-AWARE TRANSCENDENTAL RULES
; ========================================
; ln(exp(x)) = x (always safe)
(rewrite (Ln (Exp x)) x)
; exp(ln(x)) = x (only if x is positive)
(rule ((= e (Exp (Ln ?x)))
(= (ival-positive ?x) true))
((union e ?x)))
; ln(a * b) = ln(a) + ln(b) (only if both a and b are positive)
(rule ((= e (Ln (Mul ?a ?b)))
(= (ival-positive ?a) true)
(= (ival-positive ?b) true))
((union e (Add (Ln ?a) (Ln ?b)))))
; Special case: ln(exp(x) * exp(y)) = ln(exp(x)) + ln(exp(y)) = x + y
(rewrite (Ln (Mul (Exp a) (Exp b))) (Add a b))
; Special case: ln(exp(x) * y) = ln(exp(x)) + ln(y) = x + ln(y) (if y is positive)
(rule ((= e (Ln (Mul (Exp ?x) ?y)))
(= (ival-positive ?y) true))
((union e (Add ?x (Ln ?y)))))
; Special case: ln(x * exp(y)) = ln(x) + ln(exp(y)) = ln(x) + y (if x is positive)
(rule ((= e (Ln (Mul ?x (Exp ?y))))
(= (ival-positive ?x) true))
((union e (Add (Ln ?x) ?y))))
; ln(a / b) = ln(a) - ln(b) (only if both a and b are positive)
; Note: Division is represented as Mul(a, Pow(b, Neg(Num 1.0))) in canonical form
(rule ((= e (Ln (Mul ?a (Pow ?b (Neg (Num 1.0))))))
(= (ival-positive ?a) true)
(= (ival-positive ?b) true))
((union e (Add (Ln ?a) (Neg (Ln ?b))))))
; ln(x^a) = a * ln(x) (only if x is positive)
(rule ((= e (Ln (Pow ?x ?a)))
(= (ival-positive ?x) true))
((union e (Mul ?a (Ln ?x)))))
; exp(a + b) = exp(a) * exp(b)
(rewrite (Exp (Add a b)) (Mul (Exp a) (Exp b)))
; exp(a - b) = exp(a) / exp(b) -> exp(a) * exp(-b) in canonical form
(rewrite (Exp (Add a (Neg b))) (Mul (Exp a) (Exp (Neg b))))
; ========================================
; DOMAIN-AWARE SQUARE ROOT RULES
; ========================================
; sqrt(0) = 0
(rewrite (Sqrt (Num 0.0)) (Num 0.0))
; sqrt(1) = 1
(rewrite (Sqrt (Num 1.0)) (Num 1.0))
; sqrt(x^2) = |x| = x (only if x is non-negative)
(rule ((= e (Sqrt (Pow ?x (Num 2.0))))
(= (ival-nonneg ?x) true))
((union e ?x)))
; sqrt(x * x) = |x| = x (only if x is non-negative)
(rule ((= e (Sqrt (Mul ?x ?x)))
(= (ival-nonneg ?x) true))
((union e ?x)))
; sqrt(a * b) = sqrt(a) * sqrt(b) (only if both a and b are non-negative)
(rule ((= e (Sqrt (Mul ?a ?b)))
(= (ival-nonneg ?a) true)
(= (ival-nonneg ?b) true))
((union e (Mul (Sqrt ?a) (Sqrt ?b)))))
; ========================================
; POWER SIMPLIFICATION RULES
; ========================================
; x^(a + b) = x^a * x^b (only if x is positive)
(rule ((= e (Pow ?x (Add ?a ?b)))
(= (ival-positive ?x) true))
((union e (Mul (Pow ?x ?a) (Pow ?x ?b)))))
; (x^a)^b = x^(a*b) (only if x is positive)
(rule ((= e (Pow (Pow ?x ?a) ?b))
(= (ival-positive ?x) true))
((union e (Pow ?x (Mul ?a ?b)))))
; (a * b)^c = a^c * b^c (only if a and b are positive)
(rule ((= e (Pow (Mul ?a ?b) ?c))
(= (ival-positive ?a) true)
(= (ival-positive ?b) true))
((union e (Mul (Pow ?a ?c) (Pow ?b ?c)))))
"
.to_string()
}
/// Optimizes `expr` by adding it to the e-graph, running the rules for 10
/// iterations and extracting the best equivalent expression.
///
/// If the extracted egglog output cannot be parsed back into an `ASTRepr`,
/// the original expression is returned unchanged.
///
/// # Errors
///
/// Returns `DSLCompileError::Optimization` when egglog rejects the `let`
/// binding, fails to run the rules, or fails to extract the expression.
pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    let egglog_expr = self.ast_to_egglog(expr)?;
    let expr_id = format!("expr_{}", self.var_counter);
    self.var_counter += 1;

    let add_command = format!("(let {expr_id} {egglog_expr})");
    self.egraph
        .parse_and_run_program(None, &add_command)
        .map_err(|e| {
            DSLCompileError::Optimization(format!("Failed to add expression to egglog: {e}"))
        })?;
    self.egraph
        .parse_and_run_program(None, "(run 10)")
        .map_err(|e| {
            DSLCompileError::Optimization(format!("Failed to run mathematical rules: {e}"))
        })?;

    // The cache entry exists only so `extract_best` can fall back to the
    // original expression. Insert it just before extraction and drop it
    // afterwards so the map does not grow with every call (or retain entries
    // for calls that failed earlier).
    self.expr_cache.insert(expr_id.clone(), expr.clone());
    let result = self.extract_best(&expr_id);
    self.expr_cache.remove(&expr_id);
    result
}
/// Produces a human-readable interval description for `expr`.
///
/// The expression is added to the e-graph and the rules are run for 5
/// iterations, but the returned text comes from the Rust-side
/// `analyze_interval_heuristic`, not from the e-graph.
///
/// # Errors
///
/// Returns `DSLCompileError::Optimization` when egglog rejects the binding
/// or fails to run.
pub fn analyze_interval(&mut self, expr: &ASTRepr<f64>) -> Result<String> {
    let term = self.ast_to_egglog(expr)?;
    let binding = format!("interval_expr_{}", self.var_counter);
    self.var_counter += 1;

    let let_command = format!("(let {binding} {term})");
    if let Err(e) = self.egraph.parse_and_run_program(None, &let_command) {
        return Err(DSLCompileError::Optimization(format!(
            "Failed to add expression for interval analysis: {e}"
        )));
    }

    if let Err(e) = self.egraph.parse_and_run_program(None, "(run 5)") {
        return Err(DSLCompileError::Optimization(format!(
            "Failed to run interval analysis: {e}"
        )));
    }

    self.analyze_interval_heuristic(expr)
}
/// Reports whether `expr` is provably a valid argument for `operation`.
///
/// Recognised operations are `"ln"` (argument must be positive), `"sqrt"`
/// (argument must be non-negative) and `"div"` (denominator must be
/// non-zero). Any other operation name is treated as always safe.
pub fn is_domain_safe(&mut self, expr: &ASTRepr<f64>, operation: &str) -> Result<bool> {
    if operation == "ln" {
        return self.is_positive_definite(expr);
    }
    if operation == "sqrt" {
        return self.is_non_negative(expr);
    }
    if operation == "div" {
        return self.is_nonzero_denominator(expr);
    }
    // Operations without a restricted domain are accepted as-is.
    Ok(true)
}
/// Builds a textual interval description by structural recursion on `expr`.
///
/// Constants, variables, sums, products, `exp`, `ln` and `sqrt` get a
/// dedicated description; every other node yields a generic message.
fn analyze_interval_heuristic(&self, expr: &ASTRepr<f64>) -> Result<String> {
    let description = match expr {
        ASTRepr::Constant(val) => format!("[{val}, {val}] (singleton interval)"),
        ASTRepr::Variable(_) => String::from("(-∞, +∞) (unknown variable bounds)"),
        ASTRepr::Add(left, right) => {
            let left_analysis = self.analyze_interval_heuristic(left)?;
            let right_analysis = self.analyze_interval_heuristic(right)?;
            format!("Sum of intervals: {left_analysis} + {right_analysis}")
        }
        ASTRepr::Mul(left, right) => {
            let left_analysis = self.analyze_interval_heuristic(left)?;
            let right_analysis = self.analyze_interval_heuristic(right)?;
            format!("Product of intervals: {left_analysis} * {right_analysis}")
        }
        ASTRepr::Exp(_) => String::from("(0, +∞) (exponential is always positive)"),
        ASTRepr::Ln(inner) => {
            let argument_ok = self.is_positive_definite(inner)?;
            let text = if argument_ok {
                "(-∞, +∞) (ln of positive expression)"
            } else {
                "Domain error: ln requires positive argument"
            };
            String::from(text)
        }
        ASTRepr::Sqrt(inner) => {
            let argument_ok = self.is_non_negative(inner)?;
            let text = if argument_ok {
                "[0, +∞) (sqrt of non-negative expression)"
            } else {
                "Domain error: sqrt requires non-negative argument"
            };
            String::from(text)
        }
        _ => String::from("Complex expression - detailed analysis needed"),
    };
    Ok(description)
}
/// Conservatively decides whether `expr` is strictly positive for every input.
///
/// Returns `Ok(true)` only when positivity can be established structurally:
/// * a constant greater than zero,
/// * any `exp(..)`,
/// * a product of two provably positive or two provably negative factors,
/// * a power with a provably positive base,
/// * a power with constant exponent `0` (the result is `1`),
/// * a power with an even constant exponent and a provably negative base,
/// * `sqrt` of a provably positive argument.
///
/// Everything else yields `Ok(false)`, meaning "not proven", not "proven false".
fn is_positive_definite(&self, expr: &ASTRepr<f64>) -> Result<bool> {
    match expr {
        ASTRepr::Constant(val) => Ok(*val > 0.0),
        ASTRepr::Exp(_) => Ok(true),
        ASTRepr::Mul(left, right) => {
            let left_pos = self.is_positive_definite(left)?;
            let right_pos = self.is_positive_definite(right)?;
            let left_neg = self.is_negative_definite(left)?;
            let right_neg = self.is_negative_definite(right)?;
            Ok((left_pos && right_pos) || (left_neg && right_neg))
        }
        ASTRepr::Pow(base, exp) => {
            if self.is_positive_definite(base)? {
                return Ok(true);
            }
            if let ASTRepr::Constant(exp_val) = exp.as_ref() {
                // x^0 evaluates to 1 regardless of the base.
                if *exp_val == 0.0 {
                    return Ok(true);
                }
                // The float remainder is zero exactly for even integers; unlike an
                // `as i64` cast it does not saturate for very large exponents, and
                // it is NaN (hence false) for infinite or NaN exponents.
                let is_even = *exp_val % 2.0 == 0.0;
                if is_even {
                    // An even power is only *strictly* positive when the base can
                    // never be zero (e.g. x^2 is 0 at x = 0). A positive base was
                    // handled above, so the remaining provable case is a strictly
                    // negative base.
                    return self.is_negative_definite(base);
                }
            }
            Ok(false)
        }
        ASTRepr::Sqrt(inner) => self.is_positive_definite(inner),
        _ => Ok(false),
    }
}
/// Conservatively decides whether `expr` is strictly negative.
///
/// Only negative constants and the negation of a provably positive
/// expression are recognised; anything else is reported as not proven.
fn is_negative_definite(&self, expr: &ASTRepr<f64>) -> Result<bool> {
    let proven = match expr {
        ASTRepr::Constant(val) => *val < 0.0,
        ASTRepr::Neg(inner) => self.is_positive_definite(inner)?,
        _ => false,
    };
    Ok(proven)
}
/// Conservatively decides whether `expr` is greater than or equal to zero.
///
/// Returns `Ok(true)` for non-negative constants, `exp(..)`, `sqrt(..)`,
/// products of two non-negative or two non-positive factors, powers with a
/// non-negative base or an even constant exponent, and sums of two
/// non-negative terms. Everything else yields `Ok(false)` ("not proven").
fn is_non_negative(&self, expr: &ASTRepr<f64>) -> Result<bool> {
    match expr {
        ASTRepr::Constant(val) => Ok(*val >= 0.0),
        ASTRepr::Exp(_) => Ok(true),
        ASTRepr::Sqrt(_) => Ok(true),
        ASTRepr::Mul(left, right) => {
            let left_nonneg = self.is_non_negative(left)?;
            let right_nonneg = self.is_non_negative(right)?;
            let left_nonpos = self.is_non_positive(left)?;
            let right_nonpos = self.is_non_positive(right)?;
            Ok((left_nonneg && right_nonneg) || (left_nonpos && right_nonpos))
        }
        ASTRepr::Pow(base, exp) => {
            if self.is_non_negative(base)? {
                return Ok(true);
            }
            if let ASTRepr::Constant(exp_val) = exp.as_ref() {
                // Even-integer test on the float itself. An `as i64` cast saturates
                // to `i64::MAX` (odd) for exponents beyond the i64 range and would
                // misclassify them; the remainder is NaN (false) for inf/NaN.
                return Ok(*exp_val % 2.0 == 0.0);
            }
            Ok(false)
        }
        ASTRepr::Add(left, right) => {
            Ok(self.is_non_negative(left)? && self.is_non_negative(right)?)
        }
        _ => Ok(false),
    }
}
/// Conservatively decides whether `expr` is less than or equal to zero.
///
/// Recognises constants `<= 0` and the negation of a provably non-negative
/// expression; anything else is reported as not proven.
fn is_non_positive(&self, expr: &ASTRepr<f64>) -> Result<bool> {
    let proven = match expr {
        ASTRepr::Constant(val) => *val <= 0.0,
        ASTRepr::Neg(inner) => self.is_non_negative(inner)?,
        _ => false,
    };
    Ok(proven)
}
/// Conservatively decides whether `expr` can never evaluate to zero.
///
/// Recognises non-zero constants, `exp(..)`, and `sqrt` of a provably
/// positive argument; anything else is reported as not proven.
fn is_nonzero_denominator(&self, expr: &ASTRepr<f64>) -> Result<bool> {
    if let ASTRepr::Constant(val) = expr {
        return Ok(*val != 0.0);
    }
    if let ASTRepr::Exp(_) = expr {
        return Ok(true);
    }
    if let ASTRepr::Sqrt(inner) = expr {
        return self.is_positive_definite(inner);
    }
    Ok(false)
}
/// Serialises `expr` as an egglog s-expression over the `Math` datatype.
///
/// Subtraction is emitted as `Add` with a `Neg` right operand and division as
/// `Mul` with `Pow(.., Neg(Num 1.0))`, so only the constructors declared in
/// the rule program appear in the output. Integral constants are printed
/// with one decimal place (`42.0`); variables are printed as `"x<index>"`.
pub fn ast_to_egglog(&self, expr: &ASTRepr<f64>) -> Result<String> {
    // Shorthand for rendering a child node.
    let render = |node: &ASTRepr<f64>| self.ast_to_egglog(node);

    let rendered = match expr {
        ASTRepr::Constant(value) => {
            let is_integral = value.fract() == 0.0;
            if is_integral {
                format!("(Num {value:.1})")
            } else {
                format!("(Num {value})")
            }
        }
        ASTRepr::Variable(index) => format!("(Var \"x{index}\")"),
        ASTRepr::Add(left, right) => {
            let (left_s, right_s) = (render(left)?, render(right)?);
            format!("(Add {left_s} {right_s})")
        }
        ASTRepr::Sub(left, right) => {
            let (left_s, right_s) = (render(left)?, render(right)?);
            format!("(Add {left_s} (Neg {right_s}))")
        }
        ASTRepr::Mul(left, right) => {
            let (left_s, right_s) = (render(left)?, render(right)?);
            format!("(Mul {left_s} {right_s})")
        }
        ASTRepr::Div(left, right) => {
            let (left_s, right_s) = (render(left)?, render(right)?);
            format!("(Mul {left_s} (Pow {right_s} (Neg (Num 1.0))))")
        }
        ASTRepr::Pow(base, exp) => {
            let (base_s, exp_s) = (render(base)?, render(exp)?);
            format!("(Pow {base_s} {exp_s})")
        }
        ASTRepr::Neg(inner) => {
            let inner_s = render(inner)?;
            format!("(Neg {inner_s})")
        }
        ASTRepr::Ln(inner) => {
            let inner_s = render(inner)?;
            format!("(Ln {inner_s})")
        }
        ASTRepr::Exp(inner) => {
            let inner_s = render(inner)?;
            format!("(Exp {inner_s})")
        }
        ASTRepr::Sin(inner) => {
            let inner_s = render(inner)?;
            format!("(Sin {inner_s})")
        }
        ASTRepr::Cos(inner) => {
            let inner_s = render(inner)?;
            format!("(Cos {inner_s})")
        }
        ASTRepr::Sqrt(inner) => {
            let inner_s = render(inner)?;
            format!("(Sqrt {inner_s})")
        }
    };
    Ok(rendered)
}
/// Asks egglog to extract the binding `expr_id` and parses the result.
///
/// When the extracted text cannot be parsed, the expression cached under
/// `expr_id` is returned instead.
///
/// # Errors
///
/// Returns `DSLCompileError::Optimization` if the `extract` command fails, or
/// if parsing fails and no cached expression exists for `expr_id`.
fn extract_best(&mut self, expr_id: &str) -> Result<ASTRepr<f64>> {
    let command = format!("(extract {expr_id})");
    let lines = self
        .egraph
        .parse_and_run_program(None, &command)
        .map_err(|e| {
            DSLCompileError::Optimization(format!(
                "Failed to extract optimized expression: {e}"
            ))
        })?;

    let rendered = lines.join("\n");
    self.parse_egglog_output(&rendered).or_else(|_| {
        // Best-effort fallback: hand back the unoptimized original.
        self.expr_cache.get(expr_id).cloned().ok_or_else(|| {
            DSLCompileError::Optimization("Expression not found in cache".to_string())
        })
    })
}
/// Parses egglog's textual output into an `ASTRepr`, ignoring surrounding whitespace.
fn parse_egglog_output(&self, output: &str) -> Result<ASTRepr<f64>> {
    self.parse_sexpr(output.trim())
}
fn parse_sexpr(&self, s: &str) -> Result<ASTRepr<f64>> {
let s = s.trim();
if !s.starts_with('(') || !s.ends_with(')') {
return Err(DSLCompileError::Optimization(format!(
"Invalid s-expression: {s}"
)));
}
let inner = &s[1..s.len() - 1];
let tokens = self.tokenize_sexpr(inner)?;
if tokens.is_empty() {
return Err(DSLCompileError::Optimization(
"Empty s-expression".to_string(),
));
}
match tokens[0].as_str() {
"Num" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Num requires exactly one argument".to_string(),
));
}
let value = tokens[1].parse::<f64>().map_err(|_| {
DSLCompileError::Optimization(format!("Invalid number: {}", tokens[1]))
})?;
Ok(ASTRepr::Constant(value))
}
"Var" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Var requires exactly one argument".to_string(),
));
}
let var_name = tokens[1].trim_matches('"');
if let Some(index_str) = var_name.strip_prefix('x') {
let index = index_str.parse::<usize>().map_err(|_| {
DSLCompileError::Optimization(format!(
"Invalid variable index: {index_str}"
))
})?;
Ok(ASTRepr::Variable(index))
} else {
Err(DSLCompileError::Optimization(format!(
"Invalid variable name: {var_name}"
)))
}
}
"Add" => {
if tokens.len() != 3 {
return Err(DSLCompileError::Optimization(
"Add requires exactly two arguments".to_string(),
));
}
let left = self.parse_sexpr(&tokens[1])?;
let right = self.parse_sexpr(&tokens[2])?;
Ok(ASTRepr::Add(Box::new(left), Box::new(right)))
}
"Mul" => {
if tokens.len() != 3 {
return Err(DSLCompileError::Optimization(
"Mul requires exactly two arguments".to_string(),
));
}
let left = self.parse_sexpr(&tokens[1])?;
let right = self.parse_sexpr(&tokens[2])?;
Ok(ASTRepr::Mul(Box::new(left), Box::new(right)))
}
"Neg" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Neg requires exactly one argument".to_string(),
));
}
let inner = self.parse_sexpr(&tokens[1])?;
Ok(ASTRepr::Neg(Box::new(inner)))
}
"Pow" => {
if tokens.len() != 3 {
return Err(DSLCompileError::Optimization(
"Pow requires exactly two arguments".to_string(),
));
}
let base = self.parse_sexpr(&tokens[1])?;
let exp = self.parse_sexpr(&tokens[2])?;
Ok(ASTRepr::Pow(Box::new(base), Box::new(exp)))
}
"Ln" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Ln requires exactly one argument".to_string(),
));
}
let inner = self.parse_sexpr(&tokens[1])?;
Ok(ASTRepr::Ln(Box::new(inner)))
}
"Exp" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Exp requires exactly one argument".to_string(),
));
}
let inner = self.parse_sexpr(&tokens[1])?;
Ok(ASTRepr::Exp(Box::new(inner)))
}
"Sin" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Sin requires exactly one argument".to_string(),
));
}
let inner = self.parse_sexpr(&tokens[1])?;
Ok(ASTRepr::Sin(Box::new(inner)))
}
"Cos" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Cos requires exactly one argument".to_string(),
));
}
let inner = self.parse_sexpr(&tokens[1])?;
Ok(ASTRepr::Cos(Box::new(inner)))
}
"Sqrt" => {
if tokens.len() != 2 {
return Err(DSLCompileError::Optimization(
"Sqrt requires exactly one argument".to_string(),
));
}
let inner = self.parse_sexpr(&tokens[1])?;
Ok(ASTRepr::Sqrt(Box::new(inner)))
}
_ => Err(DSLCompileError::Optimization(format!(
"Unknown operation: {}",
tokens[0]
))),
}
}
fn tokenize_sexpr(&self, s: &str) -> Result<Vec<String>> {
let mut tokens = Vec::new();
let mut current_token = String::new();
let mut paren_depth = 0;
let mut in_string = false;
let chars = s.chars().peekable();
for ch in chars {
match ch {
'"' => {
in_string = !in_string;
current_token.push(ch);
}
'(' if !in_string => {
if paren_depth == 0 && !current_token.is_empty() {
tokens.push(current_token.trim().to_string());
current_token.clear();
}
current_token.push(ch);
paren_depth += 1;
}
')' if !in_string => {
current_token.push(ch);
paren_depth -= 1;
if paren_depth == 0 {
tokens.push(current_token.trim().to_string());
current_token.clear();
}
}
' ' | '\t' | '\n' if !in_string && paren_depth == 0 => {
if !current_token.is_empty() {
tokens.push(current_token.trim().to_string());
current_token.clear();
}
}
_ => {
current_token.push(ch);
}
}
}
if !current_token.is_empty() {
tokens.push(current_token.trim().to_string());
}
Ok(tokens)
}
}
/// Stand-in optimizer used when the `optimization` feature is disabled.
///
/// It carries no state and leaves every expression untouched.
#[cfg(not(feature = "optimization"))]
pub struct NativeEgglogOptimizer;

#[cfg(not(feature = "optimization"))]
impl NativeEgglogOptimizer {
    /// Creates the stand-in optimizer; this never fails.
    pub fn new() -> Result<Self> {
        Ok(NativeEgglogOptimizer)
    }

    /// Returns a clone of `expr` without performing any optimization.
    pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
        let unchanged = expr.clone();
        Ok(unchanged)
    }
}
/// Convenience wrapper: builds a fresh `NativeEgglogOptimizer` and optimizes `expr` with it.
///
/// # Errors
///
/// Propagates any error from constructing the optimizer or from `optimize`.
pub fn optimize_with_native_egglog(expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    NativeEgglogOptimizer::new().and_then(|mut optimizer| optimizer.optimize(expr))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the native egglog optimizer.
    //!
    //! Tests that call methods which only exist on the feature-enabled
    //! optimizer (`ast_to_egglog`, `analyze_interval`, `is_domain_safe`) are
    //! gated on the `optimization` feature so the test module also compiles
    //! against the no-op stand-in.

    use super::*;
    use crate::final_tagless::{ASTEval, ASTMathExpr};

    /// Construction succeeds in both feature configurations.
    #[test]
    fn test_native_egglog_creation() {
        let result = NativeEgglogOptimizer::new();
        if let Err(e) = &result {
            println!("Error creating NativeEgglogOptimizer: {e}");
        }
        assert!(result.is_ok());
    }

    /// Constants, variables and sums serialise to the expected s-expressions.
    #[cfg(feature = "optimization")]
    #[test]
    fn test_ast_to_egglog_conversion() {
        let optimizer = NativeEgglogOptimizer::new().unwrap();

        let num = ASTRepr::Constant(42.0);
        let egglog_str = optimizer.ast_to_egglog(&num).unwrap();
        assert_eq!(egglog_str, "(Num 42.0)");

        let var = ASTRepr::Variable(0);
        let egglog_str = optimizer.ast_to_egglog(&var).unwrap();
        assert_eq!(egglog_str, "(Var \"x0\")");

        let add = ASTRepr::Add(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Constant(1.0)),
        );
        let egglog_str = optimizer.ast_to_egglog(&add).unwrap();
        assert_eq!(egglog_str, "(Add (Var \"x0\") (Num 1.0))");
    }

    /// Subtraction and division are lowered to `Add`/`Neg` and `Mul`/`Pow`.
    #[cfg(feature = "optimization")]
    #[test]
    fn test_canonical_form_conversion() {
        let optimizer = NativeEgglogOptimizer::new().unwrap();

        let sub = ASTRepr::Sub(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Constant(1.0)),
        );
        let egglog_str = optimizer.ast_to_egglog(&sub).unwrap();
        assert_eq!(egglog_str, "(Add (Var \"x0\") (Neg (Num 1.0)))");

        let div = ASTRepr::Div(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Constant(2.0)),
        );
        let egglog_str = optimizer.ast_to_egglog(&div).unwrap();
        assert_eq!(
            egglog_str,
            "(Mul (Var \"x0\") (Pow (Num 2.0) (Neg (Num 1.0))))"
        );
    }

    /// `x + 0` can be optimized without error in either feature configuration.
    #[test]
    fn test_basic_optimization() {
        let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(0.0));
        let result = optimize_with_native_egglog(&expr);
        assert!(result.is_ok());
    }

    /// Both the always-safe `ln(exp(x))` and the guarded `exp(ln(x))` optimize without error.
    #[test]
    fn test_domain_aware_optimization() {
        let mut optimizer = NativeEgglogOptimizer::new().unwrap();

        let safe_expr = ASTRepr::Ln(Box::new(ASTRepr::Exp(Box::new(ASTRepr::Variable(0)))));
        let result = optimizer.optimize(&safe_expr);
        assert!(result.is_ok());

        let potentially_unsafe =
            ASTRepr::Exp(Box::new(ASTRepr::Ln(Box::new(ASTRepr::Variable(0)))));
        let result = optimizer.optimize(&potentially_unsafe);
        assert!(result.is_ok());
    }

    /// Interval analysis succeeds for constants, variables and compound expressions.
    #[cfg(feature = "optimization")]
    #[test]
    fn test_interval_analysis() {
        let mut optimizer = NativeEgglogOptimizer::new().unwrap();

        let constant_expr = ASTRepr::Constant(5.0);
        let interval_info = optimizer.analyze_interval(&constant_expr);
        assert!(interval_info.is_ok());

        let var_expr = ASTRepr::Variable(0);
        let interval_info = optimizer.analyze_interval(&var_expr);
        assert!(interval_info.is_ok());

        let complex_expr = ASTRepr::Add(
            Box::new(ASTRepr::Constant(2.0)),
            Box::new(ASTRepr::Mul(
                Box::new(ASTRepr::Variable(0)),
                Box::new(ASTRepr::Constant(3.0)),
            )),
        );
        let interval_info = optimizer.analyze_interval(&complex_expr);
        assert!(interval_info.is_ok());
    }

    /// Domain checks return the expected verdicts, not merely `Ok`.
    #[cfg(feature = "optimization")]
    #[test]
    fn test_domain_safety_checks() {
        let mut optimizer = NativeEgglogOptimizer::new().unwrap();

        // ln(5) is defined: a positive constant is provably positive.
        let positive_constant = ASTRepr::Constant(5.0);
        let is_safe = optimizer.is_domain_safe(&positive_constant, "ln");
        assert!(matches!(is_safe, Ok(true)));

        // x + 1 is zero at x = -1, so it must not be reported as a safe denominator.
        let nonzero_expr = ASTRepr::Add(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Constant(1.0)),
        );
        let is_safe = optimizer.is_domain_safe(&nonzero_expr, "div");
        assert!(matches!(is_safe, Ok(false)));

        // x^2 is never negative, so sqrt(x^2) is always defined.
        let sqrt_expr = ASTRepr::Pow(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Constant(2.0)),
        );
        let is_safe = optimizer.is_domain_safe(&sqrt_expr, "sqrt");
        assert!(matches!(is_safe, Ok(true)));
    }

    /// `ln(2 * 3)` with positive constant factors optimizes without error.
    #[test]
    fn test_domain_aware_ln_rules() {
        let mut optimizer = NativeEgglogOptimizer::new().unwrap();
        let ln_product = ASTRepr::Ln(Box::new(ASTRepr::Mul(
            Box::new(ASTRepr::Constant(2.0)),
            Box::new(ASTRepr::Constant(3.0)),
        )));
        let result = optimizer.optimize(&ln_product);
        assert!(result.is_ok());
    }

    /// `sqrt(x^2)` with an unconstrained `x` optimizes without error.
    #[test]
    fn test_sqrt_domain_awareness() {
        let mut optimizer = NativeEgglogOptimizer::new().unwrap();
        let sqrt_square = ASTRepr::Sqrt(Box::new(ASTRepr::Pow(
            Box::new(ASTRepr::Variable(0)),
            Box::new(ASTRepr::Constant(2.0)),
        )));
        let result = optimizer.optimize(&sqrt_square);
        assert!(result.is_ok());
    }
}