Skip to main content

oxiz_solver/combination/
partition_refinement.rs

1//! Partition Refinement for Theory Combination.
2//!
3//! This module implements partition refinement algorithms for managing
4//! equality arrangements in Nelson-Oppen combination:
5//! - Set partition enumeration
6//! - Partition refinement with constraints
7//! - Efficient arrangement generation
8//! - Backtrackable partition data structures
9//!
10//! ## Partition Refinement
11//!
12//! Given a set of terms {t1, ..., tn}, we need to enumerate all possible
13//! **partitions** (equivalence relations) over these terms. Each partition
14//! represents a possible equality arrangement.
15//!
16//! ## Bell Numbers
17//!
18//! The number of partitions of n elements is given by the Bell number B(n):
19//! - B(1) = 1
20//! - B(2) = 2
21//! - B(3) = 5
22//! - B(4) = 15
23//! - B(5) = 52
24//!
//! This grows very quickly, so efficient enumeration and pruning are critical.
26//!
27//! ## Partition Refinement Algorithm
28//!
29//! Starting from the finest partition (all singletons), we can:
30//! 1. Merge classes based on constraints
31//! 2. Enumerate coarser partitions
32//! 3. Backtrack when conflicts arise
33//!
34//! ## References
35//!
36//! - Knuth TAOCP Vol 4A: "Combinatorial Algorithms, Part 1"
37//! - Restricted Growth Strings for partition enumeration
38//! - Z3's theory combination implementation
39
40#![allow(missing_docs)]
41#![allow(dead_code)]
42#[allow(unused_imports)]
43use crate::prelude::*;
44
/// Identifier of a term taking part in theory combination.
pub type TermId = u32;

/// Identifier of a theory solver.
pub type TheoryId = u32;

/// Decision level of the search (the refinement engine starts at level 0).
pub type DecisionLevel = u32;

/// Index of an equivalence class inside a [`Partition`].
pub type ClassId = usize;
56
57/// Equality between terms.
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
59pub struct Equality {
60    /// Left-hand side.
61    pub lhs: TermId,
62    /// Right-hand side.
63    pub rhs: TermId,
64}
65
66impl Equality {
67    /// Create new equality.
68    pub fn new(lhs: TermId, rhs: TermId) -> Self {
69        if lhs <= rhs {
70            Self { lhs, rhs }
71        } else {
72            Self { lhs: rhs, rhs: lhs }
73        }
74    }
75}
76
77/// Set partition of terms.
78#[derive(Debug, Clone)]
79pub struct Partition {
80    /// Equivalence classes.
81    classes: Vec<FxHashSet<TermId>>,
82
83    /// Term to class mapping.
84    term_to_class: FxHashMap<TermId, ClassId>,
85
86    /// Representative for each class.
87    representatives: Vec<TermId>,
88}
89
90impl Partition {
91    /// Create finest partition (all singletons).
92    pub fn finest(terms: &[TermId]) -> Self {
93        let mut classes = Vec::new();
94        let mut term_to_class = FxHashMap::default();
95        let mut representatives = Vec::new();
96
97        for (i, &term) in terms.iter().enumerate() {
98            let mut class = FxHashSet::default();
99            class.insert(term);
100            classes.push(class);
101            term_to_class.insert(term, i);
102            representatives.push(term);
103        }
104
105        Self {
106            classes,
107            term_to_class,
108            representatives,
109        }
110    }
111
112    /// Create coarsest partition (all terms in one class).
113    pub fn coarsest(terms: &[TermId]) -> Self {
114        if terms.is_empty() {
115            return Self {
116                classes: Vec::new(),
117                term_to_class: FxHashMap::default(),
118                representatives: Vec::new(),
119            };
120        }
121
122        let mut class = FxHashSet::default();
123        let mut term_to_class = FxHashMap::default();
124
125        for &term in terms {
126            class.insert(term);
127            term_to_class.insert(term, 0);
128        }
129
130        Self {
131            classes: vec![class],
132            term_to_class,
133            representatives: vec![terms[0]],
134        }
135    }
136
137    /// Merge two classes.
138    pub fn merge(&mut self, t1: TermId, t2: TermId) -> Result<(), String> {
139        let c1 = *self.term_to_class.get(&t1).ok_or("Term not in partition")?;
140        let c2 = *self.term_to_class.get(&t2).ok_or("Term not in partition")?;
141
142        if c1 == c2 {
143            return Ok(());
144        }
145
146        // Merge smaller into larger
147        let (src, dst) = if self.classes[c1].len() < self.classes[c2].len() {
148            (c1, c2)
149        } else {
150            (c2, c1)
151        };
152
153        // Move all terms from src to dst
154        let src_terms: Vec<_> = self.classes[src].iter().copied().collect();
155        for term in src_terms {
156            self.classes[dst].insert(term);
157            self.term_to_class.insert(term, dst);
158        }
159
160        // Clear source class
161        self.classes[src].clear();
162
163        Ok(())
164    }
165
166    /// Get all equalities implied by this partition.
167    pub fn get_equalities(&self) -> Vec<Equality> {
168        let mut equalities = Vec::new();
169
170        for class in &self.classes {
171            if class.len() > 1 {
172                let terms: Vec<_> = class.iter().copied().collect();
173                // Use star topology: all terms equal to first term
174                let rep = terms[0];
175                for &term in &terms[1..] {
176                    equalities.push(Equality::new(rep, term));
177                }
178            }
179        }
180
181        equalities
182    }
183
184    /// Get number of non-empty classes.
185    pub fn num_classes(&self) -> usize {
186        self.classes.iter().filter(|c| !c.is_empty()).count()
187    }
188
189    /// Check if two terms are in the same class.
190    pub fn are_equal(&self, t1: TermId, t2: TermId) -> bool {
191        if let (Some(&c1), Some(&c2)) = (self.term_to_class.get(&t1), self.term_to_class.get(&t2)) {
192            c1 == c2
193        } else {
194            false
195        }
196    }
197
198    /// Get representative for a term.
199    pub fn get_representative(&self, term: TermId) -> Option<TermId> {
200        self.term_to_class
201            .get(&term)
202            .and_then(|&class_id| self.representatives.get(class_id))
203            .copied()
204    }
205
206    /// Get all terms in the same class as a term.
207    pub fn get_class(&self, term: TermId) -> Option<&FxHashSet<TermId>> {
208        self.term_to_class
209            .get(&term)
210            .and_then(|&class_id| self.classes.get(class_id))
211    }
212
213    /// Clone partition.
214    pub fn clone_partition(&self) -> Partition {
215        self.clone()
216    }
217}
218
219/// Partition refinement algorithm.
220pub struct PartitionRefinement {
221    /// Current partition.
222    partition: Partition,
223
224    /// Refinement history for backtracking.
225    history: Vec<Partition>,
226
227    /// Decision levels.
228    decision_levels: Vec<DecisionLevel>,
229
230    /// Current decision level.
231    current_level: DecisionLevel,
232}
233
234impl PartitionRefinement {
235    /// Create new refinement starting from finest partition.
236    pub fn new(terms: &[TermId]) -> Self {
237        Self {
238            partition: Partition::finest(terms),
239            history: Vec::new(),
240            decision_levels: Vec::new(),
241            current_level: 0,
242        }
243    }
244
245    /// Refine with equality.
246    pub fn refine(&mut self, eq: Equality) -> Result<(), String> {
247        self.history.push(self.partition.clone_partition());
248        self.decision_levels.push(self.current_level);
249        self.partition.merge(eq.lhs, eq.rhs)
250    }
251
252    /// Refine with multiple equalities.
253    pub fn refine_batch(&mut self, equalities: &[Equality]) -> Result<(), String> {
254        for &eq in equalities {
255            self.refine(eq)?;
256        }
257        Ok(())
258    }
259
260    /// Get current partition.
261    pub fn current(&self) -> &Partition {
262        &self.partition
263    }
264
265    /// Backtrack one step.
266    pub fn backtrack_step(&mut self) -> Result<(), String> {
267        self.partition = self.history.pop().ok_or("No refinement to backtrack")?;
268        self.decision_levels.pop();
269        Ok(())
270    }
271
272    /// Backtrack to decision level.
273    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
274        while !self.decision_levels.is_empty() {
275            if let Some(&last_level) = self.decision_levels.last() {
276                if last_level > level {
277                    self.backtrack_step()?;
278                } else {
279                    break;
280                }
281            } else {
282                break;
283            }
284        }
285
286        self.current_level = level;
287        Ok(())
288    }
289
290    /// Push decision level.
291    pub fn push_decision_level(&mut self) {
292        self.current_level += 1;
293    }
294
295    /// Clear history.
296    pub fn clear_history(&mut self) {
297        self.history.clear();
298        self.decision_levels.clear();
299    }
300}
301
302/// Partition enumerator using Restricted Growth Strings.
303pub struct PartitionEnumerator {
304    /// Number of elements.
305    n: usize,
306
307    /// Terms being partitioned.
308    terms: Vec<TermId>,
309
310    /// Current RGS (Restricted Growth String).
311    rgs: Vec<usize>,
312
313    /// Maximum value seen so far.
314    max_val: usize,
315
316    /// Is enumeration complete?
317    done: bool,
318}
319
320impl PartitionEnumerator {
321    /// Create new enumerator.
322    pub fn new(terms: Vec<TermId>) -> Self {
323        let n = terms.len();
324        Self {
325            n,
326            terms,
327            rgs: vec![0; n],
328            max_val: 0,
329            done: n == 0,
330        }
331    }
332
333    /// Get next partition.
334    #[allow(clippy::should_implement_trait)]
335    pub fn next(&mut self) -> Option<Partition> {
336        if self.done {
337            return None;
338        }
339
340        // Build partition from current RGS
341        let partition = self.rgs_to_partition();
342
343        // Generate next RGS
344        self.next_rgs();
345
346        Some(partition)
347    }
348
349    /// Convert RGS to partition.
350    fn rgs_to_partition(&self) -> Partition {
351        let mut classes: Vec<FxHashSet<TermId>> = vec![FxHashSet::default(); self.max_val + 1];
352        let mut term_to_class = FxHashMap::default();
353        let mut representatives = vec![0; self.max_val + 1];
354
355        for (i, &class_id) in self.rgs.iter().enumerate() {
356            let term = self.terms[i];
357            classes[class_id].insert(term);
358            term_to_class.insert(term, class_id);
359
360            if representatives[class_id] == 0 || term < representatives[class_id] {
361                representatives[class_id] = term;
362            }
363        }
364
365        Partition {
366            classes,
367            term_to_class,
368            representatives,
369        }
370    }
371
372    /// Generate next RGS.
373    fn next_rgs(&mut self) {
374        // Find rightmost position that can be incremented
375        let mut i = self.n;
376        while i > 0 {
377            i -= 1;
378
379            let can_increment = if i == 0 {
380                false
381            } else {
382                let max_up_to_i = self.rgs[..i].iter().max().copied().unwrap_or(0);
383                self.rgs[i] <= max_up_to_i
384            };
385
386            if can_increment {
387                self.rgs[i] += 1;
388
389                // Update max_val
390                self.max_val = self.rgs.iter().max().copied().unwrap_or(0);
391
392                // Reset suffix to 0
393                for j in (i + 1)..self.n {
394                    self.rgs[j] = 0;
395                }
396
397                return;
398            }
399        }
400
401        self.done = true;
402    }
403
404    /// Reset enumerator.
405    pub fn reset(&mut self) {
406        self.rgs = vec![0; self.n];
407        self.max_val = 0;
408        self.done = self.n == 0;
409    }
410
411    /// Get number of remaining partitions (approximate).
412    pub fn count_remaining(&self) -> usize {
413        // Bell number computation (simplified)
414        bell_number(self.n)
415    }
416}
417
/// Compute the Bell number `B(n)`, the number of set partitions of `n` elements.
///
/// Uses the Bell triangle. Each row starts with the last entry of the previous
/// row, and each further entry adds the entry diagonally above-left of it. Row
/// `k` (0-based) ends with `B(k + 1)`, so `B(n)` is the last entry of row
/// `n - 1`.
///
/// Returns `usize::MAX` when `B(n)` does not fit in a `usize` (for example from
/// `n = 26` on 64-bit targets).
fn bell_number(n: usize) -> usize {
    if n == 0 {
        return 1;
    }

    let mut row: Vec<usize> = vec![1];
    for _ in 1..n {
        let mut next = Vec::with_capacity(row.len() + 1);
        let mut acc = row[row.len() - 1];
        next.push(acc);
        for &above in &row {
            acc = match acc.checked_add(above) {
                Some(sum) => sum,
                // Entries only grow along a row and from row to row, so B(n)
                // is at least this large and overflows as well.
                None => return usize::MAX,
            };
            next.push(acc);
        }
        row = next;
    }

    row[row.len() - 1]
}
438
/// Tuning knobs for [`PartitionRefinementManager`].
#[derive(Debug, Clone)]
pub struct PartitionRefinementConfig {
    /// Whether the manager builds an enumerator over all partitions.
    pub enable_enumeration: bool,

    /// Upper bound on how many partitions the manager hands out.
    pub max_partitions: usize,

    /// Enable constraint-guided refinement (not consulted by the manager in this module).
    pub constraint_guided: bool,

    /// Whether backtrack requests are honoured.
    pub enable_backtracking: bool,
}

impl Default for PartitionRefinementConfig {
    /// Everything switched on, with at most 1000 enumerated partitions.
    fn default() -> Self {
        let max_partitions = 1000;
        Self {
            max_partitions,
            enable_enumeration: true,
            constraint_guided: true,
            enable_backtracking: true,
        }
    }
}
465
/// Counters collected by [`PartitionRefinementManager`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct PartitionRefinementStats {
    /// Number of equalities merged into the current partition.
    pub refinements: u64,
    /// Number of partitions handed out by the enumerator.
    pub partitions_enumerated: u64,
    /// Number of backtrack requests carried out.
    pub backtracks: u64,
    /// Number of equality constraints processed by the manager.
    pub constraints_applied: u64,
}
478
479/// Partition refinement manager.
480pub struct PartitionRefinementManager {
481    /// Configuration.
482    config: PartitionRefinementConfig,
483
484    /// Statistics.
485    stats: PartitionRefinementStats,
486
487    /// Refinement algorithm.
488    refinement: PartitionRefinement,
489
490    /// Enumerator (if enumeration enabled).
491    enumerator: Option<PartitionEnumerator>,
492
493    /// Constraint queue.
494    constraints: VecDeque<Equality>,
495}
496
497impl PartitionRefinementManager {
498    /// Create new manager.
499    pub fn new(terms: Vec<TermId>) -> Self {
500        Self::with_config(terms, PartitionRefinementConfig::default())
501    }
502
503    /// Create with configuration.
504    pub fn with_config(terms: Vec<TermId>, config: PartitionRefinementConfig) -> Self {
505        let enumerator = if config.enable_enumeration {
506            Some(PartitionEnumerator::new(terms.clone()))
507        } else {
508            None
509        };
510
511        Self {
512            config,
513            stats: PartitionRefinementStats::default(),
514            refinement: PartitionRefinement::new(&terms),
515            enumerator,
516            constraints: VecDeque::new(),
517        }
518    }
519
520    /// Get statistics.
521    pub fn stats(&self) -> &PartitionRefinementStats {
522        &self.stats
523    }
524
525    /// Add constraint.
526    pub fn add_constraint(&mut self, eq: Equality) {
527        self.constraints.push_back(eq);
528        self.stats.constraints_applied += 1;
529    }
530
531    /// Apply constraints and refine.
532    pub fn apply_constraints(&mut self) -> Result<(), String> {
533        while let Some(eq) = self.constraints.pop_front() {
534            self.refinement.refine(eq)?;
535            self.stats.refinements += 1;
536        }
537        Ok(())
538    }
539
540    /// Get current partition.
541    pub fn current_partition(&self) -> &Partition {
542        self.refinement.current()
543    }
544
545    /// Get next enumerated partition.
546    pub fn next_partition(&mut self) -> Option<Partition> {
547        if let Some(ref mut enumerator) = self.enumerator {
548            if self.stats.partitions_enumerated >= self.config.max_partitions as u64 {
549                return None;
550            }
551
552            let partition = enumerator.next();
553            if partition.is_some() {
554                self.stats.partitions_enumerated += 1;
555            }
556            partition
557        } else {
558            None
559        }
560    }
561
562    /// Backtrack to decision level.
563    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
564        if !self.config.enable_backtracking {
565            return Ok(());
566        }
567
568        self.refinement.backtrack(level)?;
569        self.stats.backtracks += 1;
570        Ok(())
571    }
572
573    /// Push decision level.
574    pub fn push_decision_level(&mut self) {
575        self.refinement.push_decision_level();
576    }
577
578    /// Clear all state.
579    pub fn clear(&mut self) {
580        self.refinement.clear_history();
581        self.constraints.clear();
582
583        if let Some(ref mut enumerator) = self.enumerator {
584            enumerator.reset();
585        }
586    }
587
588    /// Reset statistics.
589    pub fn reset_stats(&mut self) {
590        self.stats = PartitionRefinementStats::default();
591    }
592}
593
594/// Partition comparison utilities.
595pub struct PartitionComparator;
596
597impl PartitionComparator {
598    /// Check if p1 is finer than p2.
599    pub fn is_finer(p1: &Partition, p2: &Partition) -> bool {
600        // p1 is finer if every class in p1 is a subset of some class in p2
601        for class1 in &p1.classes {
602            if class1.is_empty() {
603                continue;
604            }
605
606            // Check if all terms in class1 are in the same class in p2
607            let first_term = *class1.iter().next().expect("Non-empty class");
608            let p2_class = p2.term_to_class.get(&first_term);
609
610            for &term in class1 {
611                if p2.term_to_class.get(&term) != p2_class {
612                    return false;
613                }
614            }
615        }
616
617        true
618    }
619
620    /// Check if partitions are equal.
621    pub fn are_equal(p1: &Partition, p2: &Partition) -> bool {
622        Self::is_finer(p1, p2) && Self::is_finer(p2, p1)
623    }
624}
625
#[cfg(test)]
mod tests {
    use super::*;

    /// The terms `1..=n`, used as the ground set by most tests.
    fn ground(n: TermId) -> Vec<TermId> {
        (1..=n).collect()
    }

    #[test]
    fn test_finest_partition() {
        let finest = Partition::finest(&ground(3));

        assert_eq!(finest.num_classes(), 3);
        assert!(!finest.are_equal(1, 2));
    }

    #[test]
    fn test_coarsest_partition() {
        let coarsest = Partition::coarsest(&ground(3));

        assert_eq!(coarsest.num_classes(), 1);
        assert!([(1, 2), (2, 3)]
            .iter()
            .all(|&(a, b)| coarsest.are_equal(a, b)));
    }

    #[test]
    fn test_partition_merge() {
        let mut p = Partition::finest(&ground(4));
        p.merge(1, 2).expect("Merge failed");

        assert_eq!(p.num_classes(), 3);
        assert!(p.are_equal(1, 2));
        assert!(!p.are_equal(1, 3));
    }

    #[test]
    fn test_partition_equalities() {
        let mut p = Partition::finest(&ground(3));
        for (a, b) in [(1, 2), (2, 3)].iter().copied() {
            p.merge(a, b).expect("Merge failed");
        }

        // One class of three terms needs exactly two star-shaped equalities.
        assert_eq!(p.get_equalities().len(), 2);
    }

    #[test]
    fn test_refinement() {
        let mut r = PartitionRefinement::new(&ground(4));
        r.refine(Equality::new(1, 2)).expect("Refine failed");

        assert!(r.current().are_equal(1, 2));
    }

    #[test]
    fn test_refinement_backtrack() {
        let mut r = PartitionRefinement::new(&ground(4));
        r.refine(Equality::new(1, 2)).expect("Refine failed");
        r.backtrack_step().expect("Backtrack failed");

        assert!(!r.current().are_equal(1, 2));
    }

    #[test]
    fn test_bell_number() {
        let table = [(0, 1), (1, 1), (2, 2), (3, 5), (4, 15)];
        for &(n, expected) in table.iter() {
            assert_eq!(bell_number(n), expected);
        }
    }

    #[test]
    fn test_partition_enumerator() {
        let mut enumerator = PartitionEnumerator::new(ground(3));
        let produced = std::iter::from_fn(|| enumerator.next()).count();

        assert_eq!(produced, 5); // B(3) = 5
    }

    #[test]
    fn test_manager() {
        let mut manager = PartitionRefinementManager::new(ground(3));
        manager.add_constraint(Equality::new(1, 2));
        manager.apply_constraints().expect("Apply failed");

        assert!(manager.current_partition().are_equal(1, 2));
    }

    #[test]
    fn test_partition_comparison() {
        let ground_set = ground(3);
        let (finest, coarsest) = (
            Partition::finest(&ground_set),
            Partition::coarsest(&ground_set),
        );

        assert!(PartitionComparator::is_finer(&finest, &coarsest));
        assert!(!PartitionComparator::is_finer(&coarsest, &finest));
    }

    #[test]
    fn test_representative() {
        let mut p = Partition::finest(&ground(3));
        p.merge(1, 2).expect("Merge failed");

        assert_eq!(p.get_representative(1), p.get_representative(2));
    }

    #[test]
    fn test_get_class() {
        let mut p = Partition::finest(&ground(4));
        p.merge(1, 2).expect("Merge failed");
        p.merge(2, 3).expect("Merge failed");

        let class = p.get_class(1).expect("No class");
        assert_eq!(class.len(), 3);
        assert!((1..=3).all(|t: TermId| class.contains(&t)));
    }
}