use crate::util::INTERNAL_SYMBOL_PREFIX;
use egglog::ast::*;
use egglog::util::FreshGen;
use egglog::*;
use egglog_ast::span::{RustSpan, Span};
use std::sync::Arc;
/// Test macro that renames every `rule` command.
///
/// The new name is drawn from the fresh-symbol generator using the hint
/// `"{prefix}_{current name}"`; every other command is forwarded unchanged.
struct PrefixRuleMacro {
    /// Leading part of the name hint passed to the fresh-symbol generator.
    prefix: String,
}

impl CommandMacro for PrefixRuleMacro {
    fn transform(
        &self,
        command: Command,
        symbol_gen: &mut util::SymbolGen,
        _type_info: &TypeInfo,
    ) -> Result<Vec<Command>, Error> {
        if let Command::Rule { mut rule } = command {
            let name_hint = format!("{}_{}", self.prefix, rule.name);
            rule.name = symbol_gen.fresh(&name_hint);
            Ok(vec![Command::Rule { rule }])
        } else {
            // Anything that is not a rule passes through untouched.
            Ok(vec![command])
        }
    }
}
/// Test macro that replaces each `rule` command with two copies of itself.
///
/// The copies are renamed through the fresh-symbol generator with the hints
/// `"dup1_{current name}"` and `"dup2_{current name}"`, generated in that order.
/// Every other command is forwarded unchanged.
struct DuplicateRuleMacro;

impl CommandMacro for DuplicateRuleMacro {
    fn transform(
        &self,
        command: Command,
        symbol_gen: &mut util::SymbolGen,
        _type_info: &TypeInfo,
    ) -> Result<Vec<Command>, Error> {
        match command {
            Command::Rule { rule } => {
                // The array is consumed front to back, so the `dup1` copy draws its
                // fresh symbol before the `dup2` copy does.
                let copies: Vec<Command> = [("dup1", rule.clone()), ("dup2", rule)]
                    .into_iter()
                    .map(|(tag, mut copy)| {
                        copy.name = symbol_gen.fresh(&format!("{tag}_{}", copy.name));
                        Command::Rule { rule: copy }
                    })
                    .collect();
                Ok(copies)
            }
            other => Ok(vec![other]),
        }
    }
}
/// Test macro that appends a `print-size` command after certain rules.
///
/// When a `rule` command's body consists of exactly one fact that is a single
/// call expression, a `(print-size <head>)` command for that call's head is
/// emitted immediately after the rule. All other commands (and rules with any
/// other body shape) are forwarded unchanged.
struct AddPrintSizeAfterRuleMacro;

impl CommandMacro for AddPrintSizeAfterRuleMacro {
    fn transform(
        &self,
        command: Command,
        _symbol_gen: &mut util::SymbolGen,
        _type_info: &TypeInfo,
    ) -> Result<Vec<Command>, Error> {
        // Inspect the command by reference so it can be forwarded as-is,
        // instead of cloning the whole command AST just to look at it.
        let print_target = match &command {
            Command::Rule {
                rule: GenericRule { body, .. },
            } => match body.as_slice() {
                [Fact::Fact(GenericExpr::Call(_span, head, _children))] => Some(head.to_string()),
                _ => None,
            },
            _ => None,
        };
        let mut commands = vec![command];
        if let Some(head) = print_target {
            commands.push(Command::PrintSize(span!(), Some(head)));
        }
        Ok(commands)
    }
}
/// A single `PrefixRuleMacro` renames the rule, while the rest of the
/// desugared program (sort, constructor and the desugared `let`) stays intact.
#[test]
fn test_single_macro_with_desugar_program() {
    let mut egraph = EGraph::default();
    egraph
        .command_macros_mut()
        .register(Arc::new(PrefixRuleMacro {
            prefix: "test".to_string(),
        }));
    let input = r#"
(datatype Math (Num i64))
(rule ((Num x)) ((Num (+ x 1))))
(let a (Num 1))
"#;
    let output = egraph
        .desugar_program(None, input)
        .unwrap()
        .iter()
        .map(|s| s.to_string())
        .collect::<Vec<_>>()
        .join("\n");
    // Every failure message includes the full output, matching the other tests.
    assert!(
        output.contains("(sort Math)"),
        "Expected sort declaration: {}",
        output
    );
    assert!(
        output.contains("(constructor Num (i64) Math)"),
        "Expected Num constructor: {}",
        output
    );
    // A plain `&str` pattern; no need to allocate a `String` for `contains`.
    assert!(
        output.contains("_test_0"),
        "Expected rule name to be prefixed with test_: {}",
        output
    );
    assert!(
        output.contains("(set (a) (Num 1))"),
        "Expected let desugared: {}",
        output
    );
}
/// Registers two `PrefixRuleMacro`s ("first", then "second") and expects the
/// rule name to contain `second_` wrapped around the fresh `first_` name.
#[test]
fn test_multiple_macros_compose_with_desugar_program() {
    let mut egraph = EGraph::default();
    for prefix in ["first", "second"] {
        egraph
            .command_macros_mut()
            .register(Arc::new(PrefixRuleMacro {
                prefix: prefix.to_string(),
            }));
    }
    let input = r#"
(datatype Math (Num i64))
(rule ((Num x)) ((Num (+ x 1))))"#;
    let desugared: Vec<String> = egraph
        .desugar_program(None, input)
        .unwrap()
        .iter()
        .map(|command| command.to_string())
        .collect();
    let output = desugared.join("\n");
    let expected = format!("second_{INTERNAL_SYMBOL_PREFIX}first_");
    assert!(
        output.contains(&expected),
        "Expected rule name to have both prefixes in order: {}",
        output
    );
}
/// `DuplicateRuleMacro` turns the single input rule into exactly two rules,
/// whose names carry the `dup1_` and `dup2_` hints.
#[test]
fn test_duplicate_macro_creates_two_rules() {
    let mut egraph = EGraph::default();
    egraph
        .command_macros_mut()
        .register(Arc::new(DuplicateRuleMacro));
    let input = r#"
(datatype Math (Num i64))
(rule ((Num x)) ((Num (+ x 1))))
"#;
    let desugared: Vec<String> = egraph
        .desugar_program(None, input)
        .unwrap()
        .iter()
        .map(|command| command.to_string())
        .collect();
    let output = desugared.join("\n");
    for (needle, ordinal) in [("dup1_", "first"), ("dup2_", "second")] {
        assert!(
            output.contains(needle),
            "Expected {ordinal} duplicated rule: {output}"
        );
    }
    let rule_count = output.matches("(rule ").count();
    assert_eq!(
        rule_count, 2,
        "Expected exactly 2 rules, got {rule_count}"
    );
}
/// `AddPrintSizeAfterRuleMacro` emits `(print-size Num)` for a rule whose body
/// is the single call `(Num x)`, and places it after that rule. The unrelated
/// `let` command is still desugared normally.
#[test]
fn test_macro_adds_commands_after_rules() {
    let mut egraph = EGraph::default();
    egraph
        .command_macros_mut()
        .register(Arc::new(AddPrintSizeAfterRuleMacro));
    let input = r#"
(datatype Math (Num i64))
(rule ((Num x)) ((Num (+ x 1))))
(let a (Num 1))
"#;
    let output = egraph
        .desugar_program(None, input)
        .unwrap()
        .iter()
        .map(|s| s.to_string())
        .collect::<Vec<_>>()
        .join("\n");
    assert!(
        output.contains("(print-size Num)"),
        "Expected print-size marker after rule: {}",
        output
    );
    // The test name promises "after": check the ordering, not just presence.
    let rule_pos = output
        .find("(rule ")
        .unwrap_or_else(|| panic!("Expected the rule in the output: {output}"));
    let print_pos = output
        .find("(print-size Num)")
        .unwrap_or_else(|| panic!("Expected print-size marker after rule: {output}"));
    assert!(
        rule_pos < print_pos,
        "Expected print-size marker after rule: {}",
        output
    );
    assert!(
        output.contains("(set (a) (Num 1))"),
        "Expected desugared a: {}",
        output
    );
}
/// Duplicating and then prefixing: both rule copies receive the `prefixed_`
/// hint wrapped around their fresh `dup1_`/`dup2_` names, and exactly two
/// rules remain.
#[test]
fn test_complex_macro_composition() {
    let mut egraph = EGraph::default();
    let macros = egraph.command_macros_mut();
    macros.register(Arc::new(DuplicateRuleMacro));
    macros.register(Arc::new(PrefixRuleMacro {
        prefix: "prefixed".to_string(),
    }));
    let input = r#"
(datatype Math (Num i64))
(rule ((Num x)) ((Num (+ x 1))))"#;
    let desugared: Vec<String> = egraph
        .desugar_program(None, input)
        .unwrap()
        .iter()
        .map(|command| command.to_string())
        .collect();
    let output = desugared.join("\n");
    for (tag, ordinal) in [("dup1_", "first"), ("dup2_", "second")] {
        let expected = format!("prefixed_{INTERNAL_SYMBOL_PREFIX}{tag}");
        assert!(
            output.contains(&expected),
            "Expected {ordinal} rule with both transformations: {output}"
        );
    }
    assert_eq!(
        output.matches("(rule ").count(),
        2,
        "Expected exactly 2 rules after duplication"
    );
}
/// A program whose rule is duplicated by `DuplicateRuleMacro` still parses,
/// runs one iteration, and passes its `check`.
#[test]
fn test_macros_work_with_actual_program_execution() {
    let program = r#"
(datatype Math (Num i64))
(rule ((Num x)) ((Num (+ x 1))))
(let a (Num 1))
(run 1)
(check (= a (Num 1)))
"#;
    let mut egraph = EGraph::default();
    egraph
        .command_macros_mut()
        .register(Arc::new(DuplicateRuleMacro));
    let result = egraph.parse_and_run_program(None, program);
    assert!(
        result.is_ok(),
        "Program with duplicated rules should run: {:?}",
        result
    );
}
/// Test macro that reads the type environment while rewriting commands.
///
/// For every `rule` command, its body facts are type-checked with the supplied
/// `TypeInfo`. Any type error is propagated, and the successful result is
/// discarded. The command itself is always forwarded unchanged.
struct TypeInfoReader;

impl CommandMacro for TypeInfoReader {
    fn transform(
        &self,
        command: Command,
        symbol_gen: &mut util::SymbolGen,
        type_info: &TypeInfo,
    ) -> Result<Vec<Command>, Error> {
        if let Command::Rule { rule } = &command {
            // Only the error matters here; the typechecked facts are not kept.
            type_info.typecheck_facts(symbol_gen, &rule.body)?;
        }
        Ok(vec![command])
    }
}
/// With `TypeInfoReader` registered (it type-checks each rule body during macro
/// expansion), a program containing a rule, a union and a check runs
/// successfully.
#[test]
fn test_macro_accesses_type_info() {
    let program = r#"
(datatype Math (Num i64))
(rule ((Num x)) ((Num (+ x 1))))
(let a (Num 1))
(constructor B () Math)
(union a (B))
(check (= (B) (Num 1)))
"#;
    let mut egraph = EGraph::default();
    egraph
        .command_macros_mut()
        .register(Arc::new(TypeInfoReader));
    let result = egraph.parse_and_run_program(None, program);
    assert!(
        result.is_ok(),
        "Program with type info reading macro should run: {:?}",
        result
    );
}