use std::{
iter,
ops::Range,
sync::{Arc, Mutex},
};
use egglog_reports::ReportLevel;
use crate::numeric_id::NumericId;
use crate::{
PlanStrategy,
action::WriteVal,
common::Value,
free_join::{CounterId, Database, TableId},
make_external_func,
query::RuleSetBuilder,
table::SortedWritesTable,
table_shortcuts::v,
table_spec::{ColumnId, Constraint},
uf::DisplacedTable,
};
// Install mimalloc as the process-wide allocator for the binary that
// contains this file.
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// Joins `add` against `num` twice and feeds the matched constants to an
/// external function.
///
/// Setup: `num` receives the constants 0..10 (one fresh id each) and `add`
/// receives one row per consecutive pair of those ids. The single rule
/// matches `add(x, y, z, _)`, `num(a, x, _)` and `num(b, y, _)`, computes
/// `a + b` through the `add_int` external function and inserts the sum into
/// `num` under id `z`.
#[test]
fn basic_query() {
    let MathEgraph {
        num,
        add,
        id_counter,
        mut db,
        ..
    } = basic_math_egraph();
    db.base_values_mut().register_type::<i64>();
    // External function: unwrap two i64 base values and return their sum.
    let add_int = db.add_external_function(Box::new(make_external_func(|exec_state, args| {
        let [x, y] = args else { panic!() };
        let x: i64 = exec_state.base_values().unwrap(*x);
        let y: i64 = exec_state.base_values().unwrap(*y);
        let z: i64 = x + y;
        Some(exec_state.base_values().get(z))
    })));

    // Ids of the `num` rows in insertion order. The rule below joins the
    // first two columns of `add` against column 1 of `num`, so `add` rows
    // must be built from these ids rather than from the interned constants.
    let mut ids = Vec::new();
    {
        let mut num_buf = db.new_buffer(num);
        for i in 0..10 {
            let id = Value::from_usize(db.inc_counter(id_counter));
            let constant = db.base_values().get::<i64>(i as i64);
            ids.push(id);
            num_buf.stage_insert(&[constant, id, Value::new(0)]);
        }
    }
    db.merge_all();

    // One `add` row per consecutive pair: (0, 1), (2, 3), ..., (8, 9).
    {
        let mut add_buf = db.new_buffer(add);
        for pair in ids.chunks(2) {
            let &[x, y] = pair else { unreachable!() };
            let id = Value::from_usize(db.inc_counter(id_counter));
            add_buf.stage_insert(&[x, y, id, Value::new(0)]);
        }
    }
    db.merge_all();

    let mut rsb = RuleSetBuilder::new(&mut db);
    let mut add_query = rsb.new_rule();
    let x = add_query.new_var_named("x");
    let y = add_query.new_var_named("y");
    let z = add_query.new_var_named("z");
    let t1 = add_query.new_var_named("t1");
    let t2 = add_query.new_var_named("t2");
    let t3 = add_query.new_var_named("t3");
    let a = add_query.new_var_named("a");
    let b = add_query.new_var_named("b");
    add_query
        .add_atom(add, &[x.into(), y.into(), z.into(), t1.into()], &[])
        .unwrap();
    add_query
        .add_atom(num, &[a.into(), x.into(), t2.into()], &[])
        .unwrap();
    add_query
        .add_atom(num, &[b.into(), y.into(), t3.into()], &[])
        .unwrap();
    let mut rules = add_query.build();
    let add_a_b = rules.call_external(add_int, &[a.into(), b.into()]).unwrap();
    rules
        .insert(num, &[add_a_b.into(), z.into(), Value::new(1).into()])
        .unwrap();
    rules.build_with_description("add");
    let rule_set = rsb.build();

    let report = db.run_rule_set(&rule_set, ReportLevel::TimeOnly);
    assert!(report.changed, "{report:?}");
    // One match per `add` row.
    assert_eq!(report.num_matches("add"), 5, "{report:?}");

    let num_table = db.get_table(num);
    let all_num = num_table.all();
    let items = num_table.scan(all_num.as_ref());
    let mut res = Vec::from_iter(
        items
            .iter()
            .map(|(_, row)| db.base_values().unwrap::<i64>(row[0])),
    );
    res.sort();
    // The sums are 1, 5, 9, 13 and 17; 1, 5 and 9 were already present in
    // column 0 of `num`, so only 13 and 17 show up as new constants.
    assert_eq!(res, Vec::from_iter((0..10).chain([13, 17])));
}
/// Runs the unconstrained line-graph scenario with `PlanStrategy::PureSize`.
#[test]
fn line_graph_1_fj_puresize() {
    let strategy = PlanStrategy::PureSize;
    line_graph_1_test(strategy);
}
/// Runs the unconstrained line-graph scenario with `PlanStrategy::MinCover`.
#[test]
fn line_graph_1_fj_mincover() {
    let strategy = PlanStrategy::MinCover;
    line_graph_1_test(strategy);
}
/// Runs the unconstrained line-graph scenario with `PlanStrategy::Gj`.
#[test]
fn line_graph_1_gj() {
    let strategy = PlanStrategy::Gj;
    line_graph_1_test(strategy);
}
fn line_graph_1_test(strat: PlanStrategy) {
let mut db = Database::default();
let edge_impl = SortedWritesTable::new(
2,
2,
None,
vec![],
Box::new(move |_, a, b, _| {
if a != b {
panic!("merge not supported")
} else {
false
}
}),
);
let edges = db.add_table(edge_impl, iter::empty(), iter::empty());
let nodes = Vec::from_iter((0..10).map(Value::new));
{
let mut edge_buf = db.new_buffer(edges);
for edge in nodes.windows(2) {
edge_buf.stage_insert(edge);
}
}
db.merge_all();
let mut rsb = RuleSetBuilder::new(&mut db);
let mut query = rsb.new_rule();
query.set_plan_strategy(strat);
let x = query.new_var_named("x");
let y = query.new_var_named("y");
let z = query.new_var_named("z");
query.add_atom(edges, &[x.into(), y.into()], &[]).unwrap();
query.add_atom(edges, &[y.into(), z.into()], &[]).unwrap();
let mut rule = query.build();
rule.insert(edges, &[x.into(), z.into()]).unwrap();
rule.build();
let rule_set = rsb.build();
assert!(db.run_rule_set(&rule_set, ReportLevel::TimeOnly).changed);
let mut expected = Vec::from_iter(
nodes
.windows(2)
.map(|x| vec![x[0], x[1]])
.chain(nodes.windows(3).map(|x| vec![x[0], x[2]])),
);
expected.sort();
let edges_table = db.get_table(edges);
let all = edges_table.all();
let vals = edges_table.scan(all.as_ref());
let mut got = Vec::from_iter(vals.iter().map(|(_, row)| row.to_vec()));
got.sort();
assert_eq!(expected, got);
}
/// Runs the constrained line-graph scenario with `PlanStrategy::PureSize`.
#[test]
fn line_graph_2_fj_puresize() {
    let strategy = PlanStrategy::PureSize;
    line_graph_2_test(strategy);
}
/// Runs the constrained line-graph scenario with `PlanStrategy::MinCover`.
#[test]
fn line_graph_2_fj_mincover() {
    let strategy = PlanStrategy::MinCover;
    line_graph_2_test(strategy);
}
/// Runs the constrained line-graph scenario with `PlanStrategy::Gj`.
#[test]
fn line_graph_2_gj() {
    let strategy = PlanStrategy::Gj;
    line_graph_2_test(strategy);
}
fn line_graph_2_test(strat: PlanStrategy) {
let mut db = Database::default();
let edge_impl = SortedWritesTable::new(
2,
2,
None,
vec![],
Box::new(move |_, a, b, _| {
if a != b {
panic!("merge not supported")
} else {
false
}
}),
);
let edges = db.add_table(edge_impl, iter::empty(), iter::empty());
let nodes = Vec::from_iter((0..10).map(Value::new));
{
let mut edge_buf = db.new_buffer(edges);
for edge in nodes.windows(2) {
edge_buf.stage_insert(edge);
}
}
db.merge_all();
let mut rsb = RuleSetBuilder::new(&mut db);
let mut query = rsb.new_rule();
query.set_plan_strategy(strat);
let x = query.new_var_named("x");
let y = query.new_var_named("y");
let z = query.new_var_named("z");
query
.add_atom(
edges,
&[x.into(), y.into()],
&[Constraint::GtConst {
col: ColumnId::new(1),
val: Value::new(1),
}],
)
.unwrap();
query.add_atom(edges, &[y.into(), z.into()], &[]).unwrap();
let mut rule = query.build();
rule.insert(edges, &[x.into(), z.into()]).unwrap();
rule.build();
let rule_set = rsb.build();
assert!(db.run_rule_set(&rule_set, ReportLevel::TimeOnly).changed);
let mut expected = Vec::from_iter(
nodes.windows(2).map(|x| vec![x[0], x[1]]).chain(
nodes
.windows(3)
.filter(|x| x[1] > Value::new(1))
.map(|x| vec![x[0], x[2]]),
),
);
expected.sort();
let edges_table = db.get_table(edges);
let all = edges_table.all();
let vals = edges_table.scan(all.as_ref());
let mut got = Vec::from_iter(vals.iter().map(|(_, row)| row.to_vec()));
got.sort();
assert_eq!(expected, got);
}
/// Intersects three identical binary relations on a single repeated variable.
///
/// Each relation holds `(n, n)` for `n` in `0..10`. The rule matches
/// `rel(x, x)` in all three, each atom constrained to column 0 > 5, and
/// writes `x` into the unary table `u`. The expected output is `6..10`.
fn intersection_test(strat: PlanStrategy) {
    let mut db = Database::default();
    // The output table is constructed first and registered last, so the
    // three input relations are added to the database before it.
    let u = SortedWritesTable::new(
        1,
        1,
        None,
        vec![],
        Box::new(move |_, old, new, _| {
            assert!(old == new, "merge not supported");
            false
        }),
    );
    let mut rst_ids: Vec<TableId> = Vec::new();
    for _ in 0..3 {
        let relation = SortedWritesTable::new(
            2,
            2,
            None,
            vec![],
            Box::new(move |_, old, new, _| {
                assert!(old == new, "merge not supported");
                false
            }),
        );
        rst_ids.push(db.add_table(relation, iter::empty(), iter::empty()));
    }
    let u_id = db.add_table(u, iter::empty(), iter::empty());

    // Fill every relation with the diagonal (n, n).
    for &rel in &rst_ids {
        let mut staged = db.new_buffer(rel);
        for n in 0..10 {
            staged.stage_insert(&[Value::new(n), Value::new(n)]);
        }
    }
    db.merge_all();

    let mut rsb = RuleSetBuilder::new(&mut db);
    let mut query = rsb.new_rule();
    query.set_plan_strategy(strat);
    let x = query.new_var_named("x");
    for &rel in &rst_ids {
        let above_five = Constraint::GtConst {
            col: ColumnId::new(0),
            val: Value::new(5),
        };
        query
            .add_atom(rel, &[x.into(), x.into()], &[above_five])
            .unwrap();
    }
    let mut rule = query.build();
    rule.insert(u_id, &[x.into()]).unwrap();
    rule.build();
    let rule_set = rsb.build();

    let report = db.run_rule_set(&rule_set, ReportLevel::TimeOnly);
    assert!(report.changed);

    let expected: Vec<_> = (6..10).map(|n| vec![Value::new(n)]).collect();
    let u_table = db.get_table(u_id);
    let everything = u_table.all();
    let scanned = u_table.scan(everything.as_ref());
    let mut got: Vec<_> = scanned.iter().map(|(_, row)| row.to_vec()).collect();
    got.sort();
    assert_eq!(expected, got);
}
/// Runs the three-way intersection scenario with `PlanStrategy::PureSize`.
#[test]
fn intersection_test_fj_puresize() {
    let strategy = PlanStrategy::PureSize;
    intersection_test(strategy);
}
/// Runs the three-way intersection scenario with `PlanStrategy::MinCover`.
#[test]
fn intersection_test_fj_mincover() {
    let strategy = PlanStrategy::MinCover;
    intersection_test(strategy);
}
/// Runs the three-way intersection scenario with `PlanStrategy::Gj`.
#[test]
fn intersection_test_gj() {
    let strategy = PlanStrategy::Gj;
    intersection_test(strategy);
}
/// Smallest associativity scenario: a single rule applied once to five
/// hand-written `add` rows.
///
/// The rule matches `add(y, z, i1)` with timestamp in `[0, 1)` and
/// `add(x, i1, i2)` with timestamp in `[1, 2)`, looks up or creates
/// `add(x, y)`, and inserts `add(<that id>, z, i2)` at timestamp 2.
#[test]
fn minimal_ac() {
    let MathEgraph {
        add,
        id_counter,
        mut db,
        ..
    } = basic_math_egraph();

    // Rows stamped 0.
    {
        let mut staged = db.new_buffer(add);
        for row in [
            [v(0), v(0), v(1), v(0)],
            [v(0), v(1), v(2), v(0)],
            [v(0), v(2), v(3), v(0)],
        ] {
            staged.stage_insert(&row);
        }
    }
    db.merge_all();
    // Rows stamped 1.
    {
        let mut staged = db.new_buffer(add);
        for row in [[v(1), v(0), v(2), v(1)], [v(1), v(1), v(3), v(1)]] {
            staged.stage_insert(&row);
        }
    }
    db.merge_all();

    // Half-open window `[lo, hi)` on column 3 of `add`.
    let ts_window = |lo: Value, hi: Value| {
        [
            Constraint::GeConst {
                col: ColumnId::new(3),
                val: lo,
            },
            Constraint::LtConst {
                col: ColumnId::new(3),
                val: hi,
            },
        ]
    };

    let mut rsb = db.new_rule_set();
    let mut add_assoc = rsb.new_rule();
    let x = add_assoc.new_var_named("x");
    let y = add_assoc.new_var_named("y");
    let z = add_assoc.new_var_named("z");
    let i1 = add_assoc.new_var_named("i1");
    let i2 = add_assoc.new_var_named("i2");
    let t1 = add_assoc.new_var_named("t1");
    let t2 = add_assoc.new_var_named("t2");
    add_assoc
        .add_atom(
            add,
            &[y.into(), z.into(), i1.into(), t1.into()],
            &ts_window(v(0), v(1)),
        )
        .unwrap();
    add_assoc
        .add_atom(
            add,
            &[x.into(), i1.into(), i2.into(), t2.into()],
            &ts_window(v(1), v(2)),
        )
        .unwrap();
    let mut rules = add_assoc.build();
    let on_miss = [
        WriteVal::IncCounter(id_counter),
        WriteVal::QueryEntry(v(2).into()),
    ];
    let res = rules
        .lookup_or_insert(add, &[x.into(), y.into()], &on_miss, ColumnId::new(2))
        .unwrap();
    rules
        .insert(add, &[res.into(), z.into(), i2.into(), v(2).into()])
        .unwrap();
    rules.build();
    let rule_set = rsb.build();
    db.run_rule_set(&rule_set, ReportLevel::TimeOnly);

    let add_table = db.get_table(add);
    let everything = add_table.all();
    let scanned = add_table.scan(everything.as_ref());
    let mut res: Vec<_> = scanned.iter().map(|(_, row)| row.to_vec()).collect();
    res.sort();
    // The five seed rows plus the single row produced by the rule.
    let expected = vec![
        vec![v(0), v(0), v(1), v(0)],
        vec![v(0), v(1), v(2), v(0)],
        vec![v(0), v(2), v(3), v(0)],
        vec![v(1), v(0), v(2), v(1)],
        vec![v(1), v(1), v(3), v(1)],
        vec![v(2), v(0), v(3), v(2)],
    ];
    assert_eq!(res, expected);
}
/// Runs the associativity/commutativity scenario with `PlanStrategy::Gj`.
#[test]
fn ac_gj() {
    let strategy = PlanStrategy::Gj;
    ac_test(strategy);
}
/// Runs the associativity/commutativity scenario with `PlanStrategy::MinCover`.
#[test]
fn ac_fj_mincover() {
    let strategy = PlanStrategy::MinCover;
    ac_test(strategy);
}
/// Runs the associativity/commutativity scenario with `PlanStrategy::PureSize`.
#[test]
fn ac_fj_puresize() {
    let strategy = PlanStrategy::PureSize;
    ac_test(strategy);
}
/// Associativity/commutativity saturation test, run with plan strategy `strat`.
///
/// Builds two differently nested chains of `add` rows over the same `N`
/// numbers, then alternates between (1) a rule set containing two
/// associativity rules and one commutativity rule and (2) repeated "rebuild"
/// passes that rewrite `num`/`add` rows using the `uf` table, until the
/// rewrite rules report no change. Finally asserts that the roots of both
/// chains resolve to the same value through `uf`.
fn ac_test(strat: PlanStrategy) {
const N: usize = 5;
let MathEgraph {
num,
add,
id_counter,
mut db,
uf,
} = basic_math_egraph();
let mut ids = Vec::new();
db.base_values_mut().register_type::<i64>();
// Insert the constants 0..N into `num`; `ids` keeps the interned constants
// (not the counter-issued ids), and those are what the `add` chains below
// use as leaves.
for i in 0..N {
let id = db.inc_counter(id_counter);
let i = db.base_values().get::<i64>(i as i64);
ids.push(i);
db.new_buffer(num)
.stage_insert(&[i, Value::from_usize(id), Value::new(0)]);
}
db.merge_all();
let (left_root, right_root) = {
let mut add_ids = Vec::new();
let mut prev = ids[0];
// First chain: each new leaf goes in column 0 and the running result in
// column 1. Note the loop variable `num` shadows the `num` table id here.
for num in &ids[1..] {
let id = Value::from_usize(db.inc_counter(id_counter));
db.new_buffer(add)
.stage_insert(&[*num, prev, id, Value::new(0)]);
prev = id;
add_ids.push(id);
}
let left_root = *add_ids.last().unwrap();
add_ids.clear();
prev = *ids.last().unwrap();
// Second chain: starts from the last leaf and walks the remaining leaves in
// reverse, with the running result in column 0 and the leaf in column 1.
for num in ids[0..(N - 1)].iter().rev() {
let id = Value::from_usize(db.inc_counter(id_counter));
db.new_buffer(add)
.stage_insert(&[prev, *num, id, Value::new(0)]);
prev = id;
add_ids.push(id);
}
let right_root = *add_ids.last().unwrap();
(left_root, right_root)
};
db.merge_all();
// Builds and runs one rule set holding two associativity rules and one
// commutativity rule. Rows are selected by constraints on column 3 of `add`;
// everything written by these rules carries `recent_range.end` in column 3.
// Returns the report produced by `run_rule_set`.
let run_ac_rule = move |db: &mut Database, recent_range: Range<Value>| {
let old_range = Value::new(0)..recent_range.start;
let all_range = Value::new(0)..recent_range.end;
let next_ts = recent_range.end;
let mut rsb = RuleSetBuilder::new(db);
// Two window pairings for the two atoms: (all, recent) and (recent, old).
for (l_range, r_range) in [
(all_range, recent_range.clone()),
(recent_range.clone(), old_range.clone()),
] {
let mut add_assoc = rsb.new_rule();
add_assoc.set_plan_strategy(strat);
let x = add_assoc.new_var_named("x");
let y = add_assoc.new_var_named("y");
let z = add_assoc.new_var_named("z");
let i1 = add_assoc.new_var_named("i1");
let i2 = add_assoc.new_var_named("i2");
let t1 = add_assoc.new_var_named("t1");
let t2 = add_assoc.new_var_named("t2");
// Inner sum add(y, z) = i1, with column 3 in [l_range.start, l_range.end).
add_assoc
.add_atom(
add,
&[y.into(), z.into(), i1.into(), t1.into()],
&[
Constraint::GeConst {
col: ColumnId::new(3),
val: l_range.start,
},
Constraint::LtConst {
col: ColumnId::new(3),
val: l_range.end,
},
],
)
.unwrap();
// Outer sum add(x, i1) = i2, with column 3 in [r_range.start, r_range.end).
add_assoc
.add_atom(
add,
&[x.into(), i1.into(), i2.into(), t2.into()],
&[
Constraint::GeConst {
col: ColumnId::new(3),
val: r_range.start,
},
Constraint::LtConst {
col: ColumnId::new(3),
val: r_range.end,
},
],
)
.unwrap();
let mut rules = add_assoc.build();
// Fetch column 2 of the add(x, y) row; the two `WriteVal`s supply a fresh
// counter value and `next_ts` for the remaining columns (presumably used
// only when the row does not exist yet -- confirm against
// `lookup_or_insert`).
let res = rules
.lookup_or_insert(
add,
&[x.into(), y.into()],
&[
WriteVal::IncCounter(id_counter),
WriteVal::QueryEntry(next_ts.into()),
],
ColumnId::new(2),
)
.unwrap();
// Re-associated row: add(add(x, y), z) = i2.
rules
.insert(add, &[res.into(), z.into(), i2.into(), next_ts.into()])
.unwrap();
rules.build();
}
// Commutativity: for rows whose column 3 equals `recent_range.start`, insert
// the row with its first two columns swapped.
let mut add_comm = rsb.new_rule();
add_comm.set_plan_strategy(strat);
let x = add_comm.new_var_named("x");
let y = add_comm.new_var_named("y");
let z = add_comm.new_var_named("z");
let t1 = add_comm.new_var_named("t1");
add_comm
.add_atom(
add,
&[x.into(), y.into(), z.into(), t1.into()],
&[Constraint::EqConst {
col: ColumnId::new(3),
val: recent_range.start,
}],
)
.unwrap();
let mut rules = add_comm.build();
rules
.insert(add, &[y.into(), x.into(), z.into(), next_ts.into()])
.unwrap();
rules.build();
let rule_set = rsb.build();
db.run_rule_set(&rule_set, ReportLevel::TimeOnly)
};
// One rebuild pass: rewrites `num` and `add` rows according to `uf`, writing
// the rewritten rows with `cur_ts + 1` in their last column. Returns that new
// timestamp together with whether any executed rule set reported a change.
let rebuild = |db: &mut Database, cur_ts: Value| -> (Value, bool) {
let next_ts = Value::new(cur_ts.rep() + 1);
let mut rsb = db.new_rule_set();
// Adds a single `num`-rewriting rule to `rsb`. It is re-added to every rule
// set built below because each `rsb.build()` starts a fresh builder.
// NOTE(review): when `incremental_rebuild` returns true this closure scans
// all of `num` and looks ids up in `uf`, whereas the `add` code further down
// uses its true branch for the rules that join directly against `uf` --
// confirm the asymmetry is intended.
let num_rebuild = |rsb: &mut RuleSetBuilder, cur_ts: Value, next_ts: Value| {
let num_size = rsb.estimate_size(num, None);
let uf_size = rsb.estimate_size(
uf,
Some(Constraint::EqConst {
col: ColumnId::new(2),
val: cur_ts,
}),
);
let mut num_rebuild = rsb.new_rule();
num_rebuild.set_plan_strategy(strat);
if incremental_rebuild(uf_size, num_size) {
// Scan every `num` row, fetch column 1 of `uf` for its id (falling back to
// the id itself), and rewrite only rows whose id actually differs.
let x = num_rebuild.new_var_named("x");
let id = num_rebuild.new_var_named("id");
let t1 = num_rebuild.new_var_named("t1");
num_rebuild
.add_atom(num, &[x.into(), id.into(), t1.into()], &[])
.unwrap();
let mut rules = num_rebuild.build();
let id_canon = rules
.lookup_with_default(uf, &[id.into()], id.into(), ColumnId::new(1))
.unwrap();
rules.assert_ne(id.into(), id_canon.into()).unwrap();
rules
.insert(num, &[x.into(), id_canon.into(), next_ts.into()])
.unwrap();
rules.build();
} else {
// Join `num` with the `uf` rows whose column 2 equals `cur_ts`.
let x = num_rebuild.new_var_named("x");
let id = num_rebuild.new_var_named("id");
let t1 = num_rebuild.new_var_named("t1");
let id_new = num_rebuild.new_var_named("id_new");
let t2 = num_rebuild.new_var_named("t2");
num_rebuild
.add_atom(num, &[x.into(), id.into(), t1.into()], &[])
.unwrap();
num_rebuild
.add_atom(
uf,
&[id.into(), id_new.into(), t2.into()],
&[Constraint::EqConst {
col: ColumnId::new(2),
val: cur_ts,
}],
)
.unwrap();
let mut rules = num_rebuild.build();
rules
.insert(num, &[x.into(), id_new.into(), next_ts.into()])
.unwrap();
rules.build();
}
};
num_rebuild(&mut rsb, cur_ts, next_ts);
let mut changed = false;
let add_size = rsb.estimate_size(add, None);
let uf_size = rsb.estimate_size(
uf,
Some(Constraint::EqConst {
col: ColumnId::new(2),
val: cur_ts,
}),
);
if incremental_rebuild(uf_size, add_size) {
// Three separate rule sets, each joining `add` with `uf` on a different
// column: the result id (column 2), then column 0, then column 1. The other
// two columns are resolved through `lookup_with_default`.
let mut add_rebuild_id = rsb.new_rule();
add_rebuild_id.set_plan_strategy(strat);
let x = add_rebuild_id.new_var_named("x");
let y = add_rebuild_id.new_var_named("y");
let id = add_rebuild_id.new_var_named("id");
let t1 = add_rebuild_id.new_var_named("t1");
let id_new = add_rebuild_id.new_var_named("id_new");
let t2 = add_rebuild_id.new_var_named("t2");
add_rebuild_id
.add_atom(add, &[x.into(), y.into(), id.into(), t1.into()], &[])
.unwrap();
add_rebuild_id
.add_atom(
uf,
&[id.into(), id_new.into(), t2.into()],
&[Constraint::EqConst {
col: ColumnId::new(2),
val: cur_ts,
}],
)
.unwrap();
let mut rules = add_rebuild_id.build();
let x_new = rules
.lookup_with_default(uf, &[x.into()], x.into(), ColumnId::new(1))
.unwrap();
let y_new = rules
.lookup_with_default(uf, &[y.into()], y.into(), ColumnId::new(1))
.unwrap();
// Replace the stale row with its rewritten counterpart.
rules.remove(add, &[x.into(), y.into()]).unwrap();
rules
.insert(
add,
&[x_new.into(), y_new.into(), id_new.into(), next_ts.into()],
)
.unwrap();
rules.build();
let rs = rsb.build();
changed |= db.run_rule_set(&rs, ReportLevel::TimeOnly).changed;
// Second rule set: join on column 0 of `add`.
let mut rsb = db.new_rule_set();
num_rebuild(&mut rsb, cur_ts, next_ts);
let mut add_rebuild_l = rsb.new_rule();
add_rebuild_l.set_plan_strategy(strat);
let x = add_rebuild_l.new_var_named("x");
let y = add_rebuild_l.new_var_named("y");
let id = add_rebuild_l.new_var_named("id");
let t1 = add_rebuild_l.new_var_named("t1");
let x_new = add_rebuild_l.new_var_named("x_new");
let t2 = add_rebuild_l.new_var_named("t2");
add_rebuild_l
.add_atom(add, &[x.into(), y.into(), id.into(), t1.into()], &[])
.unwrap();
add_rebuild_l
.add_atom(
uf,
&[x.into(), x_new.into(), t2.into()],
&[Constraint::EqConst {
col: ColumnId::new(2),
val: cur_ts,
}],
)
.unwrap();
let mut rules = add_rebuild_l.build();
let y_new = rules
.lookup_with_default(uf, &[y.into()], y.into(), ColumnId::new(1))
.unwrap();
let id_new = rules
.lookup_with_default(uf, &[id.into()], id.into(), ColumnId::new(1))
.unwrap();
rules.remove(add, &[x.into(), y.into()]).unwrap();
rules
.insert(
add,
&[x_new.into(), y_new.into(), id_new.into(), next_ts.into()],
)
.unwrap();
rules.build();
let rs = rsb.build();
changed |= db.run_rule_set(&rs, ReportLevel::TimeOnly).changed;
// Third rule set: join on column 1 of `add`.
let mut rsb = db.new_rule_set();
num_rebuild(&mut rsb, cur_ts, next_ts);
let mut add_rebuild_r = rsb.new_rule();
add_rebuild_r.set_plan_strategy(strat);
let x = add_rebuild_r.new_var_named("x");
let y = add_rebuild_r.new_var_named("y");
let id = add_rebuild_r.new_var_named("id");
let t1 = add_rebuild_r.new_var_named("t1");
let y_new = add_rebuild_r.new_var_named("y_new");
let t2 = add_rebuild_r.new_var_named("t2");
add_rebuild_r
.add_atom(add, &[x.into(), y.into(), id.into(), t1.into()], &[])
.unwrap();
add_rebuild_r
.add_atom(
uf,
&[y.into(), y_new.into(), t2.into()],
&[Constraint::EqConst {
col: ColumnId::new(2),
val: cur_ts,
}],
)
.unwrap();
let mut rules = add_rebuild_r.build();
let x_new = rules
.lookup_with_default(uf, &[x.into()], x.into(), ColumnId::new(1))
.unwrap();
let id_new = rules
.lookup_with_default(uf, &[id.into()], id.into(), ColumnId::new(1))
.unwrap();
rules.remove(add, &[x.into(), y.into()]).unwrap();
rules
.insert(
add,
&[x_new.into(), y_new.into(), id_new.into(), next_ts.into()],
)
.unwrap();
rules.build();
let rs = rsb.build();
changed |= db.run_rule_set(&rs, ReportLevel::TimeOnly).changed;
} else {
// Single rule: scan every `add` row, resolve all three ids through `uf`
// (defaulting to the id itself), and rewrite the row only when at least one
// of them differs from the original.
let mut rebuild = rsb.new_rule();
rebuild.set_plan_strategy(strat);
let x = rebuild.new_var_named("x");
let y = rebuild.new_var_named("y");
let id = rebuild.new_var_named("id");
let t1 = rebuild.new_var_named("t1");
rebuild
.add_atom(add, &[x.into(), y.into(), id.into(), t1.into()], &[])
.unwrap();
let mut rules = rebuild.build();
let x_canon = rules
.lookup_with_default(uf, &[x.into()], x.into(), ColumnId::new(1))
.unwrap();
let y_canon = rules
.lookup_with_default(uf, &[y.into()], y.into(), ColumnId::new(1))
.unwrap();
let id_canon = rules
.lookup_with_default(uf, &[id.into()], id.into(), ColumnId::new(1))
.unwrap();
rules
.assert_any_ne(
&[x.into(), y.into(), id.into()],
&[x_canon.into(), y_canon.into(), id_canon.into()],
)
.unwrap();
rules.remove(add, &[x.into(), y.into()]).unwrap();
rules
.insert(
add,
&[
x_canon.into(),
y_canon.into(),
id_canon.into(),
next_ts.into(),
],
)
.unwrap();
rules.build();
let rs = rsb.build();
changed |= db.run_rule_set(&rs, ReportLevel::TimeOnly).changed;
}
(next_ts, changed)
};
// Saturation loop. `cur_ts..next_ts` is the "recent" window handed to the
// rewrite rules.
let mut cur_ts = Value::new(0);
let mut next_ts = Value::new(1);
loop {
if !run_ac_rule(&mut db, cur_ts..next_ts).changed {
break;
}
let start = next_ts;
let mut new_ids_at = start;
let mut changed = true;
// Rebuild until a pass reports no change; each pass advances the timestamp
// by one.
while changed {
let (next_ts, rebuild_changed) = rebuild(&mut db, new_ids_at);
new_ids_at = next_ts;
changed = rebuild_changed;
}
// The next recent window covers everything written since `start`.
cur_ts = start;
next_ts = Value::new(new_ids_at.rep() + 1);
}
// Resolve both roots through `uf` (column 1 of the row keyed by the root, or
// the root itself when there is no such row) and require them to agree.
let uf_table = db.get_table(uf);
let l_canon = uf_table
.get_row(&[left_root])
.map(|row| row.vals[1])
.unwrap_or(left_root);
let r_canon = uf_table
.get_row(&[right_root])
.map(|row| row.vals[1])
.unwrap_or(right_root);
assert_eq!(l_canon, r_canon);
}
/// Database plus the table and counter handles produced by
/// `basic_math_egraph`.
struct MathEgraph {
/// Handle of the `DisplacedTable`; the merge callbacks of `num` and `add`
/// stage rows into it when two ids collide.
uf: TableId,
/// Handle of the three-column table whose rows are written as
/// `[constant, id, timestamp]` by the tests.
num: TableId,
/// Handle of the four-column table whose rows are written as
/// `[lhs, rhs, id, timestamp]` by the tests.
add: TableId,
/// Counter the tests draw fresh ids from.
id_counter: CounterId,
/// The database that owns the tables and the counter.
db: Database,
}
/// Creates a fresh database with a `DisplacedTable` (`uf`), a three-column
/// `num` table, a four-column `add` table and one id counter.
///
/// Both `num` and `add` share the same merge policy: when the incoming row
/// carries a different id than the existing one, the pair of ids (together
/// with the incoming row's last column) is staged into `uf` and the incoming
/// row replaces the existing one; otherwise nothing changes.
fn basic_math_egraph() -> MathEgraph {
    let mut db = Database::default();
    let uf = db.add_table(DisplacedTable::default(), iter::empty(), iter::empty());

    // `num`: the id lives in column 1, and column 2 is passed as the third
    // constructor argument.
    let num_impl = SortedWritesTable::new(
        1,
        3,
        Some(ColumnId::new(2)),
        vec![],
        Box::new(move |state, old, new, out| {
            if old[1] == new[1] {
                return false;
            }
            state.stage_insert(uf, &[old[1], new[1], new[2]]);
            out.extend_from_slice(new);
            true
        }),
    );
    let id_counter = db.add_counter();
    let num = db.add_table(num_impl, iter::once(uf), iter::empty());

    // `add`: the id lives in column 2, and column 3 is passed as the third
    // constructor argument.
    let add_impl = SortedWritesTable::new(
        2,
        4,
        Some(ColumnId::new(3)),
        vec![],
        Box::new(move |state, old, new, out| {
            if old[2] == new[2] {
                return false;
            }
            state.stage_insert(uf, &[old[2], new[2], new[3]]);
            out.extend_from_slice(new);
            true
        }),
    );
    let add = db.add_table(add_impl, iter::once(uf), iter::empty());

    MathEgraph {
        db,
        id_counter,
        add,
        num,
        uf,
    }
}
/// Size heuristic used by the rebuild code in `ac_test` to choose between
/// two rule shapes.
///
/// Returns `true` exactly when a quarter of `uf_size` (integer division,
/// rounded down) is strictly greater than `table_size`.
fn incremental_rebuild(uf_size: usize, table_size: usize) -> bool {
    let quarter_of_uf = uf_size / 4;
    table_size < quarter_of_uf
}
/// Exercises `lookup_with_fallback` when the fallback fails for some matches.
///
/// For every row `(x, y)` of `g`, the rule looks `x` up in `f` and falls
/// back to the `assert_even` external function, which yields `None` for odd
/// inputs. The test expects only `x = 1` (present in `f`) and `x = 4` (even)
/// to reach the later `log_vals` call and the insert into `h`.
#[test]
fn lookup_with_fallback_partial_success() {
    let mut db = Database::default();
    // Three two-column tables, registered in the order f, g, h.
    let [f, g, h]: [TableId; 3] = std::array::from_fn(|_| {
        db.add_table(
            SortedWritesTable::new(
                1,
                2,
                None,
                vec![],
                Box::new(move |_, old, new, _| {
                    assert!(old[0] == new[0], "merge not supported");
                    false
                }),
            ),
            iter::empty(),
            iter::empty(),
        )
    });
    {
        let mut staged = db.new_buffer(f);
        for key in [1, 2] {
            staged.stage_insert(&[v(key), v(0)]);
        }
    }
    {
        let mut staged = db.new_buffer(g);
        for key in [1, 3, 4, 5] {
            staged.stage_insert(&[v(key), v(0)]);
        }
    }
    db.merge_all();

    // Records every value that reaches the `log_vals` call.
    let log = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&log);
    let log_vals = db.add_external_function(Box::new(make_external_func(move |_, args| {
        let [x] = args else { panic!() };
        sink.lock().unwrap().push(*x);
        Some(*x)
    })));
    // Passes even values through and fails on odd ones.
    let assert_even = db.add_external_function(Box::new(make_external_func(|_, args| {
        let [x] = args else { panic!() };
        x.rep().is_multiple_of(2).then_some(*x)
    })));

    let mut rsb = RuleSetBuilder::new(&mut db);
    let mut query = rsb.new_rule();
    let x = query.new_var_named("x");
    let y = query.new_var_named("y");
    query.add_atom(g, &[x.into(), y.into()], &[]).unwrap();
    let mut rb = query.build();
    let res = rb
        .lookup_with_fallback(f, &[x.into()], ColumnId::new(0), assert_even, &[x.into()])
        .unwrap();
    rb.call_external(log_vals, &[x.into()]).unwrap();
    rb.insert(h, &[res.into(), y.into()]).unwrap();
    rb.build();
    let rs = rsb.build();
    let report = db.run_rule_set(&rs, ReportLevel::TimeOnly);
    assert!(report.changed);

    let h_table = db.get_table(h);
    let everything = h_table.all();
    let scanned = h_table.scan(everything.as_ref());
    let mut h_contents: Vec<_> = scanned.iter().map(|(_, row)| row.to_vec()).collect();
    h_contents.sort();
    assert_eq!(h_contents, vec![vec![v(1), v(0)], vec![v(4), v(0)]]);

    let mut sorted_log = log.lock().unwrap().clone();
    sorted_log.sort();
    assert_eq!(sorted_log, vec![v(1), v(4)]);
}
/// Exercises `call_external_with_fallback` with a primary function that
/// fails on odd inputs and a fallback that fails on 5.
///
/// For the keys 1, 2, 3 and 5 in `f` the test expects results 2, 2, 4 and
/// no result respectively, so `h` ends up holding the keys 2 and 4.
#[test]
fn call_external_with_fallback() {
    let mut db = Database::default();
    // Two two-column tables, registered in the order f, h.
    let [f, h]: [TableId; 2] = std::array::from_fn(|_| {
        db.add_table(
            SortedWritesTable::new(
                1,
                2,
                None,
                vec![],
                Box::new(move |_, old, new, _| {
                    assert!(old[0] == new[0], "merge not supported");
                    false
                }),
            ),
            iter::empty(),
            iter::empty(),
        )
    });
    {
        let mut staged = db.new_buffer(f);
        for key in [1, 2, 3, 5] {
            staged.stage_insert(&[v(key), v(0)]);
        }
    }
    db.merge_all();

    // Primary: passes even values through and fails on odd ones.
    let assert_even = db.add_external_function(Box::new(make_external_func(|_, args| {
        let [x] = args else { panic!() };
        x.rep().is_multiple_of(2).then_some(*x)
    })));
    // Fallback: increments its argument, except that it fails on 5.
    let inc = db.add_external_function(Box::new(make_external_func(|_, args| {
        let [x] = args else { panic!() };
        (x.rep() != 5).then(|| x.inc())
    })));

    let mut rsb = RuleSetBuilder::new(&mut db);
    let mut query = rsb.new_rule();
    let x = query.new_var_named("x");
    let y = query.new_var_named("y");
    query.add_atom(f, &[x.into(), y.into()], &[]).unwrap();
    let mut rb = query.build();
    let res = rb
        .call_external_with_fallback(assert_even, &[x.into()], inc, &[x.into()])
        .unwrap();
    rb.insert(h, &[res.into(), y.into()]).unwrap();
    rb.build();
    let rs = rsb.build();
    let report = db.run_rule_set(&rs, ReportLevel::TimeOnly);
    assert!(report.changed);

    let h_table = db.get_table(h);
    let everything = h_table.all();
    let scanned = h_table.scan(everything.as_ref());
    let mut h_contents: Vec<_> = scanned.iter().map(|(_, row)| row.to_vec()).collect();
    h_contents.sort();
    assert_eq!(h_contents, vec![vec![v(2), v(0)], vec![v(4), v(0)]]);
}
/// Checks that `trigger_early_stop` cuts a rule run short.
///
/// A 500 000-row table is scanned by a rule whose only action is an external
/// function that counts its calls and requests an early stop from the
/// 1000th call onwards. The test requires both the reported match count and
/// the call count to land in `[1000, 100_000)`: at least the calls needed to
/// trigger the stop, and well below the table size.
#[test]
fn early_stop() {
    let mut db = Database::default();
    // The merge callback always reports "no change".
    let data_table = db.add_table(
        SortedWritesTable::new(1, 2, None, vec![], Box::new(|_, _, _, _| false)),
        iter::empty(),
        iter::empty(),
    );
    {
        let mut buf = db.new_buffer(data_table);
        for i in 0..500_000 {
            buf.stage_insert(&[Value::from_usize(i), Value::from_usize(i)]);
        }
    }
    db.merge_all();

    let call_count = Arc::new(Mutex::new(0usize));
    let call_count_clone = call_count.clone();
    let stop_trigger =
        db.add_external_function(Box::new(make_external_func(move |exec_state, args| {
            let mut count = call_count_clone.lock().unwrap();
            *count += 1;
            // Keep requesting the stop on every call from the 1000th onwards.
            if *count >= 1000 {
                exec_state.trigger_early_stop();
            }
            let [x] = args else { panic!() };
            Some(*x)
        })));

    let mut rsb = RuleSetBuilder::new(&mut db);
    let mut query = rsb.new_rule();
    let x = query.new_var_named("x");
    let y = query.new_var_named("y");
    query
        .add_atom(data_table, &[x.into(), y.into()], &[])
        .unwrap();
    let mut rb = query.build();
    let _ = rb.call_external(stop_trigger, &[x.into()]).unwrap();
    rb.build_with_description("early_stop_test");
    let rs = rsb.build();
    let report = db.run_rule_set(&rs, ReportLevel::TimeOnly);

    let matches = report.num_matches("early_stop_test");
    // The messages quote the bound that is actually asserted (100k).
    assert!(
        matches < 100_000,
        "Expected much fewer than 100k matches due to early stopping, got {}, (call_count={})",
        matches,
        call_count.lock().unwrap(),
    );
    assert!(
        matches >= 1000,
        "Expected at least 1000 matches before stopping, got {} (call_count={})",
        matches,
        call_count.lock().unwrap(),
    );
    let final_count = *call_count.lock().unwrap();
    assert!(
        final_count >= 1000,
        "External function called {} times, should be at least 1000",
        final_count
    );
    assert!(
        final_count < 100_000,
        "External function called {} times, should be much less than 100k",
        final_count
    );
}