use std::collections::{BTreeMap, BTreeSet};
use xlog_core::ScalarType;
use xlog_logic::ast::{Atom, BodyLiteral, Rule, Term};
use xlog_logic::hypergraph::{
evaluate_fixpoint_typed, evaluate_rule_typed, evaluate_scc_fixpoint_typed, explain_plans,
plan_rules, AppearanceOrder, FixpointConfig, RefRelation, RefRelationStore, RefValue, RulePlan,
};
/// Builds a `Term::Variable` whose name is a copy of `name`.
fn var(name: &str) -> Term {
    let owned = String::from(name);
    Term::Variable(owned)
}
/// Builds the atom `predicate(terms...)`, copying the predicate name into owned storage.
fn atom(predicate: &str, terms: Vec<Term>) -> Atom {
    Atom {
        terms,
        predicate: predicate.to_owned(),
    }
}
/// Wraps the atom `predicate(terms...)` in `BodyLiteral::Positive` for use in a rule body.
fn pos(predicate: &str, terms: Vec<Term>) -> BodyLiteral {
    let positive = atom(predicate, terms);
    BodyLiteral::Positive(positive)
}
/// Assembles the rule `head :- body` from an already-built head atom and body literals.
fn rule_with(head: Atom, body: Vec<BodyLiteral>) -> Rule {
    Rule { body, head }
}
/// Builds an all-`U32` test relation from literal rows.
///
/// The arity, and therefore the width of the schema, is taken from the first row.
/// An empty `rows` slice yields a zero-arity relation with no rows.
///
/// # Panics
///
/// Panics if the rows do not all have the same length. A ragged fixture would
/// otherwise produce rows that disagree with the declared schema and make any
/// downstream certification meaningless.
fn u32_relation(rows: &[&[u32]]) -> RefRelation {
    let arity = rows.first().map_or(0, |r| r.len());
    assert!(
        rows.iter().all(|r| r.len() == arity),
        "u32_relation: ragged fixture, every row must have arity {arity}"
    );
    RefRelation {
        schema: vec![ScalarType::U32; arity],
        rows: rows
            .iter()
            .map(|r| r.iter().copied().map(RefValue::U32).collect())
            .collect(),
    }
}
/// Returns a relation store holding exactly one relation, `rel`, registered under `name`.
fn store_with_one(name: &str, rel: RefRelation) -> RefRelationStore {
    BTreeMap::from([(name.to_owned(), rel)])
}
/// Decodes ternary `U32` rows into `(u32, u32, u32)` tuples, sorted so the result can
/// be compared against an expected list regardless of evaluation order.
///
/// # Panics
///
/// Panics if any row is not exactly three `RefValue::U32` values. The exact arity is
/// required, rather than reading the first three columns, so that an oracle producing
/// extra columns cannot pass silently.
fn triples(rows: &[Vec<RefValue>]) -> Vec<(u32, u32, u32)> {
    let mut out: Vec<(u32, u32, u32)> = rows
        .iter()
        .map(|r| match r.as_slice() {
            [RefValue::U32(a), RefValue::U32(b), RefValue::U32(c)] => (*a, *b, *c),
            other => panic!("unexpected row shape: {other:?}"),
        })
        .collect();
    out.sort_unstable();
    out
}
/// Decodes binary `U32` rows into `(u32, u32)` tuples, sorted so the result can be
/// compared against an expected list regardless of evaluation order.
///
/// # Panics
///
/// Panics if any row is not exactly two `RefValue::U32` values. The exact arity is
/// required, rather than reading the first two columns, so that a wrong-arity
/// relation cannot pass silently.
fn pairs(rows: &[Vec<RefValue>]) -> Vec<(u32, u32)> {
    let mut out: Vec<(u32, u32)> = rows
        .iter()
        .map(|r| match r.as_slice() {
            [RefValue::U32(a), RefValue::U32(b)] => (*a, *b),
            other => panic!("unexpected row shape: {other:?}"),
        })
        .collect();
    out.sort_unstable();
    out
}
fn pairs_from_rel(rel: &RefRelation) -> Vec<(u32, u32)> {
pairs(&rel.rows)
}
/// Certifies the triangle query `tri(X, Y, Z) :- e(X, Y), e(Y, Z), e(X, Z)`.
///
/// The planner must report the single rule as a multiway candidate headed by `tri`,
/// with three hyperedges and a three-variable order. Evaluation must return exactly
/// the five triangles in the fixture: four from the clique on {1, 2, 3, 4} and one
/// on {5, 6, 7}. The explain output must label the rule `tri/0: multiway`.
#[test]
fn triangle_certification() {
    let body = vec![
        pos("e", vec![var("X"), var("Y")]),
        pos("e", vec![var("Y"), var("Z")]),
        pos("e", vec![var("X"), var("Z")]),
    ];
    let tri_rule = rule_with(atom("tri", vec![var("X"), var("Y"), var("Z")]), body);

    // Edges are oriented low -> high: a 4-clique on {1, 2, 3, 4} plus a disjoint
    // triangle on {5, 6, 7}.
    let store = store_with_one(
        "e",
        u32_relation(&[
            &[1, 2],
            &[1, 3],
            &[1, 4],
            &[2, 3],
            &[2, 4],
            &[3, 4],
            &[5, 6],
            &[5, 7],
            &[6, 7],
        ]),
    );

    let plans = plan_rules(std::slice::from_ref(&tri_rule), &store).expect("triangle plans");
    assert_eq!(plans.len(), 1);
    match &plans[0] {
        RulePlan::MultiwayCandidate {
            head_predicate,
            hypergraph,
            variable_order,
        } => {
            assert_eq!(head_predicate, "tri");
            assert_eq!(hypergraph.hyperedge_count(), 3);
            assert_eq!(variable_order.len(), 3);
        }
        other => panic!("expected MultiwayCandidate, got {other:?}"),
    }

    let derived = evaluate_rule_typed(&tri_rule, &store, &AppearanceOrder).expect("triangle eval");
    let want = vec![(1, 2, 3), (1, 2, 4), (1, 3, 4), (2, 3, 4), (5, 6, 7)];
    assert_eq!(triples(&derived), want);

    let explained = explain_plans(&plans);
    let is_multiway = explained.contains("tri/0: multiway");
    assert!(is_multiway, "expected multiway tri in explain:\n{explained}");
}
/// Independent reference implementation of same-generation, used to cross-check the
/// fixpoint evaluator.
///
/// `parent_edges` holds `(child, parent)` pairs. The function computes the least
/// fixpoint of
///
///   sg(X, Y) :- parent(X, P), parent(Y, P).
///   sg(X, Y) :- parent(X, A), sg(A, B), parent(Y, B).
///
/// It re-applies the recursive rule to the whole relation until a round adds nothing,
/// and returns the pairs in ascending order.
fn sg_reference(parent_edges: &[(u32, u32)]) -> Vec<(u32, u32)> {
    // Index children by parent so both rules become lookups instead of nested scans.
    let mut children_of: BTreeMap<u32, Vec<u32>> = BTreeMap::new();
    for &(child, parent) in parent_edges {
        children_of.entry(parent).or_default().push(child);
    }

    // Base rule: every pair of children sharing a parent, each child paired with
    // itself included.
    let mut sg: BTreeSet<(u32, u32)> = children_of
        .values()
        .flat_map(|kids| {
            kids.iter()
                .flat_map(move |&x| kids.iter().map(move |&y| (x, y)))
        })
        .collect();

    // Recursive rule: lift each known pair (a, b) to all (child of a, child of b).
    // Derivations are buffered so each round reads a stable relation.
    loop {
        let mut fresh: Vec<(u32, u32)> = Vec::new();
        for &(a, b) in &sg {
            if let (Some(xs), Some(ys)) = (children_of.get(&a), children_of.get(&b)) {
                for &x in xs {
                    for &y in ys {
                        fresh.push((x, y));
                    }
                }
            }
        }
        let size_before = sg.len();
        sg.extend(fresh);
        if sg.len() == size_before {
            break;
        }
    }
    sg.into_iter().collect()
}
/// Certifies the same-generation program against the independent `sg_reference`.
///
///   sg(X, Y) :- parent(X, P), parent(Y, P).
///   sg(X, Y) :- parent(X, A), sg(A, B), parent(Y, B).
///
/// Both rules must plan as multiway candidates headed by `sg`. The fixpoint result
/// must equal `sg_reference` on the same `(child, parent)` edges. The explain output
/// must mention both rules and contain exactly two multiway lines.
#[test]
fn same_generation_certification() {
    let sibling_rule = rule_with(
        atom("sg", vec![var("X"), var("Y")]),
        vec![
            pos("parent", vec![var("X"), var("P")]),
            pos("parent", vec![var("Y"), var("P")]),
        ],
    );
    let cousin_rule = rule_with(
        atom("sg", vec![var("X"), var("Y")]),
        vec![
            pos("parent", vec![var("X"), var("A")]),
            pos("sg", vec![var("A"), var("B")]),
            pos("parent", vec![var("Y"), var("B")]),
        ],
    );
    let rules = vec![sibling_rule, cousin_rule];

    // (child, parent) edges of a small family tree rooted at 10.
    let parent_pairs: Vec<(u32, u32)> = vec![(1, 10), (2, 10), (11, 1), (12, 2), (13, 1), (14, 12)];
    let parent_rows: Vec<[u32; 2]> = parent_pairs
        .iter()
        .map(|&(child, parent)| [child, parent])
        .collect();
    let parent_refs: Vec<&[u32]> = parent_rows.iter().map(|row| &row[..]).collect();
    let store = store_with_one("parent", u32_relation(&parent_refs));

    let plans = plan_rules(&rules, &store).expect("sg plans");
    assert_eq!(plans.len(), 2);
    plans.iter().enumerate().for_each(|(i, p)| match p {
        RulePlan::MultiwayCandidate { head_predicate, .. } => {
            assert_eq!(head_predicate, "sg", "rule index {i}");
        }
        other => panic!("expected MultiwayCandidate for rule {i}, got {other:?}"),
    });

    let config = FixpointConfig::default();
    let result = evaluate_fixpoint_typed(&rules, &store, "sg", &AppearanceOrder, &config)
        .expect("sg fixpoint converges");
    assert_eq!(
        pairs_from_rel(&result),
        sg_reference(&parent_pairs),
        "oracle output disagrees with reference SG impl"
    );

    let explained = explain_plans(&plans);
    let has_both_rules = explained.contains("sg/0:") && explained.contains("sg/1:");
    assert!(has_both_rules, "expected both sg rules in explain:\n{explained}");
    assert_eq!(
        explained.matches("multiway vars=").count(),
        2,
        "expected 2 multiway lines:\n{explained}"
    );
}
/// Certifies a mutually recursive SCC: `even_path` / `odd_path` over a simple chain.
///
///   even_path(X, Z) :- e(X, M), e(M, Z).
///   even_path(X, Z) :- e(X, Y), odd_path(Y, Z).
///   odd_path(X, Z)  :- e(X, M), e(M, Y), e(Y, Z).
///   odd_path(X, Z)  :- e(X, Y), even_path(Y, Z).
///
/// `even_path` therefore holds for paths of even length >= 2, and `odd_path` for paths
/// of odd length >= 3; there is no length-1 seed. All four rules must plan as multiway
/// candidates, and the SCC fixpoint must produce exactly those pairs. The explain
/// output must contain four multiway lines and mention `even_path` before `odd_path`.
#[test]
fn mutually_recursive_parity_scc_certification() {
    let even_seed = rule_with(
        atom("even_path", vec![var("X"), var("Z")]),
        vec![
            pos("e", vec![var("X"), var("M")]),
            pos("e", vec![var("M"), var("Z")]),
        ],
    );
    let even_step = rule_with(
        atom("even_path", vec![var("X"), var("Z")]),
        vec![
            pos("e", vec![var("X"), var("Y")]),
            pos("odd_path", vec![var("Y"), var("Z")]),
        ],
    );
    let odd_seed = rule_with(
        atom("odd_path", vec![var("X"), var("Z")]),
        vec![
            pos("e", vec![var("X"), var("M")]),
            pos("e", vec![var("M"), var("Y")]),
            pos("e", vec![var("Y"), var("Z")]),
        ],
    );
    let odd_step = rule_with(
        atom("odd_path", vec![var("X"), var("Z")]),
        vec![
            pos("e", vec![var("X"), var("Y")]),
            pos("even_path", vec![var("Y"), var("Z")]),
        ],
    );
    let chain_len: u32 = 6;
    let chain_rows: Vec<Vec<u32>> = (1..chain_len).map(|i| vec![i, i + 1]).collect();
    let chain_refs: Vec<&[u32]> = chain_rows.iter().map(|v| v.as_slice()).collect();
    let store = store_with_one("e", u32_relation(&chain_refs));
    let mut rules: BTreeMap<String, Vec<Rule>> = BTreeMap::new();
    rules.insert("even_path".into(), vec![even_seed, even_step]);
    rules.insert("odd_path".into(), vec![odd_seed, odd_step]);
    // `flat` follows BTreeMap key order: even_path rules first, then odd_path rules.
    let flat: Vec<Rule> = rules.values().flatten().cloned().collect();
    let plans = plan_rules(&flat, &store).expect("scc plans");
    assert_eq!(plans.len(), 4);
    for (i, p) in plans.iter().enumerate() {
        match p {
            RulePlan::MultiwayCandidate { head_predicate, .. } => {
                assert!(
                    head_predicate == "even_path" || head_predicate == "odd_path",
                    "unexpected head at index {i}: {head_predicate}"
                );
            }
            other => panic!("expected MultiwayCandidate for rule {i}, got {other:?}"),
        }
    }
    let result =
        evaluate_scc_fixpoint_typed(&rules, &store, &AppearanceOrder, &FixpointConfig::default())
            .expect("SCC fixpoint converges");

    // On the chain 1 -> 2 -> ... -> chain_len, the only path from i to j has length
    // j - i. Each expected relation is therefore every pair whose distance has the
    // right parity and is at least the seed length. The distances are derived from
    // `chain_len` (the maximum distance is chain_len - 1) instead of being listed, so
    // the oracle stays correct if the chain is lengthened or shortened.
    let pairs_at_distances = |min_len: u32| -> Vec<(u32, u32)> {
        let mut out: Vec<(u32, u32)> = Vec::new();
        for i in 1..=chain_len {
            for d in (min_len..chain_len).step_by(2) {
                if i + d <= chain_len {
                    out.push((i, i + d));
                }
            }
        }
        out.sort();
        out
    };
    let even_expected = pairs_at_distances(2);
    let odd_expected = pairs_at_distances(3);

    assert_eq!(
        pairs_from_rel(result.get("even_path").expect("even_path present")),
        even_expected
    );
    assert_eq!(
        pairs_from_rel(result.get("odd_path").expect("odd_path present")),
        odd_expected
    );
    let explained = explain_plans(&plans);
    let multiway_count = explained.matches("multiway vars=").count();
    assert_eq!(multiway_count, 4, "expected 4 multiway lines:\n{explained}");
    let pos_even = explained.find("even_path/").expect("even_path present");
    let pos_odd = explained.find("odd_path/").expect("odd_path present");
    assert!(
        pos_even < pos_odd,
        "expected even_path before odd_path in canonical explain:\n{explained}"
    );
}
/// Certifies a skewed cyclic join:
///
///   result(X, Y, Z) :- big(X, Y), small_a(Y, Z), small_b(X, Z).
///
/// `big` holds every ordered pair of distinct values in 1..=8 (56 rows). The two
/// small relations hold 4 rows each. The single rule must plan as one
/// three-hyperedge multiway candidate headed by `result`, evaluate to exactly four
/// rows, and be labelled `result/0: multiway` in the explain output.
#[test]
fn skewed_multiway_certification() {
    let r = rule_with(
        atom("result", vec![var("X"), var("Y"), var("Z")]),
        vec![
            pos("big", vec![var("X"), var("Y")]),
            pos("small_a", vec![var("Y"), var("Z")]),
            pos("small_b", vec![var("X"), var("Z")]),
        ],
    );
    let big_rows: Vec<Vec<u32>> = (1u32..=8)
        .flat_map(|x| (1u32..=8).filter(move |y| *y != x).map(move |y| vec![x, y]))
        .collect();
    let big_refs: Vec<&[u32]> = big_rows.iter().map(|v| v.as_slice()).collect();
    let big = u32_relation(&big_refs);
    let small_a = u32_relation(&[&[2, 10], &[3, 20], &[4, 30], &[5, 40]]);
    let small_b = u32_relation(&[&[1, 10], &[2, 20], &[3, 30], &[4, 40]]);
    let mut store: RefRelationStore = BTreeMap::new();
    store.insert("big".into(), big);
    store.insert("small_a".into(), small_a);
    store.insert("small_b".into(), small_b);
    let plans = plan_rules(std::slice::from_ref(&r), &store).expect("skewed plans");
    // Check the count first, so a planner regression fails with a clear length
    // mismatch instead of an index-out-of-bounds panic on `plans[0]`.
    assert_eq!(plans.len(), 1);
    match &plans[0] {
        RulePlan::MultiwayCandidate {
            head_predicate,
            hypergraph,
            ..
        } => {
            assert_eq!(head_predicate, "result");
            assert_eq!(hypergraph.hyperedge_count(), 3);
        }
        other => panic!("expected MultiwayCandidate, got {other:?}"),
    }
    let rows = evaluate_rule_typed(&r, &store, &AppearanceOrder).expect("skewed eval");
    // Each Z value pins Y through small_a and X through small_b. big(X, Y) holds for
    // every such pair because X != Y, so there is exactly one result row per Z.
    let expected = vec![(1, 2, 10), (2, 3, 20), (3, 4, 30), (4, 5, 40)];
    assert_eq!(triples(&rows), expected);
    let explained = explain_plans(&plans);
    assert!(
        explained.contains("result/0: multiway"),
        "expected multiway result in explain:\n{explained}"
    );
}
/// Certifies reachability over a 12-node chain, where the recursive rule must be
/// re-applied many times before the fixpoint converges.
///
///   reach(X, Y) :- e(X, M), e(M, Y).
///   reach(X, Z) :- e(X, Y), reach(Y, Z).
///
/// The base rule starts at length-2 paths, so `reach` is exactly the set of pairs
/// `(i, j)` with `j >= i + 2` on the chain `1 -> 2 -> ... -> 12`. Both rules must plan
/// as multiway candidates headed by `reach`, and the explain output must list both.
#[test]
fn deep_recursive_frontier_certification() {
    let two_hop = rule_with(
        atom("reach", vec![var("X"), var("Y")]),
        vec![
            pos("e", vec![var("X"), var("M")]),
            pos("e", vec![var("M"), var("Y")]),
        ],
    );
    let extend = rule_with(
        atom("reach", vec![var("X"), var("Z")]),
        vec![
            pos("e", vec![var("X"), var("Y")]),
            pos("reach", vec![var("Y"), var("Z")]),
        ],
    );
    let rules = vec![two_hop, extend];

    let chain_len: u32 = 12;
    let chain_rows: Vec<[u32; 2]> = (1..chain_len).map(|from| [from, from + 1]).collect();
    let chain_refs: Vec<&[u32]> = chain_rows.iter().map(|row| &row[..]).collect();
    let store = store_with_one("e", u32_relation(&chain_refs));

    let plans = plan_rules(&rules, &store).expect("frontier plans");
    assert_eq!(plans.len(), 2);
    plans.iter().enumerate().for_each(|(i, p)| match p {
        RulePlan::MultiwayCandidate { head_predicate, .. } => {
            assert_eq!(head_predicate, "reach", "rule index {i}");
        }
        other => panic!("expected MultiwayCandidate for rule {i}, got {other:?}"),
    });

    let config = FixpointConfig::default();
    let result = evaluate_fixpoint_typed(&rules, &store, "reach", &AppearanceOrder, &config)
        .expect("frontier fixpoint converges");
    // Generated with both i and j ascending, so the list is already sorted.
    let expected: Vec<(u32, u32)> = (1..chain_len)
        .flat_map(|i| ((i + 2)..=chain_len).map(move |j| (i, j)))
        .collect();
    assert_eq!(pairs_from_rel(&result), expected);

    let explained = explain_plans(&plans);
    let has_both_rules = explained.contains("reach/0:") && explained.contains("reach/1:");
    assert!(has_both_rules, "expected both reach rules in explain:\n{explained}");
    assert_eq!(
        explained.matches("multiway vars=").count(),
        2,
        "expected 2 multiway lines:\n{explained}"
    );
}