use std::{fmt, iter, ops::Range, ops::RangeBounds};
use sorted_iter::sorted_iterator::SortedByItem;
use crate::{div_ceil, safe_n_mask};
#[cfg(test)]
mod tests;
/// Compile-time facts about the integer type used as a bitset storage block.
trait BlockT {
    /// Bits held by one block, as a `usize` so bit indices can be divided by it directly.
    const BITS64: usize;
}
impl BlockT for u32 {
    // Byte width times eight: 4 bytes -> 32 bits.
    const BITS64: usize = 8 * std::mem::size_of::<Self>();
}
/// A bitset stored as a sequence of `u32` blocks.
///
/// Bit `i` lives at position `i % 32` (counted from the least-significant bit)
/// of block `i / 32`. `B` is the backing storage, e.g. `Vec<u32>` or `Box<[u32]>`.
#[derive(Clone, Copy, Default, PartialEq, Eq)]
pub struct Bitset<B: AsRef<[u32]>>(pub B);
/// Backing storage for a bitset that can grow by appending zeroed blocks.
pub trait ExtendBlocks: AsMut<[u32]> + AsRef<[u32]> {
    /// Appends at least `extra_blocks` zero-valued blocks to the end of the storage.
    ///
    /// Implementations may add more than requested; the `Box<[u32]>` impl rounds the
    /// new length up to a power of two.
    fn extend_blocks(&mut self, extra_blocks: usize);
}
impl ExtendBlocks for Box<[u32]> {
fn extend_blocks(&mut self, extra_blocks: usize) {
let old_len = self.len();
let new_len = (old_len + extra_blocks).next_power_of_two().max(8);
let mut self_vec = std::mem::take(self).into_vec();
self_vec.extend(iter::repeat(0).take(new_len - old_len));
*self = self_vec.into();
}
}
impl ExtendBlocks for Vec<u32> {
fn extend_blocks(&mut self, extra_blocks: usize) {
self.extend(iter::repeat(0).take(extra_blocks));
}
}
#[cfg(feature = "smallvec")]
impl<A: smallvec::Array<Item = u32>> ExtendBlocks for smallvec::SmallVec<A> {
    /// Appends exactly `extra_blocks` zeroed blocks (spilling to the heap if needed).
    fn extend_blocks(&mut self, extra_blocks: usize) {
        self.extend((0..extra_blocks).map(|_| 0));
    }
}
impl<B: ExtendBlocks> Bitset<B> {
    /// Sets bit `bit` to 1, first growing the storage if `bit` lies past its end.
    pub fn enable_bit_extending(&mut self, bit: usize) {
        let (block, offset) = (bit / u32::BITS64, bit % u32::BITS64);
        // Blocks still needed so that index `block` exists (0 when it already does).
        let missing = (block + 1).saturating_sub(self.0.as_ref().len());
        if missing > 0 {
            self.0.extend_blocks(missing);
        }
        self.0.as_mut()[block] |= 1 << offset;
    }
}
impl<B: AsRef<[u32]> + AsMut<[u32]>> Bitset<B> {
    /// Sets bit `bit` to 1.
    ///
    /// Returns `None`, leaving the bitset untouched, when `bit` is at or past the end
    /// of the storage. The storage is never grown.
    #[inline]
    pub fn enable_bit(&mut self, bit: usize) -> Option<()> {
        let mask = 1_u32 << (bit % u32::BITS64);
        let slot = self.0.as_mut().get_mut(bit / u32::BITS64)?;
        *slot |= mask;
        Some(())
    }

    /// Clears bit `bit` to 0.
    ///
    /// Returns `None`, leaving the bitset untouched, when `bit` is at or past the end
    /// of the storage.
    #[inline]
    pub fn disable_bit(&mut self, bit: usize) -> Option<()> {
        let keep = !(1_u32 << (bit % u32::BITS64));
        let slot = self.0.as_mut().get_mut(bit / u32::BITS64)?;
        *slot &= keep;
        Some(())
    }

    /// Clears every bit in `range`, one bit at a time.
    ///
    /// Indices past the end of the storage are skipped silently.
    #[inline]
    pub fn disable_range(&mut self, range: Range<usize>) {
        for bit in range {
            // `None` only means the bit is out of range; nothing to clear there.
            let _ = self.disable_bit(bit);
        }
    }
}
impl<B: AsRef<[u32]>> Bitset<B> {
    /// Number of bits the backing storage can hold: the block count times 32.
    #[inline]
    pub fn bit_len(&self) -> usize {
        self.0.as_ref().len() * u32::BITS64
    }
    /// Returns whether bit `at` is set.
    ///
    /// Bits at or past [`Bitset::bit_len`] read as `false` instead of panicking.
    #[inline]
    pub fn bit(&self, at: usize) -> bool {
        let block = at / u32::BITS64;
        let offset = (at % u32::BITS64) as u32;
        let offset = 1 << offset;
        let Some(block) = self.0.as_ref().get(block) else {
            return false;
        };
        block & offset == offset
    }
    /// Reads the 32 bits starting at bit `at`; bit `at` becomes bit 0 of the result.
    ///
    /// Blocks missing from the storage read as zero. Returns `Ok` when
    /// `at + 32 <= bit_len()`. Otherwise returns `Err` carrying the bits that were in
    /// range, with the out-of-range bits as zero.
    ///
    /// NOTE(review): assumes `safe_n_mask(n)` yields a mask of the low `n` bits — confirm.
    #[inline]
    #[allow(clippy::similar_names)]
    pub fn u32_at(&self, at: usize) -> Result<u32, u32> {
        let block = at / u32::BITS64;
        let offset = (at % u32::BITS64) as u32;
        if offset == 0 {
            // Block-aligned read: the whole block, or `Err(0)` if it does not exist.
            self.0.as_ref().get(block).copied().ok_or(0)
        } else {
            // Unaligned read: the high `inset` bits of `block` become the low bits of
            // the result, and the low `offset` bits of `block + 1` fill the top.
            let inset = u32::BITS - offset;
            let msb_0 = self.0.as_ref().get(block).map_or(0, |&t| t) >> offset;
            let lsb_1 = self.0.as_ref().get(block + 1).map_or(0, |&t| t) << inset;
            let mask = safe_n_mask(inset);
            let spills_out = at + 32 > self.bit_len();
            let ctor = if spills_out { Err } else { Ok };
            ctor((msb_0 & mask) | (lsb_1 & !mask))
        }
    }
    /// Reads `n` bits starting at bit `at`, right-aligned in the returned value.
    ///
    /// Returns `None` when `at + n` exceeds [`Bitset::bit_len`].
    ///
    /// NOTE(review): presumably callers pass `n <= 32`. Larger `n` is not rejected
    /// here, and its result depends on `safe_n_mask` — confirm.
    #[inline]
    #[allow(clippy::similar_names)]
    pub fn n_at(&self, n: u32, at: usize) -> Option<u32> {
        let block = at / u32::BITS64;
        let offset = (at % u32::BITS64) as u32;
        let n_mask = safe_n_mask(n);
        if at + n as usize > self.bit_len() {
            None
        } else if offset + n <= 32 {
            // All `n` bits sit inside a single block.
            let value = *self.0.as_ref().get(block)?;
            Some((value >> offset) & n_mask)
        } else {
            // The bits straddle `block` and `block + 1`. `wrapping_shl` avoids a
            // shift-overflow panic should `inset` be 32, which can only happen when
            // `offset == 0` and `n > 32`.
            let inset = u32::BITS - offset;
            let msb_0 = self.0.as_ref().get(block)? >> offset;
            let lsb_1 = self.0.as_ref().get(block + 1)?.wrapping_shl(inset);
            let mask = safe_n_mask(inset);
            let value = (msb_0 & mask) | (lsb_1 & !mask);
            Some(value & n_mask)
        }
    }
    /// Iterator over the indices of every set bit, in ascending order.
    #[inline]
    pub fn ones(&self) -> Ones {
        let blocks = self.0.as_ref();
        // Preload the first block. With empty storage the iterator starts with an
        // all-zero `bitset` and no remaining blocks, so it yields nothing.
        let (bitset, remaining_blocks) = blocks.split_first().map_or((0, blocks), |(b, r)| (*b, r));
        Ones { block_idx: 0, crop: 0, bitset, remaining_blocks }
    }
    /// Iterator over the indices of set bits that fall inside `range`, ascending.
    ///
    /// # Panics
    ///
    /// Panics when the blocks spanned by `range` are not all present in the storage
    /// (the block slice below is indexed directly), e.g. an end bound past
    /// [`Bitset::bit_len`].
    #[inline]
    pub fn ones_in_range(&self, range: impl RangeBounds<usize>) -> Ones {
        // Normalise the bounds to a half-open `start..end` in bit indices.
        let start = match range.start_bound() {
            std::ops::Bound::Included(start) => *start,
            std::ops::Bound::Excluded(start) => *start + 1,
            std::ops::Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            std::ops::Bound::Included(end) => *end + 1,
            std::ops::Bound::Excluded(end) => *end,
            std::ops::Bound::Unbounded => self.bit_len(),
        };
        // Bit offsets of the range's edges inside their respective blocks.
        let crop = Range {
            start: (start % u32::BITS64) as u32,
            end: (end % u32::BITS64) as u32,
        };
        // Block indices covered by the range; the end is rounded up via `div_ceil`
        // (presumably ceiling division) so a partial last block is included.
        let range = Range {
            start: start / u32::BITS64,
            end: div_ceil(end, u32::BITS64),
        };
        let all_blocks = &self.0.as_ref()[range.clone()];
        let (mut bitset, remaining_blocks) = all_blocks
            .split_first()
            .map_or((0, all_blocks), |(b, r)| (*b, r));
        // Clear the bits of the first block that lie below `start`.
        bitset &= ((1 << crop.start) - 1) ^ u32::MAX;
        // When the first block is also the last, clear bits from `end` upward.
        // `crop.end == 0` means the range ends on a block boundary: nothing to clear.
        // Later final blocks are cropped the same way by `Ones::next`.
        if remaining_blocks.is_empty() && crop.end != 0 {
            bitset &= (1 << crop.end) - 1;
        }
        Ones {
            block_idx: range.start as u32,
            crop: crop.end,
            bitset,
            remaining_blocks,
        }
    }
}
impl<B: AsRef<[u32]>> fmt::Debug for Bitset<B> {
    /// Formats as `[` + blocks in 8-digit lowercase hex joined by `_` + `]`.
    ///
    /// Example: `[00000001_deadbeef]`, or `[]` when there are no blocks.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[")?;
        let mut blocks = self.0.as_ref().iter();
        if let Some(first) = blocks.next() {
            write!(f, "{first:08x}")?;
            for block in blocks {
                write!(f, "_{block:08x}")?;
            }
        }
        f.write_str("]")
    }
}
impl<'a, B: AsRef<[u32]>> IntoIterator for &'a Bitset<B> {
    type Item = u32;
    type IntoIter = Ones<'a>;

    /// Iterates the indices of every set bit, covering bits `0..bit_len()`.
    fn into_iter(self) -> Self::IntoIter {
        let end = self.bit_len();
        self.ones_in_range(..end)
    }
}
impl Extend<u32> for Bitset<Vec<u32>> {
    /// Sets every bit index yielded by `iter`, growing the storage as needed.
    #[inline]
    fn extend<T: IntoIterator<Item = u32>>(&mut self, iter: T) {
        for bit in iter {
            self.enable_bit_extending(bit as usize);
        }
    }
}
impl Extend<usize> for Bitset<Vec<u32>> {
    /// Sets every bit index yielded by `iter`, growing the storage as needed.
    #[inline]
    fn extend<T: IntoIterator<Item = usize>>(&mut self, iter: T) {
        for bit in iter {
            self.enable_bit_extending(bit);
        }
    }
}
impl Extend<u32> for Bitset<Box<[u32]>> {
    /// Sets every bit index yielded by `iter` that fits in the current storage.
    ///
    /// Indices at or past `bit_len()` are dropped: `enable_bit` returns `None` for
    /// them and that result is ignored.
    /// NOTE(review): unlike the `Vec` impl this never grows the storage;
    /// `enable_bit_extending` is also available for `Box<[u32]>` if growth is wanted.
    #[inline]
    fn extend<T: IntoIterator<Item = u32>>(&mut self, iter: T) {
        for bit in iter {
            let _ = self.enable_bit(bit as usize);
        }
    }
}
impl Extend<usize> for Bitset<Box<[u32]>> {
    /// Sets every bit index yielded by `iter` that fits in the current storage.
    ///
    /// Indices at or past `bit_len()` are dropped without growing the storage.
    #[inline]
    fn extend<T: IntoIterator<Item = usize>>(&mut self, iter: T) {
        for bit in iter {
            let _ = self.enable_bit(bit);
        }
    }
}
impl FromIterator<u32> for Bitset<Box<[u32]>> {
fn from_iter<T: IntoIterator<Item = u32>>(iter: T) -> Self {
let acc: Bitset<Vec<_>> = iter.into_iter().collect();
Bitset(acc.0.into_boxed_slice())
}
}
impl FromIterator<u32> for Bitset<Vec<u32>> {
fn from_iter<T: IntoIterator<Item = u32>>(iter: T) -> Self {
let iter = iter.into_iter();
let mut acc = Bitset(Vec::new());
acc.extend(iter);
acc
}
}
impl FromIterator<usize> for Bitset<Box<[u32]>> {
fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
let acc: Bitset<Vec<_>> = iter.into_iter().collect();
Bitset(acc.0.into_boxed_slice())
}
}
impl FromIterator<usize> for Bitset<Vec<u32>> {
fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
let iter = iter.into_iter();
let mut acc = Bitset(Vec::new());
acc.extend(iter);
acc
}
}
/// Iterator over the indices of set bits in a [`Bitset`], in ascending order.
///
/// Created by [`Bitset::ones`] and [`Bitset::ones_in_range`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ones<'a> {
    /// Index of the block `bitset` was taken from. A yielded index is
    /// `block_idx * 32 + bit position`.
    block_idx: u32,
    /// How many low bits of the final block are in range; `0` means the whole block.
    crop: u32,
    /// Set bits of the current block that have not been yielded yet.
    bitset: u32,
    /// Blocks after the current one that are still to be visited.
    remaining_blocks: &'a [u32],
}
impl Iterator for Ones<'_> {
    type Item = u32;

    /// Yields the next set bit index, loading further blocks as the current one empties.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        while self.bitset == 0 {
            let Some((&bitset, remaining_blocks)) = self.remaining_blocks.split_first() else {
                return None;
            };
            self.bitset = bitset;
            self.remaining_blocks = remaining_blocks;
            // The final block keeps only its low `crop` bits; `crop == 0` means the whole block.
            if self.remaining_blocks.is_empty() && self.crop != 0 {
                self.bitset &= (1 << self.crop) - 1;
            }
            self.block_idx += 1;
        }
        let position = self.bitset.trailing_zeros();
        // Clear the lowest set bit, i.e. the one being yielded (`bitset != 0` here).
        self.bitset &= self.bitset - 1;
        Some(self.block_idx * u32::BITS + position)
    }

    /// Exact number of indices still to be yielded.
    ///
    /// The final block is masked the same way `next` masks it: only its low `crop`
    /// bits count, or all 32 when `crop == 0`. Previously the inverse mask was used,
    /// which counted exactly the out-of-range bits. The sum is accumulated in `usize`
    /// so very large bitsets cannot overflow a `u32` count.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let current = self.bitset.count_ones() as usize;
        let rest = match self.remaining_blocks.split_last() {
            None => 0,
            Some((&last, middle)) => {
                let last_mask = if self.crop == 0 { u32::MAX } else { (1 << self.crop) - 1 };
                let middle_ones: usize = middle.iter().map(|b| b.count_ones() as usize).sum();
                middle_ones + (last & last_mask).count_ones() as usize
            }
        };
        let exact = current + rest;
        (exact, Some(exact))
    }
}
// `size_hint` always reports an `(n, Some(n))` pair, which `ExactSizeIterator::len` relies on.
impl ExactSizeIterator for Ones<'_> {}
// Indices come out in ascending order: blocks are visited front to back and, within
// a block, bits are taken from the least-significant end first.
impl SortedByItem for Ones<'_> {}
impl Ones<'_> {
    /// Returns `true` when every bit still covered by this iterator is set.
    ///
    /// The rules are:
    /// - The final block is checked on its low `crop` bits, or on all 32 when
    ///   `crop == 0`, matching how `next` crops it.
    /// - Blocks in between must be `u32::MAX`.
    /// - The current block must be full when further blocks follow.
    ///
    /// An empty or all-zero range is therefore no longer reported as "all one".
    ///
    /// NOTE(review): `Ones` does not remember the start offset given to
    /// `Bitset::ones_in_range`, nor which bits `next` already yielded. A range that
    /// does not start on a block boundary, or a partially consumed iterator, therefore
    /// reports `false` even when the requested bits are all set.
    #[must_use]
    pub fn all_one(self) -> bool {
        let last_mask = if self.crop == 0 { u32::MAX } else { (1 << self.crop) - 1 };
        let Some((&last, middle)) = self.remaining_blocks.split_last() else {
            // The current block is also the final one.
            return (self.bitset & last_mask) == last_mask;
        };
        self.bitset == u32::MAX
            && middle.iter().all(|&block| block == u32::MAX)
            && (last & last_mask) == last_mask
    }
}