1use std::collections::HashMap;
29
30use crate::term::{Literal, Term};
31
/// Index of a node in the e-graph; also the id of that node's union-find element.
type TermId = usize;

/// Disjoint-set forest over dense integer ids, using path compression and
/// union by rank.
///
/// Each set is identified by the id of its root element. `UnionFind::default()`
/// is equivalent to `UnionFind::new()`: an empty forest.
#[derive(Debug, Clone, Default)]
pub struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<TermId>,
    /// Rank (upper bound on tree height) of element `i`; only consulted for roots.
    rank: Vec<usize>,
}
48
49impl UnionFind {
50 pub fn new() -> Self {
51 UnionFind {
52 parent: Vec::new(),
53 rank: Vec::new(),
54 }
55 }
56
57 pub fn make_set(&mut self) -> TermId {
59 let id = self.parent.len();
60 self.parent.push(id);
61 self.rank.push(0);
62 id
63 }
64
65 pub fn find(&mut self, x: TermId) -> TermId {
67 if self.parent[x] != x {
68 self.parent[x] = self.find(self.parent[x]);
69 }
70 self.parent[x]
71 }
72
73 pub fn union(&mut self, x: TermId, y: TermId) -> bool {
75 let rx = self.find(x);
76 let ry = self.find(y);
77 if rx == ry {
78 return false;
79 }
80
81 if self.rank[rx] < self.rank[ry] {
82 self.parent[rx] = ry;
83 } else if self.rank[rx] > self.rank[ry] {
84 self.parent[ry] = rx;
85 } else {
86 self.parent[ry] = rx;
87 self.rank[rx] += 1;
88 }
89 true
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Hash)]
102pub enum ENode {
103 Lit(i64),
105 Var(i64),
107 Name(String),
109 App {
111 func: TermId,
113 arg: TermId,
115 },
116}
117
118pub struct EGraph {
123 nodes: Vec<ENode>,
125 uf: UnionFind,
127 node_map: HashMap<ENode, TermId>,
129 pending: Vec<(TermId, TermId)>,
131 use_list: Vec<Vec<TermId>>,
134}
135
136impl EGraph {
137 pub fn new() -> Self {
138 EGraph {
139 nodes: Vec::new(),
140 uf: UnionFind::new(),
141 node_map: HashMap::new(),
142 pending: Vec::new(),
143 use_list: Vec::new(),
144 }
145 }
146
147 pub fn add(&mut self, node: ENode) -> TermId {
149 if let Some(&id) = self.node_map.get(&node) {
151 return id;
152 }
153
154 let id = self.nodes.len();
155 self.nodes.push(node.clone());
156 self.node_map.insert(node.clone(), id);
157 self.uf.make_set();
158 self.use_list.push(Vec::new());
159
160 if let ENode::App { func, arg } = &node {
162 self.use_list[*func].push(id);
163 self.use_list[*arg].push(id);
164 }
165
166 id
167 }
168
169 pub fn merge(&mut self, a: TermId, b: TermId) {
171 self.pending.push((a, b));
172 self.propagate();
173 }
174
175 fn propagate(&mut self) {
177 while let Some((a, b)) = self.pending.pop() {
178 let ra = self.uf.find(a);
179 let rb = self.uf.find(b);
180 if ra == rb {
181 continue;
182 }
183
184 let uses_a: Vec<TermId> = self.use_list[ra].clone();
186 let uses_b: Vec<TermId> = self.use_list[rb].clone();
187
188 self.uf.union(ra, rb);
190 let new_root = self.uf.find(ra);
191
192 for &ua in &uses_a {
195 for &ub in &uses_b {
196 if self.congruent(ua, ub) {
197 self.pending.push((ua, ub));
198 }
199 }
200 }
201
202 if new_root == ra {
204 for u in uses_b {
205 self.use_list[ra].push(u);
206 }
207 } else {
208 for u in uses_a {
209 self.use_list[rb].push(u);
210 }
211 }
212 }
213 }
214
215 fn congruent(&mut self, a: TermId, b: TermId) -> bool {
217 match (&self.nodes[a].clone(), &self.nodes[b].clone()) {
218 (ENode::App { func: f1, arg: a1 }, ENode::App { func: f2, arg: a2 }) => {
219 self.uf.find(*f1) == self.uf.find(*f2) && self.uf.find(*a1) == self.uf.find(*a2)
220 }
221 _ => false,
222 }
223 }
224
225 pub fn equivalent(&mut self, a: TermId, b: TermId) -> bool {
227 self.uf.find(a) == self.uf.find(b)
228 }
229}
230
231pub fn reify(egraph: &mut EGraph, term: &Term) -> Option<TermId> {
245 if let Some(n) = extract_slit(term) {
247 return Some(egraph.add(ENode::Lit(n)));
248 }
249
250 if let Some(i) = extract_svar(term) {
252 return Some(egraph.add(ENode::Var(i)));
253 }
254
255 if let Some(name) = extract_sname(term) {
257 return Some(egraph.add(ENode::Name(name)));
258 }
259
260 if let Some((func_term, arg_term)) = extract_sapp(term) {
262 let func = reify(egraph, &func_term)?;
263 let arg = reify(egraph, &arg_term)?;
264 return Some(egraph.add(ENode::App { func, arg }));
265 }
266
267 None
268}
269
270pub fn decompose_goal(goal: &Term) -> (Vec<(Term, Term)>, Term) {
288 let mut hypotheses = Vec::new();
289 let mut current = goal.clone();
290
291 while let Some((hyp, rest)) = extract_implication(¤t) {
293 if let Some((lhs, rhs)) = extract_equality(&hyp) {
294 hypotheses.push((lhs, rhs));
295 }
296 current = rest;
297 }
298
299 (hypotheses, current)
300}
301
302pub fn check_goal(goal: &Term) -> bool {
318 let (hypotheses, conclusion) = decompose_goal(goal);
319
320 let (lhs, rhs) = match extract_equality(&conclusion) {
322 Some(eq) => eq,
323 None => return false,
324 };
325
326 let mut egraph = EGraph::new();
327
328 let lhs_id = match reify(&mut egraph, &lhs) {
332 Some(id) => id,
333 None => return false,
334 };
335
336 let rhs_id = match reify(&mut egraph, &rhs) {
337 Some(id) => id,
338 None => return false,
339 };
340
341 for (h_lhs, h_rhs) in &hypotheses {
344 let h_lhs_id = match reify(&mut egraph, h_lhs) {
345 Some(id) => id,
346 None => return false,
347 };
348 let h_rhs_id = match reify(&mut egraph, h_rhs) {
349 Some(id) => id,
350 None => return false,
351 };
352 egraph.merge(h_lhs_id, h_rhs_id);
353 }
354
355 egraph.equivalent(lhs_id, rhs_id)
357}
358
359fn extract_slit(term: &Term) -> Option<i64> {
365 if let Term::App(ctor, arg) = term {
366 if let Term::Global(name) = ctor.as_ref() {
367 if name == "SLit" {
368 if let Term::Lit(Literal::Int(n)) = arg.as_ref() {
369 return Some(*n);
370 }
371 }
372 }
373 }
374 None
375}
376
377fn extract_svar(term: &Term) -> Option<i64> {
379 if let Term::App(ctor, arg) = term {
380 if let Term::Global(name) = ctor.as_ref() {
381 if name == "SVar" {
382 if let Term::Lit(Literal::Int(i)) = arg.as_ref() {
383 return Some(*i);
384 }
385 }
386 }
387 }
388 None
389}
390
391fn extract_sname(term: &Term) -> Option<String> {
393 if let Term::App(ctor, arg) = term {
394 if let Term::Global(name) = ctor.as_ref() {
395 if name == "SName" {
396 if let Term::Lit(Literal::Text(s)) = arg.as_ref() {
397 return Some(s.clone());
398 }
399 }
400 }
401 }
402 None
403}
404
405fn extract_sapp(term: &Term) -> Option<(Term, Term)> {
407 if let Term::App(outer, arg) = term {
408 if let Term::App(sapp, func) = outer.as_ref() {
409 if let Term::Global(ctor) = sapp.as_ref() {
410 if ctor == "SApp" {
411 return Some((func.as_ref().clone(), arg.as_ref().clone()));
412 }
413 }
414 }
415 }
416 None
417}
418
419fn extract_implication(term: &Term) -> Option<(Term, Term)> {
421 if let Some((op, hyp, concl)) = extract_binary_app(term) {
422 if op == "implies" {
423 return Some((hyp, concl));
424 }
425 }
426 None
427}
428
429fn extract_equality(term: &Term) -> Option<(Term, Term)> {
431 if let Some((op, lhs, rhs)) = extract_binary_app(term) {
432 if op == "Eq" || op == "eq" {
433 return Some((lhs, rhs));
434 }
435 }
436 None
437}
438
439fn extract_binary_app(term: &Term) -> Option<(String, Term, Term)> {
441 if let Term::App(outer, b) = term {
442 if let Term::App(sapp_outer, inner) = outer.as_ref() {
443 if let Term::Global(ctor) = sapp_outer.as_ref() {
444 if ctor == "SApp" {
445 if let Term::App(partial, a) = inner.as_ref() {
446 if let Term::App(sapp_inner, op_term) = partial.as_ref() {
447 if let Term::Global(ctor2) = sapp_inner.as_ref() {
448 if ctor2 == "SApp" {
449 if let Some(op) = extract_sname(op_term) {
450 return Some((
451 op,
452 a.as_ref().clone(),
453 b.as_ref().clone(),
454 ));
455 }
456 }
457 }
458 }
459 }
460 }
461 }
462 }
463 }
464 None
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    // ----- term builders ----------------------------------------------------

    /// Builds `ctor payload`, i.e. `App(Global(ctor), payload)`.
    fn tagged(ctor: &str, payload: Term) -> Term {
        Term::App(Box::new(Term::Global(ctor.to_string())), Box::new(payload))
    }

    /// `SName s`
    fn make_sname(s: &str) -> Term {
        tagged("SName", Term::Lit(Literal::Text(s.to_string())))
    }

    /// `SVar i`
    fn make_svar(i: i64) -> Term {
        tagged("SVar", Term::Lit(Literal::Int(i)))
    }

    /// `SApp f a`
    fn make_sapp(f: Term, a: Term) -> Term {
        Term::App(Box::new(tagged("SApp", f)), Box::new(a))
    }

    /// Curried binary application `op lhs rhs`.
    fn make_binop(op: &str, lhs: Term, rhs: Term) -> Term {
        make_sapp(make_sapp(make_sname(op), lhs), rhs)
    }

    /// The goal `x = y → f(x) = f(y)`, shared by several tests.
    fn congruence_goal() -> Term {
        let x = make_svar(0);
        let y = make_svar(1);
        let f = make_sname("f");
        let hyp = make_binop("Eq", x.clone(), y.clone());
        let concl = make_binop("Eq", make_sapp(f.clone(), x), make_sapp(f, y));
        make_binop("implies", hyp, concl)
    }

    /// Adds the application node `func arg` to `graph`.
    fn apply(graph: &mut EGraph, func: usize, arg: usize) -> usize {
        graph.add(ENode::App { func, arg })
    }

    // ----- union-find ---------------------------------------------------------

    #[test]
    fn test_union_find_basic() {
        let mut sets = UnionFind::new();
        let (p, q) = (sets.make_set(), sets.make_set());
        assert_ne!(sets.find(p), sets.find(q));
        sets.union(p, q);
        assert_eq!(sets.find(p), sets.find(q));
    }

    #[test]
    fn test_union_find_transitivity() {
        let mut sets = UnionFind::new();
        let ids: Vec<_> = (0..3).map(|_| sets.make_set()).collect();
        sets.union(ids[0], ids[1]);
        sets.union(ids[1], ids[2]);
        assert_eq!(sets.find(ids[0]), sets.find(ids[2]));
    }

    // ----- e-graph ------------------------------------------------------------

    #[test]
    fn test_egraph_reflexive() {
        let mut graph = EGraph::new();
        let v = graph.add(ENode::Var(0));
        assert!(graph.equivalent(v, v));
    }

    #[test]
    fn test_egraph_congruence() {
        let mut graph = EGraph::new();
        let x = graph.add(ENode::Var(0));
        let y = graph.add(ENode::Var(1));
        let f = graph.add(ENode::Name("f".to_string()));
        let fx = apply(&mut graph, f, x);
        let fy = apply(&mut graph, f, y);

        // Distinct until their arguments are identified.
        assert!(!graph.equivalent(fx, fy));

        graph.merge(x, y);
        assert!(graph.equivalent(fx, fy));
    }

    #[test]
    fn test_egraph_nested_congruence() {
        let mut graph = EGraph::new();
        let a = graph.add(ENode::Var(0));
        let b = graph.add(ENode::Var(1));
        let c = graph.add(ENode::Var(2));
        let f = graph.add(ENode::Name("f".to_string()));

        let fa = apply(&mut graph, f, a);
        let fc = apply(&mut graph, f, c);
        let ffa = apply(&mut graph, f, fa);
        let ffc = apply(&mut graph, f, fc);

        // a = b and b = c must propagate through two application layers.
        graph.merge(a, b);
        graph.merge(b, c);
        assert!(graph.equivalent(ffa, ffc));
    }

    #[test]
    fn test_egraph_binary_congruence() {
        let mut graph = EGraph::new();
        let a = graph.add(ENode::Var(0));
        let b = graph.add(ENode::Var(1));
        let c = graph.add(ENode::Var(2));
        let add = graph.add(ENode::Name("add".to_string()));

        // Curried: add(a, c) is (add a) c.
        let add_a = apply(&mut graph, add, a);
        let add_b = apply(&mut graph, add, b);
        let add_a_c = apply(&mut graph, add_a, c);
        let add_b_c = apply(&mut graph, add_b, c);

        assert!(!graph.equivalent(add_a_c, add_b_c));
        graph.merge(a, b);
        assert!(graph.equivalent(add_a_c, add_b_c));
    }

    // ----- term extraction ------------------------------------------------------

    #[test]
    fn test_extract_sname() {
        assert_eq!(extract_sname(&make_sname("f")), Some("f".to_string()));
    }

    #[test]
    fn test_extract_svar() {
        assert_eq!(extract_svar(&make_svar(0)), Some(0));
    }

    #[test]
    fn test_extract_sapp() {
        let term = make_sapp(make_sname("f"), make_svar(0));
        let (func, arg) = extract_sapp(&term).expect("SApp f x should match");
        assert_eq!(extract_sname(&func), Some("f".to_string()));
        assert_eq!(extract_svar(&arg), Some(0));
    }

    #[test]
    fn test_extract_binary_app() {
        let term = make_binop("Eq", make_svar(0), make_svar(1));
        let (op, a, b) = extract_binary_app(&term).expect("Should extract binary app");
        assert_eq!(op, "Eq");
        assert_eq!(extract_svar(&a), Some(0));
        assert_eq!(extract_svar(&b), Some(1));
    }

    #[test]
    fn test_extract_equality() {
        let term = make_binop("Eq", make_svar(0), make_svar(1));
        let (lhs, rhs) = extract_equality(&term).expect("Should extract equality");
        assert_eq!(extract_svar(&lhs), Some(0));
        assert_eq!(extract_svar(&rhs), Some(1));
    }

    #[test]
    fn test_extract_implication() {
        let goal = congruence_goal();
        let (hyp, concl) = extract_implication(&goal).expect("Should extract implication");

        let (h_lhs, h_rhs) = extract_equality(&hyp).expect("Hypothesis should be equality");
        assert_eq!(extract_svar(&h_lhs), Some(0));
        assert_eq!(extract_svar(&h_rhs), Some(1));

        assert!(extract_equality(&concl).is_some(), "Conclusion should be equality");
    }

    // ----- goals ----------------------------------------------------------------

    #[test]
    fn test_decompose_goal_with_hypothesis() {
        let (hypotheses, conclusion) = decompose_goal(&congruence_goal());
        assert_eq!(hypotheses.len(), 1, "Should have 1 hypothesis");
        assert!(extract_equality(&conclusion).is_some(), "Conclusion should be equality");
    }

    #[test]
    fn test_check_goal_with_hypothesis() {
        assert!(check_goal(&congruence_goal()), "CC should prove x=y → f(x)=f(y)");
    }
}