use crate::nonce::Nonce;
use mirai_annotations::checked_verify;
use std::{
collections::{BTreeMap, BTreeSet},
usize::MAX,
};
/// A partition of a set of [`Nonce`]s into disjoint equivalence classes.
///
/// Two nonces are considered equal exactly when they are in the same class. Each class is
/// labelled with a `usize` id, and every method keeps the two maps below in sync:
///
/// - `nonce_to_id[n] == id` if and only if `id_to_nonce_set[&id]` contains `n`;
/// - no class stored in `id_to_nonce_set` is empty.
///
/// Ids assigned by [`Partition::add_nonce`] and [`Partition::add_equality`] are arbitrary
/// labels. [`Partition::join`] and [`Partition::construct_canonical_partition`] relabel every
/// class with the smallest `Nonce::inner` value it contains, so partitions produced by them
/// compare equal (via the derived `PartialEq`) exactly when they have the same classes.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Partition {
    /// Maps each nonce to the id of the class containing it.
    nonce_to_id: BTreeMap<Nonce, usize>,
    /// Maps each class id to the (non-empty) set of nonces in that class.
    id_to_nonce_set: BTreeMap<usize, BTreeSet<Nonce>>,
}

impl Partition {
    /// Adds `new_nonce` to the partition as a new singleton class.
    ///
    /// The class id is `new_nonce.inner()` when that id is free. It may already be taken,
    /// because `remove_nonce` and `add_equality` let a class keep the id of a nonce that has
    /// since left it; in that case the next free id is used so the existing class is not
    /// overwritten.
    ///
    /// `new_nonce` is expected not to be in the partition already.
    pub fn add_nonce(&mut self, new_nonce: Nonce) {
        let mut id = new_nonce.inner();
        // Reusing a taken id would replace another class's member set with `{new_nonce}`
        // while those members still map to `id`, making them "equal" to `new_nonce`.
        while self.id_to_nonce_set.contains_key(&id) {
            id = id.wrapping_add(1);
        }
        self.nonce_to_id.insert(new_nonce.clone(), id);
        let mut singleton_set = BTreeSet::new();
        singleton_set.insert(new_nonce);
        self.id_to_nonce_set.insert(id, singleton_set);
    }

    /// Removes `nonce` from the partition; its class is dropped if it becomes empty.
    ///
    /// The remaining members of the class keep its id.
    ///
    /// # Panics
    ///
    /// Panics if `nonce` is not in the partition.
    pub fn remove_nonce(&mut self, nonce: Nonce) {
        let id = self
            .nonce_to_id
            .remove(&nonce)
            .expect("nonce is not in the partition");
        let class = self
            .id_to_nonce_set
            .get_mut(&id)
            .expect("class id missing from id_to_nonce_set");
        class.remove(&nonce);
        if class.is_empty() {
            self.id_to_nonce_set.remove(&id);
        }
    }

    /// Merges the classes of `nonce1` and `nonce2`. The merged class keeps `nonce1`'s id.
    ///
    /// # Panics
    ///
    /// Panics if either nonce is not in the partition.
    pub fn add_equality(&mut self, nonce1: Nonce, nonce2: Nonce) {
        let id1 = self.nonce_to_id[&nonce1];
        let id2 = self.nonce_to_id[&nonce2];
        if id1 == id2 {
            return;
        }
        let mut moved = self
            .id_to_nonce_set
            .remove(&id2)
            .expect("class id missing from id_to_nonce_set");
        for nonce in &moved {
            if let Some(id) = self.nonce_to_id.get_mut(nonce) {
                *id = id1;
            }
        }
        if let Some(class1) = self.id_to_nonce_set.get_mut(&id1) {
            class1.append(&mut moved);
        }
    }

    /// Returns whether `nonce1` and `nonce2` are in the same class.
    ///
    /// # Panics
    ///
    /// Panics if either nonce is not in the partition.
    pub fn is_equal(&self, nonce1: Nonce, nonce2: Nonce) -> bool {
        self.nonce_to_id[&nonce1] == self.nonce_to_id[&nonce2]
    }

    /// Renames every nonce through `nonce_map` and labels each resulting class with its
    /// smallest `Nonce::inner` value.
    ///
    /// `nonce_map` is presumably injective on this partition's nonces; if two classes map to
    /// sets with the same canonical id, the later class (in id order) replaces the earlier.
    ///
    /// # Panics
    ///
    /// Panics if `nonce_map` has no entry for some nonce of the partition.
    pub fn construct_canonical_partition(&self, nonce_map: &BTreeMap<Nonce, Nonce>) -> Self {
        let id_to_nonce_set: BTreeMap<usize, BTreeSet<Nonce>> = self
            .id_to_nonce_set
            .values()
            .map(|nonce_set| {
                let canonical_nonce_set: BTreeSet<Nonce> = nonce_set
                    .iter()
                    .map(|nonce| nonce_map[nonce].clone())
                    .collect();
                (Self::canonical_id(&canonical_nonce_set), canonical_nonce_set)
            })
            .collect();
        let nonce_to_id = Self::compute_nonce_to_id(&id_to_nonce_set);
        Self {
            nonce_to_id,
            id_to_nonce_set,
        }
    }

    /// Returns the set of all nonces in the partition.
    pub fn nonces(&self) -> BTreeSet<Nonce> {
        self.nonce_to_id.keys().cloned().collect()
    }

    /// Combines two partitions over the same nonces: two nonces share a class in the result
    /// exactly when they share a class in both `self` and `partition`. Result classes are
    /// labelled with their smallest `Nonce::inner` value.
    ///
    /// # Panics
    ///
    /// Panics if a nonce of `self` is missing from `partition`.
    pub fn join(&self, partition: &Partition) -> Self {
        checked_verify!(self.nonces() == partition.nonces());
        // Grouping nonces by their (self id, partition id) pair yields exactly the non-empty
        // intersections of a class of `self` with a class of `partition`, in O(n log n).
        let mut id_pair_to_nonce_set: BTreeMap<(usize, usize), BTreeSet<Nonce>> =
            BTreeMap::new();
        for (nonce, &id) in &self.nonce_to_id {
            let id_pair = (id, partition.nonce_to_id[nonce]);
            id_pair_to_nonce_set
                .entry(id_pair)
                .or_default()
                .insert(nonce.clone());
        }
        let id_to_nonce_set: BTreeMap<usize, BTreeSet<Nonce>> = id_pair_to_nonce_set
            .into_iter()
            .map(|(_, nonce_set)| (Self::canonical_id(&nonce_set), nonce_set))
            .collect();
        let nonce_to_id = Self::compute_nonce_to_id(&id_to_nonce_set);
        Self {
            nonce_to_id,
            id_to_nonce_set,
        }
    }

    /// Returns the smallest `Nonce::inner` value in `nonce_set`, or `MAX` (`usize::MAX`)
    /// when the set is empty.
    fn canonical_id(nonce_set: &BTreeSet<Nonce>) -> usize {
        nonce_set
            .iter()
            .map(|nonce| nonce.inner())
            .min()
            .unwrap_or(MAX)
    }

    /// Inverts a class-id -> nonce-set map into a nonce -> class-id map.
    fn compute_nonce_to_id(
        id_to_nonce_set: &BTreeMap<usize, BTreeSet<Nonce>>,
    ) -> BTreeMap<Nonce, usize> {
        id_to_nonce_set
            .iter()
            .flat_map(|(&id, nonce_set)| nonce_set.iter().map(move |nonce| (nonce.clone(), id)))
            .collect()
    }
}