1#![allow(missing_docs)]
41#![allow(dead_code)]
42
43use rustc_hash::{FxHashMap, FxHashSet};
44use std::collections::VecDeque;
45
/// Numeric identifier of a term whose equivalence class is tracked by a [`Partition`].
pub type TermId = u32;

/// Index of an equivalence class inside a [`Partition`]'s class table.
pub type ClassId = usize;

/// Decision level used to scope refinements so they can be undone by backtracking.
pub type DecisionLevel = u32;

/// Numeric identifier of a theory.
pub type TheoryId = u32;
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
60pub struct Equality {
61 pub lhs: TermId,
63 pub rhs: TermId,
65}
66
67impl Equality {
68 pub fn new(lhs: TermId, rhs: TermId) -> Self {
70 if lhs <= rhs {
71 Self { lhs, rhs }
72 } else {
73 Self { lhs: rhs, rhs: lhs }
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct Partition {
81 classes: Vec<FxHashSet<TermId>>,
83
84 term_to_class: FxHashMap<TermId, ClassId>,
86
87 representatives: Vec<TermId>,
89}
90
91impl Partition {
92 pub fn finest(terms: &[TermId]) -> Self {
94 let mut classes = Vec::new();
95 let mut term_to_class = FxHashMap::default();
96 let mut representatives = Vec::new();
97
98 for (i, &term) in terms.iter().enumerate() {
99 let mut class = FxHashSet::default();
100 class.insert(term);
101 classes.push(class);
102 term_to_class.insert(term, i);
103 representatives.push(term);
104 }
105
106 Self {
107 classes,
108 term_to_class,
109 representatives,
110 }
111 }
112
113 pub fn coarsest(terms: &[TermId]) -> Self {
115 if terms.is_empty() {
116 return Self {
117 classes: Vec::new(),
118 term_to_class: FxHashMap::default(),
119 representatives: Vec::new(),
120 };
121 }
122
123 let mut class = FxHashSet::default();
124 let mut term_to_class = FxHashMap::default();
125
126 for &term in terms {
127 class.insert(term);
128 term_to_class.insert(term, 0);
129 }
130
131 Self {
132 classes: vec![class],
133 term_to_class,
134 representatives: vec![terms[0]],
135 }
136 }
137
138 pub fn merge(&mut self, t1: TermId, t2: TermId) -> Result<(), String> {
140 let c1 = *self.term_to_class.get(&t1).ok_or("Term not in partition")?;
141 let c2 = *self.term_to_class.get(&t2).ok_or("Term not in partition")?;
142
143 if c1 == c2 {
144 return Ok(());
145 }
146
147 let (src, dst) = if self.classes[c1].len() < self.classes[c2].len() {
149 (c1, c2)
150 } else {
151 (c2, c1)
152 };
153
154 let src_terms: Vec<_> = self.classes[src].iter().copied().collect();
156 for term in src_terms {
157 self.classes[dst].insert(term);
158 self.term_to_class.insert(term, dst);
159 }
160
161 self.classes[src].clear();
163
164 Ok(())
165 }
166
167 pub fn get_equalities(&self) -> Vec<Equality> {
169 let mut equalities = Vec::new();
170
171 for class in &self.classes {
172 if class.len() > 1 {
173 let terms: Vec<_> = class.iter().copied().collect();
174 let rep = terms[0];
176 for &term in &terms[1..] {
177 equalities.push(Equality::new(rep, term));
178 }
179 }
180 }
181
182 equalities
183 }
184
185 pub fn num_classes(&self) -> usize {
187 self.classes.iter().filter(|c| !c.is_empty()).count()
188 }
189
190 pub fn are_equal(&self, t1: TermId, t2: TermId) -> bool {
192 if let (Some(&c1), Some(&c2)) = (self.term_to_class.get(&t1), self.term_to_class.get(&t2)) {
193 c1 == c2
194 } else {
195 false
196 }
197 }
198
199 pub fn get_representative(&self, term: TermId) -> Option<TermId> {
201 self.term_to_class
202 .get(&term)
203 .and_then(|&class_id| self.representatives.get(class_id))
204 .copied()
205 }
206
207 pub fn get_class(&self, term: TermId) -> Option<&FxHashSet<TermId>> {
209 self.term_to_class
210 .get(&term)
211 .and_then(|&class_id| self.classes.get(class_id))
212 }
213
214 pub fn clone_partition(&self) -> Partition {
216 self.clone()
217 }
218}
219
220pub struct PartitionRefinement {
222 partition: Partition,
224
225 history: Vec<Partition>,
227
228 decision_levels: Vec<DecisionLevel>,
230
231 current_level: DecisionLevel,
233}
234
235impl PartitionRefinement {
236 pub fn new(terms: &[TermId]) -> Self {
238 Self {
239 partition: Partition::finest(terms),
240 history: Vec::new(),
241 decision_levels: Vec::new(),
242 current_level: 0,
243 }
244 }
245
246 pub fn refine(&mut self, eq: Equality) -> Result<(), String> {
248 self.history.push(self.partition.clone_partition());
249 self.decision_levels.push(self.current_level);
250 self.partition.merge(eq.lhs, eq.rhs)
251 }
252
253 pub fn refine_batch(&mut self, equalities: &[Equality]) -> Result<(), String> {
255 for &eq in equalities {
256 self.refine(eq)?;
257 }
258 Ok(())
259 }
260
261 pub fn current(&self) -> &Partition {
263 &self.partition
264 }
265
266 pub fn backtrack_step(&mut self) -> Result<(), String> {
268 self.partition = self.history.pop().ok_or("No refinement to backtrack")?;
269 self.decision_levels.pop();
270 Ok(())
271 }
272
273 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
275 while !self.decision_levels.is_empty() {
276 if let Some(&last_level) = self.decision_levels.last() {
277 if last_level > level {
278 self.backtrack_step()?;
279 } else {
280 break;
281 }
282 } else {
283 break;
284 }
285 }
286
287 self.current_level = level;
288 Ok(())
289 }
290
291 pub fn push_decision_level(&mut self) {
293 self.current_level += 1;
294 }
295
296 pub fn clear_history(&mut self) {
298 self.history.clear();
299 self.decision_levels.clear();
300 }
301}
302
303pub struct PartitionEnumerator {
305 n: usize,
307
308 terms: Vec<TermId>,
310
311 rgs: Vec<usize>,
313
314 max_val: usize,
316
317 done: bool,
319}
320
321impl PartitionEnumerator {
322 pub fn new(terms: Vec<TermId>) -> Self {
324 let n = terms.len();
325 Self {
326 n,
327 terms,
328 rgs: vec![0; n],
329 max_val: 0,
330 done: n == 0,
331 }
332 }
333
334 #[allow(clippy::should_implement_trait)]
336 pub fn next(&mut self) -> Option<Partition> {
337 if self.done {
338 return None;
339 }
340
341 let partition = self.rgs_to_partition();
343
344 self.next_rgs();
346
347 Some(partition)
348 }
349
350 fn rgs_to_partition(&self) -> Partition {
352 let mut classes: Vec<FxHashSet<TermId>> = vec![FxHashSet::default(); self.max_val + 1];
353 let mut term_to_class = FxHashMap::default();
354 let mut representatives = vec![0; self.max_val + 1];
355
356 for (i, &class_id) in self.rgs.iter().enumerate() {
357 let term = self.terms[i];
358 classes[class_id].insert(term);
359 term_to_class.insert(term, class_id);
360
361 if representatives[class_id] == 0 || term < representatives[class_id] {
362 representatives[class_id] = term;
363 }
364 }
365
366 Partition {
367 classes,
368 term_to_class,
369 representatives,
370 }
371 }
372
373 fn next_rgs(&mut self) {
375 let mut i = self.n;
377 while i > 0 {
378 i -= 1;
379
380 let can_increment = if i == 0 {
381 false
382 } else {
383 let max_up_to_i = self.rgs[..i].iter().max().copied().unwrap_or(0);
384 self.rgs[i] <= max_up_to_i
385 };
386
387 if can_increment {
388 self.rgs[i] += 1;
389
390 self.max_val = self.rgs.iter().max().copied().unwrap_or(0);
392
393 for j in (i + 1)..self.n {
395 self.rgs[j] = 0;
396 }
397
398 return;
399 }
400 }
401
402 self.done = true;
403 }
404
405 pub fn reset(&mut self) {
407 self.rgs = vec![0; self.n];
408 self.max_val = 0;
409 self.done = self.n == 0;
410 }
411
412 pub fn count_remaining(&self) -> usize {
414 bell_number(self.n)
416 }
417}
418
/// Returns the Bell number `B(n)`: the number of set partitions of an
/// `n`-element set.
///
/// Computed exactly with the Bell triangle. If `B(n)` does not fit in a
/// `usize`, returns `usize::MAX` as a saturating "too many to count" value.
fn bell_number(n: usize) -> usize {
    if n == 0 {
        return 1;
    }

    // `row` is row k of the Bell triangle: it starts with B(k) and ends with
    // B(k+1). Row 0 is [1], so row n-1 ends with B(n). Every entry of a row is
    // bounded by its last entry, so an overflow here means B(n) itself overflows.
    let mut row: Vec<usize> = vec![1];
    for _ in 1..n {
        let mut acc = *row.last().expect("Bell triangle rows are never empty");
        let mut next = Vec::with_capacity(row.len() + 1);
        next.push(acc);
        for &above in &row {
            acc = match acc.checked_add(above) {
                Some(sum) => sum,
                None => return usize::MAX,
            };
            next.push(acc);
        }
        row = next;
    }

    *row.last().expect("Bell triangle rows are never empty")
}
439
/// Tuning knobs for the partition refinement manager.
#[derive(Debug, Clone)]
pub struct PartitionRefinementConfig {
    /// Whether a partition enumerator is created so partitions can be enumerated.
    pub enable_enumeration: bool,
    /// Cap on the number of partitions handed out by enumeration.
    pub max_partitions: usize,
    /// Flag for constraint-guided refinement (NOTE(review): not read by the code visible here).
    pub constraint_guided: bool,
    /// When `false`, backtracking requests are accepted but do nothing.
    pub enable_backtracking: bool,
}

impl Default for PartitionRefinementConfig {
    /// Every feature switched on, enumeration capped at 1000 partitions.
    fn default() -> Self {
        let max_partitions = 1000;
        Self {
            enable_backtracking: true,
            constraint_guided: true,
            max_partitions,
            enable_enumeration: true,
        }
    }
}
466
/// Counters reported by the partition refinement manager.
#[derive(Clone, Debug, Default)]
pub struct PartitionRefinementStats {
    /// Equalities successfully merged while applying queued constraints.
    pub refinements: u64,
    /// Partitions handed out by enumeration.
    pub partitions_enumerated: u64,
    /// Backtrack requests forwarded to the refinement layer.
    pub backtracks: u64,
    /// Equalities queued as constraints (counted when enqueued).
    pub constraints_applied: u64,
}
479
480pub struct PartitionRefinementManager {
482 config: PartitionRefinementConfig,
484
485 stats: PartitionRefinementStats,
487
488 refinement: PartitionRefinement,
490
491 enumerator: Option<PartitionEnumerator>,
493
494 constraints: VecDeque<Equality>,
496}
497
498impl PartitionRefinementManager {
499 pub fn new(terms: Vec<TermId>) -> Self {
501 Self::with_config(terms, PartitionRefinementConfig::default())
502 }
503
504 pub fn with_config(terms: Vec<TermId>, config: PartitionRefinementConfig) -> Self {
506 let enumerator = if config.enable_enumeration {
507 Some(PartitionEnumerator::new(terms.clone()))
508 } else {
509 None
510 };
511
512 Self {
513 config,
514 stats: PartitionRefinementStats::default(),
515 refinement: PartitionRefinement::new(&terms),
516 enumerator,
517 constraints: VecDeque::new(),
518 }
519 }
520
521 pub fn stats(&self) -> &PartitionRefinementStats {
523 &self.stats
524 }
525
526 pub fn add_constraint(&mut self, eq: Equality) {
528 self.constraints.push_back(eq);
529 self.stats.constraints_applied += 1;
530 }
531
532 pub fn apply_constraints(&mut self) -> Result<(), String> {
534 while let Some(eq) = self.constraints.pop_front() {
535 self.refinement.refine(eq)?;
536 self.stats.refinements += 1;
537 }
538 Ok(())
539 }
540
541 pub fn current_partition(&self) -> &Partition {
543 self.refinement.current()
544 }
545
546 pub fn next_partition(&mut self) -> Option<Partition> {
548 if let Some(ref mut enumerator) = self.enumerator {
549 if self.stats.partitions_enumerated >= self.config.max_partitions as u64 {
550 return None;
551 }
552
553 let partition = enumerator.next();
554 if partition.is_some() {
555 self.stats.partitions_enumerated += 1;
556 }
557 partition
558 } else {
559 None
560 }
561 }
562
563 pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
565 if !self.config.enable_backtracking {
566 return Ok(());
567 }
568
569 self.refinement.backtrack(level)?;
570 self.stats.backtracks += 1;
571 Ok(())
572 }
573
574 pub fn push_decision_level(&mut self) {
576 self.refinement.push_decision_level();
577 }
578
579 pub fn clear(&mut self) {
581 self.refinement.clear_history();
582 self.constraints.clear();
583
584 if let Some(ref mut enumerator) = self.enumerator {
585 enumerator.reset();
586 }
587 }
588
589 pub fn reset_stats(&mut self) {
591 self.stats = PartitionRefinementStats::default();
592 }
593}
594
595pub struct PartitionComparator;
597
598impl PartitionComparator {
599 pub fn is_finer(p1: &Partition, p2: &Partition) -> bool {
601 for class1 in &p1.classes {
603 if class1.is_empty() {
604 continue;
605 }
606
607 let first_term = *class1.iter().next().expect("Non-empty class");
609 let p2_class = p2.term_to_class.get(&first_term);
610
611 for &term in class1 {
612 if p2.term_to_class.get(&term) != p2_class {
613 return false;
614 }
615 }
616 }
617
618 true
619 }
620
621 pub fn are_equal(p1: &Partition, p2: &Partition) -> bool {
623 Self::is_finer(p1, p2) && Self::is_finer(p2, p1)
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the finest partition over `terms`.
    fn singletons(terms: &[TermId]) -> Partition {
        Partition::finest(terms)
    }

    #[test]
    fn test_finest_partition() {
        let p = singletons(&[1, 2, 3]);
        assert_eq!(p.num_classes(), 3);
        assert!(!p.are_equal(1, 2));
    }

    #[test]
    fn test_coarsest_partition() {
        let p = Partition::coarsest(&[1, 2, 3]);
        assert_eq!(p.num_classes(), 1);
        for &(a, b) in &[(1, 2), (2, 3)] {
            assert!(p.are_equal(a, b));
        }
    }

    #[test]
    fn test_partition_merge() {
        let mut p = singletons(&[1, 2, 3, 4]);
        p.merge(1, 2).expect("Merge failed");

        assert_eq!(p.num_classes(), 3);
        assert!(p.are_equal(1, 2));
        assert!(!p.are_equal(1, 3));
    }

    #[test]
    fn test_partition_equalities() {
        let mut p = singletons(&[1, 2, 3]);
        for &(a, b) in &[(1, 2), (2, 3)] {
            p.merge(a, b).expect("Merge failed");
        }
        assert_eq!(p.get_equalities().len(), 2);
    }

    #[test]
    fn test_refinement() {
        let mut r = PartitionRefinement::new(&[1, 2, 3, 4]);
        r.refine(Equality::new(1, 2)).expect("Refine failed");
        assert!(r.current().are_equal(1, 2));
    }

    #[test]
    fn test_refinement_backtrack() {
        let mut r = PartitionRefinement::new(&[1, 2, 3, 4]);
        r.refine(Equality::new(1, 2)).expect("Refine failed");
        r.backtrack_step().expect("Backtrack failed");
        assert!(!r.current().are_equal(1, 2));
    }

    #[test]
    fn test_bell_number() {
        let expected: [usize; 5] = [1, 1, 2, 5, 15];
        for (n, &b) in expected.iter().enumerate() {
            assert_eq!(bell_number(n), b);
        }
    }

    #[test]
    fn test_partition_enumerator() {
        let mut e = PartitionEnumerator::new(vec![1, 2, 3]);
        let count = std::iter::from_fn(|| e.next()).count();
        assert_eq!(count, 5);
    }

    #[test]
    fn test_manager() {
        let mut m = PartitionRefinementManager::new(vec![1, 2, 3]);
        m.add_constraint(Equality::new(1, 2));
        m.apply_constraints().expect("Apply failed");
        assert!(m.current_partition().are_equal(1, 2));
    }

    #[test]
    fn test_partition_comparison() {
        let terms = [1, 2, 3];
        let fine = Partition::finest(&terms);
        let coarse = Partition::coarsest(&terms);

        assert!(PartitionComparator::is_finer(&fine, &coarse));
        assert!(!PartitionComparator::is_finer(&coarse, &fine));
    }

    #[test]
    fn test_representative() {
        let mut p = singletons(&[1, 2, 3]);
        p.merge(1, 2).expect("Merge failed");
        assert_eq!(p.get_representative(1), p.get_representative(2));
    }

    #[test]
    fn test_get_class() {
        let mut p = singletons(&[1, 2, 3, 4]);
        p.merge(1, 2).expect("Merge failed");
        p.merge(2, 3).expect("Merge failed");

        let class = p.get_class(1).expect("No class");
        assert_eq!(class.len(), 3);
        for member in [1u32, 2, 3].iter() {
            assert!(class.contains(member));
        }
    }
}