1use rustc_hash::{FxHashMap, FxHashSet};
36use std::cmp::Ordering;
37use std::collections::BinaryHeap;
38
/// Numeric identifier of a term.
pub type TermId = u32;

/// Numeric identifier of a theory.
pub type TheoryId = u32;

/// Decision level, counted upward from 0.
pub type DecisionLevel = u32;
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
50pub struct Equality {
51 pub lhs: TermId,
53 pub rhs: TermId,
55}
56
57impl Equality {
58 pub fn new(lhs: TermId, rhs: TermId) -> Self {
60 if lhs <= rhs {
61 Self { lhs, rhs }
62 } else {
63 Self { lhs: rhs, rhs: lhs }
64 }
65 }
66}
67
/// How an equivalence class turns its members into interface equalities.
///
/// The per-variant notes describe what `InterfaceEClass::generate_equalities`
/// does for each strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GenerationStrategy {
    /// One equality for every unordered pair of members.
    Eager,
    /// No equalities are produced.
    Lazy,
    /// One equality between each non-representative member and the representative.
    Minimal,
    /// Currently produces the same equalities as [`GenerationStrategy::Minimal`].
    Incremental,
    /// All pairs for classes of at most two members, otherwise the same
    /// equalities as [`GenerationStrategy::Minimal`].
    Adaptive,
}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq)]
85pub struct EqualityPriority {
86 pub level: u32,
88 pub relevancy: u32,
90 pub decision_level: DecisionLevel,
92}
93
94impl Ord for EqualityPriority {
95 fn cmp(&self, other: &Self) -> Ordering {
96 self.level
97 .cmp(&other.level)
98 .then_with(|| self.relevancy.cmp(&other.relevancy))
99 .then_with(|| other.decision_level.cmp(&self.decision_level))
100 }
101}
102
103impl PartialOrd for EqualityPriority {
104 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
105 Some(self.cmp(other))
106 }
107}
108
109#[derive(Debug, Clone)]
111pub struct InterfaceEquality {
112 pub equality: Equality,
114 pub theories: FxHashSet<TheoryId>,
116 pub priority: EqualityPriority,
118 pub is_necessary: bool,
120 pub timestamp: u64,
122}
123
124impl PartialEq for InterfaceEquality {
125 fn eq(&self, other: &Self) -> bool {
126 self.equality == other.equality
127 }
128}
129
130impl Eq for InterfaceEquality {}
131
132impl Ord for InterfaceEquality {
133 fn cmp(&self, other: &Self) -> Ordering {
134 self.priority.cmp(&other.priority)
135 }
136}
137
138impl PartialOrd for InterfaceEquality {
139 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140 Some(self.cmp(other))
141 }
142}
143
144#[derive(Debug, Clone)]
146pub struct InterfaceEClass {
147 pub representative: TermId,
149 pub members: FxHashSet<TermId>,
151 pub theories: FxHashSet<TheoryId>,
153 pub strategy: GenerationStrategy,
155}
156
157impl InterfaceEClass {
158 fn new(representative: TermId, theory: TheoryId) -> Self {
160 let mut members = FxHashSet::default();
161 members.insert(representative);
162
163 let mut theories = FxHashSet::default();
164 theories.insert(theory);
165
166 Self {
167 representative,
168 members,
169 theories,
170 strategy: GenerationStrategy::Minimal,
171 }
172 }
173
174 fn add_term(&mut self, term: TermId, theory: TheoryId) {
176 self.members.insert(term);
177 self.theories.insert(theory);
178 }
179
180 fn merge(&mut self, other: &InterfaceEClass) {
182 for &term in &other.members {
183 self.members.insert(term);
184 }
185 for &theory in &other.theories {
186 self.theories.insert(theory);
187 }
188 }
189
190 fn is_shared(&self) -> bool {
192 self.theories.len() > 1
193 }
194
195 fn generate_equalities(
197 &self,
198 timestamp: u64,
199 decision_level: DecisionLevel,
200 ) -> Vec<InterfaceEquality> {
201 match self.strategy {
202 GenerationStrategy::Eager => self.generate_eager(timestamp, decision_level),
203 GenerationStrategy::Lazy => Vec::new(), GenerationStrategy::Minimal => self.generate_minimal(timestamp, decision_level),
205 GenerationStrategy::Incremental => self.generate_incremental(timestamp, decision_level),
206 GenerationStrategy::Adaptive => self.generate_adaptive(timestamp, decision_level),
207 }
208 }
209
210 fn generate_eager(
212 &self,
213 timestamp: u64,
214 decision_level: DecisionLevel,
215 ) -> Vec<InterfaceEquality> {
216 let mut equalities = Vec::new();
217 let members: Vec<_> = self.members.iter().copied().collect();
218
219 for i in 0..members.len() {
220 for j in (i + 1)..members.len() {
221 equalities.push(InterfaceEquality {
222 equality: Equality::new(members[i], members[j]),
223 theories: self.theories.clone(),
224 priority: EqualityPriority {
225 level: 100,
226 relevancy: 50,
227 decision_level,
228 },
229 is_necessary: false,
230 timestamp,
231 });
232 }
233 }
234
235 equalities
236 }
237
238 fn generate_minimal(
240 &self,
241 timestamp: u64,
242 decision_level: DecisionLevel,
243 ) -> Vec<InterfaceEquality> {
244 let mut equalities = Vec::new();
245 let rep = self.representative;
246
247 for &term in &self.members {
248 if term != rep {
249 equalities.push(InterfaceEquality {
250 equality: Equality::new(term, rep),
251 theories: self.theories.clone(),
252 priority: EqualityPriority {
253 level: 100,
254 relevancy: 50,
255 decision_level,
256 },
257 is_necessary: true,
258 timestamp,
259 });
260 }
261 }
262
263 equalities
264 }
265
266 fn generate_incremental(
268 &self,
269 timestamp: u64,
270 decision_level: DecisionLevel,
271 ) -> Vec<InterfaceEquality> {
272 self.generate_minimal(timestamp, decision_level)
274 }
275
276 fn generate_adaptive(
278 &self,
279 timestamp: u64,
280 decision_level: DecisionLevel,
281 ) -> Vec<InterfaceEquality> {
282 if self.members.len() <= 2 {
284 self.generate_eager(timestamp, decision_level)
285 } else {
286 self.generate_minimal(timestamp, decision_level)
287 }
288 }
289}
290
291#[derive(Debug, Clone)]
293pub struct InterfaceEqualityConfig {
294 pub default_strategy: GenerationStrategy,
296
297 pub enable_minimization: bool,
299
300 pub enable_priority: bool,
302
303 pub max_batch_size: usize,
305
306 pub track_relevancy: bool,
308
309 pub adaptive_threshold: usize,
311}
312
313impl Default for InterfaceEqualityConfig {
314 fn default() -> Self {
315 Self {
316 default_strategy: GenerationStrategy::Minimal,
317 enable_minimization: true,
318 enable_priority: true,
319 max_batch_size: 1000,
320 track_relevancy: true,
321 adaptive_threshold: 10,
322 }
323 }
324}
325
/// Counters maintained by `InterfaceEqualityManager`; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct InterfaceEqualityStats {
    /// Equalities added to the pending queue.
    pub equalities_generated: u64,
    /// Counter updated by `minimize_equalities` for equalities it removes.
    pub equalities_minimized: u64,
    /// Queued equalities that came from a class using the `Eager` strategy.
    pub eager_generations: u64,
    /// Queued equalities that came from a class using the `Lazy` strategy.
    pub lazy_generations: u64,
    /// Queued equalities that came from a class using the `Minimal` strategy.
    pub minimal_generations: u64,
    /// Classes created so far; merging two classes does not decrease it.
    pub eclasses: u64,
    /// Non-empty batches handed out by the manager.
    pub batches_sent: u64,
}
344
345pub struct InterfaceEqualityManager {
347 config: InterfaceEqualityConfig,
349
350 stats: InterfaceEqualityStats,
352
353 term_to_eclass: FxHashMap<TermId, usize>,
355
356 eclasses: Vec<InterfaceEClass>,
358
359 pending: BinaryHeap<InterfaceEquality>,
361
362 generated: FxHashSet<Equality>,
364
365 timestamp: u64,
367
368 decision_level: DecisionLevel,
370
371 relevancy: FxHashMap<TermId, u32>,
373
374 history: FxHashMap<DecisionLevel, Vec<Equality>>,
376}
377
378impl InterfaceEqualityManager {
379 pub fn new() -> Self {
381 Self::with_config(InterfaceEqualityConfig::default())
382 }
383
384 pub fn with_config(config: InterfaceEqualityConfig) -> Self {
386 Self {
387 config,
388 stats: InterfaceEqualityStats::default(),
389 term_to_eclass: FxHashMap::default(),
390 eclasses: Vec::new(),
391 pending: BinaryHeap::new(),
392 generated: FxHashSet::default(),
393 timestamp: 0,
394 decision_level: 0,
395 relevancy: FxHashMap::default(),
396 history: FxHashMap::default(),
397 }
398 }
399
400 pub fn stats(&self) -> &InterfaceEqualityStats {
402 &self.stats
403 }
404
405 pub fn register_term(&mut self, term: TermId, theory: TheoryId) {
407 if let Some(&eclass_id) = self.term_to_eclass.get(&term) {
408 self.eclasses[eclass_id].add_term(term, theory);
409 } else {
410 let eclass_id = self.eclasses.len();
411 self.eclasses.push(InterfaceEClass::new(term, theory));
412 self.term_to_eclass.insert(term, eclass_id);
413 self.stats.eclasses += 1;
414 }
415 }
416
417 pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
419 let lhs_class = self.find_or_create(lhs);
420 let rhs_class = self.find_or_create(rhs);
421
422 if lhs_class == rhs_class {
423 return Ok(());
424 }
425
426 let (small, large) =
428 if self.eclasses[lhs_class].members.len() < self.eclasses[rhs_class].members.len() {
429 (lhs_class, rhs_class)
430 } else {
431 (rhs_class, lhs_class)
432 };
433
434 let small_eclass = self.eclasses[small].clone();
435 self.eclasses[large].merge(&small_eclass);
436
437 for &term in &small_eclass.members {
439 self.term_to_eclass.insert(term, large);
440 }
441
442 if self.eclasses[large].is_shared() {
444 self.generate_equalities_for_class(large)?;
445 }
446
447 Ok(())
448 }
449
450 fn find_or_create(&mut self, term: TermId) -> usize {
452 if let Some(&eclass_id) = self.term_to_eclass.get(&term) {
453 eclass_id
454 } else {
455 let eclass_id = self.eclasses.len();
456 self.eclasses.push(InterfaceEClass::new(term, 0));
457 self.term_to_eclass.insert(term, eclass_id);
458 self.stats.eclasses += 1;
459 eclass_id
460 }
461 }
462
463 fn generate_equalities_for_class(&mut self, eclass_id: usize) -> Result<(), String> {
465 if eclass_id >= self.eclasses.len() {
466 return Err("Invalid eclass ID".to_string());
467 }
468
469 let eclass = &self.eclasses[eclass_id];
470 let equalities = eclass.generate_equalities(self.timestamp, self.decision_level);
471
472 for eq in equalities {
473 if !self.generated.contains(&eq.equality) {
474 self.generated.insert(eq.equality);
475 self.pending.push(eq);
476 self.stats.equalities_generated += 1;
477
478 match eclass.strategy {
480 GenerationStrategy::Eager => self.stats.eager_generations += 1,
481 GenerationStrategy::Lazy => self.stats.lazy_generations += 1,
482 GenerationStrategy::Minimal => self.stats.minimal_generations += 1,
483 _ => {}
484 }
485 }
486 }
487
488 self.timestamp += 1;
489 Ok(())
490 }
491
492 pub fn get_pending_batch(&mut self) -> Vec<InterfaceEquality> {
494 let mut batch = Vec::new();
495
496 while batch.len() < self.config.max_batch_size {
497 if let Some(eq) = self.pending.pop() {
498 batch.push(eq);
499 } else {
500 break;
501 }
502 }
503
504 if !batch.is_empty() {
505 self.stats.batches_sent += 1;
506 }
507
508 batch
509 }
510
511 pub fn get_all_pending(&mut self) -> Vec<InterfaceEquality> {
513 let mut all = Vec::new();
514
515 while let Some(eq) = self.pending.pop() {
516 all.push(eq);
517 }
518
519 if !all.is_empty() {
520 self.stats.batches_sent += 1;
521 }
522
523 all
524 }
525
526 pub fn set_strategy(
528 &mut self,
529 term: TermId,
530 strategy: GenerationStrategy,
531 ) -> Result<(), String> {
532 let eclass_id = self
533 .term_to_eclass
534 .get(&term)
535 .ok_or("Term not registered")?;
536 self.eclasses[*eclass_id].strategy = strategy;
537 Ok(())
538 }
539
540 pub fn minimize_equalities(&mut self) {
544 if !self.config.enable_minimization {
545 return;
546 }
547
548 let all_pending: Vec<_> = self.pending.drain().collect();
550 let mut necessary = Vec::new();
551
552 let mut by_class: FxHashMap<usize, Vec<InterfaceEquality>> = FxHashMap::default();
554
555 for eq in all_pending {
556 if let Some(&eclass_id) = self.term_to_eclass.get(&eq.equality.lhs) {
557 by_class.entry(eclass_id).or_default().push(eq);
558 }
559 }
560
561 for (_eclass_id, mut equalities) in by_class {
563 if equalities.len() <= 2 {
564 necessary.extend(equalities);
565 continue;
566 }
567
568 let rep = equalities[0].equality.lhs;
570
571 equalities.retain(|eq| eq.equality.lhs == rep || eq.equality.rhs == rep);
573
574 let before = equalities.len();
575 let minimized = equalities.len();
576 self.stats.equalities_minimized += (before - minimized) as u64;
577
578 necessary.extend(equalities);
579 }
580
581 for eq in necessary {
583 self.pending.push(eq);
584 }
585 }
586
587 pub fn update_relevancy(&mut self, term: TermId, score: u32) {
589 if !self.config.track_relevancy {
590 return;
591 }
592
593 self.relevancy.insert(term, score);
594
595 let all_pending: Vec<_> = self.pending.drain().collect();
597
598 for mut eq in all_pending {
599 if eq.equality.lhs == term || eq.equality.rhs == term {
600 eq.priority.relevancy = score;
601 }
602 self.pending.push(eq);
603 }
604 }
605
606 pub fn push_decision_level(&mut self) {
608 self.decision_level += 1;
609 }
610
611 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
613 if level > self.decision_level {
614 return Err("Cannot backtrack to future level".to_string());
615 }
616
617 let all_pending: Vec<_> = self.pending.drain().collect();
619
620 for eq in all_pending {
621 if eq.priority.decision_level <= level {
622 self.pending.push(eq);
623 } else {
624 self.generated.remove(&eq.equality);
625 }
626 }
627
628 let levels_to_remove: Vec<_> = self
630 .history
631 .keys()
632 .filter(|&&l| l > level)
633 .copied()
634 .collect();
635
636 for l in levels_to_remove {
637 self.history.remove(&l);
638 }
639
640 self.decision_level = level;
641 Ok(())
642 }
643
644 pub fn clear(&mut self) {
646 self.term_to_eclass.clear();
647 self.eclasses.clear();
648 self.pending.clear();
649 self.generated.clear();
650 self.timestamp = 0;
651 self.decision_level = 0;
652 self.relevancy.clear();
653 self.history.clear();
654 }
655
656 pub fn reset_stats(&mut self) {
658 self.stats = InterfaceEqualityStats::default();
659 }
660
661 pub fn pending_count(&self) -> usize {
663 self.pending.len()
664 }
665
666 pub fn is_generated(&self, eq: &Equality) -> bool {
668 self.generated.contains(eq)
669 }
670
671 pub fn force_generate_all(&mut self) -> Result<(), String> {
673 for eclass_id in 0..self.eclasses.len() {
674 if self.eclasses[eclass_id].is_shared() {
675 self.generate_equalities_for_class(eclass_id)?;
676 }
677 }
678 Ok(())
679 }
680
681 pub fn get_eclass(&self, term: TermId) -> Option<&InterfaceEClass> {
683 self.term_to_eclass
684 .get(&term)
685 .and_then(|&id| self.eclasses.get(id))
686 }
687
688 pub fn get_representative(&self, term: TermId) -> Option<TermId> {
690 self.get_eclass(term).map(|ec| ec.representative)
691 }
692
693 pub fn are_equal(&self, lhs: TermId, rhs: TermId) -> bool {
695 if let (Some(&lhs_class), Some(&rhs_class)) =
696 (self.term_to_eclass.get(&lhs), self.term_to_eclass.get(&rhs))
697 {
698 lhs_class == rhs_class
699 } else {
700 false
701 }
702 }
703}
704
705impl Default for InterfaceEqualityManager {
706 fn default() -> Self {
707 Self::new()
708 }
709}
710
711pub struct EqualityScheduler {
713 scheduled: BinaryHeap<InterfaceEquality>,
715 policy: SchedulingPolicy,
717}
718
719#[derive(Debug, Clone, Copy, PartialEq, Eq)]
721pub enum SchedulingPolicy {
722 Fifo,
724 Priority,
726 Relevancy,
728 RoundRobin,
730}
731
732impl EqualityScheduler {
733 pub fn new(policy: SchedulingPolicy) -> Self {
735 Self {
736 scheduled: BinaryHeap::new(),
737 policy,
738 }
739 }
740
741 pub fn schedule(&mut self, equality: InterfaceEquality) {
743 self.scheduled.push(equality);
744 }
745
746 #[allow(clippy::should_implement_trait)]
748 pub fn next(&mut self) -> Option<InterfaceEquality> {
749 match self.policy {
750 SchedulingPolicy::Fifo => {
751 let all: Vec<_> = self.scheduled.drain().collect();
753
754 all.into_iter().next()
755 }
756 SchedulingPolicy::Priority | SchedulingPolicy::Relevancy => self.scheduled.pop(),
757 SchedulingPolicy::RoundRobin => {
758 self.scheduled.pop()
760 }
761 }
762 }
763
764 pub fn next_batch(&mut self, size: usize) -> Vec<InterfaceEquality> {
766 let mut batch = Vec::new();
767
768 for _ in 0..size {
769 if let Some(eq) = self.next() {
770 batch.push(eq);
771 } else {
772 break;
773 }
774 }
775
776 batch
777 }
778
779 pub fn clear(&mut self) {
781 self.scheduled.clear();
782 }
783}
784
785pub struct EqualityMinimizer {
787 parent: FxHashMap<TermId, TermId>,
789 rank: FxHashMap<TermId, usize>,
791}
792
793impl EqualityMinimizer {
794 pub fn new() -> Self {
796 Self {
797 parent: FxHashMap::default(),
798 rank: FxHashMap::default(),
799 }
800 }
801
802 pub fn add_equality(&mut self, eq: Equality) {
804 let lhs_root = self.find(eq.lhs);
805 let rhs_root = self.find(eq.rhs);
806
807 if lhs_root == rhs_root {
808 return;
809 }
810
811 let lhs_rank = self.rank.get(&lhs_root).copied().unwrap_or(0);
812 let rhs_rank = self.rank.get(&rhs_root).copied().unwrap_or(0);
813
814 if lhs_rank < rhs_rank {
815 self.parent.insert(lhs_root, rhs_root);
816 } else if lhs_rank > rhs_rank {
817 self.parent.insert(rhs_root, lhs_root);
818 } else {
819 self.parent.insert(lhs_root, rhs_root);
820 self.rank.insert(rhs_root, rhs_rank + 1);
821 }
822 }
823
824 fn find(&mut self, mut term: TermId) -> TermId {
826 let mut path = Vec::new();
827
828 while let Some(&parent) = self.parent.get(&term) {
829 if parent == term {
830 break;
831 }
832 path.push(term);
833 term = parent;
834 }
835
836 for node in path {
837 self.parent.insert(node, term);
838 }
839
840 term
841 }
842
843 pub fn is_redundant(&mut self, eq: &Equality) -> bool {
845 self.find(eq.lhs) == self.find(eq.rhs)
846 }
847
848 pub fn minimize(&mut self, equalities: Vec<Equality>) -> Vec<Equality> {
850 let mut minimal = Vec::new();
851
852 for eq in equalities {
853 if !self.is_redundant(&eq) {
854 self.add_equality(eq);
855 minimal.push(eq);
856 }
857 }
858
859 minimal
860 }
861
862 pub fn clear(&mut self) {
864 self.parent.clear();
865 self.rank.clear();
866 }
867}
868
869impl Default for EqualityMinimizer {
870 fn default() -> Self {
871 Self::new()
872 }
873}
874
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding a term under a second theory makes the class shared.
    #[test]
    fn test_interface_eclass() {
        let mut class = InterfaceEClass::new(1, 0);
        class.add_term(2, 1);

        assert_eq!(class.members.len(), 2);
        assert!(class.is_shared());
    }

    /// Minimal generation links every other member to the representative.
    #[test]
    fn test_minimal_generation() {
        let mut class = InterfaceEClass::new(1, 0);
        for term in [2, 3] {
            class.add_term(term, 0);
        }

        let generated = class.generate_minimal(0, 0);
        assert_eq!(generated.len(), 2);
    }

    /// A fresh manager has generated nothing.
    #[test]
    fn test_manager_creation() {
        let mgr = InterfaceEqualityManager::new();
        assert_eq!(mgr.stats().equalities_generated, 0);
    }

    /// Registering the same term for two theories creates a single class.
    #[test]
    fn test_register_term() {
        let mut mgr = InterfaceEqualityManager::new();
        mgr.register_term(1, 0);
        mgr.register_term(1, 1);

        assert_eq!(mgr.stats().eclasses, 1);
    }

    /// Asserting an equality puts both terms in the same class.
    #[test]
    fn test_assert_equality() {
        let mut mgr = InterfaceEqualityManager::new();
        mgr.register_term(1, 0);
        mgr.register_term(2, 1);

        mgr.assert_equality(1, 2).expect("Assert failed");
        assert!(mgr.are_equal(1, 2));
    }

    /// Merging terms of a shared class queues at least one equality.
    #[test]
    fn test_get_pending() {
        let mut mgr = InterfaceEqualityManager::new();
        mgr.register_term(1, 0);
        mgr.register_term(2, 1);
        mgr.register_term(1, 1);
        mgr.assert_equality(1, 2).expect("Assert failed");

        let queued = mgr.get_all_pending();
        assert!(!queued.is_empty());
    }

    /// Five mutually equal terms need at most four equalities.
    #[test]
    fn test_minimization() {
        let mut mgr = InterfaceEqualityManager::new();

        for term in 1..=5 {
            mgr.register_term(term, 0);
            mgr.register_term(term, 1);
        }

        for term in 2..=5 {
            mgr.assert_equality(1, term).expect("Assert failed");
        }

        mgr.minimize_equalities();

        let queued = mgr.get_all_pending();
        assert!(queued.len() <= 4);
    }

    /// A scheduled equality can be retrieved again.
    #[test]
    fn test_scheduler() {
        let mut sched = EqualityScheduler::new(SchedulingPolicy::Priority);

        let priority = EqualityPriority {
            level: 100,
            relevancy: 50,
            decision_level: 0,
        };
        let entry = InterfaceEquality {
            equality: Equality::new(1, 2),
            theories: FxHashSet::default(),
            priority,
            is_necessary: true,
            timestamp: 0,
        };

        sched.schedule(entry);
        assert!(sched.next().is_some());
    }

    /// 1 = 2 and 2 = 3 make 1 = 3 redundant.
    #[test]
    fn test_minimizer() {
        let mut min = EqualityMinimizer::new();

        let first = Equality::new(1, 2);
        let second = Equality::new(2, 3);
        let implied = Equality::new(1, 3);
        min.add_equality(first);
        min.add_equality(second);

        assert!(min.is_redundant(&implied));
    }

    /// Backtracking to an earlier level succeeds.
    #[test]
    fn test_backtrack() {
        let mut mgr = InterfaceEqualityManager::new();

        mgr.push_decision_level();
        mgr.register_term(1, 0);

        mgr.backtrack(0).expect("Backtrack failed");
    }

    /// The strategy set through the manager is visible on the class.
    #[test]
    fn test_set_strategy() {
        let mut mgr = InterfaceEqualityManager::new();
        mgr.register_term(1, 0);

        mgr.set_strategy(1, GenerationStrategy::Eager)
            .expect("Set strategy failed");

        let class = mgr.get_eclass(1).expect("No eclass");
        assert_eq!(class.strategy, GenerationStrategy::Eager);
    }
}