use crate::scalar::{self, ScalarBitSet, ScalarBitSetStorage};
use alloc::boxed::Box;
use core::{cmp, iter, mem};
use wasmtime_core::alloc::{TryExtend, TryVec};
use wasmtime_core::error::OutOfMemory;
/// A growable bit set, stored as a boxed slice of fixed-width
/// [`ScalarBitSet`] words.
///
/// `T` is the storage word type of each sub-set; it defaults to `usize`.
#[derive(Clone, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "enable-serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
pub struct CompoundBitSet<T = usize> {
    // The backing words. Word `w` holds the elements in the range
    // `w * BITS_PER_SCALAR .. (w + 1) * BITS_PER_SCALAR`.
    elems: Box<[ScalarBitSet<T>]>,
    // Cached largest element currently in the set, or `None` when the set
    // is empty. Kept in sync by `insert`, `remove`, and `clear`.
    max: Option<u32>,
}
impl core::fmt::Debug for CompoundBitSet {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "CompoundBitSet ")?;
f.debug_set().entries(self.iter()).finish()
}
}
impl CompoundBitSet {
#[inline]
pub fn new() -> Self {
CompoundBitSet::default()
}
}
impl<T: ScalarBitSetStorage> CompoundBitSet<T> {
    /// How many elements a single backing word can hold.
    const BITS_PER_SCALAR: usize = mem::size_of::<T>() * 8;

    /// Construct a set with capacity for elements `0..capacity`.
    ///
    /// Panics on allocation failure; see [`Self::try_with_capacity`] for
    /// the fallible version.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self::try_with_capacity(capacity).unwrap()
    }

    /// Fallible version of [`Self::with_capacity`]: returns `OutOfMemory`
    /// instead of panicking when allocation fails.
    #[inline]
    pub fn try_with_capacity(capacity: usize) -> Result<Self, OutOfMemory> {
        let mut bitset = Self::default();
        bitset.try_ensure_capacity(capacity)?;
        Ok(bitset)
    }

    /// The number of elements currently in the set.
    ///
    /// Note: computed by summing every word's population count, so this is
    /// O(capacity / word size), not O(1).
    #[inline]
    pub fn len(&self) -> usize {
        self.elems.iter().map(|sub| usize::from(sub.len())).sum()
    }

    /// How many elements this set can hold without growing its storage.
    pub fn capacity(&self) -> usize {
        self.elems.len() * Self::BITS_PER_SCALAR
    }

    /// Is the set empty?
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Split element `i` into the index of the backing word that holds it
    /// and the bit offset within that word.
    #[inline]
    fn word_and_bit(i: usize) -> (usize, u8) {
        let word = i / Self::BITS_PER_SCALAR;
        let bit = i % Self::BITS_PER_SCALAR;
        // The remainder is always `< BITS_PER_SCALAR`, which fits in a
        // `u8` for the supported storage widths.
        let bit = u8::try_from(bit).unwrap();
        (word, bit)
    }

    /// Inverse of [`Self::word_and_bit`]: recombine a word index and bit
    /// offset into the element value they represent.
    #[inline]
    fn elem(word: usize, bit: u8) -> usize {
        let bit = usize::from(bit);
        debug_assert!(bit < Self::BITS_PER_SCALAR);
        word * Self::BITS_PER_SCALAR + bit
    }

    /// Is `i` contained in the set?
    ///
    /// Elements beyond the current capacity are simply absent — no growth,
    /// no panic.
    #[inline]
    pub fn contains(&self, i: usize) -> bool {
        let (word, bit) = Self::word_and_bit(i);
        if word < self.elems.len() {
            self.elems[word].contains(bit)
        } else {
            false
        }
    }

    /// Grow the backing storage so that elements `0..n` can be stored.
    ///
    /// Panics on allocation failure; see [`Self::try_ensure_capacity`] for
    /// the fallible version.
    #[inline]
    pub fn ensure_capacity(&mut self, n: usize) {
        self.try_ensure_capacity(n).unwrap()
    }

    /// Fallible growth: ensure elements `0..n` fit, returning
    /// `OutOfMemory` on allocation failure.
    #[inline]
    pub fn try_ensure_capacity(&mut self, n: usize) -> Result<(), OutOfMemory> {
        // Find the word that would hold bit `n - 1`; a request for zero
        // capacity needs no storage at all.
        let (word, _bit) = Self::word_and_bit(match n.checked_sub(1) {
            None => return Ok(()),
            Some(n) => n,
        });
        if word < self.elems.len() {
            return Ok(());
        }
        // Rust allocations may not exceed `isize::MAX` bytes, so an index
        // this large is a caller bug rather than a recoverable OOM.
        assert!(word < usize::try_from(isize::MAX).unwrap());
        let delta = word - self.elems.len();
        let to_grow = delta + 1;
        // Amortize repeated growth: grow by at least the current size
        // (i.e. at least double) ...
        let to_grow = cmp::max(to_grow, self.elems.len());
        // ... and by at least a few words, so tiny sets don't reallocate
        // on every insertion.
        let to_grow = cmp::max(to_grow, 4);
        // Temporarily take the boxed slice out of `self` so it can be
        // extended as a vector and boxed again, without cloning.
        let mut new_elems = TryVec::from(mem::take(&mut self.elems));
        new_elems.reserve_exact(to_grow)?;
        new_elems.try_extend(iter::repeat(ScalarBitSet::new()).take(to_grow))?;
        self.elems = new_elems.into_boxed_slice()?;
        Ok(())
    }

    /// Insert `i` into the set, growing the storage as needed.
    ///
    /// Returns `true` if `i` was newly inserted, `false` if it was already
    /// present. Panics on allocation failure.
    #[inline]
    pub fn insert(&mut self, i: usize) -> bool {
        self.ensure_capacity(i + 1);
        let (word, bit) = Self::word_and_bit(i);
        let is_new = self.elems[word].insert(bit);
        // Keep the cached maximum in sync with the new element.
        let i = u32::try_from(i).unwrap();
        self.max = self.max.map(|max| cmp::max(max, i)).or(Some(i));
        is_new
    }

    /// Remove `i` from the set, returning whether it was present.
    ///
    /// Removing an element beyond the current capacity is a no-op that
    /// returns `false`.
    #[inline]
    pub fn remove(&mut self, i: usize) -> bool {
        let (word, bit) = Self::word_and_bit(i);
        if word < self.elems.len() {
            let sub = &mut self.elems[word];
            let was_present = sub.remove(bit);
            // If we just removed the largest element, the cached maximum
            // is stale and must be recomputed.
            if was_present && self.max() == Some(i) {
                self.update_max(word);
            }
            was_present
        } else {
            false
        }
    }

    /// Recompute `self.max` after the previous maximum — which lived in
    /// word `word_of_old_max` — was removed. Scans words from high to low
    /// and takes the largest bit of the first non-empty word; yields
    /// `None` when the set became empty.
    fn update_max(&mut self, word_of_old_max: usize) {
        self.max = self.elems[0..word_of_old_max + 1]
            .iter()
            .enumerate()
            .rev()
            .filter_map(|(word, sub)| {
                let bit = sub.max()?;
                Some(u32::try_from(Self::elem(word, bit)).unwrap())
            })
            .next();
    }

    /// The largest element in the set, or `None` if the set is empty.
    /// O(1): reads the cached value.
    #[inline]
    pub fn max(&self) -> Option<usize> {
        self.max.map(|m| usize::try_from(m).unwrap())
    }

    /// Remove and return the largest element, or `None` if the set is
    /// empty.
    #[inline]
    pub fn pop(&mut self) -> Option<usize> {
        let max = self.max()?;
        self.remove(max);
        Some(max)
    }

    /// Remove every element, retaining the allocated capacity.
    ///
    /// Only the words up to the current maximum are reset; anything beyond
    /// must already be empty.
    #[inline]
    pub fn clear(&mut self) {
        let max = match self.max() {
            Some(max) => max,
            None => return,
        };
        let (word, _bit) = Self::word_and_bit(max);
        debug_assert!(self.elems[word + 1..].iter().all(|sub| sub.is_empty()));
        for sub in &mut self.elems[..=word] {
            *sub = ScalarBitSet::new();
        }
        self.max = None;
    }

    /// Iterate over the elements in this set.
    #[inline]
    pub fn iter(&self) -> Iter<'_, T> {
        Iter {
            bitset: self,
            word: 0,
            sub: None,
        }
    }

    /// Iterate over the underlying storage words, up to and including the
    /// word that holds the current maximum element. Trailing all-zero
    /// words beyond the maximum are not yielded.
    pub fn iter_scalars(&self) -> impl Iterator<Item = ScalarBitSet<T>> + '_ {
        let nwords = match self.max {
            Some(n) => 1 + (n as usize / Self::BITS_PER_SCALAR),
            None => 0,
        };
        self.elems.iter().copied().take(nwords)
    }
}
impl<'a, T: ScalarBitSetStorage> IntoIterator for &'a CompoundBitSet<T> {
type Item = usize;
type IntoIter = Iter<'a, T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// An iterator over the elements of a [`CompoundBitSet`].
pub struct Iter<'a, T = usize> {
    // The set being iterated.
    bitset: &'a CompoundBitSet<T>,
    // Index of the backing word currently being drained.
    sub: Option<scalar::Iter<T>>,
    // In-progress iterator over the bits of word `word`, or `None` before
    // iteration of a word has begun.
    word: usize,
}
impl<T: ScalarBitSetStorage> Iterator for Iter<'_, T> {
    type Item = usize;

    /// Yield the next element, walking the backing words in order and
    /// draining each word's bits before moving on to the next word.
    #[inline]
    fn next(&mut self) -> Option<usize> {
        loop {
            // Drain the word currently in progress, if any.
            if let Some(bit) = self.sub.as_mut().and_then(|bits| bits.next()) {
                return Some(CompoundBitSet::<T>::elem(self.word, bit));
            }
            // An exhausted word advances us to the next one; a fresh
            // iterator (`sub == None`) starts on the current word.
            if self.sub.is_some() {
                self.word += 1;
            }
            // Running off the end of the backing slice ends iteration.
            self.sub = Some(self.bitset.elems.get(self.word)?.iter());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Requesting zero capacity must not allocate any backing words, for
    /// either constructor.
    #[test]
    fn zero_capacity_no_allocs() {
        let with_cap = CompoundBitSet::<u32>::with_capacity(0);
        assert_eq!(with_cap.capacity(), 0);

        let fresh = CompoundBitSet::new();
        assert_eq!(fresh.capacity(), 0);
    }
}