#[cfg(feature = "optimization")]
use egglog::EGraph;
use crate::error::{MathCompileError, Result};
use crate::final_tagless::ASTRepr;
use std::collections::HashMap;
/// Names of the algebraic simplification patterns known to this module.
///
/// NOTE(review): nothing in this file references the enum; the variant names
/// appear to mirror the `optimize_*` passes on `EgglogOptimizer` — confirm
/// against callers elsewhere in the crate.
///
/// The enum is fieldless, so it is cheap to copy and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationPattern {
    /// Presumably `0 + x`.
    AddZeroLeft,
    /// Presumably `x + 0`.
    AddZeroRight,
    /// Presumably `x + x`.
    AddSameExpr,
    /// Presumably `0 * x`.
    MulZeroLeft,
    /// Presumably `x * 0`.
    MulZeroRight,
    /// Presumably `1 * x`.
    MulOneLeft,
    /// Presumably `x * 1`.
    MulOneRight,
    /// Presumably `ln(exp(x))`.
    LnExp,
    /// Presumably `exp(ln(x))`.
    ExpLn,
    /// Presumably `x ^ 0`.
    PowZero,
    /// Presumably `x ^ 1`.
    PowOne,
}
/// Expression optimizer backed by an egglog e-graph (only built with the
/// `optimization` feature).
#[cfg(feature = "optimization")]
pub struct EgglogOptimizer {
/// E-graph that `new` initialises with the `Math` datatype and rewrite rules.
egraph: EGraph,
/// Original expressions keyed by the `expr_N` identifier generated in `optimize`.
expr_map: HashMap<String, ASTRepr<f64>>,
/// Counter used to generate the next `expr_N` identifier.
var_counter: usize,
}
#[cfg(feature = "optimization")]
impl EgglogOptimizer {
/// Creates an optimizer whose e-graph is pre-loaded with the `Math` datatype
/// and the algebraic rewrite rules below.
///
/// `Var` carries an `i64` because `jit_repr_to_egglog` emits variables as a
/// bare integer index (`(Var 0)`); declaring it as `String` makes egglog
/// reject every expression that contains a variable.
///
/// # Errors
///
/// Returns `MathCompileError::Optimization` if egglog fails to parse or run
/// the initialisation program.
pub fn new() -> Result<Self> {
    let mut egraph = EGraph::default();
    let program = r"
(datatype Math
(Num f64)
(Var i64)
(Add Math Math)
(Sub Math Math)
(Mul Math Math)
(Div Math Math)
(Pow Math Math)
(Neg Math)
(Ln Math)
(Exp Math)
(Sin Math)
(Cos Math)
(Sqrt Math))
; Commutativity rules (proven to work correctly)
(rewrite (Add ?x ?y) (Add ?y ?x))
(rewrite (Mul ?x ?y) (Mul ?y ?x))
; Arithmetic identity rules
(rewrite (Add ?x (Num 0.0)) ?x)
(rewrite (Add (Num 0.0) ?x) ?x)
(rewrite (Mul ?x (Num 1.0)) ?x)
(rewrite (Mul (Num 1.0) ?x) ?x)
(rewrite (Mul ?x (Num 0.0)) (Num 0.0))
(rewrite (Mul (Num 0.0) ?x) (Num 0.0))
(rewrite (Sub ?x (Num 0.0)) ?x)
(rewrite (Sub ?x ?x) (Num 0.0))
(rewrite (Div ?x (Num 1.0)) ?x)
(rewrite (Div ?x ?x) (Num 1.0))
(rewrite (Pow ?x (Num 0.0)) (Num 1.0))
(rewrite (Pow ?x (Num 1.0)) ?x)
(rewrite (Pow (Num 1.0) ?x) (Num 1.0))
(rewrite (Pow (Num 0.0) ?x) (Num 0.0))
; Negation rules
(rewrite (Neg (Neg ?x)) ?x)
(rewrite (Neg (Num 0.0)) (Num 0.0))
(rewrite (Add (Neg ?x) ?x) (Num 0.0))
(rewrite (Add ?x (Neg ?x)) (Num 0.0))
; Exponential and logarithm rules (bidirectional)
(rewrite (Ln (Num 1.0)) (Num 0.0))
(rewrite (Ln (Exp ?x)) ?x)
(rewrite (Exp (Num 0.0)) (Num 1.0))
(rewrite (Exp (Ln ?x)) ?x)
(rewrite (Exp (Add ?x ?y)) (Mul (Exp ?x) (Exp ?y)))
(rewrite (Ln (Mul ?x ?y)) (Add (Ln ?x) (Ln ?y)))
; Trigonometric rules
(rewrite (Sin (Num 0.0)) (Num 0.0))
(rewrite (Cos (Num 0.0)) (Num 1.0))
(rewrite (Add (Mul (Sin ?x) (Sin ?x)) (Mul (Cos ?x) (Cos ?x))) (Num 1.0))
; Square root rules
(rewrite (Sqrt (Num 0.0)) (Num 0.0))
(rewrite (Sqrt (Num 1.0)) (Num 1.0))
(rewrite (Sqrt (Mul ?x ?x)) ?x)
(rewrite (Pow (Sqrt ?x) (Num 2.0)) ?x)
; Advanced algebraic rules
(rewrite (Add ?x ?x) (Mul (Num 2.0) ?x))
(rewrite (Mul (Num 2.0) ?x) (Add ?x ?x))
(rewrite (Mul ?x (Div (Num 1.0) ?x)) (Num 1.0))
; Power rules
(rewrite (Pow ?x (Add ?a ?b)) (Mul (Pow ?x ?a) (Pow ?x ?b)))
(rewrite (Pow (Mul ?x ?y) ?z) (Mul (Pow ?x ?z) (Pow ?y ?z)))
(rewrite (Mul (Pow ?x ?a) (Pow ?x ?b)) (Pow ?x (Add ?a ?b)))
; Distributive properties
(rewrite (Mul ?x (Add ?y ?z)) (Add (Mul ?x ?y) (Mul ?x ?z)))
(rewrite (Mul (Add ?y ?z) ?x) (Add (Mul ?y ?x) (Mul ?z ?x)))
";
    egraph.parse_and_run_program(None, program).map_err(|e| {
        MathCompileError::Optimization(format!("Failed to initialize egglog with rules: {e}"))
    })?;
    Ok(Self {
        egraph,
        expr_map: HashMap::new(),
        var_counter: 0,
    })
}
/// Optimizes `expr` and returns the simplified expression.
///
/// The expression is serialised to egglog syntax, bound to a fresh
/// `expr_N` name in the e-graph, and the rule set is run for 10 iterations.
/// The result is then obtained from `extract_best_expression`; if that step
/// fails, the failure is reported on stderr and the original expression is
/// returned unchanged.
///
/// # Errors
///
/// Returns `MathCompileError::Optimization` if egglog rejects the expression
/// or the saturation run fails, and propagates any error from
/// `jit_repr_to_egglog`.
pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    let egglog_expr = self.jit_repr_to_egglog(expr)?;
    let expr_id = format!("expr_{}", self.var_counter);
    self.var_counter += 1;

    let command = format!("(let {expr_id} {egglog_expr})");
    self.egraph
        .parse_and_run_program(None, &command)
        .map_err(|e| {
            MathCompileError::Optimization(format!("Egglog failed to add expression: {e}"))
        })?;
    self.egraph
        .parse_and_run_program(None, "(run 10)")
        .map_err(|e| {
            MathCompileError::Optimization(format!("Egglog equality saturation failed: {e}"))
        })?;

    // Register the expression only for the duration of the extraction so the
    // map does not grow by one entry on every call (including failed ones).
    self.expr_map.insert(expr_id.clone(), expr.clone());
    let extracted = self.extract_best_expression(&expr_id);
    self.expr_map.remove(&expr_id);

    Ok(extracted.unwrap_or_else(|e| {
        eprintln!("Egglog extraction failed: {e}, using original expression");
        expr.clone()
    }))
}
/// Produces the optimized form of the expression registered under `expr_id`.
///
/// The e-graph is not consulted here: the original expression is looked up in
/// `expr_map` and simplified with the Rust rewrite passes.
///
/// # Errors
///
/// Returns `MathCompileError::Optimization` when `expr_id` is not in the map.
fn extract_best_expression(&mut self, expr_id: &str) -> Result<ASTRepr<f64>> {
    match self.expr_map.get(expr_id) {
        Some(stored) => self.apply_comprehensive_optimization(stored),
        None => Err(MathCompileError::Optimization(
            "Expression not found in map".to_string(),
        )),
    }
}
/// Repeatedly applies every rewrite pass until the expression stops changing
/// or `MAX_ITERATIONS` rounds have run, and returns the last result.
///
/// # Errors
///
/// Propagates any error returned by `apply_all_optimizations`.
fn apply_comprehensive_optimization(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    const MAX_ITERATIONS: usize = 10;

    let mut current = expr.clone();
    for _ in 0..MAX_ITERATIONS {
        let next = self.apply_all_optimizations(&current)?;
        // A round that changes nothing means a fixed point was reached.
        let converged = self.expressions_structurally_equal(&current, &next);
        current = next;
        if converged {
            break;
        }
    }
    Ok(current)
}
/// Runs one bottom-up round: children are optimized first, then the
/// top-level passes are applied to the rebuilt node.
fn apply_all_optimizations(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    self.apply_optimizations_recursively(expr)
        .and_then(|children_done| self.apply_top_level_optimizations(&children_done))
}
/// Rebuilds `expr` with each direct child replaced by its optimized form.
///
/// Constants and variables have no children and are returned as clones.
/// For binary nodes the left operand is processed before the right one.
fn apply_optimizations_recursively(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    // Optimizes one child and boxes it, ready to be placed back in a node.
    let rebuild = |child: &ASTRepr<f64>| -> Result<Box<ASTRepr<f64>>> {
        self.apply_all_optimizations(child).map(Box::new)
    };

    let node = match expr {
        ASTRepr::Constant(_) | ASTRepr::Variable(_) => expr.clone(),
        ASTRepr::Add(lhs, rhs) => ASTRepr::Add(rebuild(lhs)?, rebuild(rhs)?),
        ASTRepr::Sub(lhs, rhs) => ASTRepr::Sub(rebuild(lhs)?, rebuild(rhs)?),
        ASTRepr::Mul(lhs, rhs) => ASTRepr::Mul(rebuild(lhs)?, rebuild(rhs)?),
        ASTRepr::Div(lhs, rhs) => ASTRepr::Div(rebuild(lhs)?, rebuild(rhs)?),
        ASTRepr::Pow(base, exponent) => ASTRepr::Pow(rebuild(base)?, rebuild(exponent)?),
        ASTRepr::Neg(arg) => ASTRepr::Neg(rebuild(arg)?),
        ASTRepr::Ln(arg) => ASTRepr::Ln(rebuild(arg)?),
        ASTRepr::Exp(arg) => ASTRepr::Exp(rebuild(arg)?),
        ASTRepr::Sin(arg) => ASTRepr::Sin(rebuild(arg)?),
        ASTRepr::Cos(arg) => ASTRepr::Cos(rebuild(arg)?),
        ASTRepr::Sqrt(arg) => ASTRepr::Sqrt(rebuild(arg)?),
    };
    Ok(node)
}
fn apply_top_level_optimizations(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
let mut result = expr.clone();
result = self.optimize_add_zero(&result)?;
result = self.optimize_add_same(&result)?;
result = self.optimize_mul_zero(&result)?;
result = self.optimize_mul_one(&result)?;
result = self.optimize_ln_exp(&result)?;
result = self.optimize_exp_ln(&result)?;
result = self.optimize_pow_zero(&result)?;
result = self.optimize_pow_one(&result)?;
result = self.optimize_constant_folding(&result)?;
result = self.optimize_double_negation(&result)?;
result = self.optimize_distributive(&result)?;
Ok(result)
}
/// Folds a binary node whose two operands are both constants into a single
/// constant; every other expression is returned as a clone.
///
/// Division is only folded when the divisor's magnitude exceeds
/// `f64::EPSILON`; otherwise the division node is kept as is.
fn optimize_constant_folding(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    // Yields both operand values when the two operands are constants.
    let constants = |lhs: &ASTRepr<f64>, rhs: &ASTRepr<f64>| match (lhs, rhs) {
        (ASTRepr::Constant(a), ASTRepr::Constant(b)) => Some((*a, *b)),
        _ => None,
    };

    let folded = match expr {
        ASTRepr::Add(lhs, rhs) => constants(lhs, rhs).map(|(a, b)| a + b),
        ASTRepr::Sub(lhs, rhs) => constants(lhs, rhs).map(|(a, b)| a - b),
        ASTRepr::Mul(lhs, rhs) => constants(lhs, rhs).map(|(a, b)| a * b),
        ASTRepr::Div(lhs, rhs) => constants(lhs, rhs)
            .filter(|&(_, b)| b.abs() > f64::EPSILON)
            .map(|(a, b)| a / b),
        ASTRepr::Pow(base, exponent) => constants(base, exponent).map(|(a, b)| a.powf(b)),
        _ => None,
    };

    Ok(match folded {
        Some(value) => ASTRepr::Constant(value),
        None => expr.clone(),
    })
}
/// Rewrites `-(-x)` to `x`; any other expression is returned as a clone.
fn optimize_double_negation(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    if let ASTRepr::Neg(outer) = expr {
        if let ASTRepr::Neg(core) = outer.as_ref() {
            return Ok((**core).clone());
        }
    }
    Ok(expr.clone())
}
/// Expands a product over a sum: `a * (b + c)` becomes `a*b + a*c` and
/// `(a + b) * c` becomes `a*c + b*c`.
///
/// When both factors are sums, the right-hand sum is the one expanded.
/// Expressions that are not such a product are returned as clones.
fn optimize_distributive(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    let expanded = match expr {
        ASTRepr::Mul(lhs, rhs) => match (lhs.as_ref(), rhs.as_ref()) {
            // Right-hand sum takes priority, so it is tested first.
            (_, ASTRepr::Add(b, c)) => ASTRepr::Add(
                Box::new(ASTRepr::Mul(lhs.clone(), b.clone())),
                Box::new(ASTRepr::Mul(lhs.clone(), c.clone())),
            ),
            (ASTRepr::Add(a, b), _) => ASTRepr::Add(
                Box::new(ASTRepr::Mul(a.clone(), rhs.clone())),
                Box::new(ASTRepr::Mul(b.clone(), rhs.clone())),
            ),
            _ => expr.clone(),
        },
        _ => expr.clone(),
    };
    Ok(expanded)
}
/// Serialises `expr` as an egglog s-expression over the `Math` datatype.
///
/// Constants with no fractional part are printed with one decimal place
/// (`42` becomes `(Num 42.0)`); other constants use the default `f64`
/// formatting. Variables are written as `(Var <index>)`.
fn jit_repr_to_egglog(&self, expr: &ASTRepr<f64>) -> Result<String> {
    // `(Op <arg>)` for single-operand nodes.
    let unary = |op: &str, arg: &ASTRepr<f64>| -> Result<String> {
        let arg = self.jit_repr_to_egglog(arg)?;
        Ok(format!("({op} {arg})"))
    };
    // `(Op <lhs> <rhs>)`; the left operand is serialised first.
    let binary = |op: &str, lhs: &ASTRepr<f64>, rhs: &ASTRepr<f64>| -> Result<String> {
        let lhs = self.jit_repr_to_egglog(lhs)?;
        let rhs = self.jit_repr_to_egglog(rhs)?;
        Ok(format!("({op} {lhs} {rhs})"))
    };

    match expr {
        ASTRepr::Constant(value) if value.fract() == 0.0 => Ok(format!("(Num {value:.1})")),
        ASTRepr::Constant(value) => Ok(format!("(Num {value})")),
        ASTRepr::Variable(index) => Ok(format!("(Var {index})")),
        ASTRepr::Add(lhs, rhs) => binary("Add", lhs, rhs),
        ASTRepr::Sub(lhs, rhs) => binary("Sub", lhs, rhs),
        ASTRepr::Mul(lhs, rhs) => binary("Mul", lhs, rhs),
        ASTRepr::Div(lhs, rhs) => binary("Div", lhs, rhs),
        ASTRepr::Pow(base, exponent) => binary("Pow", base, exponent),
        ASTRepr::Neg(arg) => unary("Neg", arg),
        ASTRepr::Ln(arg) => unary("Ln", arg),
        ASTRepr::Exp(arg) => unary("Exp", arg),
        ASTRepr::Sin(arg) => unary("Sin", arg),
        ASTRepr::Cos(arg) => unary("Cos", arg),
        ASTRepr::Sqrt(arg) => unary("Sqrt", arg),
    }
}
fn egglog_to_jit_repr(&self, egglog_str: &str) -> Result<ASTRepr<f64>> {
let trimmed = egglog_str.trim();
if !trimmed.starts_with('(') {
return Err(MathCompileError::Optimization(
"Invalid egglog expression format".to_string(),
));
}
let inner = &trimmed[1..trimmed.len() - 1];
let parts: Vec<&str> = self.parse_sexpr_parts(inner)?;
if parts.is_empty() {
return Err(MathCompileError::Optimization(
"Empty egglog expression".to_string(),
));
}
match parts[0] {
"Num" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Num expression".to_string(),
));
}
let value: f64 = parts[1].parse().map_err(|_| {
MathCompileError::Optimization("Invalid number format".to_string())
})?;
Ok(ASTRepr::Constant(value))
}
"Var" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Var expression".to_string(),
));
}
let var_name = parts[1].trim_matches('"');
Ok(ASTRepr::Variable(var_name.parse::<usize>().unwrap_or(0)))
}
"Add" => {
if parts.len() != 3 {
return Err(MathCompileError::Optimization(
"Invalid Add expression".to_string(),
));
}
let left = self.egglog_to_jit_repr(parts[1])?;
let right = self.egglog_to_jit_repr(parts[2])?;
Ok(ASTRepr::Add(Box::new(left), Box::new(right)))
}
"Sub" => {
if parts.len() != 3 {
return Err(MathCompileError::Optimization(
"Invalid Sub expression".to_string(),
));
}
let left = self.egglog_to_jit_repr(parts[1])?;
let right = self.egglog_to_jit_repr(parts[2])?;
Ok(ASTRepr::Sub(Box::new(left), Box::new(right)))
}
"Mul" => {
if parts.len() != 3 {
return Err(MathCompileError::Optimization(
"Invalid Mul expression".to_string(),
));
}
let left = self.egglog_to_jit_repr(parts[1])?;
let right = self.egglog_to_jit_repr(parts[2])?;
Ok(ASTRepr::Mul(Box::new(left), Box::new(right)))
}
"Div" => {
if parts.len() != 3 {
return Err(MathCompileError::Optimization(
"Invalid Div expression".to_string(),
));
}
let left = self.egglog_to_jit_repr(parts[1])?;
let right = self.egglog_to_jit_repr(parts[2])?;
Ok(ASTRepr::Div(Box::new(left), Box::new(right)))
}
"Pow" => {
if parts.len() != 3 {
return Err(MathCompileError::Optimization(
"Invalid Pow expression".to_string(),
));
}
let base = self.egglog_to_jit_repr(parts[1])?;
let exp = self.egglog_to_jit_repr(parts[2])?;
Ok(ASTRepr::Pow(Box::new(base), Box::new(exp)))
}
"Neg" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Neg expression".to_string(),
));
}
let inner = self.egglog_to_jit_repr(parts[1])?;
Ok(ASTRepr::Neg(Box::new(inner)))
}
"Ln" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Ln expression".to_string(),
));
}
let inner = self.egglog_to_jit_repr(parts[1])?;
Ok(ASTRepr::Ln(Box::new(inner)))
}
"Exp" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Exp expression".to_string(),
));
}
let inner = self.egglog_to_jit_repr(parts[1])?;
Ok(ASTRepr::Exp(Box::new(inner)))
}
"Sin" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Sin expression".to_string(),
));
}
let inner = self.egglog_to_jit_repr(parts[1])?;
Ok(ASTRepr::Sin(Box::new(inner)))
}
"Cos" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Cos expression".to_string(),
));
}
let inner = self.egglog_to_jit_repr(parts[1])?;
Ok(ASTRepr::Cos(Box::new(inner)))
}
"Sqrt" => {
if parts.len() != 2 {
return Err(MathCompileError::Optimization(
"Invalid Sqrt expression".to_string(),
));
}
let inner = self.egglog_to_jit_repr(parts[1])?;
Ok(ASTRepr::Sqrt(Box::new(inner)))
}
_ => Err(MathCompileError::Optimization(format!(
"Unknown egglog operator: {}",
parts[0]
))),
}
}
fn parse_sexpr_parts<'a>(&self, input: &'a str) -> Result<Vec<&'a str>> {
let mut parts = Vec::new();
let mut current_start = 0;
let mut paren_depth = 0;
let mut in_string = false;
let mut escape_next = false;
let chars: Vec<char> = input.chars().collect();
let mut i = 0;
while i < chars.len() {
let ch = chars[i];
if escape_next {
escape_next = false;
i += 1;
continue;
}
match ch {
'\\' if in_string => escape_next = true,
'"' => in_string = !in_string,
'(' if !in_string => paren_depth += 1,
')' if !in_string => paren_depth -= 1,
' ' | '\t' | '\n' | '\r' if !in_string && paren_depth == 0 => {
if i > current_start {
let part = input[current_start..i].trim();
if !part.is_empty() {
parts.push(part);
}
}
while i + 1 < chars.len() && chars[i + 1].is_whitespace() {
i += 1;
}
current_start = i + 1;
}
_ => {}
}
i += 1;
}
if current_start < input.len() {
let part = input[current_start..].trim();
if !part.is_empty() {
parts.push(part);
}
}
Ok(parts)
}
/// Returns `true` when `a` and `b` are the same tree: same node kinds in the
/// same positions, equal variable indices and equal constants.
///
/// Constants are compared exactly. An absolute tolerance would equate
/// distinct small values such as `1e-20` and `2e-20`. Two NaN constants are
/// treated as equal so that a tree containing NaN can still be recognised as
/// unchanged by the fixed-point loop.
// Exact float comparison is intentional here (see above).
#[allow(clippy::float_cmp)]
fn expressions_structurally_equal(&self, a: &ASTRepr<f64>, b: &ASTRepr<f64>) -> bool {
    match (a, b) {
        (ASTRepr::Constant(x), ASTRepr::Constant(y)) => x == y || (x.is_nan() && y.is_nan()),
        (ASTRepr::Variable(i), ASTRepr::Variable(j)) => i == j,
        (ASTRepr::Add(a1, a2), ASTRepr::Add(b1, b2))
        | (ASTRepr::Sub(a1, a2), ASTRepr::Sub(b1, b2))
        | (ASTRepr::Mul(a1, a2), ASTRepr::Mul(b1, b2))
        | (ASTRepr::Div(a1, a2), ASTRepr::Div(b1, b2))
        | (ASTRepr::Pow(a1, a2), ASTRepr::Pow(b1, b2)) => {
            self.expressions_structurally_equal(a1, b1)
                && self.expressions_structurally_equal(a2, b2)
        }
        (ASTRepr::Neg(x), ASTRepr::Neg(y))
        | (ASTRepr::Ln(x), ASTRepr::Ln(y))
        | (ASTRepr::Exp(x), ASTRepr::Exp(y))
        | (ASTRepr::Sin(x), ASTRepr::Sin(y))
        | (ASTRepr::Cos(x), ASTRepr::Cos(y))
        | (ASTRepr::Sqrt(x), ASTRepr::Sqrt(y)) => self.expressions_structurally_equal(x, y),
        // Different node kinds are never equal.
        _ => false,
    }
}
/// Rewrites `0 + x` and `x + 0` to `x`; other expressions are cloned.
///
/// Only a constant that is exactly zero qualifies. A tolerance would drop
/// tiny but meaningful terms such as `1e-20`.
// Exact float comparison is intentional: only a literal zero is an identity.
#[allow(clippy::float_cmp)]
fn optimize_add_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    match expr {
        ASTRepr::Add(left, right) => {
            if matches!(left.as_ref(), ASTRepr::Constant(c) if *c == 0.0) {
                Ok(right.as_ref().clone())
            } else if matches!(right.as_ref(), ASTRepr::Constant(c) if *c == 0.0) {
                Ok(left.as_ref().clone())
            } else {
                Ok(expr.clone())
            }
        }
        _ => Ok(expr.clone()),
    }
}
/// Rewrites `x + x` to `2 * x` when both operands are structurally equal;
/// any other expression is returned as a clone.
fn optimize_add_same(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    if let ASTRepr::Add(lhs, rhs) = expr {
        if self.expressions_structurally_equal(lhs, rhs) {
            let two = Box::new(ASTRepr::Constant(2.0));
            return Ok(ASTRepr::Mul(two, lhs.clone()));
        }
    }
    Ok(expr.clone())
}
/// Rewrites `0 * x` and `x * 0` to the constant `0`; other expressions are
/// cloned.
///
/// Only a constant that is exactly zero qualifies. A tolerance would fold
/// products with tiny non-zero factors (e.g. `1e-20 * x`) to zero.
// Exact float comparison is intentional: only a literal zero annihilates.
#[allow(clippy::float_cmp)]
fn optimize_mul_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    match expr {
        ASTRepr::Mul(left, right)
            if matches!(left.as_ref(), ASTRepr::Constant(c) if *c == 0.0)
                || matches!(right.as_ref(), ASTRepr::Constant(c) if *c == 0.0) =>
        {
            Ok(ASTRepr::Constant(0.0))
        }
        _ => Ok(expr.clone()),
    }
}
/// Rewrites `1 * x` and `x * 1` to `x`; other expressions are cloned.
///
/// Only a constant that is exactly one qualifies, so factors that merely
/// round close to one are preserved.
// Exact float comparison is intentional: only a literal one is an identity.
#[allow(clippy::float_cmp)]
fn optimize_mul_one(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    match expr {
        ASTRepr::Mul(left, right) => {
            if matches!(left.as_ref(), ASTRepr::Constant(c) if *c == 1.0) {
                Ok(right.as_ref().clone())
            } else if matches!(right.as_ref(), ASTRepr::Constant(c) if *c == 1.0) {
                Ok(left.as_ref().clone())
            } else {
                Ok(expr.clone())
            }
        }
        _ => Ok(expr.clone()),
    }
}
/// Rewrites `ln(exp(x))` to `x`; any other expression is returned as a clone.
fn optimize_ln_exp(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    if let ASTRepr::Ln(ln_arg) = expr {
        if let ASTRepr::Exp(core) = ln_arg.as_ref() {
            return Ok((**core).clone());
        }
    }
    Ok(expr.clone())
}
/// Rewrites `exp(ln(x))` to `x`; any other expression is returned as a clone.
fn optimize_exp_ln(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    if let ASTRepr::Exp(exp_arg) = expr {
        if let ASTRepr::Ln(core) = exp_arg.as_ref() {
            return Ok((**core).clone());
        }
    }
    Ok(expr.clone())
}
/// Rewrites `x ^ 0` to the constant `1`; other expressions are cloned.
///
/// Only an exponent that is exactly zero qualifies, so tiny non-zero
/// exponents are left for evaluation.
// Exact float comparison is intentional: only a literal zero exponent counts.
#[allow(clippy::float_cmp)]
fn optimize_pow_zero(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    match expr {
        ASTRepr::Pow(_base, exp) if matches!(exp.as_ref(), ASTRepr::Constant(c) if *c == 0.0) => {
            Ok(ASTRepr::Constant(1.0))
        }
        _ => Ok(expr.clone()),
    }
}
/// Rewrites `x ^ 1` to `x`; other expressions are cloned.
///
/// Only an exponent that is exactly one qualifies, so exponents that merely
/// round close to one are preserved.
// Exact float comparison is intentional: only a literal one exponent counts.
#[allow(clippy::float_cmp)]
fn optimize_pow_one(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    match expr {
        ASTRepr::Pow(base, exp) if matches!(exp.as_ref(), ASTRepr::Constant(c) if *c == 1.0) => {
            Ok(base.as_ref().clone())
        }
        _ => Ok(expr.clone()),
    }
}
}
/// Stand-in optimizer used when the `optimization` feature is disabled; it
/// carries no state.
#[cfg(not(feature = "optimization"))]
pub struct EgglogOptimizer;
#[cfg(not(feature = "optimization"))]
impl EgglogOptimizer {
pub fn new() -> Result<Self> {
Ok(Self)
}
pub fn optimize(&mut self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
Ok(expr.clone())
}
}
/// Convenience wrapper: builds a fresh `EgglogOptimizer` and runs it once on
/// `expr`.
///
/// # Errors
///
/// Propagates any error from `EgglogOptimizer::new` or
/// `EgglogOptimizer::optimize`.
pub fn optimize_with_egglog(expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    EgglogOptimizer::new().and_then(|mut optimizer| optimizer.optimize(expr))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::final_tagless::{ASTEval, ASTMathExpr};

    /// Construction must succeed with and without the `optimization` feature.
    #[test]
    fn test_egglog_optimizer_creation() {
        assert!(EgglogOptimizer::new().is_ok());
    }

    /// Constants, variables and a simple sum serialise to the expected text.
    #[test]
    fn test_jit_repr_to_egglog_conversion() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            let expr = ASTRepr::Constant(42.0);
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Num 42.0)");

            let expr = ASTRepr::Variable(0);
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Var 0)");

            let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(1.0));
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Add (Var 0) (Num 1.0))");
        }
    }

    /// `optimize_with_egglog` runs end to end; with the feature enabled a
    /// constants-only `2 + 0` must simplify to the constant `2`.
    #[test]
    fn test_basic_optimization() {
        let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(0.0));
        let result = optimize_with_egglog(&expr);

        #[cfg(feature = "optimization")]
        {
            // Smoke check only: the call must return without panicking.
            let _ = result;

            let constants_only = ASTEval::add(ASTEval::constant(2.0), ASTEval::constant(0.0));
            let folded = optimize_with_egglog(&constants_only);
            assert!(matches!(folded, Ok(ASTRepr::Constant(v)) if v == 2.0));
        }
        #[cfg(not(feature = "optimization"))]
        {
            assert!(result.is_ok());
        }
    }

    /// A nested expression serialises with every operator present.
    #[test]
    fn test_complex_expression_conversion() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();
            let expr = ASTEval::sin(ASTEval::add(
                ASTEval::pow(ASTEval::var(0), ASTEval::constant(2.0)),
                ASTEval::constant(1.0),
            ));
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            for needle in ["Sin", "Add", "Pow", "Var 0"] {
                assert!(egglog_str.contains(needle));
            }
        }
    }

    /// Serialises `x + 0` and feeds it through `optimize` as a smoke check.
    #[test]
    fn test_egglog_rules_application() {
        #[cfg(feature = "optimization")]
        {
            let mut optimizer = EgglogOptimizer::new().unwrap();
            let expr = ASTEval::add(ASTEval::var(0), ASTEval::constant(0.0));
            let egglog_str = optimizer.jit_repr_to_egglog(&expr).unwrap();
            assert_eq!(egglog_str, "(Add (Var 0) (Num 0.0))");
            // The outcome is intentionally not asserted here.
            let _result = optimizer.optimize(&expr);
        }
    }

    /// Top-level parts are split on whitespace; nested forms stay whole.
    #[test]
    fn test_sexpr_parsing() {
        #[cfg(feature = "optimization")]
        {
            let optimizer = EgglogOptimizer::new().unwrap();

            let parts = optimizer.parse_sexpr_parts("Num 42.0").unwrap();
            assert_eq!(parts, vec!["Num", "42.0"]);

            let parts = optimizer.parse_sexpr_parts("Var 0").unwrap();
            assert_eq!(parts, vec!["Var", "0"]);

            let parts = optimizer
                .parse_sexpr_parts("Add (Num 1.0) (Num 2.0)")
                .unwrap();
            assert_eq!(parts, vec!["Add", "(Num 1.0)", "(Num 2.0)"]);
        }
    }
}