use std::collections::HashMap;
use crate::term::{Literal, Term};
/// Identifier of a node in the e-graph; also the element type of [`UnionFind`].
type TermId = usize;

/// Disjoint-set forest over consecutively allocated [`TermId`]s, using union by rank
/// and path compression.
#[derive(Debug, Default)]
pub struct UnionFind {
    // `parent[x]` is the parent of `x`; an element is a root iff it is its own parent.
    parent: Vec<TermId>,
    // Rank (an upper bound on tree height) of each element; only consulted for roots.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates an empty forest (same as [`UnionFind::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a new singleton set and returns its id. Ids are assigned consecutively
    /// starting at 0.
    pub fn make_set(&mut self) -> TermId {
        let id = self.parent.len();
        self.parent.push(id);
        self.rank.push(0);
        id
    }

    /// Returns the representative (root) of the set containing `x`, and re-points every
    /// element on the path from `x` directly at that root.
    ///
    /// Implemented iteratively so tree depth never turns into recursion depth.
    ///
    /// # Panics
    ///
    /// Panics if `x` was not produced by [`UnionFind::make_set`] on this forest.
    pub fn find(&mut self, x: TermId) -> TermId {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: path compression.
        let mut node = x;
        while self.parent[node] != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were merged and `false` if `x` and `y` were
    /// already in the same set. The root of higher rank survives; on equal ranks the
    /// root of `x`'s set survives and its rank grows by one.
    ///
    /// # Panics
    ///
    /// Panics if either id was not produced by [`UnionFind::make_set`].
    pub fn union(&mut self, x: TermId, y: TermId) -> bool {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return false;
        }
        match self.rank[rx].cmp(&self.rank[ry]) {
            std::cmp::Ordering::Less => self.parent[rx] = ry,
            std::cmp::Ordering::Greater => self.parent[ry] = rx,
            std::cmp::Ordering::Equal => {
                self.parent[ry] = rx;
                self.rank[rx] += 1;
            }
        }
        true
    }
}
/// A single e-graph node: a leaf (integer literal, variable index, or name) or the
/// application of one node to another, with both children referenced by [`TermId`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum ENode {
    /// Integer literal leaf.
    Lit(i64),
    /// Variable leaf, identified by its integer index.
    Var(i64),
    /// Named leaf (e.g. a function or constant symbol).
    Name(String),
    /// Application of the node `func` to the single node `arg`.
    App { func: TermId, arg: TermId },
}
/// E-graph that decides equalities between terms by congruence closure.
///
/// Nodes are hash-consed on their exact structure and equivalence classes live in a
/// [`UnionFind`]. Each class root carries a use list: the `App` nodes that have a child
/// in that class. Merge propagation scans these lists for new congruences.
pub struct EGraph {
    // Every node ever added, indexed by its `TermId`.
    nodes: Vec<ENode>,
    // Equivalence classes over `TermId`s.
    uf: UnionFind,
    // Hash-consing table: a node, exactly as it was passed to `add`, to its id.
    node_map: HashMap<ENode, TermId>,
    // Equalities waiting to be processed by `propagate`; drained before `merge` returns.
    pending: Vec<(TermId, TermId)>,
    // `use_list[r]` lists the `App` nodes with a `func` or `arg` in the class rooted at
    // `r`. Only entries of current roots are meaningful: when two classes merge, both
    // lists are moved onto the surviving root.
    use_list: Vec<Vec<TermId>>,
}

impl Default for EGraph {
    fn default() -> Self {
        Self::new()
    }
}

impl EGraph {
    /// Creates an empty e-graph.
    pub fn new() -> Self {
        EGraph {
            nodes: Vec::new(),
            uf: UnionFind::new(),
            node_map: HashMap::new(),
            pending: Vec::new(),
            use_list: Vec::new(),
        }
    }

    /// Adds `node` and returns its id. If an identical node was already added, its
    /// existing id is returned instead.
    ///
    /// A new `App` node is immediately merged with any existing node congruent to it
    /// (function in the same class, argument in the same class). This ensures that terms
    /// created after earlier merges still take part in congruence closure.
    ///
    /// # Panics
    ///
    /// Panics if an `App` node refers to a child id that has not been added.
    pub fn add(&mut self, node: ENode) -> TermId {
        if let Some(&id) = self.node_map.get(&node) {
            return id;
        }
        let id = self.nodes.len();
        let children = match &node {
            ENode::App { func, arg } => Some((*func, *arg)),
            _ => None,
        };
        self.nodes.push(node.clone());
        self.node_map.insert(node, id);
        self.uf.make_set();
        self.use_list.push(Vec::new());
        if let Some((func, arg)) = children {
            self.attach_app(id, func, arg);
        }
        id
    }

    /// Records the application `id = App(func, arg)` in the use lists of its children's
    /// class roots, and merges it with an already-present congruent node, if any.
    fn attach_app(&mut self, id: TermId, func: TermId, arg: TermId) {
        let func_root = self.uf.find(func);
        let arg_root = self.uf.find(arg);
        // Any existing node congruent to `id` has its function in `func_root`'s class,
        // so it is necessarily on that root's use list.
        let candidates = self.use_list[func_root].clone();
        let existing = candidates.into_iter().find(|&u| self.congruent(u, id));
        self.use_list[func_root].push(id);
        if arg_root != func_root {
            self.use_list[arg_root].push(id);
        }
        if let Some(other) = existing {
            self.merge(id, other);
        }
    }

    /// Asserts that `a` and `b` are equal and closes the e-graph under congruence.
    ///
    /// # Panics
    ///
    /// Panics if either id has not been added.
    pub fn merge(&mut self, a: TermId, b: TermId) {
        self.pending.push((a, b));
        self.propagate();
    }

    /// Processes queued equalities until none remain, queueing every pair of parent
    /// applications that becomes congruent as a result of a merge.
    fn propagate(&mut self) {
        while let Some((a, b)) = self.pending.pop() {
            let ra = self.uf.find(a);
            let rb = self.uf.find(b);
            if ra == rb {
                continue;
            }
            // Detach both classes' parents; they are re-attached to the surviving root.
            let uses_a = std::mem::take(&mut self.use_list[ra]);
            let uses_b = std::mem::take(&mut self.use_list[rb]);
            self.uf.union(ra, rb);
            let root = self.uf.find(ra);
            // A congruence created by this merge must pair a parent of one class with a
            // parent of the other, so only the cross product needs checking.
            for &ua in &uses_a {
                for &ub in &uses_b {
                    if self.congruent(ua, ub) {
                        self.pending.push((ua, ub));
                    }
                }
            }
            let mut merged = uses_a;
            merged.extend(uses_b);
            self.use_list[root] = merged;
        }
    }

    /// Returns `true` if both ids are `App` nodes whose functions are equivalent and
    /// whose arguments are equivalent.
    fn congruent(&mut self, a: TermId, b: TermId) -> bool {
        let (f1, a1, f2, a2) = match (&self.nodes[a], &self.nodes[b]) {
            (ENode::App { func: f1, arg: a1 }, ENode::App { func: f2, arg: a2 }) => {
                (*f1, *a1, *f2, *a2)
            }
            _ => return false,
        };
        self.uf.find(f1) == self.uf.find(f2) && self.uf.find(a1) == self.uf.find(a2)
    }

    /// Returns `true` if `a` and `b` are in the same equivalence class.
    pub fn equivalent(&mut self, a: TermId, b: TermId) -> bool {
        self.uf.find(a) == self.uf.find(b)
    }
}
/// Translates a syntax `Term` built from `SLit`, `SVar`, `SName` and `SApp` into
/// e-graph nodes and returns the id of the node for `term` itself.
///
/// Returns `None` if `term`, or any subterm, has another shape. In that case, nodes for
/// subterms already reified before the failure remain in `egraph`.
pub fn reify(egraph: &mut EGraph, term: &Term) -> Option<TermId> {
    let node = if let Some(value) = extract_slit(term) {
        ENode::Lit(value)
    } else if let Some(index) = extract_svar(term) {
        ENode::Var(index)
    } else if let Some(symbol) = extract_sname(term) {
        ENode::Name(symbol)
    } else {
        // Curried application: reify the function first, then the argument.
        let (head, operand) = extract_sapp(term)?;
        let func = reify(egraph, &head)?;
        let arg = reify(egraph, &operand)?;
        ENode::App { func, arg }
    };
    Some(egraph.add(node))
}
/// Splits a right-nested chain `h1 → h2 → … → concl` (encoded with the `implies`
/// operator) into its equational hypotheses and the final conclusion.
///
/// Each hypothesis that is an `Eq`/`eq` equality is returned as a `(lhs, rhs)` pair, in
/// order. Any other hypothesis is ignored. Ignoring an assumption can only cause a
/// provable goal to be rejected, never an unprovable one to be accepted. A goal with no
/// `implies` at the top is returned unchanged as the conclusion.
pub fn decompose_goal(goal: &Term) -> (Vec<(Term, Term)>, Term) {
    let mut hypotheses = Vec::new();
    let mut current = goal.clone();
    while let Some((hyp, rest)) = extract_implication(&current) {
        if let Some(equation) = extract_equality(&hyp) {
            hypotheses.push(equation);
        }
        current = rest;
    }
    (hypotheses, current)
}
/// Decides, by congruence closure, whether the equational conclusion of `goal` follows
/// from its equational hypotheses.
///
/// Returns `false` when the conclusion is not an equality or when any term cannot be
/// reified, as well as when the equality is not implied.
pub fn check_goal(goal: &Term) -> bool {
    try_check_goal(goal).unwrap_or(false)
}

/// Core of [`check_goal`]; `None` means the goal lies outside the supported fragment.
fn try_check_goal(goal: &Term) -> Option<bool> {
    let (hypotheses, conclusion) = decompose_goal(goal);
    let (lhs, rhs) = extract_equality(&conclusion)?;

    let mut egraph = EGraph::new();
    let lhs_id = reify(&mut egraph, &lhs)?;
    let rhs_id = reify(&mut egraph, &rhs)?;

    // Reify every hypothesis term before performing any merge. Merge propagation checks
    // the parents of both classes being joined, so every term that already exists when
    // a merge happens is examined for new congruences. Interleaving reification with
    // merging could create terms that are never compared with their congruent peers.
    let mut equations = Vec::with_capacity(hypotheses.len());
    for (h_lhs, h_rhs) in &hypotheses {
        let h_lhs_id = reify(&mut egraph, h_lhs)?;
        let h_rhs_id = reify(&mut egraph, h_rhs)?;
        equations.push((h_lhs_id, h_rhs_id));
    }
    for (h_lhs_id, h_rhs_id) in equations {
        egraph.merge(h_lhs_id, h_rhs_id);
    }

    Some(egraph.equivalent(lhs_id, rhs_id))
}
/// Recognizes `SLit <int>` (the global `SLit` applied to an integer literal) and
/// returns the integer; any other shape yields `None`.
fn extract_slit(term: &Term) -> Option<i64> {
    match term {
        Term::App(ctor, arg) => match (ctor.as_ref(), arg.as_ref()) {
            (Term::Global(name), Term::Lit(Literal::Int(value))) if name == "SLit" => {
                Some(*value)
            }
            _ => None,
        },
        _ => None,
    }
}
/// Recognizes `SVar <int>` (the global `SVar` applied to an integer literal) and
/// returns the variable index; any other shape yields `None`.
fn extract_svar(term: &Term) -> Option<i64> {
    let (ctor, arg): (&Term, &Term) = match term {
        Term::App(ctor, arg) => (ctor.as_ref(), arg.as_ref()),
        _ => return None,
    };
    let tagged_svar = matches!(ctor, Term::Global(name) if name == "SVar");
    match arg {
        Term::Lit(Literal::Int(index)) if tagged_svar => Some(*index),
        _ => None,
    }
}
/// Recognizes `SName <text>` (the global `SName` applied to a text literal) and returns
/// an owned copy of the name; any other shape yields `None`.
fn extract_sname(term: &Term) -> Option<String> {
    let (ctor, payload) = match term {
        Term::App(ctor, payload) => (ctor, payload),
        _ => return None,
    };
    let is_sname = match ctor.as_ref() {
        Term::Global(tag) => tag == "SName",
        _ => false,
    };
    if !is_sname {
        return None;
    }
    match payload.as_ref() {
        Term::Lit(Literal::Text(text)) => Some(text.clone()),
        _ => None,
    }
}
/// Recognizes the curried application `SApp func arg`, i.e. the shape
/// `App(App(Global("SApp"), func), arg)`, and returns owned copies of `(func, arg)`.
fn extract_sapp(term: &Term) -> Option<(Term, Term)> {
    match term {
        Term::App(head, arg) => match head.as_ref() {
            Term::App(ctor, func) => match ctor.as_ref() {
                Term::Global(name) if name == "SApp" => {
                    Some((func.as_ref().clone(), arg.as_ref().clone()))
                }
                _ => None,
            },
            _ => None,
        },
        _ => None,
    }
}
/// Recognizes the binary application `implies hyp concl` (see `extract_binary_app`)
/// and returns `(hyp, concl)`.
fn extract_implication(term: &Term) -> Option<(Term, Term)> {
    extract_binary_app(term)
        .filter(|(op, _, _)| op == "implies")
        .map(|(_, hyp, concl)| (hyp, concl))
}
/// Recognizes the binary application `Eq lhs rhs` or `eq lhs rhs` (see
/// `extract_binary_app`) and returns `(lhs, rhs)`.
fn extract_equality(term: &Term) -> Option<(Term, Term)> {
    match extract_binary_app(term)? {
        (op, lhs, rhs) if matches!(op.as_str(), "Eq" | "eq") => Some((lhs, rhs)),
        _ => None,
    }
}
/// Recognizes a binary operator application `SApp (SApp (SName op) a) b` and returns
/// `(op, a, b)`, with `a` and `b` cloned out of `term`.
fn extract_binary_app(term: &Term) -> Option<(String, Term, Term)> {
    /// Splits `App(App(Global("SApp"), func), arg)` into borrowed `(func, arg)`.
    fn split_sapp(t: &Term) -> Option<(&Term, &Term)> {
        match t {
            Term::App(head, arg) => match head.as_ref() {
                Term::App(ctor, func) => match ctor.as_ref() {
                    Term::Global(name) if name == "SApp" => Some((func.as_ref(), arg.as_ref())),
                    _ => None,
                },
                _ => None,
            },
            _ => None,
        }
    }

    let (op_applied, b) = split_sapp(term)?;
    let (op_term, a) = split_sapp(op_applied)?;
    let op = extract_sname(op_term)?;
    Some((op, a.clone(), b.clone()))
}
#[cfg(test)]
mod tests {
    // Unit tests for the union-find, the e-graph, term extraction/reification and the
    // end-to-end `check_goal` procedure.
    use super::*;

    #[test]
    fn test_union_find_basic() {
        let mut uf = UnionFind::new();
        let a = uf.make_set();
        let b = uf.make_set();
        assert_ne!(uf.find(a), uf.find(b));
        uf.union(a, b);
        assert_eq!(uf.find(a), uf.find(b));
    }

    #[test]
    fn test_union_find_transitivity() {
        let mut uf = UnionFind::new();
        let a = uf.make_set();
        let b = uf.make_set();
        let c = uf.make_set();
        uf.union(a, b);
        uf.union(b, c);
        assert_eq!(uf.find(a), uf.find(c));
    }

    #[test]
    fn test_egraph_reflexive() {
        let mut eg = EGraph::new();
        let x = eg.add(ENode::Var(0));
        assert!(eg.equivalent(x, x));
    }

    #[test]
    fn test_egraph_congruence() {
        let mut eg = EGraph::new();
        let x = eg.add(ENode::Var(0));
        let y = eg.add(ENode::Var(1));
        let f = eg.add(ENode::Name("f".to_string()));
        let fx = eg.add(ENode::App { func: f, arg: x });
        let fy = eg.add(ENode::App { func: f, arg: y });
        assert!(!eg.equivalent(fx, fy));
        eg.merge(x, y);
        assert!(eg.equivalent(fx, fy));
    }

    #[test]
    fn test_egraph_nested_congruence() {
        let mut eg = EGraph::new();
        let a = eg.add(ENode::Var(0));
        let b = eg.add(ENode::Var(1));
        let c = eg.add(ENode::Var(2));
        let f = eg.add(ENode::Name("f".to_string()));
        let fa = eg.add(ENode::App { func: f, arg: a });
        let fc = eg.add(ENode::App { func: f, arg: c });
        let ffa = eg.add(ENode::App { func: f, arg: fa });
        let ffc = eg.add(ENode::App { func: f, arg: fc });
        eg.merge(a, b);
        eg.merge(b, c);
        assert!(eg.equivalent(ffa, ffc));
    }

    #[test]
    fn test_egraph_binary_congruence() {
        let mut eg = EGraph::new();
        let a = eg.add(ENode::Var(0));
        let b = eg.add(ENode::Var(1));
        let c = eg.add(ENode::Var(2));
        let add = eg.add(ENode::Name("add".to_string()));
        let add_a = eg.add(ENode::App { func: add, arg: a });
        let add_b = eg.add(ENode::App { func: add, arg: b });
        let add_a_c = eg.add(ENode::App { func: add_a, arg: c });
        let add_b_c = eg.add(ENode::App { func: add_b, arg: c });
        assert!(!eg.equivalent(add_a_c, add_b_c));
        eg.merge(a, b);
        assert!(eg.equivalent(add_a_c, add_b_c));
    }

    fn make_sname(s: &str) -> Term {
        Term::App(
            Box::new(Term::Global("SName".to_string())),
            Box::new(Term::Lit(Literal::Text(s.to_string()))),
        )
    }

    fn make_svar(i: i64) -> Term {
        Term::App(
            Box::new(Term::Global("SVar".to_string())),
            Box::new(Term::Lit(Literal::Int(i))),
        )
    }

    fn make_slit(n: i64) -> Term {
        Term::App(
            Box::new(Term::Global("SLit".to_string())),
            Box::new(Term::Lit(Literal::Int(n))),
        )
    }

    fn make_sapp(f: Term, a: Term) -> Term {
        Term::App(
            Box::new(Term::App(
                Box::new(Term::Global("SApp".to_string())),
                Box::new(f),
            )),
            Box::new(a),
        )
    }

    /// Builds the encoded equation `Eq lhs rhs`.
    fn make_eq(lhs: Term, rhs: Term) -> Term {
        make_sapp(make_sapp(make_sname("Eq"), lhs), rhs)
    }

    /// Builds the encoded implication `implies hyp concl`.
    fn make_implies(hyp: Term, concl: Term) -> Term {
        make_sapp(make_sapp(make_sname("implies"), hyp), concl)
    }

    #[test]
    fn test_extract_sname() {
        let term = make_sname("f");
        assert_eq!(extract_sname(&term), Some("f".to_string()));
    }

    #[test]
    fn test_extract_svar() {
        let term = make_svar(0);
        assert_eq!(extract_svar(&term), Some(0));
    }

    #[test]
    fn test_extract_slit() {
        assert_eq!(extract_slit(&make_slit(42)), Some(42));
        assert_eq!(extract_slit(&make_svar(42)), None);
    }

    #[test]
    fn test_extractors_reject_other_constructors() {
        assert_eq!(extract_svar(&make_sname("x")), None);
        assert_eq!(extract_sname(&make_svar(0)), None);
        assert!(extract_sapp(&make_svar(0)).is_none());
        let lt = make_sapp(make_sapp(make_sname("lt"), make_svar(0)), make_svar(1));
        assert!(extract_equality(&lt).is_none());
        assert!(extract_implication(&lt).is_none());
    }

    #[test]
    fn test_extract_sapp() {
        let term = make_sapp(make_sname("f"), make_svar(0));
        let result = extract_sapp(&term);
        assert!(result.is_some());
        let (func, arg) = result.unwrap();
        assert_eq!(extract_sname(&func), Some("f".to_string()));
        assert_eq!(extract_svar(&arg), Some(0));
    }

    #[test]
    fn test_extract_binary_app() {
        let term = make_sapp(make_sapp(make_sname("Eq"), make_svar(0)), make_svar(1));
        let result = extract_binary_app(&term);
        assert!(result.is_some(), "Should extract binary app");
        let (op, a, b) = result.unwrap();
        assert_eq!(op, "Eq");
        assert_eq!(extract_svar(&a), Some(0));
        assert_eq!(extract_svar(&b), Some(1));
    }

    #[test]
    fn test_extract_equality() {
        let term = make_sapp(make_sapp(make_sname("Eq"), make_svar(0)), make_svar(1));
        let result = extract_equality(&term);
        assert!(result.is_some(), "Should extract equality");
        let (lhs, rhs) = result.unwrap();
        assert_eq!(extract_svar(&lhs), Some(0));
        assert_eq!(extract_svar(&rhs), Some(1));
    }

    #[test]
    fn test_extract_implication() {
        let x = make_svar(0);
        let y = make_svar(1);
        let hyp = make_sapp(make_sapp(make_sname("Eq"), x.clone()), y.clone());
        let f = make_sname("f");
        let fx = make_sapp(f.clone(), x);
        let fy = make_sapp(f, y);
        let concl = make_sapp(make_sapp(make_sname("Eq"), fx), fy);
        let goal = make_sapp(make_sapp(make_sname("implies"), hyp.clone()), concl.clone());
        let result = extract_implication(&goal);
        assert!(result.is_some(), "Should extract implication");
        let (hyp_extracted, concl_extracted) = result.unwrap();
        let hyp_eq = extract_equality(&hyp_extracted);
        assert!(hyp_eq.is_some(), "Hypothesis should be equality");
        let (h_lhs, h_rhs) = hyp_eq.unwrap();
        assert_eq!(extract_svar(&h_lhs), Some(0));
        assert_eq!(extract_svar(&h_rhs), Some(1));
        let concl_eq = extract_equality(&concl_extracted);
        assert!(concl_eq.is_some(), "Conclusion should be equality");
    }

    #[test]
    fn test_decompose_goal_with_hypothesis() {
        let x = make_svar(0);
        let y = make_svar(1);
        let hyp = make_sapp(make_sapp(make_sname("Eq"), x.clone()), y.clone());
        let f = make_sname("f");
        let fx = make_sapp(f.clone(), x);
        let fy = make_sapp(f, y);
        let concl = make_sapp(make_sapp(make_sname("Eq"), fx), fy);
        let goal = make_sapp(make_sapp(make_sname("implies"), hyp), concl);
        let (hypotheses, conclusion) = decompose_goal(&goal);
        assert_eq!(hypotheses.len(), 1, "Should have 1 hypothesis");
        let concl_eq = extract_equality(&conclusion);
        assert!(concl_eq.is_some(), "Conclusion should be equality");
    }

    #[test]
    fn test_decompose_goal_skips_non_equality_hypotheses() {
        let goal = make_implies(make_sname("p"), make_eq(make_svar(0), make_svar(0)));
        let (hypotheses, conclusion) = decompose_goal(&goal);
        assert!(hypotheses.is_empty(), "Non-equality hypothesis should be ignored");
        assert!(extract_equality(&conclusion).is_some());
    }

    #[test]
    fn test_reify_hash_conses_identical_terms() {
        let mut eg = EGraph::new();
        let term = make_sapp(make_sname("f"), make_slit(1));
        let first = reify(&mut eg, &term);
        let second = reify(&mut eg, &term);
        assert!(first.is_some());
        assert_eq!(first, second);
    }

    #[test]
    fn test_reify_rejects_unknown_terms() {
        let mut eg = EGraph::new();
        assert_eq!(reify(&mut eg, &Term::Global("x".to_string())), None);
        // An unsupported argument makes the whole application fail.
        let bad_app = make_sapp(make_sname("f"), Term::Global("x".to_string()));
        assert_eq!(reify(&mut eg, &bad_app), None);
    }

    #[test]
    fn test_check_goal_with_hypothesis() {
        let x = make_svar(0);
        let y = make_svar(1);
        let hyp = make_sapp(make_sapp(make_sname("Eq"), x.clone()), y.clone());
        let f = make_sname("f");
        let fx = make_sapp(f.clone(), x.clone());
        let fy = make_sapp(f.clone(), y.clone());
        let concl = make_sapp(make_sapp(make_sname("Eq"), fx), fy);
        let goal = make_sapp(make_sapp(make_sname("implies"), hyp), concl);
        assert!(check_goal(&goal), "CC should prove x=y → f(x)=f(y)");
    }

    #[test]
    fn test_check_goal_without_hypothesis_fails() {
        let f = make_sname("f");
        let goal = make_eq(make_sapp(f.clone(), make_svar(0)), make_sapp(f, make_svar(1)));
        assert!(!check_goal(&goal), "f(x)=f(y) must not hold without x=y");
    }

    #[test]
    fn test_check_goal_reflexivity() {
        let fx = make_sapp(make_sname("f"), make_svar(0));
        assert!(check_goal(&make_eq(fx.clone(), fx)));
    }

    #[test]
    fn test_check_goal_chained_hypotheses() {
        let (x, y, z) = (make_svar(0), make_svar(1), make_svar(2));
        let goal = make_implies(
            make_eq(x.clone(), y.clone()),
            make_implies(make_eq(y, z.clone()), make_eq(x, z)),
        );
        assert!(check_goal(&goal), "x=y → y=z → x=z");
    }

    #[test]
    fn test_check_goal_rejects_non_equality_conclusion() {
        assert!(!check_goal(&make_svar(0)));
    }
}