Skip to main content

oxiz_solver/combination/
equality_propagation.rs

//! Equality propagation engine for theory combination.
//!
//! Propagates equalities between theories using:
//! - congruence closure backed by a union-find,
//! - an e-graph of equivalence classes,
//! - explanation objects attached to derived equalities,
//! - watched equalities for lazy propagation.
#![allow(dead_code)] // Under development: several fields/helpers are not wired up yet.
9
10use oxiz_core::ast::{TermId, TermKind, TermManager};
11use rustc_hash::FxHashMap;
12use std::collections::VecDeque;
13
14/// Equality propagation engine.
15pub struct EqualityPropagator {
16    /// Union-find for equality classes
17    union_find: UnionFind,
18    /// Congruence closure data structures
19    congruence: CongruenceData,
20    /// Pending equalities to propagate
21    pending: VecDeque<(TermId, TermId, Explanation)>,
22    /// Watched equalities: term → watchers
23    watched: FxHashMap<TermId, Vec<EqualityWatch>>,
24    /// E-graph for term canonicalization
25    egraph: EGraph,
26    /// Statistics
27    stats: EqualityPropStats,
28}
29
30/// Union-find data structure for equivalence classes.
31#[derive(Debug, Clone)]
32pub struct UnionFind {
33    /// Parent pointers
34    parent: FxHashMap<TermId, TermId>,
35    /// Rank for union-by-rank
36    rank: FxHashMap<TermId, usize>,
37    /// Size of equivalence class
38    size: FxHashMap<TermId, usize>,
39}
40
41/// Congruence closure data.
42#[derive(Debug, Clone)]
43pub struct CongruenceData {
44    /// Use list: term → terms that use it
45    use_list: FxHashMap<TermId, Vec<TermId>>,
46    /// Lookup table: (function, args) → term
47    lookup: FxHashMap<CongruenceKey, TermId>,
48    /// Pending congruence checks
49    pending_congruences: VecDeque<(TermId, TermId)>,
50}
51
52/// Key for congruence lookup.
53#[derive(Debug, Clone, PartialEq, Eq, Hash)]
54pub struct CongruenceKey {
55    /// Function/operator
56    pub function: TermKind,
57    /// Canonical arguments (equivalence class representatives)
58    pub args: Vec<TermId>,
59}
60
61/// E-graph for term canonicalization.
62#[derive(Debug, Clone)]
63pub struct EGraph {
64    /// E-class membership: term → e-class
65    eclass: FxHashMap<TermId, EClassId>,
66    /// E-class contents: e-class → terms
67    nodes: FxHashMap<EClassId, Vec<TermId>>,
68    /// E-class data
69    data: FxHashMap<EClassId, EClassData>,
70    /// Next available e-class ID
71    next_id: EClassId,
72}
73
/// Identifier of an e-class; allocated sequentially starting at 0.
pub type EClassId = usize;
76
77/// Data associated with an e-class.
78#[derive(Debug, Clone)]
79pub struct EClassData {
80    /// Representative term
81    pub representative: TermId,
82    /// Size of e-class
83    pub size: usize,
84    /// Parent e-classes (for congruence)
85    pub parents: Vec<EClassId>,
86}
87
88/// Explanation for an equality.
89#[derive(Debug, Clone)]
90pub enum Explanation {
91    /// Given equality (axiom)
92    Given,
93    /// Equality by reflexivity
94    Reflexivity,
95    /// Equality by transitivity
96    Transitivity(TermId, Box<Explanation>, Box<Explanation>),
97    /// Equality by congruence
98    Congruence(Vec<(TermId, TermId, Box<Explanation>)>),
99    /// Theory propagation
100    TheoryPropagation(TheoryExplanation),
101}
102
103/// Theory-specific explanation.
104#[derive(Debug, Clone)]
105pub struct TheoryExplanation {
106    /// Theory ID
107    pub theory_id: usize,
108    /// Antecedent equalities
109    pub antecedents: Vec<(TermId, TermId)>,
110}
111
112/// Watched equality for lazy propagation.
113#[derive(Debug, Clone)]
114pub struct EqualityWatch {
115    /// Left-hand side
116    pub lhs: TermId,
117    /// Right-hand side
118    pub rhs: TermId,
119    /// Callback ID
120    pub callback: usize,
121}
122
/// Counters collected by [`EqualityPropagator`].
#[derive(Debug, Clone, Default)]
pub struct EqualityPropStats {
    /// Equalities that actually merged two classes.
    pub equalities_propagated: usize,
    /// Parent pairs found to be congruent.
    pub congruences_found: usize,
    /// Merges performed in the e-graph.
    pub egraph_merges: usize,
    /// Congruence explanations built.
    pub explanations_generated: usize,
    /// Watches found satisfied after a merge.
    pub watch_triggers: usize,
}
137
138impl UnionFind {
139    /// Create a new union-find structure.
140    pub fn new() -> Self {
141        Self {
142            parent: FxHashMap::default(),
143            rank: FxHashMap::default(),
144            size: FxHashMap::default(),
145        }
146    }
147
148    /// Find the representative of a set.
149    pub fn find(&mut self, x: TermId) -> TermId {
150        if let std::collections::hash_map::Entry::Vacant(e) = self.parent.entry(x) {
151            e.insert(x);
152            self.rank.insert(x, 0);
153            self.size.insert(x, 1);
154            return x;
155        }
156
157        let parent = self.parent[&x];
158        if parent != x {
159            // Path compression
160            let root = self.find(parent);
161            self.parent.insert(x, root);
162            root
163        } else {
164            x
165        }
166    }
167
168    /// Union two sets.
169    pub fn union(&mut self, x: TermId, y: TermId) -> bool {
170        let root_x = self.find(x);
171        let root_y = self.find(y);
172
173        if root_x == root_y {
174            return false; // Already in same set
175        }
176
177        let rank_x = self.rank.get(&root_x).copied().unwrap_or(0);
178        let rank_y = self.rank.get(&root_y).copied().unwrap_or(0);
179
180        // Union by rank
181        if rank_x < rank_y {
182            self.parent.insert(root_x, root_y);
183            let size_x = self.size.get(&root_x).copied().unwrap_or(1);
184            *self.size.entry(root_y).or_insert(1) += size_x;
185        } else if rank_x > rank_y {
186            self.parent.insert(root_y, root_x);
187            let size_y = self.size.get(&root_y).copied().unwrap_or(1);
188            *self.size.entry(root_x).or_insert(1) += size_y;
189        } else {
190            self.parent.insert(root_y, root_x);
191            *self.rank.entry(root_x).or_insert(0) += 1;
192            let size_y = self.size.get(&root_y).copied().unwrap_or(1);
193            *self.size.entry(root_x).or_insert(1) += size_y;
194        }
195
196        true
197    }
198
199    /// Check if two elements are in the same set.
200    pub fn connected(&mut self, x: TermId, y: TermId) -> bool {
201        self.find(x) == self.find(y)
202    }
203
204    /// Get size of the set containing x.
205    pub fn set_size(&mut self, x: TermId) -> usize {
206        let root = self.find(x);
207        self.size[&root]
208    }
209}
210
211impl EqualityPropagator {
212    /// Create a new equality propagator.
213    pub fn new() -> Self {
214        Self {
215            union_find: UnionFind::new(),
216            congruence: CongruenceData::new(),
217            pending: VecDeque::new(),
218            watched: FxHashMap::default(),
219            egraph: EGraph::new(),
220            stats: EqualityPropStats::default(),
221        }
222    }
223
224    /// Assert an equality.
225    pub fn assert_equality(
226        &mut self,
227        lhs: TermId,
228        rhs: TermId,
229        explanation: Explanation,
230        tm: &TermManager,
231    ) -> Result<(), String> {
232        // Check if already equal
233        if self.union_find.connected(lhs, rhs) {
234            return Ok(());
235        }
236
237        // Add to pending queue
238        self.pending.push_back((lhs, rhs, explanation));
239
240        // Propagate all pending equalities
241        self.propagate(tm)?;
242
243        Ok(())
244    }
245
246    /// Propagate all pending equalities.
247    fn propagate(&mut self, tm: &TermManager) -> Result<(), String> {
248        while let Some((lhs, rhs, explanation)) = self.pending.pop_front() {
249            self.propagate_equality(lhs, rhs, explanation, tm)?;
250        }
251
252        // Check for new congruences
253        self.check_congruences(tm)?;
254
255        Ok(())
256    }
257
258    /// Propagate a single equality.
259    fn propagate_equality(
260        &mut self,
261        lhs: TermId,
262        rhs: TermId,
263        _explanation: Explanation,
264        _tm: &TermManager,
265    ) -> Result<(), String> {
266        // Union in union-find
267        if !self.union_find.union(lhs, rhs) {
268            return Ok(()); // Already merged
269        }
270
271        self.stats.equalities_propagated += 1;
272
273        // Merge in e-graph
274        self.egraph.merge(lhs, rhs);
275        self.stats.egraph_merges += 1;
276
277        // Update use lists
278        self.congruence.merge_use_lists(lhs, rhs);
279
280        // Trigger watches
281        self.trigger_watches(lhs, rhs)?;
282
283        // Add parents to pending congruence checks
284        let lhs_parents = self.congruence.get_parents(lhs);
285        let rhs_parents = self.congruence.get_parents(rhs);
286
287        for lhs_parent in lhs_parents {
288            for &rhs_parent in &rhs_parents {
289                self.congruence
290                    .pending_congruences
291                    .push_back((lhs_parent, rhs_parent));
292            }
293        }
294
295        Ok(())
296    }
297
298    /// Check for new congruences.
299    fn check_congruences(&mut self, tm: &TermManager) -> Result<(), String> {
300        while let Some((t1, t2)) = self.congruence.pending_congruences.pop_front() {
301            // Check if they have congruent arguments
302            if self.are_congruent(t1, t2, tm)? {
303                self.stats.congruences_found += 1;
304
305                // Generate congruence explanation
306                let explanation = self.generate_congruence_explanation(t1, t2, tm)?;
307
308                // Assert equality
309                self.pending.push_back((t1, t2, explanation));
310            }
311        }
312
313        Ok(())
314    }
315
316    /// Check if two terms are congruent.
317    fn are_congruent(&mut self, t1: TermId, t2: TermId, tm: &TermManager) -> Result<bool, String> {
318        let term1 = tm.get(t1).ok_or("term not found")?;
319        let term2 = tm.get(t2).ok_or("term not found")?;
320
321        // Must have same kind
322        if std::mem::discriminant(&term1.kind) != std::mem::discriminant(&term2.kind) {
323            return Ok(false);
324        }
325
326        // Get arguments
327        let args1 = self.get_args(&term1.kind);
328        let args2 = self.get_args(&term2.kind);
329
330        if args1.len() != args2.len() {
331            return Ok(false);
332        }
333
334        // Check if all arguments are equal
335        for (arg1, arg2) in args1.iter().zip(args2.iter()) {
336            if !self.union_find.connected(*arg1, *arg2) {
337                return Ok(false);
338            }
339        }
340
341        Ok(true)
342    }
343
344    /// Generate explanation for congruence.
345    fn generate_congruence_explanation(
346        &mut self,
347        t1: TermId,
348        t2: TermId,
349        tm: &TermManager,
350    ) -> Result<Explanation, String> {
351        let term1 = tm.get(t1).ok_or("term not found")?;
352        let term2 = tm.get(t2).ok_or("term not found")?;
353
354        let args1 = self.get_args(&term1.kind);
355        let args2 = self.get_args(&term2.kind);
356
357        let mut arg_explanations = Vec::new();
358
359        for (arg1, arg2) in args1.iter().zip(args2.iter()) {
360            let expl = self.explain_equality(*arg1, *arg2)?;
361            arg_explanations.push((*arg1, *arg2, Box::new(expl)));
362        }
363
364        self.stats.explanations_generated += 1;
365
366        Ok(Explanation::Congruence(arg_explanations))
367    }
368
369    /// Explain why two terms are equal.
370    pub fn explain_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<Explanation, String> {
371        if lhs == rhs {
372            return Ok(Explanation::Reflexivity);
373        }
374
375        if !self.union_find.connected(lhs, rhs) {
376            return Err("Terms are not equal".to_string());
377        }
378
379        // Simplified: return a generic explanation
380        // Full implementation would trace union-find path
381        Ok(Explanation::Given)
382    }
383
384    /// Watch an equality.
385    pub fn watch_equality(&mut self, lhs: TermId, rhs: TermId, callback: usize) {
386        let watch = EqualityWatch { lhs, rhs, callback };
387
388        self.watched.entry(lhs).or_default().push(watch.clone());
389        self.watched.entry(rhs).or_default().push(watch);
390    }
391
392    /// Trigger watches when an equality is established.
393    fn trigger_watches(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
394        let mut triggered = Vec::new();
395
396        // Check watches on lhs
397        if let Some(watches) = self.watched.get(&lhs) {
398            for watch in watches {
399                if self.union_find.connected(watch.lhs, watch.rhs) {
400                    triggered.push(watch.callback);
401                }
402            }
403        }
404
405        // Check watches on rhs
406        if let Some(watches) = self.watched.get(&rhs) {
407            for watch in watches {
408                if self.union_find.connected(watch.lhs, watch.rhs) {
409                    triggered.push(watch.callback);
410                }
411            }
412        }
413
414        self.stats.watch_triggers += triggered.len();
415
416        Ok(())
417    }
418
419    /// Get arguments of a term.
420    fn get_args(&self, kind: &TermKind) -> Vec<TermId> {
421        match kind {
422            TermKind::And(args) | TermKind::Or(args) => args.to_vec(),
423            TermKind::Not(arg) => vec![*arg],
424            TermKind::Eq(l, r) | TermKind::Le(l, r) | TermKind::Lt(l, r) => vec![*l, *r],
425            TermKind::Add(args) | TermKind::Mul(args) => args.to_vec(),
426            _ => vec![],
427        }
428    }
429
430    /// Get statistics.
431    pub fn stats(&self) -> &EqualityPropStats {
432        &self.stats
433    }
434}
435
436impl CongruenceData {
437    /// Create new congruence data.
438    pub fn new() -> Self {
439        Self {
440            use_list: FxHashMap::default(),
441            lookup: FxHashMap::default(),
442            pending_congruences: VecDeque::new(),
443        }
444    }
445
446    /// Merge use lists when two terms become equal.
447    pub fn merge_use_lists(&mut self, t1: TermId, t2: TermId) {
448        // Simplified implementation
449        let t1_uses = self.use_list.get(&t1).cloned().unwrap_or_default();
450        let t2_uses = self.use_list.get(&t2).cloned().unwrap_or_default();
451
452        let mut merged = t1_uses;
453        merged.extend(t2_uses);
454
455        self.use_list.insert(t1, merged.clone());
456        self.use_list.insert(t2, merged);
457    }
458
459    /// Get parent terms.
460    pub fn get_parents(&self, t: TermId) -> Vec<TermId> {
461        self.use_list.get(&t).cloned().unwrap_or_default()
462    }
463}
464
465impl EGraph {
466    /// Create a new e-graph.
467    pub fn new() -> Self {
468        Self {
469            eclass: FxHashMap::default(),
470            nodes: FxHashMap::default(),
471            data: FxHashMap::default(),
472            next_id: 0,
473        }
474    }
475
476    /// Get e-class for a term.
477    pub fn get_eclass(&mut self, term: TermId) -> EClassId {
478        if let Some(&id) = self.eclass.get(&term) {
479            id
480        } else {
481            let id = self.next_id;
482            self.next_id += 1;
483
484            self.eclass.insert(term, id);
485            self.nodes.insert(id, vec![term]);
486            self.data.insert(
487                id,
488                EClassData {
489                    representative: term,
490                    size: 1,
491                    parents: Vec::new(),
492                },
493            );
494
495            id
496        }
497    }
498
499    /// Merge two terms in the e-graph.
500    pub fn merge(&mut self, t1: TermId, t2: TermId) {
501        let id1 = self.get_eclass(t1);
502        let id2 = self.get_eclass(t2);
503
504        if id1 == id2 {
505            return;
506        }
507
508        // Merge smaller into larger
509        let size1 = self.data[&id1].size;
510        let size2 = self.data[&id2].size;
511
512        let (smaller, larger) = if size1 < size2 {
513            (id1, id2)
514        } else {
515            (id2, id1)
516        };
517
518        // Update e-class membership
519        let smaller_nodes = self.nodes[&smaller].clone();
520        for &node in &smaller_nodes {
521            self.eclass.insert(node, larger);
522        }
523
524        // Merge node lists
525        if let Some(larger_nodes) = self.nodes.get_mut(&larger) {
526            larger_nodes.extend(smaller_nodes);
527        }
528        self.nodes.remove(&smaller);
529
530        // Update data
531        let smaller_size = self.data.get(&smaller).map(|d| d.size).unwrap_or(0);
532        if let Some(larger_data) = self.data.get_mut(&larger) {
533            larger_data.size += smaller_size;
534        }
535        self.data.remove(&smaller);
536    }
537}
538
539impl Default for EqualityPropagator {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
545impl Default for UnionFind {
546    fn default() -> Self {
547        Self::new()
548    }
549}
550
551impl Default for CongruenceData {
552    fn default() -> Self {
553        Self::new()
554    }
555}
556
557impl Default for EGraph {
558    fn default() -> Self {
559        Self::new()
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_union_find() {
        let mut uf = UnionFind::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);
        let t3 = TermId::from(3);

        assert!(!uf.connected(t1, t2));

        assert!(uf.union(t1, t2));
        assert!(uf.connected(t1, t2));
        // Re-merging terms already in one set reports that nothing changed.
        assert!(!uf.union(t2, t1));

        assert!(uf.union(t2, t3));
        assert!(uf.connected(t1, t3));
        assert_eq!(uf.set_size(t3), 3);
    }

    #[test]
    fn test_equality_propagator() {
        let prop = EqualityPropagator::new();
        assert_eq!(prop.stats.equalities_propagated, 0);
    }

    #[test]
    fn test_egraph() {
        let mut eg = EGraph::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);

        let id1 = eg.get_eclass(t1);
        let id2 = eg.get_eclass(t2);

        assert_ne!(id1, id2);

        eg.merge(t1, t2);

        let id1_after = eg.get_eclass(t1);
        let id2_after = eg.get_eclass(t2);

        assert_eq!(id1_after, id2_after);
        // Exactly one live class remains, holding both terms.
        assert_eq!(eg.nodes.len(), 1);
        assert_eq!(eg.data[&id1_after].size, 2);
    }
}