Skip to main content

oxiz_solver/combination/
partition_refinement.rs

1//! Partition Refinement for Theory Combination.
2//!
3//! This module implements partition refinement algorithms for managing
4//! equality arrangements in Nelson-Oppen combination:
5//! - Set partition enumeration
6//! - Partition refinement with constraints
7//! - Efficient arrangement generation
8//! - Backtrackable partition data structures
9//!
10//! ## Partition Refinement
11//!
12//! Given a set of terms {t1, ..., tn}, we need to enumerate all possible
13//! **partitions** (equivalence relations) over these terms. Each partition
14//! represents a possible equality arrangement.
15//!
16//! ## Bell Numbers
17//!
18//! The number of partitions of n elements is given by the Bell number B(n):
19//! - B(1) = 1
20//! - B(2) = 2
21//! - B(3) = 5
22//! - B(4) = 15
23//! - B(5) = 52
24//!
//! This grows very quickly, so efficient enumeration and pruning are critical.
26//!
27//! ## Partition Refinement Algorithm
28//!
29//! Starting from the finest partition (all singletons), we can:
30//! 1. Merge classes based on constraints
31//! 2. Enumerate coarser partitions
32//! 3. Backtrack when conflicts arise
33//!
34//! ## References
35//!
36//! - Knuth TAOCP Vol 4A: "Combinatorial Algorithms, Part 1"
37//! - Restricted Growth Strings for partition enumeration
38//! - Z3's theory combination implementation
39
40#![allow(missing_docs)]
41#[allow(unused_imports)]
42use crate::prelude::*;
43
/// Identifier of a term being partitioned (the key type of every partition map).
pub type TermId = u32;

/// Identifier of a theory participating in the combination.
///
/// NOTE(review): not referenced by any code in this module; presumably used by callers.
pub type TheoryId = u32;

/// Backtracking level tagged onto each recorded refinement step.
pub type DecisionLevel = u32;

/// Dense index of an equivalence class inside a [`Partition`]'s class vector.
pub type ClassId = usize;
55
/// An unordered equality `lhs = rhs` between two terms.
///
/// Values built through [`Equality::new`] are normalized so that `lhs <= rhs`,
/// which makes `a = b` and `b = a` compare and hash identically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
    /// The smaller term id (after normalization by [`Equality::new`]).
    pub lhs: TermId,
    /// The larger term id (after normalization by [`Equality::new`]).
    pub rhs: TermId,
}

impl Equality {
    /// Builds the equality between two terms, putting the smaller id in `lhs`.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        Self {
            lhs: lhs.min(rhs),
            rhs: lhs.max(rhs),
        }
    }
}
75
76/// Set partition of terms.
77#[derive(Debug, Clone)]
78pub struct Partition {
79    /// Equivalence classes.
80    classes: Vec<FxHashSet<TermId>>,
81
82    /// Term to class mapping.
83    term_to_class: FxHashMap<TermId, ClassId>,
84
85    /// Representative for each class.
86    representatives: Vec<TermId>,
87}
88
89impl Partition {
90    /// Create finest partition (all singletons).
91    pub fn finest(terms: &[TermId]) -> Self {
92        let mut classes = Vec::new();
93        let mut term_to_class = FxHashMap::default();
94        let mut representatives = Vec::new();
95
96        for (i, &term) in terms.iter().enumerate() {
97            let mut class = FxHashSet::default();
98            class.insert(term);
99            classes.push(class);
100            term_to_class.insert(term, i);
101            representatives.push(term);
102        }
103
104        Self {
105            classes,
106            term_to_class,
107            representatives,
108        }
109    }
110
111    /// Create coarsest partition (all terms in one class).
112    pub fn coarsest(terms: &[TermId]) -> Self {
113        if terms.is_empty() {
114            return Self {
115                classes: Vec::new(),
116                term_to_class: FxHashMap::default(),
117                representatives: Vec::new(),
118            };
119        }
120
121        let mut class = FxHashSet::default();
122        let mut term_to_class = FxHashMap::default();
123
124        for &term in terms {
125            class.insert(term);
126            term_to_class.insert(term, 0);
127        }
128
129        Self {
130            classes: vec![class],
131            term_to_class,
132            representatives: vec![terms[0]],
133        }
134    }
135
136    /// Merge two classes.
137    pub fn merge(&mut self, t1: TermId, t2: TermId) -> Result<(), String> {
138        let c1 = *self.term_to_class.get(&t1).ok_or("Term not in partition")?;
139        let c2 = *self.term_to_class.get(&t2).ok_or("Term not in partition")?;
140
141        if c1 == c2 {
142            return Ok(());
143        }
144
145        // Merge smaller into larger
146        let (src, dst) = if self.classes[c1].len() < self.classes[c2].len() {
147            (c1, c2)
148        } else {
149            (c2, c1)
150        };
151
152        // Move all terms from src to dst
153        let src_terms: Vec<_> = self.classes[src].iter().copied().collect();
154        for term in src_terms {
155            self.classes[dst].insert(term);
156            self.term_to_class.insert(term, dst);
157        }
158
159        // Clear source class
160        self.classes[src].clear();
161
162        Ok(())
163    }
164
165    /// Get all equalities implied by this partition.
166    pub fn get_equalities(&self) -> Vec<Equality> {
167        let mut equalities = Vec::new();
168
169        for class in &self.classes {
170            if class.len() > 1 {
171                let terms: Vec<_> = class.iter().copied().collect();
172                // Use star topology: all terms equal to first term
173                let rep = terms[0];
174                for &term in &terms[1..] {
175                    equalities.push(Equality::new(rep, term));
176                }
177            }
178        }
179
180        equalities
181    }
182
183    /// Get number of non-empty classes.
184    pub fn num_classes(&self) -> usize {
185        self.classes.iter().filter(|c| !c.is_empty()).count()
186    }
187
188    /// Check if two terms are in the same class.
189    pub fn are_equal(&self, t1: TermId, t2: TermId) -> bool {
190        if let (Some(&c1), Some(&c2)) = (self.term_to_class.get(&t1), self.term_to_class.get(&t2)) {
191            c1 == c2
192        } else {
193            false
194        }
195    }
196
197    /// Get representative for a term.
198    pub fn get_representative(&self, term: TermId) -> Option<TermId> {
199        self.term_to_class
200            .get(&term)
201            .and_then(|&class_id| self.representatives.get(class_id))
202            .copied()
203    }
204
205    /// Get all terms in the same class as a term.
206    pub fn get_class(&self, term: TermId) -> Option<&FxHashSet<TermId>> {
207        self.term_to_class
208            .get(&term)
209            .and_then(|&class_id| self.classes.get(class_id))
210    }
211
212    /// Clone partition.
213    pub fn clone_partition(&self) -> Partition {
214        self.clone()
215    }
216}
217
218/// Partition refinement algorithm.
219pub struct PartitionRefinement {
220    /// Current partition.
221    partition: Partition,
222
223    /// Refinement history for backtracking.
224    history: Vec<Partition>,
225
226    /// Decision levels.
227    decision_levels: Vec<DecisionLevel>,
228
229    /// Current decision level.
230    current_level: DecisionLevel,
231}
232
233impl PartitionRefinement {
234    /// Create new refinement starting from finest partition.
235    pub fn new(terms: &[TermId]) -> Self {
236        Self {
237            partition: Partition::finest(terms),
238            history: Vec::new(),
239            decision_levels: Vec::new(),
240            current_level: 0,
241        }
242    }
243
244    /// Refine with equality.
245    pub fn refine(&mut self, eq: Equality) -> Result<(), String> {
246        self.history.push(self.partition.clone_partition());
247        self.decision_levels.push(self.current_level);
248        self.partition.merge(eq.lhs, eq.rhs)
249    }
250
251    /// Refine with multiple equalities.
252    pub fn refine_batch(&mut self, equalities: &[Equality]) -> Result<(), String> {
253        for &eq in equalities {
254            self.refine(eq)?;
255        }
256        Ok(())
257    }
258
259    /// Get current partition.
260    pub fn current(&self) -> &Partition {
261        &self.partition
262    }
263
264    /// Backtrack one step.
265    pub fn backtrack_step(&mut self) -> Result<(), String> {
266        self.partition = self.history.pop().ok_or("No refinement to backtrack")?;
267        self.decision_levels.pop();
268        Ok(())
269    }
270
271    /// Backtrack to decision level.
272    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
273        while !self.decision_levels.is_empty() {
274            if let Some(&last_level) = self.decision_levels.last() {
275                if last_level > level {
276                    self.backtrack_step()?;
277                } else {
278                    break;
279                }
280            } else {
281                break;
282            }
283        }
284
285        self.current_level = level;
286        Ok(())
287    }
288
289    /// Push decision level.
290    pub fn push_decision_level(&mut self) {
291        self.current_level += 1;
292    }
293
294    /// Clear history.
295    pub fn clear_history(&mut self) {
296        self.history.clear();
297        self.decision_levels.clear();
298    }
299}
300
301/// Partition enumerator using Restricted Growth Strings.
302pub struct PartitionEnumerator {
303    /// Number of elements.
304    n: usize,
305
306    /// Terms being partitioned.
307    terms: Vec<TermId>,
308
309    /// Current RGS (Restricted Growth String).
310    rgs: Vec<usize>,
311
312    /// Maximum value seen so far.
313    max_val: usize,
314
315    /// Is enumeration complete?
316    done: bool,
317}
318
319impl PartitionEnumerator {
320    /// Create new enumerator.
321    pub fn new(terms: Vec<TermId>) -> Self {
322        let n = terms.len();
323        Self {
324            n,
325            terms,
326            rgs: vec![0; n],
327            max_val: 0,
328            done: n == 0,
329        }
330    }
331
332    /// Get next partition.
333    #[allow(clippy::should_implement_trait)]
334    pub fn next(&mut self) -> Option<Partition> {
335        if self.done {
336            return None;
337        }
338
339        // Build partition from current RGS
340        let partition = self.rgs_to_partition();
341
342        // Generate next RGS
343        self.next_rgs();
344
345        Some(partition)
346    }
347
348    /// Convert RGS to partition.
349    fn rgs_to_partition(&self) -> Partition {
350        let mut classes: Vec<FxHashSet<TermId>> = vec![FxHashSet::default(); self.max_val + 1];
351        let mut term_to_class = FxHashMap::default();
352        let mut representatives = vec![0; self.max_val + 1];
353
354        for (i, &class_id) in self.rgs.iter().enumerate() {
355            let term = self.terms[i];
356            classes[class_id].insert(term);
357            term_to_class.insert(term, class_id);
358
359            if representatives[class_id] == 0 || term < representatives[class_id] {
360                representatives[class_id] = term;
361            }
362        }
363
364        Partition {
365            classes,
366            term_to_class,
367            representatives,
368        }
369    }
370
371    /// Generate next RGS.
372    fn next_rgs(&mut self) {
373        // Find rightmost position that can be incremented
374        let mut i = self.n;
375        while i > 0 {
376            i -= 1;
377
378            let can_increment = if i == 0 {
379                false
380            } else {
381                let max_up_to_i = self.rgs[..i].iter().max().copied().unwrap_or(0);
382                self.rgs[i] <= max_up_to_i
383            };
384
385            if can_increment {
386                self.rgs[i] += 1;
387
388                // Update max_val
389                self.max_val = self.rgs.iter().max().copied().unwrap_or(0);
390
391                // Reset suffix to 0
392                for j in (i + 1)..self.n {
393                    self.rgs[j] = 0;
394                }
395
396                return;
397            }
398        }
399
400        self.done = true;
401    }
402
403    /// Reset enumerator.
404    pub fn reset(&mut self) {
405        self.rgs = vec![0; self.n];
406        self.max_val = 0;
407        self.done = self.n == 0;
408    }
409
410    /// Get number of remaining partitions (approximate).
411    pub fn count_remaining(&self) -> usize {
412        // Bell number computation (simplified)
413        bell_number(self.n)
414    }
415}
416
/// Computes the Bell number `B(n)`, the number of set partitions of `n` elements.
///
/// The computation uses the Bell triangle. Each row starts with the last entry
/// of the previous row, and every further entry is the sum of its left
/// neighbour and the entry above that neighbour. The last entry of row `n - 1`
/// is `B(n)`, and no intermediate value exceeds it, so overflow only happens
/// when `B(n)` itself does not fit.
///
/// Returns `usize::MAX` when `B(n)` does not fit in a `usize` (from `n = 26`
/// on 64-bit targets). Work stops at the first overflow, so large `n` is cheap.
fn bell_number(n: usize) -> usize {
    if n == 0 {
        return 1;
    }

    // Row 0 of the Bell triangle.
    let mut row: Vec<usize> = vec![1];
    for _ in 1..n {
        let mut next = Vec::with_capacity(row.len() + 1);
        let mut acc = row[row.len() - 1];
        next.push(acc);
        for &above in &row {
            acc = match acc.checked_add(above) {
                Some(sum) => sum,
                None => return usize::MAX,
            };
            next.push(acc);
        }
        row = next;
    }

    row[row.len() - 1]
}
437
/// Tuning knobs for [`PartitionRefinementManager`].
#[derive(Clone, Debug)]
pub struct PartitionRefinementConfig {
    /// Whether the manager creates a partition enumerator at construction.
    pub enable_enumeration: bool,

    /// Cap on how many partitions the manager hands out through enumeration.
    pub max_partitions: usize,

    /// Enable constraint-guided refinement.
    ///
    /// NOTE(review): not read by any code in this module; presumably consumed elsewhere.
    pub constraint_guided: bool,

    /// When `false`, backtracking requests to the manager are ignored.
    pub enable_backtracking: bool,
}

impl Default for PartitionRefinementConfig {
    /// Every feature switched on, with at most 1000 enumerated partitions.
    fn default() -> Self {
        const DEFAULT_MAX_PARTITIONS: usize = 1000;

        Self {
            max_partitions: DEFAULT_MAX_PARTITIONS,
            enable_enumeration: true,
            enable_backtracking: true,
            constraint_guided: true,
        }
    }
}
464
/// Counters maintained by [`PartitionRefinementManager`]; all start at zero.
#[derive(Clone, Debug, Default)]
pub struct PartitionRefinementStats {
    /// Successful refinements made while applying queued constraints.
    pub refinements: u64,
    /// Partitions handed out by the enumerator.
    pub partitions_enumerated: u64,
    /// Manager-level backtracks that were actually performed.
    pub backtracks: u64,
    /// Constraints queued for application (counted when queued, not when applied).
    pub constraints_applied: u64,
}
477
478/// Partition refinement manager.
479pub struct PartitionRefinementManager {
480    /// Configuration.
481    config: PartitionRefinementConfig,
482
483    /// Statistics.
484    stats: PartitionRefinementStats,
485
486    /// Refinement algorithm.
487    refinement: PartitionRefinement,
488
489    /// Enumerator (if enumeration enabled).
490    enumerator: Option<PartitionEnumerator>,
491
492    /// Constraint queue.
493    constraints: VecDeque<Equality>,
494}
495
496impl PartitionRefinementManager {
497    /// Create new manager.
498    pub fn new(terms: Vec<TermId>) -> Self {
499        Self::with_config(terms, PartitionRefinementConfig::default())
500    }
501
502    /// Create with configuration.
503    pub fn with_config(terms: Vec<TermId>, config: PartitionRefinementConfig) -> Self {
504        let enumerator = if config.enable_enumeration {
505            Some(PartitionEnumerator::new(terms.clone()))
506        } else {
507            None
508        };
509
510        Self {
511            config,
512            stats: PartitionRefinementStats::default(),
513            refinement: PartitionRefinement::new(&terms),
514            enumerator,
515            constraints: VecDeque::new(),
516        }
517    }
518
519    /// Get statistics.
520    pub fn stats(&self) -> &PartitionRefinementStats {
521        &self.stats
522    }
523
524    /// Add constraint.
525    pub fn add_constraint(&mut self, eq: Equality) {
526        self.constraints.push_back(eq);
527        self.stats.constraints_applied += 1;
528    }
529
530    /// Apply constraints and refine.
531    pub fn apply_constraints(&mut self) -> Result<(), String> {
532        while let Some(eq) = self.constraints.pop_front() {
533            self.refinement.refine(eq)?;
534            self.stats.refinements += 1;
535        }
536        Ok(())
537    }
538
539    /// Get current partition.
540    pub fn current_partition(&self) -> &Partition {
541        self.refinement.current()
542    }
543
544    /// Get next enumerated partition.
545    pub fn next_partition(&mut self) -> Option<Partition> {
546        if let Some(ref mut enumerator) = self.enumerator {
547            if self.stats.partitions_enumerated >= self.config.max_partitions as u64 {
548                return None;
549            }
550
551            let partition = enumerator.next();
552            if partition.is_some() {
553                self.stats.partitions_enumerated += 1;
554            }
555            partition
556        } else {
557            None
558        }
559    }
560
561    /// Backtrack to decision level.
562    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
563        if !self.config.enable_backtracking {
564            return Ok(());
565        }
566
567        self.refinement.backtrack(level)?;
568        self.stats.backtracks += 1;
569        Ok(())
570    }
571
572    /// Push decision level.
573    pub fn push_decision_level(&mut self) {
574        self.refinement.push_decision_level();
575    }
576
577    /// Clear all state.
578    pub fn clear(&mut self) {
579        self.refinement.clear_history();
580        self.constraints.clear();
581
582        if let Some(ref mut enumerator) = self.enumerator {
583            enumerator.reset();
584        }
585    }
586
587    /// Reset statistics.
588    pub fn reset_stats(&mut self) {
589        self.stats = PartitionRefinementStats::default();
590    }
591}
592
593/// Partition comparison utilities.
594pub struct PartitionComparator;
595
596impl PartitionComparator {
597    /// Check if p1 is finer than p2.
598    pub fn is_finer(p1: &Partition, p2: &Partition) -> bool {
599        // p1 is finer if every class in p1 is a subset of some class in p2
600        for class1 in &p1.classes {
601            if class1.is_empty() {
602                continue;
603            }
604
605            // Check if all terms in class1 are in the same class in p2
606            let first_term = *class1.iter().next().expect("Non-empty class");
607            let p2_class = p2.term_to_class.get(&first_term);
608
609            for &term in class1 {
610                if p2.term_to_class.get(&term) != p2_class {
611                    return false;
612                }
613            }
614        }
615
616        true
617    }
618
619    /// Check if partitions are equal.
620    pub fn are_equal(p1: &Partition, p2: &Partition) -> bool {
621        Self::is_finer(p1, p2) && Self::is_finer(p2, p1)
622    }
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Term ids `1..=n`, the fixture the tests below partition.
    fn terms_up_to(n: TermId) -> Vec<TermId> {
        (1..=n).collect()
    }

    #[test]
    fn test_finest_partition() {
        let finest = Partition::finest(&terms_up_to(3));

        assert_eq!(finest.num_classes(), 3);
        assert!(!finest.are_equal(1, 2));
    }

    #[test]
    fn test_coarsest_partition() {
        let coarsest = Partition::coarsest(&terms_up_to(3));

        assert_eq!(coarsest.num_classes(), 1);
        for &(a, b) in &[(1, 2), (2, 3)] {
            assert!(coarsest.are_equal(a, b), "{} and {} should share a class", a, b);
        }
    }

    #[test]
    fn test_partition_merge() {
        let mut partition = Partition::finest(&terms_up_to(4));
        partition.merge(1, 2).expect("Merge failed");

        assert_eq!(partition.num_classes(), 3);
        assert!(partition.are_equal(1, 2));
        assert!(!partition.are_equal(1, 3));
    }

    #[test]
    fn test_partition_equalities() {
        let mut partition = Partition::finest(&terms_up_to(3));
        for &(a, b) in &[(1, 2), (2, 3)] {
            partition.merge(a, b).expect("Merge failed");
        }

        // One class of three terms is described by two star-shaped equalities.
        assert_eq!(partition.get_equalities().len(), 2);
    }

    #[test]
    fn test_refinement() {
        let mut refinement = PartitionRefinement::new(&terms_up_to(4));
        refinement
            .refine(Equality::new(1, 2))
            .expect("Refine failed");

        assert!(refinement.current().are_equal(1, 2));
    }

    #[test]
    fn test_refinement_backtrack() {
        let mut refinement = PartitionRefinement::new(&terms_up_to(4));
        refinement
            .refine(Equality::new(1, 2))
            .expect("Refine failed");
        refinement.backtrack_step().expect("Backtrack failed");

        assert!(!refinement.current().are_equal(1, 2));
    }

    #[test]
    fn test_bell_number() {
        let expected: [usize; 5] = [1, 1, 2, 5, 15];
        for (n, &b) in expected.iter().enumerate() {
            assert_eq!(bell_number(n), b);
        }
    }

    #[test]
    fn test_partition_enumerator() {
        let mut enumerator = PartitionEnumerator::new(terms_up_to(3));
        let count = std::iter::from_fn(|| enumerator.next()).count();

        assert_eq!(count, 5); // B(3) = 5
    }

    #[test]
    fn test_manager() {
        let mut manager = PartitionRefinementManager::new(terms_up_to(3));
        manager.add_constraint(Equality::new(1, 2));
        manager.apply_constraints().expect("Apply failed");

        assert!(manager.current_partition().are_equal(1, 2));
    }

    #[test]
    fn test_partition_comparison() {
        let terms = terms_up_to(3);
        let (finest, coarsest) = (Partition::finest(&terms), Partition::coarsest(&terms));

        assert!(PartitionComparator::is_finer(&finest, &coarsest));
        assert!(!PartitionComparator::is_finer(&coarsest, &finest));
    }

    #[test]
    fn test_representative() {
        let mut partition = Partition::finest(&terms_up_to(3));
        partition.merge(1, 2).expect("Merge failed");

        assert_eq!(
            partition.get_representative(1),
            partition.get_representative(2)
        );
    }

    #[test]
    fn test_get_class() {
        let mut partition = Partition::finest(&terms_up_to(4));
        partition.merge(1, 2).expect("Merge failed");
        partition.merge(2, 3).expect("Merge failed");

        let class = partition.get_class(1).expect("No class");
        assert_eq!(class.len(), 3);
        assert!((1..=3).all(|t| class.contains(&t)));
    }
}