Skip to main content

oxiz_solver/combination/
partition_refinement.rs

1//! Partition Refinement for Theory Combination.
2//!
3//! This module implements partition refinement algorithms for managing
4//! equality arrangements in Nelson-Oppen combination:
5//! - Set partition enumeration
6//! - Partition refinement with constraints
7//! - Efficient arrangement generation
8//! - Backtrackable partition data structures
9//!
10//! ## Partition Refinement
11//!
12//! Given a set of terms {t1, ..., tn}, we need to enumerate all possible
13//! **partitions** (equivalence relations) over these terms. Each partition
14//! represents a possible equality arrangement.
15//!
16//! ## Bell Numbers
17//!
18//! The number of partitions of n elements is given by the Bell number B(n):
19//! - B(1) = 1
20//! - B(2) = 2
21//! - B(3) = 5
22//! - B(4) = 15
23//! - B(5) = 52
24//!
//! This grows very quickly, so efficient enumeration and pruning are critical.
26//!
27//! ## Partition Refinement Algorithm
28//!
29//! Starting from the finest partition (all singletons), we can:
30//! 1. Merge classes based on constraints
31//! 2. Enumerate coarser partitions
32//! 3. Backtrack when conflicts arise
33//!
34//! ## References
35//!
36//! - Knuth TAOCP Vol 4A: "Combinatorial Algorithms, Part 1"
37//! - Restricted Growth Strings for partition enumeration
38//! - Z3's theory combination implementation
39
40#![allow(missing_docs)]
41#![allow(dead_code)]
42
43use rustc_hash::{FxHashMap, FxHashSet};
44use std::collections::VecDeque;
45
/// Identifier of a term.
pub type TermId = u32;

/// Identifier of a theory.
pub type TheoryId = u32;

/// Decision level used to tag refinements for backtracking.
pub type DecisionLevel = u32;

/// Index of an equivalence class inside a partition.
pub type ClassId = usize;

/// An equality `lhs = rhs` between two terms.
///
/// Values built with [`Equality::new`] are normalized so that `lhs <= rhs`,
/// which makes `a = b` and `b = a` compare and hash identically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
    /// Smaller term of the pair (after normalization).
    pub lhs: TermId,
    /// Larger term of the pair (after normalization).
    pub rhs: TermId,
}

impl Equality {
    /// Builds an equality between `lhs` and `rhs`, ordering the pair so the
    /// smaller id is stored in `lhs`.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        Self {
            lhs: lhs.min(rhs),
            rhs: lhs.max(rhs),
        }
    }
}
77
78/// Set partition of terms.
79#[derive(Debug, Clone)]
80pub struct Partition {
81    /// Equivalence classes.
82    classes: Vec<FxHashSet<TermId>>,
83
84    /// Term to class mapping.
85    term_to_class: FxHashMap<TermId, ClassId>,
86
87    /// Representative for each class.
88    representatives: Vec<TermId>,
89}
90
91impl Partition {
92    /// Create finest partition (all singletons).
93    pub fn finest(terms: &[TermId]) -> Self {
94        let mut classes = Vec::new();
95        let mut term_to_class = FxHashMap::default();
96        let mut representatives = Vec::new();
97
98        for (i, &term) in terms.iter().enumerate() {
99            let mut class = FxHashSet::default();
100            class.insert(term);
101            classes.push(class);
102            term_to_class.insert(term, i);
103            representatives.push(term);
104        }
105
106        Self {
107            classes,
108            term_to_class,
109            representatives,
110        }
111    }
112
113    /// Create coarsest partition (all terms in one class).
114    pub fn coarsest(terms: &[TermId]) -> Self {
115        if terms.is_empty() {
116            return Self {
117                classes: Vec::new(),
118                term_to_class: FxHashMap::default(),
119                representatives: Vec::new(),
120            };
121        }
122
123        let mut class = FxHashSet::default();
124        let mut term_to_class = FxHashMap::default();
125
126        for &term in terms {
127            class.insert(term);
128            term_to_class.insert(term, 0);
129        }
130
131        Self {
132            classes: vec![class],
133            term_to_class,
134            representatives: vec![terms[0]],
135        }
136    }
137
138    /// Merge two classes.
139    pub fn merge(&mut self, t1: TermId, t2: TermId) -> Result<(), String> {
140        let c1 = *self.term_to_class.get(&t1).ok_or("Term not in partition")?;
141        let c2 = *self.term_to_class.get(&t2).ok_or("Term not in partition")?;
142
143        if c1 == c2 {
144            return Ok(());
145        }
146
147        // Merge smaller into larger
148        let (src, dst) = if self.classes[c1].len() < self.classes[c2].len() {
149            (c1, c2)
150        } else {
151            (c2, c1)
152        };
153
154        // Move all terms from src to dst
155        let src_terms: Vec<_> = self.classes[src].iter().copied().collect();
156        for term in src_terms {
157            self.classes[dst].insert(term);
158            self.term_to_class.insert(term, dst);
159        }
160
161        // Clear source class
162        self.classes[src].clear();
163
164        Ok(())
165    }
166
167    /// Get all equalities implied by this partition.
168    pub fn get_equalities(&self) -> Vec<Equality> {
169        let mut equalities = Vec::new();
170
171        for class in &self.classes {
172            if class.len() > 1 {
173                let terms: Vec<_> = class.iter().copied().collect();
174                // Use star topology: all terms equal to first term
175                let rep = terms[0];
176                for &term in &terms[1..] {
177                    equalities.push(Equality::new(rep, term));
178                }
179            }
180        }
181
182        equalities
183    }
184
185    /// Get number of non-empty classes.
186    pub fn num_classes(&self) -> usize {
187        self.classes.iter().filter(|c| !c.is_empty()).count()
188    }
189
190    /// Check if two terms are in the same class.
191    pub fn are_equal(&self, t1: TermId, t2: TermId) -> bool {
192        if let (Some(&c1), Some(&c2)) = (self.term_to_class.get(&t1), self.term_to_class.get(&t2)) {
193            c1 == c2
194        } else {
195            false
196        }
197    }
198
199    /// Get representative for a term.
200    pub fn get_representative(&self, term: TermId) -> Option<TermId> {
201        self.term_to_class
202            .get(&term)
203            .and_then(|&class_id| self.representatives.get(class_id))
204            .copied()
205    }
206
207    /// Get all terms in the same class as a term.
208    pub fn get_class(&self, term: TermId) -> Option<&FxHashSet<TermId>> {
209        self.term_to_class
210            .get(&term)
211            .and_then(|&class_id| self.classes.get(class_id))
212    }
213
214    /// Clone partition.
215    pub fn clone_partition(&self) -> Partition {
216        self.clone()
217    }
218}
219
220/// Partition refinement algorithm.
221pub struct PartitionRefinement {
222    /// Current partition.
223    partition: Partition,
224
225    /// Refinement history for backtracking.
226    history: Vec<Partition>,
227
228    /// Decision levels.
229    decision_levels: Vec<DecisionLevel>,
230
231    /// Current decision level.
232    current_level: DecisionLevel,
233}
234
235impl PartitionRefinement {
236    /// Create new refinement starting from finest partition.
237    pub fn new(terms: &[TermId]) -> Self {
238        Self {
239            partition: Partition::finest(terms),
240            history: Vec::new(),
241            decision_levels: Vec::new(),
242            current_level: 0,
243        }
244    }
245
246    /// Refine with equality.
247    pub fn refine(&mut self, eq: Equality) -> Result<(), String> {
248        self.history.push(self.partition.clone_partition());
249        self.decision_levels.push(self.current_level);
250        self.partition.merge(eq.lhs, eq.rhs)
251    }
252
253    /// Refine with multiple equalities.
254    pub fn refine_batch(&mut self, equalities: &[Equality]) -> Result<(), String> {
255        for &eq in equalities {
256            self.refine(eq)?;
257        }
258        Ok(())
259    }
260
261    /// Get current partition.
262    pub fn current(&self) -> &Partition {
263        &self.partition
264    }
265
266    /// Backtrack one step.
267    pub fn backtrack_step(&mut self) -> Result<(), String> {
268        self.partition = self.history.pop().ok_or("No refinement to backtrack")?;
269        self.decision_levels.pop();
270        Ok(())
271    }
272
273    /// Backtrack to decision level.
274    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
275        while !self.decision_levels.is_empty() {
276            if let Some(&last_level) = self.decision_levels.last() {
277                if last_level > level {
278                    self.backtrack_step()?;
279                } else {
280                    break;
281                }
282            } else {
283                break;
284            }
285        }
286
287        self.current_level = level;
288        Ok(())
289    }
290
291    /// Push decision level.
292    pub fn push_decision_level(&mut self) {
293        self.current_level += 1;
294    }
295
296    /// Clear history.
297    pub fn clear_history(&mut self) {
298        self.history.clear();
299        self.decision_levels.clear();
300    }
301}
302
303/// Partition enumerator using Restricted Growth Strings.
304pub struct PartitionEnumerator {
305    /// Number of elements.
306    n: usize,
307
308    /// Terms being partitioned.
309    terms: Vec<TermId>,
310
311    /// Current RGS (Restricted Growth String).
312    rgs: Vec<usize>,
313
314    /// Maximum value seen so far.
315    max_val: usize,
316
317    /// Is enumeration complete?
318    done: bool,
319}
320
321impl PartitionEnumerator {
322    /// Create new enumerator.
323    pub fn new(terms: Vec<TermId>) -> Self {
324        let n = terms.len();
325        Self {
326            n,
327            terms,
328            rgs: vec![0; n],
329            max_val: 0,
330            done: n == 0,
331        }
332    }
333
334    /// Get next partition.
335    #[allow(clippy::should_implement_trait)]
336    pub fn next(&mut self) -> Option<Partition> {
337        if self.done {
338            return None;
339        }
340
341        // Build partition from current RGS
342        let partition = self.rgs_to_partition();
343
344        // Generate next RGS
345        self.next_rgs();
346
347        Some(partition)
348    }
349
350    /// Convert RGS to partition.
351    fn rgs_to_partition(&self) -> Partition {
352        let mut classes: Vec<FxHashSet<TermId>> = vec![FxHashSet::default(); self.max_val + 1];
353        let mut term_to_class = FxHashMap::default();
354        let mut representatives = vec![0; self.max_val + 1];
355
356        for (i, &class_id) in self.rgs.iter().enumerate() {
357            let term = self.terms[i];
358            classes[class_id].insert(term);
359            term_to_class.insert(term, class_id);
360
361            if representatives[class_id] == 0 || term < representatives[class_id] {
362                representatives[class_id] = term;
363            }
364        }
365
366        Partition {
367            classes,
368            term_to_class,
369            representatives,
370        }
371    }
372
373    /// Generate next RGS.
374    fn next_rgs(&mut self) {
375        // Find rightmost position that can be incremented
376        let mut i = self.n;
377        while i > 0 {
378            i -= 1;
379
380            let can_increment = if i == 0 {
381                false
382            } else {
383                let max_up_to_i = self.rgs[..i].iter().max().copied().unwrap_or(0);
384                self.rgs[i] <= max_up_to_i
385            };
386
387            if can_increment {
388                self.rgs[i] += 1;
389
390                // Update max_val
391                self.max_val = self.rgs.iter().max().copied().unwrap_or(0);
392
393                // Reset suffix to 0
394                for j in (i + 1)..self.n {
395                    self.rgs[j] = 0;
396                }
397
398                return;
399            }
400        }
401
402        self.done = true;
403    }
404
405    /// Reset enumerator.
406    pub fn reset(&mut self) {
407        self.rgs = vec![0; self.n];
408        self.max_val = 0;
409        self.done = self.n == 0;
410    }
411
412    /// Get number of remaining partitions (approximate).
413    pub fn count_remaining(&self) -> usize {
414        // Bell number computation (simplified)
415        bell_number(self.n)
416    }
417}
418
/// Compute the Bell number B(n), the number of set partitions of `n` elements.
///
/// Uses the Bell triangle with checked arithmetic. Returns `usize::MAX` when
/// B(n) does not fit in `usize` (n >= 26 on 64-bit targets), so the result can
/// be treated as a saturating upper bound.
fn bell_number(n: usize) -> usize {
    if n == 0 {
        return 1;
    }

    // Row k of the Bell triangle starts with B(k) and ends with B(k + 1), and
    // its entries increase along the row. Building rows 0..=n-1 therefore never
    // produces a value above B(n), so an overflow here means B(n) itself is too
    // large.
    let mut row: Vec<usize> = vec![1];
    for _ in 1..n {
        let mut acc = *row.last().expect("Bell triangle rows are never empty");
        let mut next = Vec::with_capacity(row.len() + 1);
        next.push(acc);
        for &above in &row {
            acc = match acc.checked_add(above) {
                Some(value) => value,
                None => return usize::MAX,
            };
            next.push(acc);
        }
        row = next;
    }

    *row.last().expect("Bell triangle rows are never empty")
}
439
/// Tunable switches for the partition refinement manager.
#[derive(Debug, Clone)]
pub struct PartitionRefinementConfig {
    /// Whether a partition enumerator should be created.
    pub enable_enumeration: bool,

    /// Upper bound on the number of partitions handed out by enumeration.
    pub max_partitions: usize,

    /// Enable constraint-guided refinement.
    /// NOTE(review): not read by any code visible in this module — confirm usage.
    pub constraint_guided: bool,

    /// Whether backtracking requests are honored.
    pub enable_backtracking: bool,
}

impl Default for PartitionRefinementConfig {
    /// Every feature switched on, with at most 1000 enumerated partitions.
    fn default() -> Self {
        let max_partitions = 1000;
        PartitionRefinementConfig {
            max_partitions,
            enable_enumeration: true,
            constraint_guided: true,
            enable_backtracking: true,
        }
    }
}

/// Counters collected while refining; every counter starts at zero.
#[derive(Debug, Clone, Default)]
pub struct PartitionRefinementStats {
    /// Number of refinements performed.
    pub refinements: u64,
    /// Number of partitions produced by enumeration.
    pub partitions_enumerated: u64,
    /// Number of backtrack operations.
    pub backtracks: u64,
    /// Number of constraints recorded.
    pub constraints_applied: u64,
}
479
480/// Partition refinement manager.
481pub struct PartitionRefinementManager {
482    /// Configuration.
483    config: PartitionRefinementConfig,
484
485    /// Statistics.
486    stats: PartitionRefinementStats,
487
488    /// Refinement algorithm.
489    refinement: PartitionRefinement,
490
491    /// Enumerator (if enumeration enabled).
492    enumerator: Option<PartitionEnumerator>,
493
494    /// Constraint queue.
495    constraints: VecDeque<Equality>,
496}
497
498impl PartitionRefinementManager {
499    /// Create new manager.
500    pub fn new(terms: Vec<TermId>) -> Self {
501        Self::with_config(terms, PartitionRefinementConfig::default())
502    }
503
504    /// Create with configuration.
505    pub fn with_config(terms: Vec<TermId>, config: PartitionRefinementConfig) -> Self {
506        let enumerator = if config.enable_enumeration {
507            Some(PartitionEnumerator::new(terms.clone()))
508        } else {
509            None
510        };
511
512        Self {
513            config,
514            stats: PartitionRefinementStats::default(),
515            refinement: PartitionRefinement::new(&terms),
516            enumerator,
517            constraints: VecDeque::new(),
518        }
519    }
520
521    /// Get statistics.
522    pub fn stats(&self) -> &PartitionRefinementStats {
523        &self.stats
524    }
525
526    /// Add constraint.
527    pub fn add_constraint(&mut self, eq: Equality) {
528        self.constraints.push_back(eq);
529        self.stats.constraints_applied += 1;
530    }
531
532    /// Apply constraints and refine.
533    pub fn apply_constraints(&mut self) -> Result<(), String> {
534        while let Some(eq) = self.constraints.pop_front() {
535            self.refinement.refine(eq)?;
536            self.stats.refinements += 1;
537        }
538        Ok(())
539    }
540
541    /// Get current partition.
542    pub fn current_partition(&self) -> &Partition {
543        self.refinement.current()
544    }
545
546    /// Get next enumerated partition.
547    pub fn next_partition(&mut self) -> Option<Partition> {
548        if let Some(ref mut enumerator) = self.enumerator {
549            if self.stats.partitions_enumerated >= self.config.max_partitions as u64 {
550                return None;
551            }
552
553            let partition = enumerator.next();
554            if partition.is_some() {
555                self.stats.partitions_enumerated += 1;
556            }
557            partition
558        } else {
559            None
560        }
561    }
562
563    /// Backtrack to decision level.
564    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
565        if !self.config.enable_backtracking {
566            return Ok(());
567        }
568
569        self.refinement.backtrack(level)?;
570        self.stats.backtracks += 1;
571        Ok(())
572    }
573
574    /// Push decision level.
575    pub fn push_decision_level(&mut self) {
576        self.refinement.push_decision_level();
577    }
578
579    /// Clear all state.
580    pub fn clear(&mut self) {
581        self.refinement.clear_history();
582        self.constraints.clear();
583
584        if let Some(ref mut enumerator) = self.enumerator {
585            enumerator.reset();
586        }
587    }
588
589    /// Reset statistics.
590    pub fn reset_stats(&mut self) {
591        self.stats = PartitionRefinementStats::default();
592    }
593}
594
595/// Partition comparison utilities.
596pub struct PartitionComparator;
597
598impl PartitionComparator {
599    /// Check if p1 is finer than p2.
600    pub fn is_finer(p1: &Partition, p2: &Partition) -> bool {
601        // p1 is finer if every class in p1 is a subset of some class in p2
602        for class1 in &p1.classes {
603            if class1.is_empty() {
604                continue;
605            }
606
607            // Check if all terms in class1 are in the same class in p2
608            let first_term = *class1.iter().next().expect("Non-empty class");
609            let p2_class = p2.term_to_class.get(&first_term);
610
611            for &term in class1 {
612                if p2.term_to_class.get(&term) != p2_class {
613                    return false;
614                }
615            }
616        }
617
618        true
619    }
620
621    /// Check if partitions are equal.
622    pub fn are_equal(p1: &Partition, p2: &Partition) -> bool {
623        Self::is_finer(p1, p2) && Self::is_finer(p2, p1)
624    }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    /// Finest partition over `terms`, with each `(a, b)` pair merged in order.
    fn merged(terms: &[TermId], pairs: &[(TermId, TermId)]) -> Partition {
        let mut partition = Partition::finest(terms);
        for &(a, b) in pairs {
            partition.merge(a, b).expect("Merge failed");
        }
        partition
    }

    #[test]
    fn test_finest_partition() {
        let p = Partition::finest(&[1, 2, 3]);
        assert_eq!(p.num_classes(), 3);
        assert!(!p.are_equal(1, 2));
    }

    #[test]
    fn test_coarsest_partition() {
        let p = Partition::coarsest(&[1, 2, 3]);
        assert_eq!(p.num_classes(), 1);
        assert!(p.are_equal(1, 2));
        assert!(p.are_equal(2, 3));
    }

    #[test]
    fn test_partition_merge() {
        let p = merged(&[1, 2, 3, 4], &[(1, 2)]);
        assert_eq!(p.num_classes(), 3);
        assert!(p.are_equal(1, 2));
        assert!(!p.are_equal(1, 3));
    }

    #[test]
    fn test_partition_equalities() {
        let p = merged(&[1, 2, 3], &[(1, 2), (2, 3)]);
        // Star topology: a class of n terms yields n - 1 equalities.
        assert_eq!(p.get_equalities().len(), 2);
    }

    #[test]
    fn test_refinement() {
        let mut r = PartitionRefinement::new(&[1, 2, 3, 4]);
        r.refine(Equality::new(1, 2)).expect("Refine failed");
        assert!(r.current().are_equal(1, 2));
    }

    #[test]
    fn test_refinement_backtrack() {
        let mut r = PartitionRefinement::new(&[1, 2, 3, 4]);
        r.refine(Equality::new(1, 2)).expect("Refine failed");
        r.backtrack_step().expect("Backtrack failed");
        assert!(!r.current().are_equal(1, 2));
    }

    #[test]
    fn test_bell_number() {
        for (n, expected) in [(0, 1), (1, 1), (2, 2), (3, 5), (4, 15)] {
            assert_eq!(bell_number(n), expected);
        }
    }

    #[test]
    fn test_partition_enumerator() {
        let mut enumerator = PartitionEnumerator::new(vec![1, 2, 3]);
        let count = std::iter::from_fn(|| enumerator.next()).count();
        assert_eq!(count, 5); // B(3) = 5
    }

    #[test]
    fn test_manager() {
        let mut manager = PartitionRefinementManager::new(vec![1, 2, 3]);
        manager.add_constraint(Equality::new(1, 2));
        manager.apply_constraints().expect("Apply failed");
        assert!(manager.current_partition().are_equal(1, 2));
    }

    #[test]
    fn test_partition_comparison() {
        let terms: [TermId; 3] = [1, 2, 3];
        let finest = Partition::finest(&terms);
        let coarsest = Partition::coarsest(&terms);
        assert!(PartitionComparator::is_finer(&finest, &coarsest));
        assert!(!PartitionComparator::is_finer(&coarsest, &finest));
    }

    #[test]
    fn test_representative() {
        let p = merged(&[1, 2, 3], &[(1, 2)]);
        assert_eq!(p.get_representative(1), p.get_representative(2));
    }

    #[test]
    fn test_get_class() {
        let p = merged(&[1, 2, 3, 4], &[(1, 2), (2, 3)]);
        let class = p.get_class(1).expect("No class");
        assert_eq!(class.len(), 3);
        for t in [1, 2, 3] {
            assert!(class.contains(&t));
        }
    }
}