1#![allow(missing_docs)]
41#[allow(unused_imports)]
42use crate::prelude::*;
43
/// Identifier of a term; the element type stored in a [`Partition`].
pub type TermId = u32;

/// Identifier of a theory.
///
/// NOTE(review): not referenced elsewhere in this module — presumably used by
/// callers that combine several theories; confirm.
pub type TheoryId = u32;

/// Decision level attached to each refinement; backtracking to a level undoes
/// every refinement recorded at a strictly higher level.
pub type DecisionLevel = u32;

/// Index of an equivalence class inside a [`Partition`].
pub type ClassId = usize;
55
/// An undirected equality `lhs = rhs` between two terms.
///
/// When built through [`Equality::new`] the pair is stored in canonical order
/// (`lhs <= rhs`), so `a = b` and `b = a` compare and hash identically. The
/// fields are public, so a value built with a struct literal is not
/// automatically canonical.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
    /// The smaller of the two term ids (when canonical).
    pub lhs: TermId,
    /// The larger, or equal, of the two term ids (when canonical).
    pub rhs: TermId,
}

impl Equality {
    /// Builds the canonical equality between `lhs` and `rhs`, ordering the
    /// two sides so that the smaller id comes first.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        let low = lhs.min(rhs);
        let high = lhs.max(rhs);
        Self { lhs: low, rhs: high }
    }
}
75
76#[derive(Debug, Clone)]
78pub struct Partition {
79 classes: Vec<FxHashSet<TermId>>,
81
82 term_to_class: FxHashMap<TermId, ClassId>,
84
85 representatives: Vec<TermId>,
87}
88
89impl Partition {
90 pub fn finest(terms: &[TermId]) -> Self {
92 let mut classes = Vec::new();
93 let mut term_to_class = FxHashMap::default();
94 let mut representatives = Vec::new();
95
96 for (i, &term) in terms.iter().enumerate() {
97 let mut class = FxHashSet::default();
98 class.insert(term);
99 classes.push(class);
100 term_to_class.insert(term, i);
101 representatives.push(term);
102 }
103
104 Self {
105 classes,
106 term_to_class,
107 representatives,
108 }
109 }
110
111 pub fn coarsest(terms: &[TermId]) -> Self {
113 if terms.is_empty() {
114 return Self {
115 classes: Vec::new(),
116 term_to_class: FxHashMap::default(),
117 representatives: Vec::new(),
118 };
119 }
120
121 let mut class = FxHashSet::default();
122 let mut term_to_class = FxHashMap::default();
123
124 for &term in terms {
125 class.insert(term);
126 term_to_class.insert(term, 0);
127 }
128
129 Self {
130 classes: vec![class],
131 term_to_class,
132 representatives: vec![terms[0]],
133 }
134 }
135
136 pub fn merge(&mut self, t1: TermId, t2: TermId) -> Result<(), String> {
138 let c1 = *self.term_to_class.get(&t1).ok_or("Term not in partition")?;
139 let c2 = *self.term_to_class.get(&t2).ok_or("Term not in partition")?;
140
141 if c1 == c2 {
142 return Ok(());
143 }
144
145 let (src, dst) = if self.classes[c1].len() < self.classes[c2].len() {
147 (c1, c2)
148 } else {
149 (c2, c1)
150 };
151
152 let src_terms: Vec<_> = self.classes[src].iter().copied().collect();
154 for term in src_terms {
155 self.classes[dst].insert(term);
156 self.term_to_class.insert(term, dst);
157 }
158
159 self.classes[src].clear();
161
162 Ok(())
163 }
164
165 pub fn get_equalities(&self) -> Vec<Equality> {
167 let mut equalities = Vec::new();
168
169 for class in &self.classes {
170 if class.len() > 1 {
171 let terms: Vec<_> = class.iter().copied().collect();
172 let rep = terms[0];
174 for &term in &terms[1..] {
175 equalities.push(Equality::new(rep, term));
176 }
177 }
178 }
179
180 equalities
181 }
182
183 pub fn num_classes(&self) -> usize {
185 self.classes.iter().filter(|c| !c.is_empty()).count()
186 }
187
188 pub fn are_equal(&self, t1: TermId, t2: TermId) -> bool {
190 if let (Some(&c1), Some(&c2)) = (self.term_to_class.get(&t1), self.term_to_class.get(&t2)) {
191 c1 == c2
192 } else {
193 false
194 }
195 }
196
197 pub fn get_representative(&self, term: TermId) -> Option<TermId> {
199 self.term_to_class
200 .get(&term)
201 .and_then(|&class_id| self.representatives.get(class_id))
202 .copied()
203 }
204
205 pub fn get_class(&self, term: TermId) -> Option<&FxHashSet<TermId>> {
207 self.term_to_class
208 .get(&term)
209 .and_then(|&class_id| self.classes.get(class_id))
210 }
211
212 pub fn clone_partition(&self) -> Partition {
214 self.clone()
215 }
216}
217
218pub struct PartitionRefinement {
220 partition: Partition,
222
223 history: Vec<Partition>,
225
226 decision_levels: Vec<DecisionLevel>,
228
229 current_level: DecisionLevel,
231}
232
233impl PartitionRefinement {
234 pub fn new(terms: &[TermId]) -> Self {
236 Self {
237 partition: Partition::finest(terms),
238 history: Vec::new(),
239 decision_levels: Vec::new(),
240 current_level: 0,
241 }
242 }
243
244 pub fn refine(&mut self, eq: Equality) -> Result<(), String> {
246 self.history.push(self.partition.clone_partition());
247 self.decision_levels.push(self.current_level);
248 self.partition.merge(eq.lhs, eq.rhs)
249 }
250
251 pub fn refine_batch(&mut self, equalities: &[Equality]) -> Result<(), String> {
253 for &eq in equalities {
254 self.refine(eq)?;
255 }
256 Ok(())
257 }
258
259 pub fn current(&self) -> &Partition {
261 &self.partition
262 }
263
264 pub fn backtrack_step(&mut self) -> Result<(), String> {
266 self.partition = self.history.pop().ok_or("No refinement to backtrack")?;
267 self.decision_levels.pop();
268 Ok(())
269 }
270
271 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
273 while !self.decision_levels.is_empty() {
274 if let Some(&last_level) = self.decision_levels.last() {
275 if last_level > level {
276 self.backtrack_step()?;
277 } else {
278 break;
279 }
280 } else {
281 break;
282 }
283 }
284
285 self.current_level = level;
286 Ok(())
287 }
288
289 pub fn push_decision_level(&mut self) {
291 self.current_level += 1;
292 }
293
294 pub fn clear_history(&mut self) {
296 self.history.clear();
297 self.decision_levels.clear();
298 }
299}
300
301pub struct PartitionEnumerator {
303 n: usize,
305
306 terms: Vec<TermId>,
308
309 rgs: Vec<usize>,
311
312 max_val: usize,
314
315 done: bool,
317}
318
319impl PartitionEnumerator {
320 pub fn new(terms: Vec<TermId>) -> Self {
322 let n = terms.len();
323 Self {
324 n,
325 terms,
326 rgs: vec![0; n],
327 max_val: 0,
328 done: n == 0,
329 }
330 }
331
332 #[allow(clippy::should_implement_trait)]
334 pub fn next(&mut self) -> Option<Partition> {
335 if self.done {
336 return None;
337 }
338
339 let partition = self.rgs_to_partition();
341
342 self.next_rgs();
344
345 Some(partition)
346 }
347
348 fn rgs_to_partition(&self) -> Partition {
350 let mut classes: Vec<FxHashSet<TermId>> = vec![FxHashSet::default(); self.max_val + 1];
351 let mut term_to_class = FxHashMap::default();
352 let mut representatives = vec![0; self.max_val + 1];
353
354 for (i, &class_id) in self.rgs.iter().enumerate() {
355 let term = self.terms[i];
356 classes[class_id].insert(term);
357 term_to_class.insert(term, class_id);
358
359 if representatives[class_id] == 0 || term < representatives[class_id] {
360 representatives[class_id] = term;
361 }
362 }
363
364 Partition {
365 classes,
366 term_to_class,
367 representatives,
368 }
369 }
370
371 fn next_rgs(&mut self) {
373 let mut i = self.n;
375 while i > 0 {
376 i -= 1;
377
378 let can_increment = if i == 0 {
379 false
380 } else {
381 let max_up_to_i = self.rgs[..i].iter().max().copied().unwrap_or(0);
382 self.rgs[i] <= max_up_to_i
383 };
384
385 if can_increment {
386 self.rgs[i] += 1;
387
388 self.max_val = self.rgs.iter().max().copied().unwrap_or(0);
390
391 for j in (i + 1)..self.n {
393 self.rgs[j] = 0;
394 }
395
396 return;
397 }
398 }
399
400 self.done = true;
401 }
402
403 pub fn reset(&mut self) {
405 self.rgs = vec![0; self.n];
406 self.max_val = 0;
407 self.done = self.n == 0;
408 }
409
410 pub fn count_remaining(&self) -> usize {
412 bell_number(self.n)
414 }
415}
416
/// Returns the Bell number `B(n)`: the number of partitions of an `n`-element set.
///
/// Computed with the Bell triangle. Results too large for `usize` saturate to
/// `usize::MAX`; on 64-bit targets this happens for `n >= 26`. Computation
/// stops as soon as an overflow is detected, so large `n` is cheap.
fn bell_number(n: usize) -> usize {
    if n == 0 {
        return 1;
    }

    // Row k of the Bell triangle starts with B(k) and ends with B(k + 1). Each
    // entry is its left neighbour plus the entry above that neighbour.
    let mut row: Vec<usize> = vec![1];
    for _ in 1..n {
        let mut next = Vec::with_capacity(row.len() + 1);
        let mut acc = row[row.len() - 1];
        next.push(acc);
        for &above in &row {
            acc = match acc.checked_add(above) {
                Some(value) => value,
                // Entries grow along a row and B(n) is at least the last entry
                // of any earlier row, so one overflow means B(n) overflows too.
                None => return usize::MAX,
            };
            next.push(acc);
        }
        row = next;
    }
    row[row.len() - 1]
}
437
/// Tuning knobs for [`PartitionRefinementManager`].
#[derive(Debug, Clone)]
pub struct PartitionRefinementConfig {
    /// Build a [`PartitionEnumerator`] so that `next_partition` can yield
    /// candidate partitions.
    pub enable_enumeration: bool,

    /// Maximum number of partitions `next_partition` hands out.
    pub max_partitions: usize,

    /// NOTE(review): not read by `PartitionRefinementManager` in this module;
    /// confirm whether anything else consults it.
    pub constraint_guided: bool,

    /// When `false`, `PartitionRefinementManager::backtrack` does nothing.
    pub enable_backtracking: bool,
}

impl Default for PartitionRefinementConfig {
    /// Enumeration, constraint guidance and backtracking enabled; at most
    /// 1000 enumerated partitions.
    fn default() -> Self {
        let max_partitions = 1000;
        let enabled = true;
        Self {
            max_partitions,
            enable_enumeration: enabled,
            constraint_guided: enabled,
            enable_backtracking: enabled,
        }
    }
}
464
/// Counters maintained by [`PartitionRefinementManager`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct PartitionRefinementStats {
    /// Equalities successfully applied by `apply_constraints`.
    pub refinements: u64,

    /// Partitions handed out by `next_partition`.
    pub partitions_enumerated: u64,

    /// Successful `backtrack` calls made while backtracking is enabled.
    pub backtracks: u64,

    /// Equalities passed to `add_constraint`; counted when queued, not when applied.
    pub constraints_applied: u64,
}
477
478pub struct PartitionRefinementManager {
480 config: PartitionRefinementConfig,
482
483 stats: PartitionRefinementStats,
485
486 refinement: PartitionRefinement,
488
489 enumerator: Option<PartitionEnumerator>,
491
492 constraints: VecDeque<Equality>,
494}
495
496impl PartitionRefinementManager {
497 pub fn new(terms: Vec<TermId>) -> Self {
499 Self::with_config(terms, PartitionRefinementConfig::default())
500 }
501
502 pub fn with_config(terms: Vec<TermId>, config: PartitionRefinementConfig) -> Self {
504 let enumerator = if config.enable_enumeration {
505 Some(PartitionEnumerator::new(terms.clone()))
506 } else {
507 None
508 };
509
510 Self {
511 config,
512 stats: PartitionRefinementStats::default(),
513 refinement: PartitionRefinement::new(&terms),
514 enumerator,
515 constraints: VecDeque::new(),
516 }
517 }
518
519 pub fn stats(&self) -> &PartitionRefinementStats {
521 &self.stats
522 }
523
524 pub fn add_constraint(&mut self, eq: Equality) {
526 self.constraints.push_back(eq);
527 self.stats.constraints_applied += 1;
528 }
529
530 pub fn apply_constraints(&mut self) -> Result<(), String> {
532 while let Some(eq) = self.constraints.pop_front() {
533 self.refinement.refine(eq)?;
534 self.stats.refinements += 1;
535 }
536 Ok(())
537 }
538
539 pub fn current_partition(&self) -> &Partition {
541 self.refinement.current()
542 }
543
544 pub fn next_partition(&mut self) -> Option<Partition> {
546 if let Some(ref mut enumerator) = self.enumerator {
547 if self.stats.partitions_enumerated >= self.config.max_partitions as u64 {
548 return None;
549 }
550
551 let partition = enumerator.next();
552 if partition.is_some() {
553 self.stats.partitions_enumerated += 1;
554 }
555 partition
556 } else {
557 None
558 }
559 }
560
561 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
563 if !self.config.enable_backtracking {
564 return Ok(());
565 }
566
567 self.refinement.backtrack(level)?;
568 self.stats.backtracks += 1;
569 Ok(())
570 }
571
572 pub fn push_decision_level(&mut self) {
574 self.refinement.push_decision_level();
575 }
576
577 pub fn clear(&mut self) {
579 self.refinement.clear_history();
580 self.constraints.clear();
581
582 if let Some(ref mut enumerator) = self.enumerator {
583 enumerator.reset();
584 }
585 }
586
587 pub fn reset_stats(&mut self) {
589 self.stats = PartitionRefinementStats::default();
590 }
591}
592
593pub struct PartitionComparator;
595
596impl PartitionComparator {
597 pub fn is_finer(p1: &Partition, p2: &Partition) -> bool {
599 for class1 in &p1.classes {
601 if class1.is_empty() {
602 continue;
603 }
604
605 let first_term = *class1.iter().next().expect("Non-empty class");
607 let p2_class = p2.term_to_class.get(&first_term);
608
609 for &term in class1 {
610 if p2.term_to_class.get(&term) != p2_class {
611 return false;
612 }
613 }
614 }
615
616 true
617 }
618
619 pub fn are_equal(p1: &Partition, p2: &Partition) -> bool {
621 Self::is_finer(p1, p2) && Self::is_finer(p2, p1)
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_finest_partition() {
        let terms = vec![1, 2, 3];
        let partition = Partition::finest(&terms);

        assert_eq!(partition.num_classes(), 3);
        assert!(!partition.are_equal(1, 2));
    }

    #[test]
    fn test_coarsest_partition() {
        let terms = vec![1, 2, 3];
        let partition = Partition::coarsest(&terms);

        assert_eq!(partition.num_classes(), 1);
        assert!(partition.are_equal(1, 2));
        assert!(partition.are_equal(2, 3));
    }

    #[test]
    fn test_partition_merge() {
        let terms = vec![1, 2, 3, 4];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");
        assert_eq!(partition.num_classes(), 3);
        assert!(partition.are_equal(1, 2));
        assert!(!partition.are_equal(1, 3));
    }

    #[test]
    fn test_merge_unknown_term_fails() {
        let mut partition = Partition::finest(&[1, 2]);

        assert!(partition.merge(1, 99).is_err());
        // A failed merge leaves the partition untouched.
        assert_eq!(partition.num_classes(), 2);
        assert!(!partition.are_equal(1, 2));
    }

    #[test]
    fn test_partition_equalities() {
        let terms = vec![1, 2, 3];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");
        partition.merge(2, 3).expect("Merge failed");

        // One class of three terms is spanned by two equalities.
        let equalities = partition.get_equalities();
        assert_eq!(equalities.len(), 2);
    }

    #[test]
    fn test_equality_is_canonical() {
        assert_eq!(Equality::new(3, 1), Equality::new(1, 3));
        let eq = Equality::new(9, 4);
        assert_eq!((eq.lhs, eq.rhs), (4, 9));
    }

    #[test]
    fn test_refinement() {
        let terms = vec![1, 2, 3, 4];
        let mut refinement = PartitionRefinement::new(&terms);

        refinement
            .refine(Equality::new(1, 2))
            .expect("Refine failed");
        assert!(refinement.current().are_equal(1, 2));
    }

    #[test]
    fn test_refinement_backtrack() {
        let terms = vec![1, 2, 3, 4];
        let mut refinement = PartitionRefinement::new(&terms);

        refinement
            .refine(Equality::new(1, 2))
            .expect("Refine failed");
        refinement.backtrack_step().expect("Backtrack failed");

        assert!(!refinement.current().are_equal(1, 2));
    }

    #[test]
    fn test_backtrack_to_level() {
        let mut refinement = PartitionRefinement::new(&[1, 2, 3, 4]);

        refinement
            .refine(Equality::new(1, 2))
            .expect("Refine failed");
        refinement.push_decision_level();
        refinement
            .refine(Equality::new(3, 4))
            .expect("Refine failed");
        refinement.backtrack(0).expect("Backtrack failed");

        // Only the refinement made above level 0 is undone.
        assert!(refinement.current().are_equal(1, 2));
        assert!(!refinement.current().are_equal(3, 4));
    }

    #[test]
    fn test_backtrack_step_without_history_fails() {
        let mut refinement = PartitionRefinement::new(&[1, 2]);
        assert!(refinement.backtrack_step().is_err());
    }

    #[test]
    fn test_bell_number() {
        assert_eq!(bell_number(0), 1);
        assert_eq!(bell_number(1), 1);
        assert_eq!(bell_number(2), 2);
        assert_eq!(bell_number(3), 5);
        assert_eq!(bell_number(4), 15);
        assert_eq!(bell_number(5), 52);
        assert_eq!(bell_number(6), 203);
        assert_eq!(bell_number(7), 877);
        assert_eq!(bell_number(8), 4140);
    }

    #[test]
    fn test_partition_enumerator() {
        let terms = vec![1, 2, 3];
        let mut enumerator = PartitionEnumerator::new(terms);

        let mut count = 0;
        while enumerator.next().is_some() {
            count += 1;
        }

        // B(3) = 5 partitions of a three-element set.
        assert_eq!(count, 5);
    }

    #[test]
    fn test_partition_enumerator_four_terms() {
        let mut enumerator = PartitionEnumerator::new(vec![1, 2, 3, 4]);

        let mut count = 0;
        while enumerator.next().is_some() {
            count += 1;
        }

        // B(4) = 15.
        assert_eq!(count, 15);
    }

    #[test]
    fn test_enumerator_reset() {
        let mut enumerator = PartitionEnumerator::new(vec![1, 2]);
        while enumerator.next().is_some() {}
        assert!(enumerator.next().is_none());

        enumerator.reset();
        assert!(enumerator.next().is_some());
    }

    #[test]
    fn test_manager() {
        let terms = vec![1, 2, 3];
        let mut manager = PartitionRefinementManager::new(terms);

        manager.add_constraint(Equality::new(1, 2));
        manager.apply_constraints().expect("Apply failed");

        assert!(manager.current_partition().are_equal(1, 2));
    }

    #[test]
    fn test_manager_respects_max_partitions() {
        let config = PartitionRefinementConfig {
            max_partitions: 2,
            ..Default::default()
        };
        let mut manager = PartitionRefinementManager::with_config(vec![1, 2, 3], config);

        assert!(manager.next_partition().is_some());
        assert!(manager.next_partition().is_some());
        assert!(manager.next_partition().is_none());
        assert_eq!(manager.stats().partitions_enumerated, 2);
    }

    #[test]
    fn test_manager_backtracking_disabled() {
        let config = PartitionRefinementConfig {
            enable_backtracking: false,
            ..Default::default()
        };
        let mut manager = PartitionRefinementManager::with_config(vec![1, 2], config);

        manager.push_decision_level();
        manager.add_constraint(Equality::new(1, 2));
        manager.apply_constraints().expect("Apply failed");
        manager.backtrack(0).expect("Backtrack failed");

        // With backtracking disabled the refinement survives.
        assert!(manager.current_partition().are_equal(1, 2));
        assert_eq!(manager.stats().backtracks, 0);
    }

    #[test]
    fn test_partition_comparison() {
        let terms = vec![1, 2, 3];

        let finest = Partition::finest(&terms);
        let coarsest = Partition::coarsest(&terms);

        assert!(PartitionComparator::is_finer(&finest, &coarsest));
        assert!(!PartitionComparator::is_finer(&coarsest, &finest));
        assert!(PartitionComparator::are_equal(&finest, &Partition::finest(&terms)));
        assert!(!PartitionComparator::are_equal(&finest, &coarsest));
    }

    #[test]
    fn test_representative() {
        let terms = vec![1, 2, 3];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");

        let rep1 = partition.get_representative(1);
        let rep2 = partition.get_representative(2);

        assert_eq!(rep1, rep2);
    }

    #[test]
    fn test_get_class() {
        let terms = vec![1, 2, 3, 4];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");
        partition.merge(2, 3).expect("Merge failed");

        let class = partition.get_class(1).expect("No class");
        assert_eq!(class.len(), 3);
        assert!(class.contains(&1));
        assert!(class.contains(&2));
        assert!(class.contains(&3));
    }
}