#![allow(missing_docs)]
#![allow(dead_code)]
#[allow(unused_imports)]
use crate::prelude::*;
/// Identifier of a term being partitioned.
pub type TermId = u32;
/// Identifier of a theory (not referenced by the code in this section).
pub type TheoryId = u32;
/// Level tag attached to refinements so they can be undone by level.
pub type DecisionLevel = u32;
/// Index of a class slot inside a `Partition`.
pub type ClassId = usize;

/// An unordered equality between two terms.
///
/// Stored in normalised form with `lhs <= rhs`, so `a = b` and `b = a` build
/// values that compare and hash identically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
    pub lhs: TermId,
    pub rhs: TermId,
}

impl Equality {
    /// Creates the normalised equality between `lhs` and `rhs` (smaller id first).
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        let smaller = lhs.min(rhs);
        let larger = lhs.max(rhs);
        Self {
            lhs: smaller,
            rhs: larger,
        }
    }
}
/// A partition of a set of terms into disjoint equivalence classes.
///
/// Classes live in index-addressed slots (`ClassId`). [`merge`](Self::merge) never
/// removes a slot: the absorbed class is emptied in place. As a result, `classes`
/// may contain empty sets, and their `representatives` entries are then stale.
#[derive(Debug, Clone)]
pub struct Partition {
    /// Members of each class slot, indexed by `ClassId`.
    classes: Vec<FxHashSet<TermId>>,
    /// The slot that currently holds each term.
    term_to_class: FxHashMap<TermId, ClassId>,
    /// Representative term of each slot, indexed by `ClassId`.
    representatives: Vec<TermId>,
}

impl Partition {
    /// Builds the finest partition of `terms`.
    ///
    /// Each distinct term gets its own class and acts as its own representative.
    /// Duplicate entries in `terms` are ignored, so every class slot is reachable
    /// through `term_to_class`.
    pub fn finest(terms: &[TermId]) -> Self {
        let mut classes: Vec<FxHashSet<TermId>> = Vec::with_capacity(terms.len());
        let mut term_to_class: FxHashMap<TermId, ClassId> = FxHashMap::default();
        let mut representatives: Vec<TermId> = Vec::with_capacity(terms.len());
        for &term in terms {
            // A repeated term would otherwise leave an orphaned singleton slot
            // that `num_classes` counts but no lookup can reach.
            if term_to_class.contains_key(&term) {
                continue;
            }
            term_to_class.insert(term, classes.len());
            let mut class = FxHashSet::default();
            class.insert(term);
            classes.push(class);
            representatives.push(term);
        }
        Self {
            classes,
            term_to_class,
            representatives,
        }
    }

    /// Builds the coarsest partition of `terms`: a single class holding every term.
    ///
    /// The representative is `terms[0]`. An empty input yields a partition with no classes.
    pub fn coarsest(terms: &[TermId]) -> Self {
        if terms.is_empty() {
            return Self {
                classes: Vec::new(),
                term_to_class: FxHashMap::default(),
                representatives: Vec::new(),
            };
        }
        let mut class = FxHashSet::default();
        let mut term_to_class = FxHashMap::default();
        for &term in terms {
            class.insert(term);
            term_to_class.insert(term, 0);
        }
        Self {
            classes: vec![class],
            term_to_class,
            representatives: vec![terms[0]],
        }
    }

    /// Merges the classes containing `t1` and `t2`.
    ///
    /// The smaller class is moved into the larger one; on a tie, `t2`'s class
    /// moves into `t1`'s. The surviving slot keeps its representative. Merging
    /// two terms that are already in the same class is a no-op.
    ///
    /// # Errors
    ///
    /// Returns `"Term not in partition"` if either term is unknown. The
    /// partition is not modified in that case.
    pub fn merge(&mut self, t1: TermId, t2: TermId) -> Result<(), String> {
        let c1 = *self.term_to_class.get(&t1).ok_or("Term not in partition")?;
        let c2 = *self.term_to_class.get(&t2).ok_or("Term not in partition")?;
        if c1 == c2 {
            return Ok(());
        }
        // Union by size: relabel the fewer terms.
        let (src, dst) = if self.classes[c1].len() < self.classes[c2].len() {
            (c1, c2)
        } else {
            (c2, c1)
        };
        // Take the source set out (leaving an empty slot) instead of copying it into a temp Vec.
        let moved = std::mem::take(&mut self.classes[src]);
        for &term in &moved {
            self.term_to_class.insert(term, dst);
        }
        self.classes[dst].extend(moved);
        Ok(())
    }

    /// Returns equalities that generate this partition.
    ///
    /// For every class with at least two members, the output contains one
    /// `rep = t` equality for each other member `t`. `rep` is the same term
    /// that [`get_representative`](Self::get_representative) reports.
    pub fn get_equalities(&self) -> Vec<Equality> {
        let mut equalities = Vec::new();
        for (class_id, class) in self.classes.iter().enumerate() {
            if class.len() < 2 {
                continue;
            }
            // Anchor on the stored representative so the output is consistent
            // with `get_representative`. Fall back to the smallest member if
            // that slot's representative is somehow not in the class.
            let rep = match self.representatives.get(class_id) {
                Some(&r) if class.contains(&r) => r,
                _ => *class.iter().min().expect("class has at least two members"),
            };
            equalities.extend(
                class
                    .iter()
                    .filter(|&&t| t != rep)
                    .map(|&t| Equality::new(rep, t)),
            );
        }
        equalities
    }

    /// Number of non-empty classes; slots emptied by `merge` are not counted.
    pub fn num_classes(&self) -> usize {
        self.classes.iter().filter(|c| !c.is_empty()).count()
    }

    /// Returns `true` iff both terms are known and lie in the same class.
    pub fn are_equal(&self, t1: TermId, t2: TermId) -> bool {
        match (self.term_to_class.get(&t1), self.term_to_class.get(&t2)) {
            (Some(c1), Some(c2)) => c1 == c2,
            _ => false,
        }
    }

    /// Representative of the class containing `term`, or `None` for an unknown term.
    pub fn get_representative(&self, term: TermId) -> Option<TermId> {
        self.term_to_class
            .get(&term)
            .and_then(|&class_id| self.representatives.get(class_id))
            .copied()
    }

    /// The class containing `term`, or `None` for an unknown term.
    pub fn get_class(&self, term: TermId) -> Option<&FxHashSet<TermId>> {
        self.term_to_class
            .get(&term)
            .and_then(|&class_id| self.classes.get(class_id))
    }

    /// Returns an independent deep copy of this partition.
    pub fn clone_partition(&self) -> Partition {
        self.clone()
    }
}
/// A current [`Partition`] plus an undo history, so that refinements (merges)
/// can be rolled back step by step or by decision level.
///
/// Every successful [`refine`](Self::refine) saves a snapshot of the partition
/// as it was *before* the merge. The snapshot is tagged with the decision level
/// that was current at the time.
pub struct PartitionRefinement {
    /// The partition after all refinements applied so far.
    partition: Partition,
    /// Snapshots taken before each successful refinement, oldest first.
    history: Vec<Partition>,
    /// `decision_levels[i]` is the level at which `history[i]` was recorded;
    /// kept the same length as `history`.
    decision_levels: Vec<DecisionLevel>,
    /// Level that new refinements are tagged with.
    current_level: DecisionLevel,
}

impl PartitionRefinement {
    /// Starts from the finest partition of `terms` at decision level 0 with an empty history.
    pub fn new(terms: &[TermId]) -> Self {
        Self {
            partition: Partition::finest(terms),
            history: Vec::new(),
            decision_levels: Vec::new(),
            current_level: 0,
        }
    }

    /// Merges the classes of `eq.lhs` and `eq.rhs` and records an undo snapshot.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Partition::merge`. In that case the partition
    /// is restored and no history entry is recorded, so the history matches the
    /// refinements that actually succeeded.
    pub fn refine(&mut self, eq: Equality) -> Result<(), String> {
        let snapshot = self.partition.clone_partition();
        if let Err(err) = self.partition.merge(eq.lhs, eq.rhs) {
            self.partition = snapshot;
            return Err(err);
        }
        self.history.push(snapshot);
        self.decision_levels.push(self.current_level);
        Ok(())
    }

    /// Applies `equalities` in order.
    ///
    /// Stops at the first failure. Earlier refinements stay applied.
    pub fn refine_batch(&mut self, equalities: &[Equality]) -> Result<(), String> {
        for &eq in equalities {
            self.refine(eq)?;
        }
        Ok(())
    }

    /// The current partition.
    pub fn current(&self) -> &Partition {
        &self.partition
    }

    /// Undoes the most recent refinement.
    ///
    /// # Errors
    ///
    /// Returns `"No refinement to backtrack"` when the history is empty.
    pub fn backtrack_step(&mut self) -> Result<(), String> {
        self.partition = self.history.pop().ok_or("No refinement to backtrack")?;
        self.decision_levels.pop();
        Ok(())
    }

    /// Undoes every refinement recorded at a level strictly greater than
    /// `level`, then makes `level` the current level.
    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
        while matches!(self.decision_levels.last(), Some(&last) if last > level) {
            self.backtrack_step()?;
        }
        self.current_level = level;
        Ok(())
    }

    /// Increments the level used to tag subsequent refinements.
    pub fn push_decision_level(&mut self) {
        self.current_level += 1;
    }

    /// Forgets all undo snapshots; the current partition and level are kept.
    pub fn clear_history(&mut self) {
        self.history.clear();
        self.decision_levels.clear();
    }
}
/// Enumerates every set partition of a fixed list of terms.
///
/// Each partition is encoded as a *restricted growth string* (RGS):
/// - `rgs[i]` is the class index of `terms[i]`;
/// - `rgs[0] == 0`;
/// - each `rgs[i]` is at most one more than the maximum of `rgs[..i]`.
///
/// Strings are visited in lexicographic order. An empty term list yields no
/// partitions.
pub struct PartitionEnumerator {
    /// Number of terms (`terms.len()`).
    n: usize,
    /// The terms being partitioned, in RGS position order.
    terms: Vec<TermId>,
    /// Current restricted growth string.
    rgs: Vec<usize>,
    /// Largest value in `rgs`, i.e. the number of classes minus one.
    max_val: usize,
    /// Set once the last RGS has been emitted (immediately for zero terms).
    done: bool,
    /// Partitions returned by `next` since construction or the last `reset`.
    produced: usize,
}

impl PartitionEnumerator {
    /// Creates an enumerator positioned at the all-in-one-class partition.
    pub fn new(terms: Vec<TermId>) -> Self {
        let n = terms.len();
        Self {
            n,
            terms,
            rgs: vec![0; n],
            max_val: 0,
            done: n == 0,
            produced: 0,
        }
    }

    /// Returns the next partition, or `None` once all have been produced.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Option<Partition> {
        if self.done {
            return None;
        }
        let partition = self.rgs_to_partition();
        self.next_rgs();
        self.produced += 1;
        Some(partition)
    }

    /// Materialises the current RGS as a `Partition`.
    ///
    /// Each class's representative is its smallest term.
    fn rgs_to_partition(&self) -> Partition {
        let mut classes: Vec<FxHashSet<TermId>> = vec![FxHashSet::default(); self.max_val + 1];
        let mut term_to_class = FxHashMap::default();
        for (&term, &class_id) in self.terms.iter().zip(&self.rgs) {
            classes[class_id].insert(term);
            term_to_class.insert(term, class_id);
        }
        // Take the minimum directly rather than using a sentinel value: every
        // `TermId`, including 0, is a valid term. In an RGS every index
        // 0..=max_val occurs, so no class is empty.
        let representatives: Vec<TermId> = classes
            .iter()
            .map(|class| {
                *class
                    .iter()
                    .min()
                    .expect("every class index up to max_val occurs in an RGS")
            })
            .collect();
        Partition {
            classes,
            term_to_class,
            representatives,
        }
    }

    /// Advances `rgs` to its lexicographic successor, or sets `done` when none exists.
    fn next_rgs(&mut self) {
        // Position 0 is pinned to 0. Scan right-to-left for the last position
        // that may still grow, i.e. one not already exceeding its prefix maximum.
        for i in (1..self.n).rev() {
            let prefix_max = self.rgs[..i].iter().copied().max().unwrap_or(0);
            if self.rgs[i] <= prefix_max {
                self.rgs[i] += 1;
                self.rgs[i + 1..].fill(0);
                // The suffix is now all zeros, so the overall max is over prefix and position i.
                self.max_val = prefix_max.max(self.rgs[i]);
                return;
            }
        }
        self.done = true;
    }

    /// Restarts enumeration from the first partition.
    pub fn reset(&mut self) {
        self.rgs = vec![0; self.n];
        self.max_val = 0;
        self.done = self.n == 0;
        self.produced = 0;
    }

    /// Number of partitions that `next` will still return.
    ///
    /// Returns `usize::MAX` when `bell_number` reports the total as too large
    /// to represent.
    pub fn count_remaining(&self) -> usize {
        if self.done {
            return 0;
        }
        let total = bell_number(self.n);
        if total == usize::MAX {
            usize::MAX
        } else {
            total.saturating_sub(self.produced)
        }
    }
}
/// Returns the Bell number `B(n)`: the number of set partitions of an `n`-element set.
///
/// Computed exactly with the Bell triangle. Row `k` of the triangle starts with
/// `B(k)` and ends with `B(k + 1)`, and every entry of a row is at most its last
/// entry. Therefore an addition overflows exactly when `B(n)` does not fit in
/// `usize`; in that case `usize::MAX` is returned. On 64-bit targets this happens
/// for `n >= 26`. The work is bounded, because computation stops at the first
/// overflow.
fn bell_number(n: usize) -> usize {
    if n == 0 {
        return 1;
    }
    // Row 0 of the Bell triangle; its last entry is B(1).
    let mut row: Vec<usize> = vec![1];
    for _ in 1..n {
        let mut next = Vec::with_capacity(row.len() + 1);
        // Each row starts with the last entry of the previous row.
        let mut acc = *row.last().expect("triangle rows are never empty");
        next.push(acc);
        for &above in &row {
            acc = match acc.checked_add(above) {
                Some(sum) => sum,
                None => return usize::MAX,
            };
            next.push(acc);
        }
        row = next;
    }
    // The last entry of row n-1 is B(n).
    *row.last().expect("triangle rows are never empty")
}
/// Tuning knobs for `PartitionRefinementManager`.
#[derive(Debug, Clone)]
pub struct PartitionRefinementConfig {
    /// When `true`, the manager builds a partition enumerator over its terms.
    pub enable_enumeration: bool,
    /// Cap on how many partitions the manager's `next_partition` will hand out.
    pub max_partitions: usize,
    /// Not read by any code in this section — NOTE(review): confirm intended use.
    pub constraint_guided: bool,
    /// When `false`, the manager's `backtrack` returns `Ok(())` without undoing anything.
    pub enable_backtracking: bool,
}

impl Default for PartitionRefinementConfig {
    /// Every feature switched on, enumeration capped at 1000 partitions.
    fn default() -> Self {
        const DEFAULT_MAX_PARTITIONS: usize = 1000;
        Self {
            max_partitions: DEFAULT_MAX_PARTITIONS,
            enable_backtracking: true,
            constraint_guided: true,
            enable_enumeration: true,
        }
    }
}
/// Counters maintained by `PartitionRefinementManager`; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct PartitionRefinementStats {
    /// Incremented once for each queued constraint that `apply_constraints`
    /// refines successfully.
    pub refinements: u64,
    /// Incremented once for each partition returned by `next_partition`.
    pub partitions_enumerated: u64,
    /// Incremented once for each `backtrack` call that succeeds while
    /// backtracking is enabled.
    pub backtracks: u64,
    /// Incremented by `add_constraint` when a constraint is queued. Despite the
    /// name, this counts constraints *added*, not applied.
    pub constraints_applied: u64,
}
/// Coordinates partition refinement.
///
/// It queues equality constraints, applies them to a `PartitionRefinement`,
/// optionally enumerates candidate partitions, and keeps statistics.
pub struct PartitionRefinementManager {
    config: PartitionRefinementConfig,
    stats: PartitionRefinementStats,
    refinement: PartitionRefinement,
    /// Present only if `config.enable_enumeration` was set at construction.
    enumerator: Option<PartitionEnumerator>,
    /// Constraints waiting to be applied, oldest first.
    constraints: VecDeque<Equality>,
}

impl PartitionRefinementManager {
    /// Creates a manager over `terms` with the default configuration.
    pub fn new(terms: Vec<TermId>) -> Self {
        let config = PartitionRefinementConfig::default();
        Self::with_config(terms, config)
    }

    /// Creates a manager over `terms` with an explicit configuration.
    ///
    /// Refinement starts from the finest partition of `terms`. An enumerator
    /// over the same terms is built only when `config.enable_enumeration` is set.
    pub fn with_config(terms: Vec<TermId>, config: PartitionRefinementConfig) -> Self {
        let refinement = PartitionRefinement::new(&terms);
        let enumerator = config
            .enable_enumeration
            .then(|| PartitionEnumerator::new(terms));
        Self {
            stats: PartitionRefinementStats::default(),
            constraints: VecDeque::new(),
            config,
            refinement,
            enumerator,
        }
    }

    /// Read-only access to the collected statistics.
    pub fn stats(&self) -> &PartitionRefinementStats {
        &self.stats
    }

    /// Queues `eq` for the next `apply_constraints` call and bumps `constraints_applied`.
    pub fn add_constraint(&mut self, eq: Equality) {
        self.stats.constraints_applied += 1;
        self.constraints.push_back(eq);
    }

    /// Drains the queue in FIFO order, refining the partition with each constraint.
    ///
    /// On failure, the failing constraint has already been removed from the
    /// queue; any constraints after it stay queued.
    pub fn apply_constraints(&mut self) -> Result<(), String> {
        loop {
            match self.constraints.pop_front() {
                Some(pending) => {
                    self.refinement.refine(pending)?;
                    self.stats.refinements += 1;
                }
                None => return Ok(()),
            }
        }
    }

    /// The partition produced by all refinements so far.
    pub fn current_partition(&self) -> &Partition {
        self.refinement.current()
    }

    /// Returns the next enumerated partition.
    ///
    /// Returns `None` if enumeration is disabled, the `max_partitions` cap has
    /// been reached, or the enumerator is exhausted.
    pub fn next_partition(&mut self) -> Option<Partition> {
        let limit = self.config.max_partitions as u64;
        let enumerator = self.enumerator.as_mut()?;
        if self.stats.partitions_enumerated >= limit {
            return None;
        }
        let partition = enumerator.next()?;
        self.stats.partitions_enumerated += 1;
        Some(partition)
    }

    /// Backtracks the refinement to `level`; a no-op returning `Ok(())` when backtracking is disabled.
    pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
        if self.config.enable_backtracking {
            self.refinement.backtrack(level)?;
            self.stats.backtracks += 1;
        }
        Ok(())
    }

    /// Opens a new decision level in the underlying refinement.
    pub fn push_decision_level(&mut self) {
        self.refinement.push_decision_level();
    }

    /// Drops undo history and pending constraints, and rewinds the enumerator.
    ///
    /// The current partition itself is left unchanged.
    pub fn clear(&mut self) {
        self.refinement.clear_history();
        self.constraints.clear();
        if let Some(enumerator) = self.enumerator.as_mut() {
            enumerator.reset();
        }
    }

    /// Zeroes all statistics counters.
    pub fn reset_stats(&mut self) {
        self.stats = PartitionRefinementStats::default();
    }
}
/// Structural comparisons between two partitions.
pub struct PartitionComparator;

impl PartitionComparator {
    /// Returns `true` when `p1` refines `p2`.
    ///
    /// That is, every non-empty class of `p1` lies entirely inside a single
    /// class of `p2`. A term of `p1` that `p2` does not contain belongs to no
    /// class of `p2`, so it makes the result `false`. Empty slots left behind
    /// by merges are ignored.
    pub fn is_finer(p1: &Partition, p2: &Partition) -> bool {
        p1.classes.iter().all(|class| {
            let mut p2_ids = class.iter().map(|term| p2.term_to_class.get(term));
            match p2_ids.next() {
                // Empty slot: nothing to check.
                None => true,
                // Term unknown to p2: cannot be inside any p2 class.
                Some(None) => false,
                Some(Some(first)) => p2_ids.all(|id| id == Some(first)),
            }
        })
    }

    /// Returns `true` when each partition refines the other.
    ///
    /// This means both have the same classes over the same terms.
    pub fn are_equal(p1: &Partition, p2: &Partition) -> bool {
        Self::is_finer(p1, p2) && Self::is_finer(p2, p1)
    }
}
// Unit tests for the partition type, refinement history, Bell numbers,
// the RGS enumerator, the manager, and the comparator defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// `finest` gives every term its own class, so distinct terms are never equal.
#[test]
fn test_finest_partition() {
let terms = vec![1, 2, 3];
let partition = Partition::finest(&terms);
assert_eq!(partition.num_classes(), 3);
assert!(!partition.are_equal(1, 2));
}
// `coarsest` puts every term in a single class.
#[test]
fn test_coarsest_partition() {
let terms = vec![1, 2, 3];
let partition = Partition::coarsest(&terms);
assert_eq!(partition.num_classes(), 1);
assert!(partition.are_equal(1, 2));
assert!(partition.are_equal(2, 3));
}
// A merge reduces the non-empty class count by one and only joins the merged terms.
#[test]
fn test_partition_merge() {
let terms = vec![1, 2, 3, 4];
let mut partition = Partition::finest(&terms);
partition.merge(1, 2).expect("Merge failed");
assert_eq!(partition.num_classes(), 3);
assert!(partition.are_equal(1, 2));
assert!(!partition.are_equal(1, 3));
}
// A 3-member class yields one equality per non-anchor member, i.e. two.
#[test]
fn test_partition_equalities() {
let terms = vec![1, 2, 3];
let mut partition = Partition::finest(&terms);
partition.merge(1, 2).expect("Merge failed");
partition.merge(2, 3).expect("Merge failed");
let equalities = partition.get_equalities();
assert_eq!(equalities.len(), 2); }
// Refining with an equality merges the two terms in the current partition.
#[test]
fn test_refinement() {
let terms = vec![1, 2, 3, 4];
let mut refinement = PartitionRefinement::new(&terms);
refinement
.refine(Equality::new(1, 2))
.expect("Refine failed");
assert!(refinement.current().are_equal(1, 2));
}
// One backtrack step restores the partition from before the last refinement.
#[test]
fn test_refinement_backtrack() {
let terms = vec![1, 2, 3, 4];
let mut refinement = PartitionRefinement::new(&terms);
refinement
.refine(Equality::new(1, 2))
.expect("Refine failed");
refinement.backtrack_step().expect("Backtrack failed");
assert!(!refinement.current().are_equal(1, 2));
}
// Known small Bell numbers.
#[test]
fn test_bell_number() {
assert_eq!(bell_number(0), 1);
assert_eq!(bell_number(1), 1);
assert_eq!(bell_number(2), 2);
assert_eq!(bell_number(3), 5);
assert_eq!(bell_number(4), 15);
}
// Enumerating a 3-term set yields exactly B(3) = 5 partitions.
#[test]
fn test_partition_enumerator() {
let terms = vec![1, 2, 3];
let mut enumerator = PartitionEnumerator::new(terms);
let mut count = 0;
while enumerator.next().is_some() {
count += 1;
}
assert_eq!(count, 5); }
// Queued constraints take effect after `apply_constraints`.
#[test]
fn test_manager() {
let terms = vec![1, 2, 3];
let mut manager = PartitionRefinementManager::new(terms);
manager.add_constraint(Equality::new(1, 2));
manager.apply_constraints().expect("Apply failed");
assert!(manager.current_partition().are_equal(1, 2));
}
// The finest partition refines the coarsest, but not the other way round.
#[test]
fn test_partition_comparison() {
let terms = vec![1, 2, 3];
let finest = Partition::finest(&terms);
let coarsest = Partition::coarsest(&terms);
assert!(PartitionComparator::is_finer(&finest, &coarsest));
assert!(!PartitionComparator::is_finer(&coarsest, &finest));
}
// Merged terms report the same representative.
#[test]
fn test_representative() {
let terms = vec![1, 2, 3];
let mut partition = Partition::finest(&terms);
partition.merge(1, 2).expect("Merge failed");
let rep1 = partition.get_representative(1);
let rep2 = partition.get_representative(2);
assert_eq!(rep1, rep2);
}
// After two merges, `get_class` returns the full merged class for any member.
#[test]
fn test_get_class() {
let terms = vec![1, 2, 3, 4];
let mut partition = Partition::finest(&terms);
partition.merge(1, 2).expect("Merge failed");
partition.merge(2, 3).expect("Merge failed");
let class = partition.get_class(1).expect("No class");
assert_eq!(class.len(), 3);
assert!(class.contains(&1));
assert!(class.contains(&2));
assert!(class.contains(&3));
}
}