use egglog::{extract::DefaultCost, *};
use egglog_ast::span::{RustSpan, Span};
#[test]
fn globals_missing_prefix_warns_by_default() {
testing_logger::setup();
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(None, "(let value 41)")
.unwrap();
testing_logger::validate(|logs| {
let bodies: Vec<_> = logs.iter().map(|entry| entry.body.clone()).collect();
assert!(
bodies
.iter()
.any(|body| body.contains("Global `value` should start with `$`")),
"expected warning about missing global prefix, got logs: {:?}",
bodies
);
});
}
#[test]
fn globals_missing_prefix_warns_for_prefixed_pattern_variable_by_default() {
testing_logger::setup();
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(None, "(rule ((= $x 1)) ())")
.unwrap();
testing_logger::validate(|logs| {
let bodies: Vec<_> = logs.iter().map(|entry| entry.body.clone()).collect();
assert!(
bodies
.iter()
.any(|body| body.contains("Global `x` should start with `$`")),
"expected warning about missing global prefix, got logs: {:?}",
bodies
);
});
}
#[test]
fn globals_missing_prefix_warns_for_prefixed_rule_let_by_default() {
testing_logger::setup();
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(None, "(rule () ((let $y 1)))")
.unwrap();
testing_logger::validate(|logs| {
let bodies: Vec<_> = logs.iter().map(|entry| entry.body.clone()).collect();
assert!(
bodies
.iter()
.any(|body| body.contains("Global `y` should start with `$`")),
"expected warning about missing global prefix, got logs: {:?}",
bodies
);
});
}
#[test]
fn globals_missing_prefix_errors_when_opted_in() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
egraph.set_strict_mode(true);
let err = egraph
.parse_and_run_program(None, "(let value 41)")
.unwrap_err();
match err {
Error::TypeError(TypeError::GlobalMissingPrefix { ref name, .. }) => {
assert_eq!(name, "value");
}
other => panic!("expected missing dollar error, got {other:?}"),
}
}
#[test]
fn globals_missing_prefix_errors_for_prefixed_pattern_variable_when_opted_in() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
egraph.set_strict_mode(true);
let err = egraph
.parse_and_run_program(None, "(rule ((= $x 1)) ())")
.unwrap_err();
match err {
Error::TypeError(TypeError::NonGlobalPrefixed { ref name, .. }) => {
assert_eq!(name, "$x");
}
other => panic!("expected non-global prefixed variable error, got {other:?}"),
}
}
#[test]
fn globals_missing_prefix_errors_for_prefixed_rule_let_when_opted_in() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
egraph.set_strict_mode(true);
let err = egraph
.parse_and_run_program(None, "(rule () ((let $y 1)))")
.unwrap_err();
match err {
Error::TypeError(TypeError::NonGlobalPrefixed { ref name, .. }) => {
assert_eq!(name, "$y");
}
other => panic!("expected non-global prefixed variable error, got {other:?}"),
}
}
#[test]
fn globals_cannot_be_shadowed_by_pattern_variables() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let program = r#"
(let $value 41)
(rule ((= value $value)) ())
"#;
let err = egraph.parse_and_run_program(None, program).unwrap_err();
match err {
Error::Shadowing(message, _global_span, _shadow_span) => {
assert!(message.contains("pattern variable `value`"));
assert!(message.contains("global `$value`"));
}
other => panic!("expected shadowing error, got {other:?}"),
}
}
#[test]
fn test_simple_extract1() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Op (Add i64 i64))
(let expr (Add 1 1))
(extract expr)"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, 3);
}
#[test]
fn primitive_error_in_extract_returns_error() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let err = egraph
.parse_and_run_program(None, "(extract (<< 1 10000))")
.unwrap_err();
assert!(err.to_string().contains("call of primitive << failed"));
}
#[test]
fn primitive_error_in_run_schedule_returns_error() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let program = r#"
(ruleset problematic)
(rule ()
((let tmp (<< 1 10000)))
:ruleset problematic)
(run-schedule (run problematic))
"#;
let err = egraph.parse_and_run_program(None, program).unwrap_err();
assert!(err.to_string().contains("call of primitive << failed"));
}
#[test]
fn test_simple_extract2() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Term
(Origin :cost 0)
(BigStep Term :cost 10)
(SmallStep Term :cost 1)
)
(let t (Origin))
(let tb (BigStep t))
(let tbs (SmallStep tb))
(let ts (SmallStep t))
(let tss (SmallStep ts))
(let tsss (SmallStep tss))
(union tbs tsss)
(let tssss (SmallStep tsss))
(union tssss tb)
(extract tb)
"#,
)
.unwrap();
assert!(matches!(outputs[0], CommandOutput::ExtractBest(_, 4, _)));
}
#[test]
fn test_simple_extract3() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Fruit
(Apple i64 :cost 1)
(Orange f64 :cost 2)
)
(datatype Vegetable
(Broccoli bool :cost 3)
(Carrot Fruit :cost 4)
)
(let a (Apple 5))
(let o (Orange 3.14))
(let b (Broccoli true))
(let c (Carrot a))
(extract a)
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, 2);
}
#[test]
fn test_simple_extract4() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
r#"
(datatype Foo
(Foobar)
)
(subsume (Foobar))
(datatype Bar
(Barbar Foo)
)
(let x (Barbar (Foobar)))
(extract x)
"#,
)
.unwrap_err();
}
#[test]
fn test_simple_extract5() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Foo
(Foobar i64)
)
(let foobar (Foobar 42))
(datatype Bar
(Barfoo i64)
)
(let barfoo (Barfoo 24))
(sort QuaVec (Vec i64))
(sort QuaMap (Map QuaVec Foo))
(function Quaz () QuaMap :no-merge)
(set (Quaz) (map-empty))
(extract (Quaz))
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, 0);
}
#[test]
fn test_simple_extract6() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype False)
(sort QuaVec (Vec i64))
(sort QuaMap (Map QuaVec False))
(function Quaz () QuaMap :no-merge)
(set (Quaz) (map-empty))
(extract (Quaz))
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, 0);
}
#[test]
fn test_simple_extract7() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Foo
(bar)
(baz)
)
(sort Mapsrt1 (Map i64 Foo))
(let map1 (map-insert (map-empty) 0 (bar)))
(sort Mapsrt2 (Map bool Foo))
(let map2 (map-insert (map-empty) false (baz)))
;(let map2b (map-insert (map-empty) false (bar)))
;(union map2 map2b)
;(extract map1)
;(extract map2)
;(function toerr (Mapsrt2) Foo :no-merge)
;(set (toerr map2) (bar))
(union (bar) (baz))
; Also unions map1 and map2!?
(extract map1)
(extract map2)
;(extract (toerr map2))
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, 2);
}
#[test]
fn test_simple_extract8() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Foo
(bar :cost 10)
)
(function func () Foo :no-merge)
(set (func) (bar))
(extract (bar))
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, 10);
}
#[test]
fn test_simple_extract9() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Foo)
(datatype NodeA)
(datatype NodeB)
(datatype NodeC)
(constructor ctoa (NodeC) NodeA)
(constructor atob (NodeA) NodeB)
(constructor btoc (NodeB) NodeC)
(constructor bar () Foo :cost 9223372036854775807)
(constructor barbar (Foo Foo) Foo :cost 2)
(constructor groundedA (Foo) NodeA)
(let a (groundedA (barbar (bar) (bar))))
(let b (atob a))
(let c (btoc b))
(let a2 (ctoa c))
(let b2 (atob a2))
(let c2 (btoc b2))
(union a a2)
(union b b2)
(union c c2)
(extract a)
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, DefaultCost::MAX);
let outputs = egraph
.parse_and_run_program(
None,
r#"
(extract b)
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, DefaultCost::MAX);
let outputs = egraph
.parse_and_run_program(
None,
r#"
(extract c)
"#,
)
.unwrap();
let CommandOutput::ExtractBest(_, cost, _) = outputs[0] else {
panic!();
};
assert_eq!(cost, DefaultCost::MAX);
}
#[test]
fn test_extract_variants1() {
let _ = env_logger::builder().is_test(true).try_init();
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Term
(Origin :cost 0)
(BigStep Term :cost 10)
(SmallStep Term :cost 1)
)
(let t (Origin))
(let tb (BigStep t))
(let tbs (SmallStep tb))
(let ts (SmallStep t))
(let tss (SmallStep ts))
(let tsss (SmallStep tss))
(union tbs tsss)
(let tssss (SmallStep tsss))
(union tssss tb)
(extract tb 3)
"#,
)
.unwrap();
assert_eq!(
outputs[0].to_string(),
"(\n (SmallStep (SmallStep (SmallStep (SmallStep (Origin)))))\n (BigStep (Origin))\n)\n"
);
}
#[test]
fn test_subsumed_unextractable_action_extract() {
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(constructor exp () Math :cost 100)
(constructor cheap () Math :cost 1)
(union (exp) (cheap))
(extract (exp))
"#,
)
.unwrap();
assert!(match &outputs[0] {
CommandOutput::ExtractBest(termdag, _, term_id) => {
matches!(termdag.get(*term_id), Term::App(s, ..) if s == "cheap")
}
_ => false,
});
let outputs = egraph
.parse_and_run_program(
None,
r#"
(subsume (cheap))
(extract (exp))
"#,
)
.unwrap();
assert!(match &outputs[0] {
CommandOutput::ExtractBest(termdag, _, term_id) => {
matches!(termdag.get(*term_id), Term::App(s, ..) if s == "exp")
}
_ => false,
});
}
#[test]
fn test_subsume_unextractable_insert_and_merge() {
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Expr
(f Expr)
(Num i64))
(constructor exp () Expr :cost 100)
(f (Num 1))
(subsume (f (Num 1)))
(f (Num 2))
(union (Num 2) (Num 1))
(union (f (Num 2)) (exp))
(extract (f (Num 2)))
"#,
)
.unwrap();
assert!(match &outputs[0] {
CommandOutput::ExtractBest(termdag, _, term_id) => {
matches!(termdag.get(*term_id), Term::App(s, ..) if s == "exp")
}
_ => false,
});
}
#[test]
fn test_subsume_unextractable_action_extract_multiple() {
let mut egraph = EGraph::default();
let outputs1 = egraph
.parse_and_run_program(
None,
"
(datatype Math (Num i64))
(Num 1)
(union (Num 1) (Num 2))
(extract (Num 1) 2)
",
)
.unwrap();
assert!(matches!(
outputs1[0],
CommandOutput::ExtractVariants(_, ref terms) if terms.len() == 2
));
let outputs2 = egraph
.parse_and_run_program(
None,
"
(subsume (Num 2))
(extract (Num 1) 2)
",
)
.unwrap();
assert!(matches!(
outputs2[0],
CommandOutput::ExtractVariants(_, ref terms) if terms.len() == 1
));
}
#[test]
fn test_rewrite_subsumed_unextractable() {
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(constructor exp () Math :cost 100)
(constructor cheap () Math :cost 1)
(rewrite (cheap) (exp) :subsume)
(cheap)
(run 1)
(extract (cheap))
"#,
)
.unwrap();
assert_eq!(outputs[1].to_string(), "(exp)\n");
}
#[test]
fn test_rewrite_subsumed() {
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(constructor exp () Math :cost 100)
(constructor most-exp () Math :cost 1000)
(rewrite (most-exp) (exp) :subsume)
(most-exp)
(run 1)
(constructor cheap () Math :cost 1)
(rewrite (most-exp) (cheap))
(run 1)
(extract (most-exp))
"#,
)
.unwrap();
assert_eq!(outputs[2].to_string(), "(exp)\n");
}
#[test]
fn test_subsume() {
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Math
(Add Math Math)
(Num i64))
(rewrite (Add a b) (Add b a))
(let x (Add (Num 1) (Num 2)))
(let y (Add (Num 3) (Num 4)))
(subsume (Add (Num 1) (Num 2)))
(run 1)
(extract y 10)
"#,
)
.unwrap();
assert!(matches!(
outputs[1],
CommandOutput::ExtractVariants(_, ref terms) if terms.len() == 2
));
let outputs = egraph
.parse_and_run_program(
None,
r#"
;; add something equal to x that can be extracted:
(constructor otherConst () Math)
(let other (otherConst))
(union x other)
(extract x 10)
"#,
)
.unwrap();
assert!(matches!(
outputs[0],
CommandOutput::ExtractVariants(_, ref terms) if terms.len() == 1
));
}
#[test]
fn test_subsume_custom() {
let mut egraph = EGraph::default();
let res = egraph.parse_and_run_program(
None,
r#"
(function one () i64 :no-merge)
(set (one) 1)
(subsume (one))
"#,
);
assert!(res.is_err());
}
#[test]
fn test_subsume_ok() {
let mut egraph = EGraph::default();
let res = egraph.parse_and_run_program(
None,
r#"
(sort E)
(constructor one () E)
(constructor two () E)
(one)
(subsume (one))
;; subsuming a non-existent tuple
(subsume (two))
(relation R (i64))
(R 1)
(subsume (R 1))
(subsume (R 2))
"#,
);
assert!(res.is_ok());
}
#[test]
fn test_cant_subsume_merge() {
let mut egraph = EGraph::default();
let res = egraph.parse_and_run_program(
None,
r#"
(constructor one () i64 :merge old)
(set (one) 1)
(subsume (one))
"#,
);
assert!(res.is_err());
}
#[test]
fn test_value_to_classid() {
let mut egraph = EGraph::default();
let outputs = egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(constructor exp () Math )
(exp)
(extract (exp))
"#,
)
.unwrap();
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs[0].clone() else {
panic!();
};
let expr = termdag.term_to_expr(&term, span!());
let (sort, value) = egraph.eval_expr(&expr).unwrap();
let serialize_output = egraph.serialize(SerializeConfig::default());
assert!(serialize_output.is_complete());
let class_id = egraph.value_to_class_id(&sort, value);
assert!(serialize_output.egraph.class_data.get(&class_id).is_some());
assert_eq!(value, egraph.class_id_to_value(&class_id));
}
#[test]
fn test_serialize_617() {
let program = "
(sort Node)
(constructor mk (i64) Node)
(constructor mkb (i64) Node)
(rewrite (mkb x) (mk x))
(mkb 1) (mkb 3) (mkb 5) (mkb 6)
(union (mk 1) (mk 3))
(union (mk 3) (mk 5))
(run-schedule (saturate (run)))";
let mut egraph = EGraph::default();
egraph.parse_and_run_program(None, program).unwrap();
let serialize_output = egraph.serialize(SerializeConfig::default());
assert!(serialize_output.is_complete());
assert_eq!(serialize_output.egraph.class_data.len(), 6);
assert_eq!(serialize_output.egraph.nodes.len(), 12);
}
#[test]
fn test_serialize_subsume_status() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
r#"
(datatype Math)
(constructor a () Math )
(constructor b () Math )
(a)
(b)
(subsume (a))
"#,
)
.unwrap();
let serialize_output = egraph.serialize(SerializeConfig::default());
assert!(serialize_output.is_complete());
let a_id = egraph.to_node_id(
None,
egglog::SerializedNode::Function {
name: "a".into(),
offset: 0,
},
);
let b_id = egraph.to_node_id(
None,
egglog::SerializedNode::Function {
name: "b".into(),
offset: 0,
},
);
assert!(serialize_output.egraph.nodes[&a_id].subsumed);
assert!(!serialize_output.egraph.nodes[&b_id].subsumed);
}
#[test]
fn test_shadowing_query() {
let s = "(function f () i64 :no-merge) (set (f) 2) (check (= (f) f) (= f 2))";
let e = EGraph::default()
.parse_and_run_program(None, s)
.unwrap_err();
assert!(matches!(e, Error::Shadowing(_, _, _)));
}
#[test]
fn test_shadowing_push() {
let s = "(push) (let x 1) (pop) (let x 1)";
EGraph::default().parse_and_run_program(None, s).unwrap();
}
#[test]
fn test_print_function_size() {
let s = "(function f () i64 :no-merge) (set (f) 2) (print-size f)";
let outputs = EGraph::default().parse_and_run_program(None, s).unwrap();
assert_eq!(outputs[0].to_string(), "1\n");
}
#[test]
fn test_print_function() {
let s = "(function f () i64 :no-merge) (set (f) 2) (print-function f)";
let outputs = EGraph::default().parse_and_run_program(None, s).unwrap();
assert_eq!(outputs[0].to_string(), "(\n (f) -> 2\n)\n");
}
#[test]
fn test_print_function_csv() {
let s = "(function f () i64 :no-merge) (set (f) 2) (print-function f :mode csv)";
let outputs = EGraph::default().parse_and_run_program(None, s).unwrap();
assert_eq!(outputs[0].to_string(), "f,2\n");
}
#[test]
fn test_print_stats() {
let s = "(run 1) (print-stats)";
let outputs = EGraph::default().parse_and_run_program(None, s).unwrap();
assert_eq!(
outputs[1].to_string(),
"Overall statistics:\nRuleset : search 0.000s, merge 0.000s, rebuild 0.000s\n"
);
}
#[test]
fn test_run_report() {
let s = "(run 1)";
let outputs = EGraph::default().parse_and_run_program(None, s).unwrap();
assert_eq!(outputs[0].to_string(), "");
assert!(matches!(outputs[0], CommandOutput::RunSchedule(..)));
}
#[test]
fn test_serialize_message_max_functions() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
r#"
(datatype A)
(constructor a () A)
(constructor b () A)
(constructor c () A)
(a) (b) (c)
"#,
)
.unwrap();
let serialize_output = egraph.serialize(SerializeConfig {
max_functions: Some(2),
max_calls_per_function: None,
include_temporary_functions: false,
root_eclasses: vec![],
});
assert!(!serialize_output.is_complete());
assert_eq!(serialize_output.omitted_description(), "Omitted: c\n");
}
#[test]
fn test_serialize_message_max_calls_per_function() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
r#"
(datatype N)
(constructor mk (i64) N)
(mk 0) (mk 1) (mk 2) (mk 3)
"#,
)
.unwrap();
let serialize_output = egraph.serialize(SerializeConfig {
max_functions: None,
max_calls_per_function: Some(2),
include_temporary_functions: false,
root_eclasses: vec![],
});
assert!(!serialize_output.is_complete());
assert_eq!(serialize_output.omitted_description(), "Truncated: mk\n");
}