use anyhow::Result;
use serde::{Deserialize, Serialize};
pub use egglog;
/// Tuning knobs for an e-graph optimization run.
///
/// Default values come from this type's [`Default`] implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EgraphConfig {
    /// Intended upper bound on the number of rewrite iterations.
    /// NOTE(review): currently only reported in the optimizer's debug log.
    pub max_iterations: usize,
    /// Intended cap on the size of the e-graph, in nodes.
    /// NOTE(review): not read anywhere in this module yet.
    pub max_nodes: usize,
    /// Flag for enabling extra diagnostics.
    /// NOTE(review): not read anywhere in this module yet.
    pub debug: bool,
}
impl Default for EgraphConfig {
    /// Returns a configuration with a 100-iteration limit, a 100 000-node
    /// budget, and debugging turned off.
    fn default() -> Self {
        let max_iterations = 100;
        let max_nodes = 100_000;
        Self {
            debug: false,
            max_nodes,
            max_iterations,
        }
    }
}
/// A single named rewrite rule: whenever `lhs` matches, it may be replaced by `rhs`.
///
/// Patterns in this module are written as s-expressions in which
/// `?`-prefixed names (e.g. `?x`) act as pattern variables.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RewriteRule {
    /// Human-readable identifier for the rule (e.g. `"add-zero"`).
    pub name: String,
    /// Pattern to match.
    pub lhs: String,
    /// Replacement expression; may reuse the pattern variables bound by `lhs`.
    pub rhs: String,
    /// Optional guard expression (e.g. `"(!= ?x 0)"`).
    /// `None` means the rule is unconditional.
    pub condition: Option<String>,
}
impl RewriteRule {
    /// Creates an unconditional rule named `name` that rewrites `lhs` into `rhs`.
    #[must_use]
    pub fn new(name: &str, lhs: &str, rhs: &str) -> Self {
        Self {
            name: name.to_string(),
            lhs: lhs.to_string(),
            rhs: rhs.to_string(),
            condition: None,
        }
    }

    /// Attaches a guard `condition` to the rule and returns the updated rule.
    ///
    /// Any previously set condition is replaced.
    #[must_use = "`with_condition` returns the updated rule; dropping it discards the rule"]
    pub fn with_condition(mut self, condition: &str) -> Self {
        self.condition = Some(condition.to_string());
        self
    }
}
/// Returns the built-in arithmetic rule set.
///
/// It covers:
/// - additive and multiplicative identities, with both operand orders;
/// - multiplication by zero;
/// - commutativity and associativity of `+` and `*`;
/// - distribution and factoring;
/// - double negation;
/// - self-cancellation.
///
/// `div-self` is the only guarded rule. It carries the condition `(!= ?x 0)`,
/// assuming the consumer honours [`RewriteRule::condition`].
///
/// NOTE(review): `-` is used both as a unary operator (`double-neg`) and as a
/// binary one (`sub-self`). Backends with fixed-arity functions may need
/// distinct operator names. Verify against the target engine.
pub fn algebraic_rules() -> Vec<RewriteRule> {
    // (name, lhs, rhs) for every unconditional rule, in registration order.
    const UNCONDITIONAL: [(&str, &str, &str); 14] = [
        ("add-zero", "(+ ?x 0)", "?x"),
        ("add-zero-left", "(+ 0 ?x)", "?x"),
        ("mul-one", "(* ?x 1)", "?x"),
        ("mul-one-left", "(* 1 ?x)", "?x"),
        ("mul-zero", "(* ?x 0)", "0"),
        ("mul-zero-left", "(* 0 ?x)", "0"),
        ("add-comm", "(+ ?x ?y)", "(+ ?y ?x)"),
        ("mul-comm", "(* ?x ?y)", "(* ?y ?x)"),
        ("add-assoc", "(+ (+ ?x ?y) ?z)", "(+ ?x (+ ?y ?z))"),
        ("mul-assoc", "(* (* ?x ?y) ?z)", "(* ?x (* ?y ?z))"),
        ("distribute", "(* ?x (+ ?y ?z))", "(+ (* ?x ?y) (* ?x ?z))"),
        ("factor", "(+ (* ?x ?y) (* ?x ?z))", "(* ?x (+ ?y ?z))"),
        ("double-neg", "(- (- ?x))", "?x"),
        ("sub-self", "(- ?x ?x)", "0"),
    ];

    let mut rules: Vec<RewriteRule> = UNCONDITIONAL
        .iter()
        .map(|&(name, lhs, rhs)| RewriteRule::new(name, lhs, rhs))
        .collect();
    // The guarded rule goes last, matching the original ordering.
    rules.push(RewriteRule::new("div-self", "(/ ?x ?x)", "1").with_condition("(!= ?x 0)"));
    rules
}
/// Returns the built-in boolean rule set.
///
/// It covers:
/// - identity and annihilator laws for `and`/`or`, with both operand orders;
/// - idempotence and commutativity;
/// - double negation;
/// - De Morgan's laws;
/// - absorption.
///
/// The left-operand variants and the commutativity rules mirror
/// [`algebraic_rules`]. Without them, inputs such as `(and true x)` or
/// `(or false x)` could never be simplified.
pub fn boolean_rules() -> Vec<RewriteRule> {
    vec![
        RewriteRule::new("and-true", "(and ?x true)", "?x"),
        RewriteRule::new("and-true-left", "(and true ?x)", "?x"),
        RewriteRule::new("or-false", "(or ?x false)", "?x"),
        RewriteRule::new("or-false-left", "(or false ?x)", "?x"),
        RewriteRule::new("and-false", "(and ?x false)", "false"),
        RewriteRule::new("and-false-left", "(and false ?x)", "false"),
        RewriteRule::new("or-true", "(or ?x true)", "true"),
        RewriteRule::new("or-true-left", "(or true ?x)", "true"),
        RewriteRule::new("and-idem", "(and ?x ?x)", "?x"),
        RewriteRule::new("or-idem", "(or ?x ?x)", "?x"),
        // Commutativity also lets the absorption rules below match either operand order.
        RewriteRule::new("and-comm", "(and ?x ?y)", "(and ?y ?x)"),
        RewriteRule::new("or-comm", "(or ?x ?y)", "(or ?y ?x)"),
        RewriteRule::new("not-not", "(not (not ?x))", "?x"),
        RewriteRule::new("demorgan-and", "(not (and ?x ?y))", "(or (not ?x) (not ?y))"),
        RewriteRule::new("demorgan-or", "(not (or ?x ?y))", "(and (not ?x) (not ?y))"),
        RewriteRule::new("absorb-and", "(and ?x (or ?x ?y))", "?x"),
        RewriteRule::new("absorb-or", "(or ?x (and ?x ?y))", "?x"),
    ]
}
/// Holds a rewrite-rule set together with the configuration used to apply it.
#[derive(Debug, Clone)]
pub struct ExpressionOptimizer {
    /// Run configuration supplied at construction time.
    config: EgraphConfig,
    /// Active rewrite rules, in registration order.
    rules: Vec<RewriteRule>,
}
impl ExpressionOptimizer {
    /// Creates an optimizer loaded with [`algebraic_rules`].
    #[must_use]
    pub fn new(config: EgraphConfig) -> Self {
        Self {
            config,
            rules: algebraic_rules(),
        }
    }

    /// Creates an optimizer loaded with [`boolean_rules`].
    #[must_use]
    pub fn boolean(config: EgraphConfig) -> Self {
        Self {
            config,
            rules: boolean_rules(),
        }
    }

    /// Appends `rules` to the current rule set and returns the optimizer.
    ///
    /// This *extends* rather than replaces. The rules installed by the
    /// constructor are kept, and the new rules are placed after them.
    #[must_use = "`with_rules` returns the updated optimizer; dropping it discards the added rules"]
    pub fn with_rules(mut self, rules: Vec<RewriteRule>) -> Self {
        self.rules.extend(rules);
        self
    }

    /// Returns the configuration this optimizer was built with.
    pub fn config(&self) -> &EgraphConfig {
        &self.config
    }

    /// Returns the active rewrite rules, in registration order.
    pub fn rules(&self) -> &[RewriteRule] {
        &self.rules
    }

    /// Optimizes the expression `expr` and returns the result as text.
    ///
    /// **Current behaviour:** this is a pass-through placeholder. It emits a
    /// `debug`-level `tracing` event (the expression, rule count and
    /// `max_iterations`) and returns `expr` unchanged. No rules are applied,
    /// and `max_nodes` / `debug` are not consulted.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`. The `Result` is
    /// presumably kept so a real e-graph backend can report failures
    /// without an API change.
    pub fn optimize(&self, expr: &str) -> Result<String> {
        tracing::debug!(
            expr = expr,
            rules = self.rules.len(),
            max_iterations = self.config.max_iterations,
            "Optimizing expression"
        );
        // TODO(review): run the rules through an e-graph backend (e.g. the
        // re-exported `egglog`) honouring `self.config`; until then, identity.
        Ok(expr.to_string())
    }
}
impl Default for ExpressionOptimizer {
fn default() -> Self {
Self::new(EgraphConfig::default())
}
}
/// Thin façade over an [`ExpressionOptimizer`] built with the algebraic rule set.
pub struct CodeOptimizer {
    /// Underlying expression optimizer that every call is delegated to.
    optimizer: ExpressionOptimizer,
}
impl CodeOptimizer {
pub fn new() -> Self {
Self {
optimizer: ExpressionOptimizer::new(EgraphConfig::default()),
}
}
pub fn optimize_expr(&self, code: &str) -> Result<String> {
self.optimizer.optimize(code)
}
}
impl Default for CodeOptimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = EgraphConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.max_nodes, 100_000);
        assert!(!config.debug);
    }

    #[test]
    fn test_rewrite_rule() {
        let rule = RewriteRule::new("test", "(+ ?x 0)", "?x");
        assert_eq!(rule.name, "test");
        assert_eq!(rule.lhs, "(+ ?x 0)");
        assert_eq!(rule.rhs, "?x");
        assert!(rule.condition.is_none());
    }

    #[test]
    fn test_rewrite_rule_with_condition() {
        let rule = RewriteRule::new("guarded", "(/ ?x ?x)", "1").with_condition("(!= ?x 0)");
        assert_eq!(rule.condition.as_deref(), Some("(!= ?x 0)"));
    }

    #[test]
    fn test_algebraic_rules() {
        let rules = algebraic_rules();
        assert!(!rules.is_empty());
        assert!(rules.iter().any(|r| r.name == "add-zero"));
    }

    #[test]
    fn test_div_self_is_guarded() {
        let rules = algebraic_rules();
        let div_self = rules
            .iter()
            .find(|r| r.name == "div-self")
            .expect("div-self rule is registered");
        assert_eq!(div_self.condition.as_deref(), Some("(!= ?x 0)"));
    }

    #[test]
    fn test_boolean_rules() {
        let rules = boolean_rules();
        assert!(!rules.is_empty());
        assert!(rules.iter().any(|r| r.name == "not-not"));
    }

    #[test]
    fn test_boolean_optimizer_uses_boolean_rules() {
        let optimizer = ExpressionOptimizer::boolean(EgraphConfig::default());
        assert_eq!(optimizer.rules().len(), boolean_rules().len());
        assert!(optimizer.rules().iter().any(|r| r.name == "not-not"));
    }

    #[test]
    fn test_with_rules_extends() {
        let base = algebraic_rules().len();
        let optimizer = ExpressionOptimizer::default()
            .with_rules(vec![RewriteRule::new("custom", "(f ?x)", "?x")]);
        assert_eq!(optimizer.rules().len(), base + 1);
        assert_eq!(
            optimizer.rules().last().map(|r| r.name.as_str()),
            Some("custom")
        );
    }

    #[test]
    fn test_optimizer_config() {
        let optimizer = ExpressionOptimizer::default();
        assert_eq!(optimizer.config().max_iterations, 100);
    }

    #[test]
    fn test_optimizer() {
        let optimizer = ExpressionOptimizer::default();
        let result = optimizer.optimize("(+ x 0)").unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn test_code_optimizer() {
        let optimizer = CodeOptimizer::default();
        let result = optimizer.optimize_expr("(* x 1)").unwrap();
        assert!(!result.is_empty());
    }
}