// logicaffeine_kernel/cc.rs

//! Congruence Closure Tactic
//!
//! Proves equalities over uninterpreted functions using Union-Find with
//! congruence propagation. This implements a simplified Downey-Sethi-Tarjan
//! algorithm.
//!
//! # Algorithm
//!
//! The congruence closure tactic works in four steps:
//! 1. **Build E-graph**: Add all subterms from goal and hypotheses
//! 2. **Merge hypotheses**: For each hypothesis `x = y`, merge `x` and `y`
//! 3. **Propagate**: When `x = y` and `f(x)`, `f(y)` exist, merge `f(x)` and `f(y)`
//! 4. **Check**: Goal `a = b` is proved iff `a` and `b` end up in the same equivalence class
//!
//! # Supported Goals
//!
//! - Direct equalities: `Eq a a` (reflexivity)
//! - Implications: `(Eq x y) -> (Eq (f x) (f y))` (congruence)
//! - Nested implications with multiple hypotheses (non-equality hypotheses are ignored)
//!
//! # E-Graph Structure
//!
//! The E-graph maintains:
//! - A Union-Find for equivalence classes
//! - Hash-consing for structural sharing
//! - Use lists for efficient congruence propagation

28use std::collections::HashMap;
29
30use crate::term::{Literal, Term};
31
/// Identifier of an E-graph node, which is also its Union-Find element.
///
/// IDs are dense indices handed out in creation order, starting at 0.
type TermId = usize;

// =============================================================================
// UNION-FIND DATA STRUCTURE
// =============================================================================

38/// Union-Find data structure with path compression and union by rank.
39///
40/// Maintains equivalence classes over term IDs. Supports near-constant time
41/// operations for `find` (amortized) and `union`.
42pub struct UnionFind {
43    /// Parent pointer for each element (element is its own parent if root).
44    parent: Vec<TermId>,
45    /// Rank (approximate tree depth) for union by rank optimization.
46    rank: Vec<usize>,
47}
48
49impl UnionFind {
50    pub fn new() -> Self {
51        UnionFind {
52            parent: Vec::new(),
53            rank: Vec::new(),
54        }
55    }
56
57    /// Add a new element, returns its ID
58    pub fn make_set(&mut self) -> TermId {
59        let id = self.parent.len();
60        self.parent.push(id);
61        self.rank.push(0);
62        id
63    }
64
65    /// Find representative with path compression
66    pub fn find(&mut self, x: TermId) -> TermId {
67        if self.parent[x] != x {
68            self.parent[x] = self.find(self.parent[x]);
69        }
70        self.parent[x]
71    }
72
73    /// Union by rank, returns true if a merge occurred
74    pub fn union(&mut self, x: TermId, y: TermId) -> bool {
75        let rx = self.find(x);
76        let ry = self.find(y);
77        if rx == ry {
78            return false;
79        }
80
81        if self.rank[rx] < self.rank[ry] {
82            self.parent[rx] = ry;
83        } else if self.rank[rx] > self.rank[ry] {
84            self.parent[ry] = rx;
85        } else {
86            self.parent[ry] = rx;
87            self.rank[rx] += 1;
88        }
89        true
90    }
91}

// =============================================================================
// E-GRAPH DATA STRUCTURE
// =============================================================================

97/// Node in the E-graph representing a term.
98///
99/// Terms are represented in a curried style where function application
100/// is a binary node. For example, `f(x, y)` is `App(App(f, x), y)`.
101#[derive(Debug, Clone, PartialEq, Eq, Hash)]
102pub enum ENode {
103    /// Integer literal from `SLit`.
104    Lit(i64),
105    /// De Bruijn variable from `SVar`.
106    Var(i64),
107    /// Named constant or function symbol from `SName`.
108    Name(String),
109    /// Function application (curried): `func` applied to `arg`.
110    App {
111        /// The function being applied.
112        func: TermId,
113        /// The argument.
114        arg: TermId,
115    },
116}
117
118/// E-graph with congruence closure.
119///
120/// Combines a Union-Find for equivalence classes with hash-consing for
121/// structural sharing and use lists for efficient congruence propagation.
122pub struct EGraph {
123    /// The nodes stored in the graph.
124    nodes: Vec<ENode>,
125    /// Union-Find tracking equivalence classes.
126    uf: UnionFind,
127    /// Hash-consing map: node content to its canonical ID.
128    node_map: HashMap<ENode, TermId>,
129    /// Pending merges to propagate (worklist algorithm).
130    pending: Vec<(TermId, TermId)>,
131    /// Use lists: for each term, the App nodes that use it as func or arg.
132    /// Used to find potential congruences when terms are merged.
133    use_list: Vec<Vec<TermId>>,
134}
135
136impl EGraph {
137    pub fn new() -> Self {
138        EGraph {
139            nodes: Vec::new(),
140            uf: UnionFind::new(),
141            node_map: HashMap::new(),
142            pending: Vec::new(),
143            use_list: Vec::new(),
144        }
145    }
146
147    /// Add a node, return its ID (hash-consed)
148    pub fn add(&mut self, node: ENode) -> TermId {
149        // Hash-consing: return existing ID if node already exists
150        if let Some(&id) = self.node_map.get(&node) {
151            return id;
152        }
153
154        let id = self.nodes.len();
155        self.nodes.push(node.clone());
156        self.node_map.insert(node.clone(), id);
157        self.uf.make_set();
158        self.use_list.push(Vec::new());
159
160        // Register in use lists for congruence detection
161        if let ENode::App { func, arg } = &node {
162            self.use_list[*func].push(id);
163            self.use_list[*arg].push(id);
164        }
165
166        id
167    }
168
169    /// Merge two terms and propagate congruences
170    pub fn merge(&mut self, a: TermId, b: TermId) {
171        self.pending.push((a, b));
172        self.propagate();
173    }
174
175    /// Propagate congruences until fixed point
176    fn propagate(&mut self) {
177        while let Some((a, b)) = self.pending.pop() {
178            let ra = self.uf.find(a);
179            let rb = self.uf.find(b);
180            if ra == rb {
181                continue;
182            }
183
184            // Before merging, collect uses for congruence checking
185            let uses_a: Vec<TermId> = self.use_list[ra].clone();
186            let uses_b: Vec<TermId> = self.use_list[rb].clone();
187
188            // Merge equivalence classes
189            self.uf.union(ra, rb);
190            let new_root = self.uf.find(ra);
191
192            // Check for new congruences first (before modifying use lists)
193            // If f(a) and f(b) exist, and a=b now, then f(a)=f(b)
194            for &ua in &uses_a {
195                for &ub in &uses_b {
196                    if self.congruent(ua, ub) {
197                        self.pending.push((ua, ub));
198                    }
199                }
200            }
201
202            // Merge use lists (now safe to consume uses_a/uses_b)
203            if new_root == ra {
204                for u in uses_b {
205                    self.use_list[ra].push(u);
206                }
207            } else {
208                for u in uses_a {
209                    self.use_list[rb].push(u);
210                }
211            }
212        }
213    }
214
215    /// Check if two application nodes are congruent
216    fn congruent(&mut self, a: TermId, b: TermId) -> bool {
217        match (&self.nodes[a].clone(), &self.nodes[b].clone()) {
218            (ENode::App { func: f1, arg: a1 }, ENode::App { func: f2, arg: a2 }) => {
219                self.uf.find(*f1) == self.uf.find(*f2) && self.uf.find(*a1) == self.uf.find(*a2)
220            }
221            _ => false,
222        }
223    }
224
225    /// Check if two terms are in the same equivalence class
226    pub fn equivalent(&mut self, a: TermId, b: TermId) -> bool {
227        self.uf.find(a) == self.uf.find(b)
228    }
229}

// =============================================================================
// SYNTAX TERM REIFICATION
// =============================================================================

235/// Reify a Syntax term into an E-graph node.
236///
237/// Converts the deep embedding (Syntax) into E-graph nodes, returning
238/// the ID of the root node. Subterms are recursively reified and
239/// hash-consed (duplicate structures share the same ID).
240///
241/// # Returns
242///
243/// `Some(id)` on successful reification, `None` if the term cannot be reified.
244pub fn reify(egraph: &mut EGraph, term: &Term) -> Option<TermId> {
245    // SLit n -> Lit(n)
246    if let Some(n) = extract_slit(term) {
247        return Some(egraph.add(ENode::Lit(n)));
248    }
249
250    // SVar i -> Var(i)
251    if let Some(i) = extract_svar(term) {
252        return Some(egraph.add(ENode::Var(i)));
253    }
254
255    // SName s -> Name(s)
256    if let Some(name) = extract_sname(term) {
257        return Some(egraph.add(ENode::Name(name)));
258    }
259
260    // SApp f a -> App { func, arg }
261    if let Some((func_term, arg_term)) = extract_sapp(term) {
262        let func = reify(egraph, &func_term)?;
263        let arg = reify(egraph, &arg_term)?;
264        return Some(egraph.add(ENode::App { func, arg }));
265    }
266
267    None
268}

// =============================================================================
// GOAL DECOMPOSITION
// =============================================================================

274/// Decompose a goal into hypotheses and conclusion.
275///
276/// Peels off nested implications to extract equality hypotheses.
277/// For example, `(h1 -> h2 -> conclusion)` becomes `([h1, h2], conclusion)`.
278///
279/// # Returns
280///
281/// A tuple of:
282/// - Vector of equality hypothesis pairs (LHS, RHS)
283/// - The final conclusion term
284///
285/// Only equalities in hypothesis position are extracted; other hypotheses
286/// are ignored.
287pub fn decompose_goal(goal: &Term) -> (Vec<(Term, Term)>, Term) {
288    let mut hypotheses = Vec::new();
289    let mut current = goal.clone();
290
291    // Peel off nested implications
292    while let Some((hyp, rest)) = extract_implication(&current) {
293        if let Some((lhs, rhs)) = extract_equality(&hyp) {
294            hypotheses.push((lhs, rhs));
295        }
296        current = rest;
297    }
298
299    (hypotheses, current)
300}
301
302/// Check if a goal is provable by congruence closure.
303///
304/// This is the main entry point for the CC tactic. It builds an E-graph,
305/// adds hypothesis equalities, propagates congruences, and checks if the
306/// conclusion follows.
307///
308/// # Supported Goals
309///
310/// - Direct equalities: `(Eq a a)` succeeds by reflexivity
311/// - Implications: `(implies (Eq x y) (Eq (f x) (f y)))` succeeds by congruence
312/// - Multiple hypotheses: `(implies (Eq a b) (implies (Eq b c) (Eq a c)))`
313///
314/// # Returns
315///
316/// `true` if the goal is provable by congruence closure, `false` otherwise.
317pub fn check_goal(goal: &Term) -> bool {
318    let (hypotheses, conclusion) = decompose_goal(goal);
319
320    // Conclusion must be an equality
321    let (lhs, rhs) = match extract_equality(&conclusion) {
322        Some(eq) => eq,
323        None => return false,
324    };
325
326    let mut egraph = EGraph::new();
327
328    // IMPORTANT: Reify conclusion FIRST so that fx and fy exist in the graph
329    // with their use lists populated. This way when we merge x=y, congruence
330    // will propagate to fx=fy.
331    let lhs_id = match reify(&mut egraph, &lhs) {
332        Some(id) => id,
333        None => return false,
334    };
335
336    let rhs_id = match reify(&mut egraph, &rhs) {
337        Some(id) => id,
338        None => return false,
339    };
340
341    // Now reify and merge hypothesis equalities
342    // The subterms (x, y) will be hash-consed with the ones in fx, fy
343    for (h_lhs, h_rhs) in &hypotheses {
344        let h_lhs_id = match reify(&mut egraph, h_lhs) {
345            Some(id) => id,
346            None => return false,
347        };
348        let h_rhs_id = match reify(&mut egraph, h_rhs) {
349            Some(id) => id,
350            None => return false,
351        };
352        egraph.merge(h_lhs_id, h_rhs_id);
353    }
354
355    // Check if conclusion follows by congruence
356    egraph.equivalent(lhs_id, rhs_id)
357}

// =============================================================================
// HELPER EXTRACTORS
// =============================================================================

363/// Extract integer from SLit n
364fn extract_slit(term: &Term) -> Option<i64> {
365    if let Term::App(ctor, arg) = term {
366        if let Term::Global(name) = ctor.as_ref() {
367            if name == "SLit" {
368                if let Term::Lit(Literal::Int(n)) = arg.as_ref() {
369                    return Some(*n);
370                }
371            }
372        }
373    }
374    None
375}
376
377/// Extract variable index from SVar i
378fn extract_svar(term: &Term) -> Option<i64> {
379    if let Term::App(ctor, arg) = term {
380        if let Term::Global(name) = ctor.as_ref() {
381            if name == "SVar" {
382                if let Term::Lit(Literal::Int(i)) = arg.as_ref() {
383                    return Some(*i);
384                }
385            }
386        }
387    }
388    None
389}
390
391/// Extract name from SName "x"
392fn extract_sname(term: &Term) -> Option<String> {
393    if let Term::App(ctor, arg) = term {
394        if let Term::Global(name) = ctor.as_ref() {
395            if name == "SName" {
396                if let Term::Lit(Literal::Text(s)) = arg.as_ref() {
397                    return Some(s.clone());
398                }
399            }
400        }
401    }
402    None
403}
404
405/// Extract unary application: SApp f a
406fn extract_sapp(term: &Term) -> Option<(Term, Term)> {
407    if let Term::App(outer, arg) = term {
408        if let Term::App(sapp, func) = outer.as_ref() {
409            if let Term::Global(ctor) = sapp.as_ref() {
410                if ctor == "SApp" {
411                    return Some((func.as_ref().clone(), arg.as_ref().clone()));
412                }
413            }
414        }
415    }
416    None
417}
418
419/// Extract implication: SApp (SApp (SName "implies") hyp) concl
420fn extract_implication(term: &Term) -> Option<(Term, Term)> {
421    if let Some((op, hyp, concl)) = extract_binary_app(term) {
422        if op == "implies" {
423            return Some((hyp, concl));
424        }
425    }
426    None
427}
428
429/// Extract equality: SApp (SApp (SName "Eq") lhs) rhs
430fn extract_equality(term: &Term) -> Option<(Term, Term)> {
431    if let Some((op, lhs, rhs)) = extract_binary_app(term) {
432        if op == "Eq" || op == "eq" {
433            return Some((lhs, rhs));
434        }
435    }
436    None
437}
438
439/// Extract binary application: SApp (SApp (SName "op") a) b
440fn extract_binary_app(term: &Term) -> Option<(String, Term, Term)> {
441    if let Term::App(outer, b) = term {
442        if let Term::App(sapp_outer, inner) = outer.as_ref() {
443            if let Term::Global(ctor) = sapp_outer.as_ref() {
444                if ctor == "SApp" {
445                    if let Term::App(partial, a) = inner.as_ref() {
446                        if let Term::App(sapp_inner, op_term) = partial.as_ref() {
447                            if let Term::Global(ctor2) = sapp_inner.as_ref() {
448                                if ctor2 == "SApp" {
449                                    if let Some(op) = extract_sname(op_term) {
450                                        return Some((
451                                            op,
452                                            a.as_ref().clone(),
453                                            b.as_ref().clone(),
454                                        ));
455                                    }
456                                }
457                            }
458                        }
459                    }
460                }
461            }
462        }
463    }
464    None
465}

// =============================================================================
// UNIT TESTS
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // E-graph node builders
    // -------------------------------------------------------------------------

    /// Add a De Bruijn variable node.
    fn var(eg: &mut EGraph, index: i64) -> TermId {
        eg.add(ENode::Var(index))
    }

    /// Add a named symbol node.
    fn sym(eg: &mut EGraph, name: &str) -> TermId {
        eg.add(ENode::Name(name.to_string()))
    }

    /// Add a curried application node `func arg`.
    fn apply(eg: &mut EGraph, func: TermId, arg: TermId) -> TermId {
        eg.add(ENode::App { func, arg })
    }

    // -------------------------------------------------------------------------
    // UNION-FIND AND E-GRAPH TESTS
    // -------------------------------------------------------------------------

    #[test]
    fn test_union_find_basic() {
        let mut uf = UnionFind::new();
        let [a, b] = [uf.make_set(), uf.make_set()];
        assert_ne!(uf.find(a), uf.find(b));
        uf.union(a, b);
        assert_eq!(uf.find(a), uf.find(b));
    }

    #[test]
    fn test_union_find_transitivity() {
        let mut uf = UnionFind::new();
        let [a, b, c] = [uf.make_set(), uf.make_set(), uf.make_set()];
        uf.union(a, b);
        uf.union(b, c);
        assert_eq!(uf.find(a), uf.find(c));
    }

    #[test]
    fn test_egraph_reflexive() {
        let mut eg = EGraph::new();
        let x = var(&mut eg, 0);
        assert!(eg.equivalent(x, x));
    }

    #[test]
    fn test_egraph_congruence() {
        let mut eg = EGraph::new();
        let x = var(&mut eg, 0);
        let y = var(&mut eg, 1);
        let f = sym(&mut eg, "f");
        let fx = apply(&mut eg, f, x);
        let fy = apply(&mut eg, f, y);

        // Distinct arguments: no reason yet for f(x) = f(y).
        assert!(!eg.equivalent(fx, fy));

        // Once x = y, congruence forces f(x) = f(y).
        eg.merge(x, y);
        assert!(eg.equivalent(fx, fy));
    }

    #[test]
    fn test_egraph_nested_congruence() {
        let mut eg = EGraph::new();
        let a = var(&mut eg, 0);
        let b = var(&mut eg, 1);
        let c = var(&mut eg, 2);
        let f = sym(&mut eg, "f");

        let fa = apply(&mut eg, f, a);
        let fc = apply(&mut eg, f, c);
        let ffa = apply(&mut eg, f, fa);
        let ffc = apply(&mut eg, f, fc);

        // a = b and b = c propagate two levels up to f(f(a)) = f(f(c)).
        eg.merge(a, b);
        eg.merge(b, c);
        assert!(eg.equivalent(ffa, ffc));
    }

    #[test]
    fn test_egraph_binary_congruence() {
        let mut eg = EGraph::new();
        let a = var(&mut eg, 0);
        let b = var(&mut eg, 1);
        let c = var(&mut eg, 2);
        let add = sym(&mut eg, "add");

        // Curried add(a, c) = (add a) c, and likewise add(b, c).
        let add_a = apply(&mut eg, add, a);
        let add_b = apply(&mut eg, add, b);
        let add_a_c = apply(&mut eg, add_a, c);
        let add_b_c = apply(&mut eg, add_b, c);

        assert!(!eg.equivalent(add_a_c, add_b_c));
        eg.merge(a, b);
        assert!(eg.equivalent(add_a_c, add_b_c));
    }

    // -------------------------------------------------------------------------
    // Syntax term builders
    // -------------------------------------------------------------------------

    /// Build `name arg` where `name` is a global constructor.
    fn ctor(name: &str, arg: Term) -> Term {
        Term::App(Box::new(Term::Global(name.to_string())), Box::new(arg))
    }

    /// Build `SName "s"`.
    fn make_sname(s: &str) -> Term {
        ctor("SName", Term::Lit(Literal::Text(s.to_string())))
    }

    /// Build `SVar i`.
    fn make_svar(i: i64) -> Term {
        ctor("SVar", Term::Lit(Literal::Int(i)))
    }

    /// Build `SApp f a`.
    fn make_sapp(f: Term, a: Term) -> Term {
        Term::App(Box::new(ctor("SApp", f)), Box::new(a))
    }

    /// Build `SApp (SApp (SName op) lhs) rhs`.
    fn make_binop(op: &str, lhs: Term, rhs: Term) -> Term {
        make_sapp(make_sapp(make_sname(op), lhs), rhs)
    }

    /// Build `(implies (Eq x y) (Eq (f x) (f y)))` with `x = SVar 0`, `y = SVar 1`.
    fn congruence_goal() -> Term {
        let (x, y) = (make_svar(0), make_svar(1));
        let f = make_sname("f");
        let hyp = make_binop("Eq", x.clone(), y.clone());
        let concl = make_binop("Eq", make_sapp(f.clone(), x), make_sapp(f, y));
        make_binop("implies", hyp, concl)
    }

    // -------------------------------------------------------------------------
    // EXTRACTION TESTS
    // -------------------------------------------------------------------------

    #[test]
    fn test_extract_sname() {
        assert_eq!(extract_sname(&make_sname("f")), Some("f".to_string()));
    }

    #[test]
    fn test_extract_svar() {
        assert_eq!(extract_svar(&make_svar(0)), Some(0));
    }

    #[test]
    fn test_extract_sapp() {
        // SApp (SName "f") (SVar 0)
        let term = make_sapp(make_sname("f"), make_svar(0));
        let result = extract_sapp(&term);
        assert!(result.is_some());
        let (func, arg) = result.unwrap();
        assert_eq!(extract_sname(&func), Some("f".to_string()));
        assert_eq!(extract_svar(&arg), Some(0));
    }

    #[test]
    fn test_extract_binary_app() {
        // SApp (SApp (SName "Eq") (SVar 0)) (SVar 1)
        let term = make_binop("Eq", make_svar(0), make_svar(1));
        let result = extract_binary_app(&term);
        assert!(result.is_some(), "Should extract binary app");
        let (op, a, b) = result.unwrap();
        assert_eq!(op, "Eq");
        assert_eq!(extract_svar(&a), Some(0));
        assert_eq!(extract_svar(&b), Some(1));
    }

    #[test]
    fn test_extract_equality() {
        // SApp (SApp (SName "Eq") (SVar 0)) (SVar 1)
        let term = make_binop("Eq", make_svar(0), make_svar(1));
        let result = extract_equality(&term);
        assert!(result.is_some(), "Should extract equality");
        let (lhs, rhs) = result.unwrap();
        assert_eq!(extract_svar(&lhs), Some(0));
        assert_eq!(extract_svar(&rhs), Some(1));
    }

    #[test]
    fn test_extract_implication() {
        let goal = congruence_goal();

        let result = extract_implication(&goal);
        assert!(result.is_some(), "Should extract implication");
        let (hyp, concl) = result.unwrap();

        // The hypothesis is the equality x = y.
        let hyp_eq = extract_equality(&hyp);
        assert!(hyp_eq.is_some(), "Hypothesis should be equality");
        let (h_lhs, h_rhs) = hyp_eq.unwrap();
        assert_eq!(extract_svar(&h_lhs), Some(0));
        assert_eq!(extract_svar(&h_rhs), Some(1));

        // The conclusion is also an equality.
        let concl_eq = extract_equality(&concl);
        assert!(concl_eq.is_some(), "Conclusion should be equality");
    }

    #[test]
    fn test_decompose_goal_with_hypothesis() {
        let (hypotheses, conclusion) = decompose_goal(&congruence_goal());
        assert_eq!(hypotheses.len(), 1, "Should have 1 hypothesis");

        let concl_eq = extract_equality(&conclusion);
        assert!(concl_eq.is_some(), "Conclusion should be equality");
    }

    #[test]
    fn test_check_goal_with_hypothesis() {
        // (implies (Eq x y) (Eq (f x) (f y))) is provable by CC.
        let goal = congruence_goal();
        assert!(check_goal(&goal), "CC should prove x=y → f(x)=f(y)");
    }
}