Skip to main content

oxiz_solver/combination/
shared_terms_advanced.rs

1//! Advanced Shared Terms Management for Theory Combination.
2//!
3//! This module provides sophisticated shared term detection and management:
4//! - Efficient shared term detection across theories
5//! - Canonical representatives for equality classes
6//! - Congruence-based equality graphs
7//! - Interface term minimization
8//! - Incremental shared term tracking
9//!
10//! ## Shared Terms
11//!
12//! In Nelson-Oppen combination, shared terms are variables that appear
13//! in constraints of multiple theories. These terms form the "interface"
14//! between theories and are used for equality propagation.
15//!
16//! ## Canonical Representatives
17//!
18//! Each equivalence class of terms has a canonical representative.
19//! This module maintains:
20//! - Fast lookup of representatives
21//! - Efficient equality class merging
22//! - Explanations for why terms are equal
23//!
24//! ## Equality Graphs (E-graphs)
25//!
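//! E-graphs compactly represent equivalence classes closed under congruence:
//! - Efficient term canonicalization
//! - Congruence closure (congruence propagation in [`EGraph`] is currently a placeholder)
//! - Extraction of smallest equivalent terms (not yet provided by this module)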
30//!
31//! ## References
32//!
33//! - Nelson & Oppen (1979): "Simplification by Cooperating Decision Procedures"
34//! - Nieuwenhuis & Oliveras (2005): "Proof-Producing Congruence Closure"
35//! - Z3's `smt/theory_combine.cpp`
36
37use oxiz_core::TermId as CoreTermId;
38use rustc_hash::{FxHashMap, FxHashSet};
39use std::collections::VecDeque;
40
41/// Local term identifier (wrapper for core TermId).
42pub type TermId = CoreTermId;
43
44/// Theory identifier.
45pub type TheoryId = usize;
46
47/// Decision level for backtracking.
48pub type DecisionLevel = u32;
49
50/// Equality between two terms.
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
52pub struct Equality {
53    /// Left-hand side term.
54    pub lhs: TermId,
55    /// Right-hand side term.
56    pub rhs: TermId,
57}
58
59impl Equality {
60    /// Create a new equality (normalized: smaller term first).
61    pub fn new(lhs: TermId, rhs: TermId) -> Self {
62        if lhs.raw() <= rhs.raw() {
63            Self { lhs, rhs }
64        } else {
65            Self { lhs: rhs, rhs: lhs }
66        }
67    }
68
69    /// Flip the equality.
70    pub fn flip(self) -> Self {
71        Self::new(self.rhs, self.lhs)
72    }
73}
74
75/// Explanation for why two terms are equal.
76#[derive(Debug, Clone)]
77pub enum EqualityExplanation {
78    /// Given as input axiom.
79    Given,
80    /// Reflexivity: t = t.
81    Reflexive,
82    /// Theory propagation.
83    TheoryPropagation {
84        /// Source theory.
85        theory: TheoryId,
86        /// Supporting equalities.
87        support: Vec<Equality>,
88    },
89    /// Transitivity: a = b, b = c => a = c.
90    Transitive {
91        /// Intermediate term.
92        intermediate: TermId,
93        /// Left explanation.
94        left: Box<EqualityExplanation>,
95        /// Right explanation.
96        right: Box<EqualityExplanation>,
97    },
98    /// Congruence: f(a) = f(b) if a = b.
99    Congruence {
100        /// Function term.
101        function: TermId,
102        /// Argument equalities.
103        arg_equalities: Vec<(Equality, Box<EqualityExplanation>)>,
104    },
105}
106
/// Numeric identifier of an e-class inside an e-graph.
pub type EClassId = u32;
109
110/// E-node (term in e-graph).
111#[derive(Debug, Clone, PartialEq, Eq, Hash)]
112pub struct ENode {
113    /// Term identifier.
114    pub term: TermId,
115    /// E-class this node belongs to.
116    pub eclass: EClassId,
117}
118
119/// E-class (equivalence class in e-graph).
120#[derive(Debug, Clone)]
121pub struct EClass {
122    /// Unique identifier.
123    pub id: EClassId,
124    /// Representative term.
125    pub representative: TermId,
126    /// All terms in this class.
127    pub members: FxHashSet<TermId>,
128    /// Parent e-classes (for congruence).
129    pub parents: FxHashSet<EClassId>,
130    /// Size of this e-class (for union-by-size).
131    pub size: usize,
132}
133
134impl EClass {
135    /// Create new e-class with single term.
136    fn new(id: EClassId, term: TermId) -> Self {
137        let mut members = FxHashSet::default();
138        members.insert(term);
139
140        Self {
141            id,
142            representative: term,
143            members,
144            parents: FxHashSet::default(),
145            size: 1,
146        }
147    }
148
149    /// Merge another e-class into this one.
150    fn merge(&mut self, other: &EClass) {
151        for &term in &other.members {
152            self.members.insert(term);
153        }
154        for &parent in &other.parents {
155            self.parents.insert(parent);
156        }
157        self.size += other.size;
158    }
159}
160
161/// Equality graph (e-graph) for congruence closure.
162#[derive(Debug, Clone)]
163pub struct EGraph {
164    /// Term to e-class mapping.
165    term_to_eclass: FxHashMap<TermId, EClassId>,
166
167    /// E-class to class data mapping.
168    eclasses: FxHashMap<EClassId, EClass>,
169
170    /// Next available e-class ID.
171    next_eclass_id: EClassId,
172
173    /// Union-find parent pointers (for e-class merging).
174    parent: FxHashMap<EClassId, EClassId>,
175
176    /// Rank for union-by-rank.
177    rank: FxHashMap<EClassId, usize>,
178
179    /// Pending congruences to process.
180    pending_congruences: VecDeque<(EClassId, EClassId)>,
181
182    /// Explanations for e-class merges.
183    merge_explanations: FxHashMap<(EClassId, EClassId), EqualityExplanation>,
184}
185
186impl EGraph {
187    /// Create new e-graph.
188    pub fn new() -> Self {
189        Self {
190            term_to_eclass: FxHashMap::default(),
191            eclasses: FxHashMap::default(),
192            next_eclass_id: 0,
193            parent: FxHashMap::default(),
194            rank: FxHashMap::default(),
195            pending_congruences: VecDeque::new(),
196            merge_explanations: FxHashMap::default(),
197        }
198    }
199
200    /// Add term to e-graph.
201    pub fn add_term(&mut self, term: TermId) -> EClassId {
202        if let Some(&eclass_id) = self.term_to_eclass.get(&term) {
203            return self.find(eclass_id);
204        }
205
206        let eclass_id = self.next_eclass_id;
207        self.next_eclass_id += 1;
208
209        let eclass = EClass::new(eclass_id, term);
210        self.eclasses.insert(eclass_id, eclass);
211        self.term_to_eclass.insert(term, eclass_id);
212
213        eclass_id
214    }
215
216    /// Find canonical e-class ID (with path compression).
217    pub fn find(&mut self, mut eclass_id: EClassId) -> EClassId {
218        let mut path = Vec::new();
219
220        while let Some(&parent) = self.parent.get(&eclass_id) {
221            if parent == eclass_id {
222                break;
223            }
224            path.push(eclass_id);
225            eclass_id = parent;
226        }
227
228        // Path compression
229        for node in path {
230            self.parent.insert(node, eclass_id);
231        }
232
233        eclass_id
234    }
235
236    /// Merge two e-classes.
237    pub fn merge(
238        &mut self,
239        a: EClassId,
240        b: EClassId,
241        explanation: EqualityExplanation,
242    ) -> Result<EClassId, String> {
243        let a_root = self.find(a);
244        let b_root = self.find(b);
245
246        if a_root == b_root {
247            return Ok(a_root);
248        }
249
250        // Union by rank
251        let a_rank = self.rank.get(&a_root).copied().unwrap_or(0);
252        let b_rank = self.rank.get(&b_root).copied().unwrap_or(0);
253
254        let (child, parent_id) = if a_rank < b_rank {
255            (a_root, b_root)
256        } else if a_rank > b_rank {
257            (b_root, a_root)
258        } else {
259            self.rank.insert(b_root, b_rank + 1);
260            (a_root, b_root)
261        };
262
263        self.parent.insert(child, parent_id);
264
265        // Merge e-class data
266        if let Some(child_eclass) = self.eclasses.get(&child).cloned()
267            && let Some(parent_eclass) = self.eclasses.get_mut(&parent_id)
268        {
269            parent_eclass.merge(&child_eclass);
270        }
271
272        // Store explanation
273        self.merge_explanations
274            .insert((child, parent_id), explanation);
275
276        // Queue congruence checks
277        self.queue_congruence_checks(child, parent_id);
278
279        Ok(parent_id)
280    }
281
282    /// Queue congruence checks for parent terms.
283    fn queue_congruence_checks(&mut self, _a: EClassId, _b: EClassId) {
284        // Simplified: would check parent applications for congruence
285    }
286
287    /// Process pending congruences.
288    pub fn process_congruences(&mut self) -> Result<(), String> {
289        while let Some((a, b)) = self.pending_congruences.pop_front() {
290            let a_root = self.find(a);
291            let b_root = self.find(b);
292
293            if a_root != b_root {
294                self.merge(
295                    a_root,
296                    b_root,
297                    EqualityExplanation::Congruence {
298                        function: TermId::new(0), // Simplified
299                        arg_equalities: Vec::new(),
300                    },
301                )?;
302            }
303        }
304
305        Ok(())
306    }
307
308    /// Get canonical term for an e-class.
309    pub fn get_representative(&mut self, term: TermId) -> Option<TermId> {
310        let eclass_id = *self.term_to_eclass.get(&term)?;
311        let root = self.find(eclass_id);
312        self.eclasses.get(&root).map(|ec| ec.representative)
313    }
314
315    /// Check if two terms are in the same e-class.
316    pub fn are_equal(&mut self, a: TermId, b: TermId) -> bool {
317        if let (Some(&a_class), Some(&b_class)) =
318            (self.term_to_eclass.get(&a), self.term_to_eclass.get(&b))
319        {
320            self.find(a_class) == self.find(b_class)
321        } else {
322            false
323        }
324    }
325
326    /// Get explanation for why two terms are equal.
327    pub fn get_explanation(&mut self, a: TermId, b: TermId) -> Option<EqualityExplanation> {
328        if !self.are_equal(a, b) {
329            return None;
330        }
331
332        if a == b {
333            return Some(EqualityExplanation::Reflexive);
334        }
335
336        // Trace path through union-find to build explanation
337        let a_class = self.term_to_eclass.get(&a)?;
338        let b_class = self.term_to_eclass.get(&b)?;
339
340        // Copy values to avoid borrow checker issues
341        let a_class_val = *a_class;
342        let b_class_val = *b_class;
343
344        let a_root = self.find(a_class_val);
345        let b_root = self.find(b_class_val);
346
347        if a_root == b_root {
348            // Find stored explanation
349            if let Some(explanation) = self.merge_explanations.get(&(a_class_val, b_class_val)) {
350                return Some(explanation.clone());
351            }
352        }
353
354        None
355    }
356
357    /// Get all terms in the same e-class as a term.
358    pub fn get_eclass_members(&mut self, term: TermId) -> Vec<TermId> {
359        if let Some(&eclass_id) = self.term_to_eclass.get(&term) {
360            let root = self.find(eclass_id);
361            if let Some(eclass) = self.eclasses.get(&root) {
362                return eclass.members.iter().copied().collect();
363            }
364        }
365        Vec::new()
366    }
367
368    /// Clear all state.
369    pub fn clear(&mut self) {
370        self.term_to_eclass.clear();
371        self.eclasses.clear();
372        self.next_eclass_id = 0;
373        self.parent.clear();
374        self.rank.clear();
375        self.pending_congruences.clear();
376        self.merge_explanations.clear();
377    }
378}
379
380impl Default for EGraph {
381    fn default() -> Self {
382        Self::new()
383    }
384}
385
386/// Information about a shared term.
387#[derive(Debug, Clone)]
388pub struct SharedTermInfo {
389    /// Theories that use this term.
390    pub theories: FxHashSet<TheoryId>,
391
392    /// Is this term an interface term?
393    pub is_interface: bool,
394
395    /// Representative in equality class.
396    pub representative: TermId,
397
398    /// Size of equivalence class.
399    pub class_size: usize,
400
401    /// Decision level where this term became shared.
402    pub shared_at_level: DecisionLevel,
403}
404
405impl SharedTermInfo {
406    /// Create new shared term info.
407    fn new(theory: TheoryId, level: DecisionLevel) -> Self {
408        let mut theories = FxHashSet::default();
409        theories.insert(theory);
410
411        Self {
412            theories,
413            is_interface: false,
414            representative: TermId::new(0), // Will be set later
415            class_size: 1,
416            shared_at_level: level,
417        }
418    }
419}
420
/// Tunables for [`AdvancedSharedTermsManager`].
#[derive(Debug, Clone)]
pub struct SharedTermsConfig {
    /// Whether equality notifications should be batched.
    /// NOTE(review): equality assertion in this file only consults
    /// `max_batch_size`; confirm whether this flag is meant to gate batching.
    pub enable_batching: bool,

    /// Number of pending equalities that triggers an automatic flush.
    pub max_batch_size: usize,

    /// Whether equalities are maintained in an e-graph.
    pub enable_egraph: bool,

    /// Whether interface terms are collapsed to class representatives.
    pub minimize_interface: bool,

    /// Whether the explanation of each asserted equality is stored.
    pub track_explanations: bool,
}

impl Default for SharedTermsConfig {
    /// All features on, flushing after 1000 pending equalities.
    fn default() -> Self {
        SharedTermsConfig {
            track_explanations: true,
            minimize_interface: true,
            enable_egraph: true,
            max_batch_size: 1_000,
            enable_batching: true,
        }
    }
}
451
/// Counters collected by the shared terms manager (all start at zero).
#[derive(Debug, Clone, Default)]
pub struct SharedTermsStats {
    /// Distinct terms that were registered.
    pub terms_registered: u64,

    /// Calls that subscribed a theory to a term.
    pub subscriptions: u64,

    /// Equalities queued for propagation.
    pub equalities_propagated: u64,

    /// Non-empty batches of equalities that were flushed.
    pub batches_sent: u64,

    /// E-class merges performed.
    pub eclass_merges: u64,

    /// Congruences detected.
    pub congruences: u64,

    /// Terms that became interface terms.
    pub interface_terms: u64,
}
476
477/// Advanced shared terms manager for theory combination.
478#[derive(Debug)]
479pub struct AdvancedSharedTermsManager {
480    /// Configuration.
481    config: SharedTermsConfig,
482
483    /// Shared term information.
484    terms: FxHashMap<TermId, SharedTermInfo>,
485
486    /// E-graph for equality management.
487    egraph: EGraph,
488
489    /// Pending equalities to propagate.
490    pending_equalities: Vec<Equality>,
491
492    /// Theories subscribed to each term.
493    subscriptions: FxHashMap<TermId, FxHashSet<TheoryId>>,
494
495    /// Interface terms (truly shared between theories).
496    interface_terms: FxHashSet<TermId>,
497
498    /// Decision level history for backtracking.
499    decision_levels: FxHashMap<DecisionLevel, Vec<TermId>>,
500
501    /// Current decision level.
502    current_level: DecisionLevel,
503
504    /// Statistics.
505    stats: SharedTermsStats,
506
507    /// Equality explanations.
508    explanations: FxHashMap<Equality, EqualityExplanation>,
509}
510
511impl AdvancedSharedTermsManager {
512    /// Create a new shared terms manager.
513    pub fn new(config: SharedTermsConfig) -> Self {
514        Self {
515            config,
516            terms: FxHashMap::default(),
517            egraph: EGraph::new(),
518            pending_equalities: Vec::new(),
519            subscriptions: FxHashMap::default(),
520            interface_terms: FxHashSet::default(),
521            decision_levels: FxHashMap::default(),
522            current_level: 0,
523            stats: SharedTermsStats::default(),
524            explanations: FxHashMap::default(),
525        }
526    }
527
528    /// Create with default configuration.
529    pub fn default_config() -> Self {
530        Self::new(SharedTermsConfig::default())
531    }
532
533    /// Register a shared term.
534    pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
535        let is_new = !self.terms.contains_key(&term);
536
537        let entry = self.terms.entry(term).or_insert_with(|| {
538            self.stats.terms_registered += 1;
539            SharedTermInfo::new(theory, self.current_level)
540        });
541
542        let was_single_theory = entry.theories.len() == 1;
543        entry.theories.insert(theory);
544
545        // Track as interface term if used by multiple theories
546        if was_single_theory && entry.theories.len() > 1 {
547            self.interface_terms.insert(term);
548            entry.is_interface = true;
549            self.stats.interface_terms += 1;
550        }
551
552        self.stats.subscriptions += 1;
553
554        // Track subscriptions
555        self.subscriptions.entry(term).or_default().insert(theory);
556
557        // Add to e-graph
558        if self.config.enable_egraph {
559            self.egraph.add_term(term);
560        }
561
562        // Track in decision level history
563        if is_new {
564            self.decision_levels
565                .entry(self.current_level)
566                .or_default()
567                .push(term);
568        }
569    }
570
571    /// Check if a term is shared between multiple theories.
572    pub fn is_shared(&self, term: TermId) -> bool {
573        self.terms
574            .get(&term)
575            .map(|info| info.theories.len() > 1)
576            .unwrap_or(false)
577    }
578
579    /// Check if a term is an interface term.
580    pub fn is_interface_term(&self, term: TermId) -> bool {
581        self.interface_terms.contains(&term)
582    }
583
584    /// Get theories that use a term.
585    pub fn get_theories(&self, term: TermId) -> Vec<TheoryId> {
586        self.terms
587            .get(&term)
588            .map(|info| info.theories.iter().copied().collect())
589            .unwrap_or_default()
590    }
591
592    /// Assert equality between two terms.
593    ///
594    /// This merges their equivalence classes and queues notifications.
595    pub fn assert_equality(
596        &mut self,
597        lhs: TermId,
598        rhs: TermId,
599        explanation: EqualityExplanation,
600    ) -> Result<(), String> {
601        if self.config.enable_egraph && self.egraph.are_equal(lhs, rhs) {
602            return Ok(()); // Already equal
603        }
604
605        // Merge in e-graph
606        if self.config.enable_egraph {
607            let lhs_class = self.egraph.add_term(lhs);
608            let rhs_class = self.egraph.add_term(rhs);
609            self.egraph
610                .merge(lhs_class, rhs_class, explanation.clone())?;
611            self.stats.eclass_merges += 1;
612        }
613
614        // Queue equality for propagation
615        let equality = Equality::new(lhs, rhs);
616        self.pending_equalities.push(equality);
617        self.stats.equalities_propagated += 1;
618
619        // Store explanation
620        if self.config.track_explanations {
621            self.explanations.insert(equality, explanation);
622        }
623
624        // Check if should flush batch
625        if self.pending_equalities.len() >= self.config.max_batch_size {
626            self.flush_equalities();
627        }
628
629        Ok(())
630    }
631
632    /// Check if two terms are in the same equivalence class.
633    pub fn are_equal(&mut self, lhs: TermId, rhs: TermId) -> bool {
634        if !self.config.enable_egraph {
635            return lhs == rhs;
636        }
637
638        self.egraph.are_equal(lhs, rhs)
639    }
640
641    /// Get canonical representative of a term's equivalence class.
642    pub fn get_representative(&mut self, term: TermId) -> TermId {
643        if !self.config.enable_egraph {
644            return term;
645        }
646
647        self.egraph.get_representative(term).unwrap_or(term)
648    }
649
650    /// Get all terms in the same equivalence class.
651    pub fn get_eclass_members(&mut self, term: TermId) -> Vec<TermId> {
652        if !self.config.enable_egraph {
653            return vec![term];
654        }
655
656        self.egraph.get_eclass_members(term)
657    }
658
659    /// Get explanation for why two terms are equal.
660    pub fn get_equality_explanation(
661        &mut self,
662        lhs: TermId,
663        rhs: TermId,
664    ) -> Option<EqualityExplanation> {
665        let eq = Equality::new(lhs, rhs);
666
667        if let Some(explanation) = self.explanations.get(&eq) {
668            return Some(explanation.clone());
669        }
670
671        if self.config.enable_egraph {
672            return self.egraph.get_explanation(lhs, rhs);
673        }
674
675        None
676    }
677
678    /// Get pending equalities to propagate.
679    pub fn get_pending_equalities(&self) -> &[Equality] {
680        &self.pending_equalities
681    }
682
683    /// Flush pending equalities (send to theories).
684    pub fn flush_equalities(&mut self) {
685        if !self.pending_equalities.is_empty() {
686            self.stats.batches_sent += 1;
687            self.pending_equalities.clear();
688        }
689    }
690
691    /// Get all shared terms.
692    pub fn get_shared_terms(&self) -> Vec<TermId> {
693        self.terms
694            .iter()
695            .filter(|(_, info)| info.theories.len() > 1)
696            .map(|(&term, _)| term)
697            .collect()
698    }
699
700    /// Get interface terms (minimal shared terms).
701    pub fn get_interface_terms(&self) -> Vec<TermId> {
702        self.interface_terms.iter().copied().collect()
703    }
704
705    /// Minimize interface terms.
706    ///
707    /// Reduce the number of interface terms by using canonical representatives.
708    pub fn minimize_interface(&mut self) -> Vec<TermId> {
709        if !self.config.minimize_interface || !self.config.enable_egraph {
710            return self.get_interface_terms();
711        }
712
713        let mut minimal = FxHashSet::default();
714
715        // Collect terms first to avoid borrow checker issues
716        let terms: Vec<_> = self.interface_terms.iter().copied().collect();
717        for term in terms {
718            let rep = self.get_representative(term);
719            minimal.insert(rep);
720        }
721
722        minimal.into_iter().collect()
723    }
724
725    /// Push a new decision level.
726    pub fn push_decision_level(&mut self) {
727        self.current_level += 1;
728    }
729
730    /// Backtrack to a decision level.
731    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
732        if level > self.current_level {
733            return Err("Cannot backtrack to future level".to_string());
734        }
735
736        // Remove terms registered above this level
737        let levels_to_remove: Vec<_> = self
738            .decision_levels
739            .keys()
740            .filter(|&&l| l > level)
741            .copied()
742            .collect();
743
744        for l in levels_to_remove {
745            if let Some(terms) = self.decision_levels.remove(&l) {
746                for term in terms {
747                    self.terms.remove(&term);
748                    self.subscriptions.remove(&term);
749                    self.interface_terms.remove(&term);
750                }
751            }
752        }
753
754        self.current_level = level;
755        Ok(())
756    }
757
758    /// Get statistics.
759    pub fn stats(&self) -> &SharedTermsStats {
760        &self.stats
761    }
762
763    /// Reset manager state.
764    pub fn reset(&mut self) {
765        self.terms.clear();
766        self.egraph.clear();
767        self.pending_equalities.clear();
768        self.subscriptions.clear();
769        self.interface_terms.clear();
770        self.decision_levels.clear();
771        self.current_level = 0;
772        self.explanations.clear();
773        self.stats = SharedTermsStats::default();
774    }
775
776    /// Process pending congruences in e-graph.
777    pub fn process_congruences(&mut self) -> Result<(), String> {
778        if !self.config.enable_egraph {
779            return Ok(());
780        }
781
782        self.egraph.process_congruences()?;
783        Ok(())
784    }
785
786    /// Detect new shared terms based on current theory assignments.
787    pub fn detect_shared_terms(&mut self, _term_theories: &FxHashMap<TermId, FxHashSet<TheoryId>>) {
788        // Simplified: would analyze term occurrences across theories
789    }
790
791    /// Build equality explanation chain.
792    pub fn build_explanation_chain(
793        &self,
794        equalities: &[Equality],
795    ) -> Result<EqualityExplanation, String> {
796        if equalities.is_empty() {
797            return Err("No equalities to explain".to_string());
798        }
799
800        if equalities.len() == 1 {
801            let eq = &equalities[0];
802            return Ok(self
803                .explanations
804                .get(eq)
805                .cloned()
806                .unwrap_or(EqualityExplanation::Given));
807        }
808
809        // Build transitive chain
810        let mut current = equalities[0];
811        let mut explanation = self
812            .explanations
813            .get(&current)
814            .cloned()
815            .unwrap_or(EqualityExplanation::Given);
816
817        for &eq in &equalities[1..] {
818            let next_explanation = self
819                .explanations
820                .get(&eq)
821                .cloned()
822                .unwrap_or(EqualityExplanation::Given);
823
824            explanation = EqualityExplanation::Transitive {
825                intermediate: current.rhs,
826                left: Box::new(explanation),
827                right: Box::new(next_explanation),
828            };
829
830            current = eq;
831        }
832
833        Ok(explanation)
834    }
835}
836
837impl Default for AdvancedSharedTermsManager {
838    fn default() -> Self {
839        Self::default_config()
840    }
841}
842
843/// Interface term minimizer.
844///
845/// Reduces the number of interface terms by finding minimal representatives.
846pub struct InterfaceTermMinimizer {
847    /// E-graph for term equivalences.
848    egraph: EGraph,
849
850    /// Candidate interface terms.
851    candidates: FxHashSet<TermId>,
852}
853
854impl InterfaceTermMinimizer {
855    /// Create new minimizer.
856    pub fn new() -> Self {
857        Self {
858            egraph: EGraph::new(),
859            candidates: FxHashSet::default(),
860        }
861    }
862
863    /// Add candidate interface term.
864    pub fn add_candidate(&mut self, term: TermId) {
865        self.candidates.insert(term);
866        self.egraph.add_term(term);
867    }
868
869    /// Add equality between terms.
870    pub fn add_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
871        let lhs_class = self.egraph.add_term(lhs);
872        let rhs_class = self.egraph.add_term(rhs);
873        self.egraph
874            .merge(lhs_class, rhs_class, EqualityExplanation::Given)?;
875        Ok(())
876    }
877
878    /// Compute minimal interface terms.
879    pub fn minimize(&mut self) -> Vec<TermId> {
880        let mut minimal = FxHashSet::default();
881
882        for &term in &self.candidates {
883            let rep = self.egraph.get_representative(term).unwrap_or(term);
884            minimal.insert(rep);
885        }
886
887        minimal.into_iter().collect()
888    }
889
890    /// Clear state.
891    pub fn clear(&mut self) {
892        self.egraph.clear();
893        self.candidates.clear();
894    }
895}
896
897impl Default for InterfaceTermMinimizer {
898    fn default() -> Self {
899        Self::new()
900    }
901}
902
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a term id.
    fn term(id: u32) -> TermId {
        TermId::new(id)
    }

    /// Fresh manager with the default configuration.
    fn manager() -> AdvancedSharedTermsManager {
        AdvancedSharedTermsManager::default_config()
    }

    /// Fresh manager in which `term(a) = term(b)` has been asserted as given.
    fn manager_with(a: u32, b: u32) -> AdvancedSharedTermsManager {
        let mut m = manager();
        m.assert_equality(term(a), term(b), EqualityExplanation::Given)
            .expect("Assert failed");
        m
    }

    /// E-graph containing `ids`, plus the e-class returned for each id, in order.
    fn egraph_with(ids: &[u32]) -> (EGraph, Vec<EClassId>) {
        let mut g = EGraph::new();
        let classes: Vec<EClassId> = ids.iter().map(|&id| g.add_term(term(id))).collect();
        (g, classes)
    }

    #[test]
    fn test_equality_creation() {
        assert_eq!(Equality::new(term(1), term(2)), Equality::new(term(2), term(1)));
    }

    #[test]
    fn test_egraph_creation() {
        assert_eq!(EGraph::new().next_eclass_id, 0);
    }

    #[test]
    fn test_egraph_add_term() {
        let mut g = EGraph::new();
        let first = g.add_term(term(1));
        let again = g.add_term(term(1));
        assert_eq!(first, again);
    }

    #[test]
    fn test_egraph_merge() {
        let (mut g, c) = egraph_with(&[1, 2]);
        g.merge(c[0], c[1], EqualityExplanation::Given)
            .expect("Merge failed");
        assert!(g.are_equal(term(1), term(2)));
    }

    #[test]
    fn test_egraph_transitivity() {
        let (mut g, c) = egraph_with(&[1, 2, 3]);
        // Merge 1~2, then 2~3.
        for pair in c.windows(2) {
            g.merge(pair[0], pair[1], EqualityExplanation::Given)
                .expect("Merge failed");
        }
        assert!(g.are_equal(term(1), term(3)));
    }

    #[test]
    fn test_manager_creation() {
        assert_eq!(manager().stats().terms_registered, 0);
    }

    #[test]
    fn test_register_term() {
        let mut m = manager();
        for theory in 0..2 {
            m.register_term(term(1), theory);
        }

        assert!(m.is_shared(term(1)));
        assert!(m.is_interface_term(term(1)));
        assert_eq!(m.get_theories(term(1)).len(), 2);
    }

    #[test]
    fn test_equality_assertion() {
        let mut m = manager_with(1, 2);
        assert!(m.are_equal(term(1), term(2)));
        assert_eq!(m.get_pending_equalities().len(), 1);
    }

    #[test]
    fn test_representative() {
        let mut m = manager_with(1, 2);
        let rep1 = m.get_representative(term(1));
        let rep2 = m.get_representative(term(2));
        assert_eq!(rep1, rep2);
    }

    #[test]
    fn test_eclass_members() {
        let mut m = manager_with(1, 2);
        let members = m.get_eclass_members(term(1));
        for expected in [term(1), term(2)] {
            assert!(members.contains(&expected));
        }
    }

    #[test]
    fn test_flush_equalities() {
        let mut m = manager_with(1, 2);
        assert_eq!(m.get_pending_equalities().len(), 1);

        m.flush_equalities();
        assert!(m.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_interface_term_minimization() {
        let mut m = manager();
        for id in [1, 2] {
            for theory in [0, 1] {
                m.register_term(term(id), theory);
            }
        }

        m.assert_equality(term(1), term(2), EqualityExplanation::Given)
            .expect("Assert failed");

        assert_eq!(m.minimize_interface().len(), 1);
    }

    #[test]
    fn test_decision_levels() {
        let mut m = manager();
        // term(1) lands on level 1, term(2) on level 2.
        for id in [1, 2] {
            m.push_decision_level();
            m.register_term(term(id), 0);
        }

        m.backtrack(1).expect("Backtrack failed");

        assert!(m.terms.contains_key(&term(1)));
        assert!(!m.terms.contains_key(&term(2)));
    }

    #[test]
    fn test_interface_minimizer() {
        let mut minimizer = InterfaceTermMinimizer::new();
        (1..=3).for_each(|id| minimizer.add_candidate(term(id)));

        minimizer
            .add_equality(term(1), term(2))
            .expect("Equality failed");

        // Expect {rep(1, 2), 3}.
        assert_eq!(minimizer.minimize().len(), 2);
    }
}