use crate::scalar::{self, ScalarBitSet};
use alloc::boxed::Box;
use core::{cmp, iter, mem};
/// A growable set of `usize` elements, stored as a boxed slice of
/// [`ScalarBitSet<usize>`] words.
///
/// Equality is defined on the *elements* of the set, not on its allocated
/// capacity. Two sets holding the same elements are equal even if one of them
/// has allocated more words than the other.
#[derive(Clone, Default)]
#[cfg_attr(
    feature = "enable-serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
pub struct CompoundBitSet {
    /// Backing words; element `i` lives in word `i / BITS_PER_WORD` at bit
    /// `i % BITS_PER_WORD`.
    elems: Box<[ScalarBitSet<usize>]>,
    /// Cached largest element of the set, or `None` when the set is empty.
    max: Option<u32>,
}

impl PartialEq for CompoundBitSet {
    /// Compare by contents only.
    ///
    /// The overlapping words must match exactly. Any extra words the longer
    /// buffer has allocated must be empty. A derived impl would instead treat
    /// sets with different capacities as unequal even when they contain the
    /// same elements.
    fn eq(&self, other: &Self) -> bool {
        let (short, long) = if self.elems.len() <= other.elems.len() {
            (&self.elems, &other.elems)
        } else {
            (&other.elems, &self.elems)
        };
        let (common, extra) = long.split_at(short.len());
        short[..] == *common && extra.iter().all(|sub| sub.is_empty())
    }
}

impl Eq for CompoundBitSet {}
impl core::fmt::Debug for CompoundBitSet {
    /// Writes the prefix `CompoundBitSet ` followed by the set's elements in
    /// `debug_set` notation, in the order produced by `iter()`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("CompoundBitSet ")?;
        let mut members = f.debug_set();
        members.entries(self.iter());
        members.finish()
    }
}
/// Number of bits held by each backing word: the bit width of `usize`.
const BITS_PER_WORD: usize = 8 * mem::size_of::<usize>();
impl CompoundBitSet {
    /// Construct a new, empty bit set with no allocated words.
    #[inline]
    pub fn new() -> Self {
        CompoundBitSet::default()
    }

    /// Construct a new, empty bit set whose storage already covers index
    /// `capacity`, so `capacity()` is greater than `capacity` afterwards.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let mut bitset = Self::new();
        bitset.ensure_capacity(capacity);
        bitset
    }

    /// Count the number of elements in the set.
    ///
    /// This sums the population of every allocated word, so it is
    /// `O(capacity)`. Prefer [`is_empty`](Self::is_empty) for emptiness checks.
    #[inline]
    pub fn len(&self) -> usize {
        self.elems.iter().map(|sub| usize::from(sub.len())).sum()
    }

    /// Number of elements the set can hold without growing: the number of
    /// allocated words times the bits per word.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.elems.len() * BITS_PER_WORD
    }

    /// Is the set empty?
    ///
    /// This is `O(1)`. The cached maximum is `None` exactly when the set has
    /// no elements: `insert` always sets it, `remove` recomputes it when the
    /// maximum is removed, and `clear` resets it.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.max.is_none()
    }

    /// Split element index `i` into `(word index, bit position within word)`.
    #[inline]
    fn word_and_bit(i: usize) -> (usize, u8) {
        let word = i / BITS_PER_WORD;
        let bit = i % BITS_PER_WORD;
        // `bit < BITS_PER_WORD`, which always fits in a `u8`.
        let bit = u8::try_from(bit).unwrap();
        (word, bit)
    }

    /// Inverse of `word_and_bit`: recombine a word index and bit position
    /// into an element index.
    #[inline]
    fn elem(word: usize, bit: u8) -> usize {
        let bit = usize::from(bit);
        debug_assert!(bit < BITS_PER_WORD);
        word * BITS_PER_WORD + bit
    }

    /// Is `i` a member of the set?
    ///
    /// Indices beyond the allocated words are never members.
    #[inline]
    pub fn contains(&self, i: usize) -> bool {
        let (word, bit) = Self::word_and_bit(i);
        if word < self.elems.len() {
            self.elems[word].contains(bit)
        } else {
            false
        }
    }

    /// Ensure storage is allocated for the word containing index `n`, so that
    /// `capacity() > n` afterwards.
    ///
    /// When growth is needed, the number of words added is the largest of:
    /// the number required, twice the current word count, and 4.
    ///
    /// # Panics
    ///
    /// Panics if the required word index is not less than `isize::MAX`.
    #[inline]
    pub fn ensure_capacity(&mut self, n: usize) {
        let (word, _bit) = Self::word_and_bit(n);
        if word >= self.elems.len() {
            assert!(word < usize::try_from(isize::MAX).unwrap());

            let delta = word - self.elems.len();
            let to_grow = delta + 1;

            // Amortize the cost of repeated growth.
            let to_grow = cmp::max(to_grow, self.elems.len() * 2);
            // Avoid tiny allocations.
            let to_grow = cmp::max(to_grow, 4);

            let new_elems = self
                .elems
                .iter()
                .copied()
                .chain(iter::repeat(ScalarBitSet::new()).take(to_grow))
                .collect::<Box<[_]>>();
            self.elems = new_elems;
        }
    }

    /// Insert `i` into the set, returning `true` if it was not already a
    /// member.
    ///
    /// # Panics
    ///
    /// Panics if `i` does not fit in a `u32`, because the cached maximum is
    /// stored as a `u32`.
    #[inline]
    pub fn insert(&mut self, i: usize) -> bool {
        // Validate before touching any state. This way a panic cannot leave
        // the bit set while the cached maximum fails to reflect it, and no
        // storage is grown for an index that will be rejected.
        let candidate =
            u32::try_from(i).expect("CompoundBitSet elements must fit in a u32");

        self.ensure_capacity(i + 1);

        let (word, bit) = Self::word_and_bit(i);
        let is_new = self.elems[word].insert(bit);

        self.max = Some(self.max.map_or(candidate, |max| cmp::max(max, candidate)));

        is_new
    }

    /// Remove `i` from the set, returning `true` if it was a member.
    #[inline]
    pub fn remove(&mut self, i: usize) -> bool {
        let (word, bit) = Self::word_and_bit(i);
        if word < self.elems.len() {
            let sub = &mut self.elems[word];
            let was_present = sub.remove(bit);
            // Only removing the current maximum invalidates the cache.
            if was_present && self.max() == Some(i) {
                self.update_max(word);
            }
            was_present
        } else {
            false
        }
    }

    /// Recompute the cached maximum after the old maximum (which lived in
    /// word `word_of_old_max`) was removed.
    ///
    /// Words above `word_of_old_max` held nothing larger than the old maximum,
    /// so only words `0..=word_of_old_max` are scanned, from highest to lowest.
    fn update_max(&mut self, word_of_old_max: usize) {
        self.max = self.elems[0..word_of_old_max + 1]
            .iter()
            .enumerate()
            .rev()
            .filter_map(|(word, sub)| {
                let bit = sub.max()?;
                Some(u32::try_from(Self::elem(word, bit)).unwrap())
            })
            .next();
    }

    /// The largest element in the set, or `None` if the set is empty.
    ///
    /// This reads the cached maximum and does not scan the words.
    #[inline]
    pub fn max(&self) -> Option<usize> {
        self.max.map(|m| usize::try_from(m).unwrap())
    }

    /// Remove and return the largest element, or `None` if the set is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<usize> {
        let max = self.max()?;
        let removed = self.remove(max);
        debug_assert!(removed, "cached maximum must be a member of the set");
        Some(max)
    }

    /// Remove every element, keeping the allocated storage.
    ///
    /// Only words up to the one holding the maximum are reset; every word
    /// after it is already empty.
    #[inline]
    pub fn clear(&mut self) {
        let max = match self.max() {
            Some(max) => max,
            None => return,
        };
        let (word, _bit) = Self::word_and_bit(max);
        debug_assert!(self.elems[word + 1..].iter().all(|sub| sub.is_empty()));
        for sub in &mut self.elems[..=word] {
            *sub = ScalarBitSet::new();
        }
        self.max = None;
    }

    /// Iterate over the elements of the set, visiting words from lowest to
    /// highest.
    #[inline]
    pub fn iter(&self) -> Iter<'_> {
        Iter {
            bitset: self,
            word: 0,
            sub: None,
        }
    }
}
impl<'a> IntoIterator for &'a CompoundBitSet {
type Item = usize;
type IntoIter = Iter<'a>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator over the elements of a [`CompoundBitSet`], created by
/// [`CompoundBitSet::iter`].
pub struct Iter<'a> {
    /// The set being iterated.
    bitset: &'a CompoundBitSet,
    /// Index of the backing word that `sub` iterates over.
    word: usize,
    /// Iterator over the bits of word `word`. This is `None` until `next`
    /// first loads a word.
    sub: Option<scalar::Iter<usize>>,
}
impl Iterator for Iter<'_> {
    type Item = usize;

    /// Yield the next element.
    ///
    /// The bits of the current word are drained first. Then the iterator moves
    /// to the next word, stopping once the word index runs past the set's
    /// allocated words.
    #[inline]
    fn next(&mut self) -> Option<usize> {
        loop {
            if let Some(bits) = self.sub.as_mut() {
                match bits.next() {
                    Some(bit) => return Some(CompoundBitSet::elem(self.word, bit)),
                    // Current word exhausted: advance to the following one.
                    None => self.word += 1,
                }
            }

            let current_word = self.bitset.elems.get(self.word)?;
            self.sub = Some(current_word.iter());
        }
    }
}