1#![allow(missing_docs)]
41#![allow(dead_code)]
42#[allow(unused_imports)]
43use crate::prelude::*;
44
/// Identifier of a term managed by the partition structures in this module.
pub type TermId = u32;

/// Identifier of a theory.
/// NOTE(review): not referenced in this section of the file.
pub type TheoryId = u32;

/// Decision level at which refinements are recorded for backtracking.
pub type DecisionLevel = u32;

/// Index of an equivalence class within a `Partition`'s class list.
pub type ClassId = usize;
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
59pub struct Equality {
60 pub lhs: TermId,
62 pub rhs: TermId,
64}
65
66impl Equality {
67 pub fn new(lhs: TermId, rhs: TermId) -> Self {
69 if lhs <= rhs {
70 Self { lhs, rhs }
71 } else {
72 Self { lhs: rhs, rhs: lhs }
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct Partition {
80 classes: Vec<FxHashSet<TermId>>,
82
83 term_to_class: FxHashMap<TermId, ClassId>,
85
86 representatives: Vec<TermId>,
88}
89
90impl Partition {
91 pub fn finest(terms: &[TermId]) -> Self {
93 let mut classes = Vec::new();
94 let mut term_to_class = FxHashMap::default();
95 let mut representatives = Vec::new();
96
97 for (i, &term) in terms.iter().enumerate() {
98 let mut class = FxHashSet::default();
99 class.insert(term);
100 classes.push(class);
101 term_to_class.insert(term, i);
102 representatives.push(term);
103 }
104
105 Self {
106 classes,
107 term_to_class,
108 representatives,
109 }
110 }
111
112 pub fn coarsest(terms: &[TermId]) -> Self {
114 if terms.is_empty() {
115 return Self {
116 classes: Vec::new(),
117 term_to_class: FxHashMap::default(),
118 representatives: Vec::new(),
119 };
120 }
121
122 let mut class = FxHashSet::default();
123 let mut term_to_class = FxHashMap::default();
124
125 for &term in terms {
126 class.insert(term);
127 term_to_class.insert(term, 0);
128 }
129
130 Self {
131 classes: vec![class],
132 term_to_class,
133 representatives: vec![terms[0]],
134 }
135 }
136
137 pub fn merge(&mut self, t1: TermId, t2: TermId) -> Result<(), String> {
139 let c1 = *self.term_to_class.get(&t1).ok_or("Term not in partition")?;
140 let c2 = *self.term_to_class.get(&t2).ok_or("Term not in partition")?;
141
142 if c1 == c2 {
143 return Ok(());
144 }
145
146 let (src, dst) = if self.classes[c1].len() < self.classes[c2].len() {
148 (c1, c2)
149 } else {
150 (c2, c1)
151 };
152
153 let src_terms: Vec<_> = self.classes[src].iter().copied().collect();
155 for term in src_terms {
156 self.classes[dst].insert(term);
157 self.term_to_class.insert(term, dst);
158 }
159
160 self.classes[src].clear();
162
163 Ok(())
164 }
165
166 pub fn get_equalities(&self) -> Vec<Equality> {
168 let mut equalities = Vec::new();
169
170 for class in &self.classes {
171 if class.len() > 1 {
172 let terms: Vec<_> = class.iter().copied().collect();
173 let rep = terms[0];
175 for &term in &terms[1..] {
176 equalities.push(Equality::new(rep, term));
177 }
178 }
179 }
180
181 equalities
182 }
183
184 pub fn num_classes(&self) -> usize {
186 self.classes.iter().filter(|c| !c.is_empty()).count()
187 }
188
189 pub fn are_equal(&self, t1: TermId, t2: TermId) -> bool {
191 if let (Some(&c1), Some(&c2)) = (self.term_to_class.get(&t1), self.term_to_class.get(&t2)) {
192 c1 == c2
193 } else {
194 false
195 }
196 }
197
198 pub fn get_representative(&self, term: TermId) -> Option<TermId> {
200 self.term_to_class
201 .get(&term)
202 .and_then(|&class_id| self.representatives.get(class_id))
203 .copied()
204 }
205
206 pub fn get_class(&self, term: TermId) -> Option<&FxHashSet<TermId>> {
208 self.term_to_class
209 .get(&term)
210 .and_then(|&class_id| self.classes.get(class_id))
211 }
212
213 pub fn clone_partition(&self) -> Partition {
215 self.clone()
216 }
217}
218
219pub struct PartitionRefinement {
221 partition: Partition,
223
224 history: Vec<Partition>,
226
227 decision_levels: Vec<DecisionLevel>,
229
230 current_level: DecisionLevel,
232}
233
234impl PartitionRefinement {
235 pub fn new(terms: &[TermId]) -> Self {
237 Self {
238 partition: Partition::finest(terms),
239 history: Vec::new(),
240 decision_levels: Vec::new(),
241 current_level: 0,
242 }
243 }
244
245 pub fn refine(&mut self, eq: Equality) -> Result<(), String> {
247 self.history.push(self.partition.clone_partition());
248 self.decision_levels.push(self.current_level);
249 self.partition.merge(eq.lhs, eq.rhs)
250 }
251
252 pub fn refine_batch(&mut self, equalities: &[Equality]) -> Result<(), String> {
254 for &eq in equalities {
255 self.refine(eq)?;
256 }
257 Ok(())
258 }
259
260 pub fn current(&self) -> &Partition {
262 &self.partition
263 }
264
265 pub fn backtrack_step(&mut self) -> Result<(), String> {
267 self.partition = self.history.pop().ok_or("No refinement to backtrack")?;
268 self.decision_levels.pop();
269 Ok(())
270 }
271
272 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
274 while !self.decision_levels.is_empty() {
275 if let Some(&last_level) = self.decision_levels.last() {
276 if last_level > level {
277 self.backtrack_step()?;
278 } else {
279 break;
280 }
281 } else {
282 break;
283 }
284 }
285
286 self.current_level = level;
287 Ok(())
288 }
289
290 pub fn push_decision_level(&mut self) {
292 self.current_level += 1;
293 }
294
295 pub fn clear_history(&mut self) {
297 self.history.clear();
298 self.decision_levels.clear();
299 }
300}
301
302pub struct PartitionEnumerator {
304 n: usize,
306
307 terms: Vec<TermId>,
309
310 rgs: Vec<usize>,
312
313 max_val: usize,
315
316 done: bool,
318}
319
320impl PartitionEnumerator {
321 pub fn new(terms: Vec<TermId>) -> Self {
323 let n = terms.len();
324 Self {
325 n,
326 terms,
327 rgs: vec![0; n],
328 max_val: 0,
329 done: n == 0,
330 }
331 }
332
333 #[allow(clippy::should_implement_trait)]
335 pub fn next(&mut self) -> Option<Partition> {
336 if self.done {
337 return None;
338 }
339
340 let partition = self.rgs_to_partition();
342
343 self.next_rgs();
345
346 Some(partition)
347 }
348
349 fn rgs_to_partition(&self) -> Partition {
351 let mut classes: Vec<FxHashSet<TermId>> = vec![FxHashSet::default(); self.max_val + 1];
352 let mut term_to_class = FxHashMap::default();
353 let mut representatives = vec![0; self.max_val + 1];
354
355 for (i, &class_id) in self.rgs.iter().enumerate() {
356 let term = self.terms[i];
357 classes[class_id].insert(term);
358 term_to_class.insert(term, class_id);
359
360 if representatives[class_id] == 0 || term < representatives[class_id] {
361 representatives[class_id] = term;
362 }
363 }
364
365 Partition {
366 classes,
367 term_to_class,
368 representatives,
369 }
370 }
371
372 fn next_rgs(&mut self) {
374 let mut i = self.n;
376 while i > 0 {
377 i -= 1;
378
379 let can_increment = if i == 0 {
380 false
381 } else {
382 let max_up_to_i = self.rgs[..i].iter().max().copied().unwrap_or(0);
383 self.rgs[i] <= max_up_to_i
384 };
385
386 if can_increment {
387 self.rgs[i] += 1;
388
389 self.max_val = self.rgs.iter().max().copied().unwrap_or(0);
391
392 for j in (i + 1)..self.n {
394 self.rgs[j] = 0;
395 }
396
397 return;
398 }
399 }
400
401 self.done = true;
402 }
403
404 pub fn reset(&mut self) {
406 self.rgs = vec![0; self.n];
407 self.max_val = 0;
408 self.done = self.n == 0;
409 }
410
411 pub fn count_remaining(&self) -> usize {
413 bell_number(self.n)
415 }
416}
417
/// Returns the Bell number `B(n)`, the number of ways to partition a set of
/// `n` elements (`B(0) = 1`).
///
/// Uses the Bell triangle. Each row starts with the last entry of the row
/// above; every further entry is its left neighbour plus the entry above that
/// neighbour. Row `k` starts with `B(k)` and ends with `B(k + 1)`, so `B(n)`
/// is the last entry of row `n - 1`.
///
/// Entries grow along a row, so if any addition overflows, `B(n)` itself
/// exceeds `usize::MAX`. That case saturates to `usize::MAX`.
fn bell_number(n: usize) -> usize {
    if n == 0 {
        return 1;
    }

    // Row 0 of the triangle.
    let mut row: Vec<usize> = vec![1];
    for _ in 1..n {
        let mut acc = *row.last().expect("Bell triangle rows are never empty");
        let mut next_row = Vec::with_capacity(row.len() + 1);
        next_row.push(acc);
        for &above in &row {
            acc = match acc.checked_add(above) {
                Some(sum) => sum,
                None => return usize::MAX,
            };
            next_row.push(acc);
        }
        row = next_row;
    }

    *row.last().expect("Bell triangle rows are never empty")
}
438
/// Tunables for `PartitionRefinementManager`.
#[derive(Debug, Clone)]
pub struct PartitionRefinementConfig {
    /// Whether the manager builds an enumerator to hand out candidate partitions.
    pub enable_enumeration: bool,

    /// Upper bound on the number of partitions the manager hands out.
    pub max_partitions: usize,

    /// NOTE(review): not read anywhere in this section of the file.
    pub constraint_guided: bool,

    /// Whether the manager's `backtrack` undoes refinements (otherwise a no-op).
    pub enable_backtracking: bool,
}

impl Default for PartitionRefinementConfig {
    /// All features switched on, with a budget of 1000 enumerated partitions.
    fn default() -> Self {
        const DEFAULT_MAX_PARTITIONS: usize = 1000;

        Self {
            max_partitions: DEFAULT_MAX_PARTITIONS,
            enable_enumeration: true,
            enable_backtracking: true,
            constraint_guided: true,
        }
    }
}
465
/// Counters maintained by `PartitionRefinementManager`; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct PartitionRefinementStats {
    /// Constraints successfully applied by `apply_constraints`.
    pub refinements: u64,
    /// Partitions handed out by `next_partition`.
    pub partitions_enumerated: u64,
    /// Successful `backtrack` calls made while backtracking is enabled.
    pub backtracks: u64,
    /// Constraints queued through `add_constraint`.
    /// NOTE(review): despite the name, this counts enqueues, not applications.
    pub constraints_applied: u64,
}
478
479pub struct PartitionRefinementManager {
481 config: PartitionRefinementConfig,
483
484 stats: PartitionRefinementStats,
486
487 refinement: PartitionRefinement,
489
490 enumerator: Option<PartitionEnumerator>,
492
493 constraints: VecDeque<Equality>,
495}
496
497impl PartitionRefinementManager {
498 pub fn new(terms: Vec<TermId>) -> Self {
500 Self::with_config(terms, PartitionRefinementConfig::default())
501 }
502
503 pub fn with_config(terms: Vec<TermId>, config: PartitionRefinementConfig) -> Self {
505 let enumerator = if config.enable_enumeration {
506 Some(PartitionEnumerator::new(terms.clone()))
507 } else {
508 None
509 };
510
511 Self {
512 config,
513 stats: PartitionRefinementStats::default(),
514 refinement: PartitionRefinement::new(&terms),
515 enumerator,
516 constraints: VecDeque::new(),
517 }
518 }
519
520 pub fn stats(&self) -> &PartitionRefinementStats {
522 &self.stats
523 }
524
525 pub fn add_constraint(&mut self, eq: Equality) {
527 self.constraints.push_back(eq);
528 self.stats.constraints_applied += 1;
529 }
530
531 pub fn apply_constraints(&mut self) -> Result<(), String> {
533 while let Some(eq) = self.constraints.pop_front() {
534 self.refinement.refine(eq)?;
535 self.stats.refinements += 1;
536 }
537 Ok(())
538 }
539
540 pub fn current_partition(&self) -> &Partition {
542 self.refinement.current()
543 }
544
545 pub fn next_partition(&mut self) -> Option<Partition> {
547 if let Some(ref mut enumerator) = self.enumerator {
548 if self.stats.partitions_enumerated >= self.config.max_partitions as u64 {
549 return None;
550 }
551
552 let partition = enumerator.next();
553 if partition.is_some() {
554 self.stats.partitions_enumerated += 1;
555 }
556 partition
557 } else {
558 None
559 }
560 }
561
562 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
564 if !self.config.enable_backtracking {
565 return Ok(());
566 }
567
568 self.refinement.backtrack(level)?;
569 self.stats.backtracks += 1;
570 Ok(())
571 }
572
573 pub fn push_decision_level(&mut self) {
575 self.refinement.push_decision_level();
576 }
577
578 pub fn clear(&mut self) {
580 self.refinement.clear_history();
581 self.constraints.clear();
582
583 if let Some(ref mut enumerator) = self.enumerator {
584 enumerator.reset();
585 }
586 }
587
588 pub fn reset_stats(&mut self) {
590 self.stats = PartitionRefinementStats::default();
591 }
592}
593
594pub struct PartitionComparator;
596
597impl PartitionComparator {
598 pub fn is_finer(p1: &Partition, p2: &Partition) -> bool {
600 for class1 in &p1.classes {
602 if class1.is_empty() {
603 continue;
604 }
605
606 let first_term = *class1.iter().next().expect("Non-empty class");
608 let p2_class = p2.term_to_class.get(&first_term);
609
610 for &term in class1 {
611 if p2.term_to_class.get(&term) != p2_class {
612 return false;
613 }
614 }
615 }
616
617 true
618 }
619
620 pub fn are_equal(p1: &Partition, p2: &Partition) -> bool {
622 Self::is_finer(p1, p2) && Self::is_finer(p2, p1)
623 }
624}
625
// Unit tests covering partition construction and merging, refinement with
// backtracking, Bell numbers, enumeration, the manager, and comparison.
#[cfg(test)]
mod tests {
    use super::*;

    // The finest partition keeps every term in its own class.
    #[test]
    fn test_finest_partition() {
        let terms = vec![1, 2, 3];
        let partition = Partition::finest(&terms);

        assert_eq!(partition.num_classes(), 3);
        assert!(!partition.are_equal(1, 2));
    }

    // The coarsest partition puts every term into one class.
    #[test]
    fn test_coarsest_partition() {
        let terms = vec![1, 2, 3];
        let partition = Partition::coarsest(&terms);

        assert_eq!(partition.num_classes(), 1);
        assert!(partition.are_equal(1, 2));
        assert!(partition.are_equal(2, 3));
    }

    // A single merge reduces the class count by one and only joins the two terms.
    #[test]
    fn test_partition_merge() {
        let terms = vec![1, 2, 3, 4];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");
        assert_eq!(partition.num_classes(), 3);
        assert!(partition.are_equal(1, 2));
        assert!(!partition.are_equal(1, 3));
    }

    // One class of three terms is reported as two equalities.
    #[test]
    fn test_partition_equalities() {
        let terms = vec![1, 2, 3];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");
        partition.merge(2, 3).expect("Merge failed");

        let equalities = partition.get_equalities();
        assert_eq!(equalities.len(), 2);
    }

    // A refinement is visible through `current()`.
    #[test]
    fn test_refinement() {
        let terms = vec![1, 2, 3, 4];
        let mut refinement = PartitionRefinement::new(&terms);

        refinement
            .refine(Equality::new(1, 2))
            .expect("Refine failed");
        assert!(refinement.current().are_equal(1, 2));
    }

    // Undoing the only refinement restores the finest partition.
    #[test]
    fn test_refinement_backtrack() {
        let terms = vec![1, 2, 3, 4];
        let mut refinement = PartitionRefinement::new(&terms);

        refinement
            .refine(Equality::new(1, 2))
            .expect("Refine failed");
        refinement.backtrack_step().expect("Backtrack failed");

        assert!(!refinement.current().are_equal(1, 2));
    }

    // First Bell numbers: 1, 1, 2, 5, 15.
    #[test]
    fn test_bell_number() {
        assert_eq!(bell_number(0), 1);
        assert_eq!(bell_number(1), 1);
        assert_eq!(bell_number(2), 2);
        assert_eq!(bell_number(3), 5);
        assert_eq!(bell_number(4), 15);
    }

    // Three terms have B(3) = 5 partitions.
    #[test]
    fn test_partition_enumerator() {
        let terms = vec![1, 2, 3];
        let mut enumerator = PartitionEnumerator::new(terms);

        let mut count = 0;
        while enumerator.next().is_some() {
            count += 1;
        }

        assert_eq!(count, 5);
    }

    // Queued constraints take effect once applied.
    #[test]
    fn test_manager() {
        let terms = vec![1, 2, 3];
        let mut manager = PartitionRefinementManager::new(terms);

        manager.add_constraint(Equality::new(1, 2));
        manager.apply_constraints().expect("Apply failed");

        assert!(manager.current_partition().are_equal(1, 2));
    }

    // The finest partition refines the coarsest, not the other way around.
    #[test]
    fn test_partition_comparison() {
        let terms = vec![1, 2, 3];

        let finest = Partition::finest(&terms);
        let coarsest = Partition::coarsest(&terms);

        assert!(PartitionComparator::is_finer(&finest, &coarsest));
        assert!(!PartitionComparator::is_finer(&coarsest, &finest));
    }

    // Merged terms share one representative.
    #[test]
    fn test_representative() {
        let terms = vec![1, 2, 3];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");

        let rep1 = partition.get_representative(1);
        let rep2 = partition.get_representative(2);

        assert_eq!(rep1, rep2);
    }

    // `get_class` returns every member of the merged class.
    #[test]
    fn test_get_class() {
        let terms = vec![1, 2, 3, 4];
        let mut partition = Partition::finest(&terms);

        partition.merge(1, 2).expect("Merge failed");
        partition.merge(2, 3).expect("Merge failed");

        let class = partition.get_class(1).expect("No class");
        assert_eq!(class.len(), 3);
        assert!(class.contains(&1));
        assert!(class.contains(&2));
        assert!(class.contains(&3));
    }
}