use crate::traits::Word;
use ambassador::delegatable_trait;
use atomic_primitive::PrimitiveAtomicUnsigned;
use impl_tools::autoimpl;
use mem_dbg::{MemDbg, MemSize};
use num_primitive::PrimitiveInteger;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
use std::{iter::FusedIterator, marker::PhantomData, sync::atomic::Ordering};
/// Panics with a descriptive message unless `$index` is a valid bit
/// position, i.e. `$index < $len`.
macro_rules! panic_if_out_of_bounds {
    ($index: expr, $len: expr) => {
        assert!(
            $index < $len,
            "Bit index out of bounds: {} >= {}",
            $index,
            $len
        )
    };
}
/// A type with a length measured in bits.
#[autoimpl(for<T: trait + ?Sized> &T, &mut T, Box<T>)]
#[delegatable_trait]
pub trait BitLength {
    /// Returns the length in bits.
    fn len(&self) -> usize;
}
// Blanket implementation: anything that exposes its backing words via
// `AsRef<[W]>` and reports a bit length gets the read-only bit operations.
impl<W: Word, T: ?Sized + AsRef<[W]> + BitLength> BitVecOps<W> for T {}
/// Read-only operations on a bit vector backed by a slice of words.
///
/// Bit `i` lives in word `i / W::BITS` at bit offset `i % W::BITS`
/// (little-endian bit order within each word).
pub trait BitVecOps<W: Word>: AsRef<[W]> + BitLength {
    /// Returns the bit at `index`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`.
    #[inline]
    fn get(&self, index: usize) -> bool {
        panic_if_out_of_bounds!(index, self.len());
        // SAFETY: the bounds check above guarantees `index < self.len()`.
        unsafe { self.get_unchecked(index) }
    }
    /// Returns the bit at `index` without bounds checking.
    ///
    /// # Safety
    /// `index` must be strictly less than `self.len()`.
    #[inline(always)]
    unsafe fn get_unchecked(&self, index: usize) -> bool {
        let bits_per_word = W::BITS as usize;
        let word_index = index / bits_per_word;
        // SAFETY: the caller guarantees `index < self.len()`, which implies
        // `word_index` is within the backing slice.
        let word = unsafe { *self.as_ref().get_unchecked(word_index) };
        (word >> (index % bits_per_word)) & W::ONE != W::ZERO
    }
    /// Returns an iterator over all bits, in order of increasing position.
    #[inline(always)]
    fn iter(&self) -> BitIter<'_, W, [W]> {
        BitIter::new(self.as_ref(), self.len())
    }
    /// Returns an iterator over the positions of the set bits, in
    /// increasing order.
    fn iter_ones(&self) -> OnesIter<'_, W, [W]> {
        OnesIter::new(self.as_ref(), self.len())
    }
    /// Returns an iterator over the positions of the unset bits, in
    /// increasing order.
    fn iter_zeros(&self) -> ZerosIter<'_, W, [W]> {
        ZerosIter::new(self.as_ref(), self.len())
    }
    /// Counts the set bits, processing full words in parallel with Rayon.
    #[cfg(feature = "rayon")]
    fn par_count_ones(&self) -> usize {
        let bits_per_word = W::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_ref();
        // Initialize directly instead of declare-then-assign, and reuse the
        // `bits` binding instead of calling `self.as_ref()` twice.
        let mut num_ones: usize = bits[..full_words]
            .par_iter()
            .with_min_len(crate::RAYON_MIN_LEN)
            .map(|x| x.count_ones() as usize)
            .sum();
        if residual != 0 {
            // Shift left so the bits at or beyond `len` fall off before
            // counting the partial last word.
            num_ones += (bits[full_words] << (bits_per_word - residual)).count_ones() as usize
        }
        num_ones
    }
}
// Blanket implementation: mutable access to the backing words plus a bit
// length is all that is needed for the mutating bit operations.
impl<W: Word, T: AsRef<[W]> + AsMut<[W]> + BitLength> BitVecOpsMut<W> for T {}
/// Mutating operations on a bit vector backed by a mutable slice of words.
pub trait BitVecOpsMut<W: Word>: AsRef<[W]> + AsMut<[W]> + BitLength {
    /// Sets the bit at `index` to `value`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`.
    #[inline]
    fn set(&mut self, index: usize, value: bool) {
        panic_if_out_of_bounds!(index, self.len());
        // SAFETY: the bounds check above guarantees `index < self.len()`.
        unsafe { self.set_unchecked(index, value) }
    }
    /// Sets the bit at `index` to `value` without bounds checking.
    ///
    /// # Safety
    /// `index` must be strictly less than `self.len()`.
    #[inline(always)]
    unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
        let bits_per_word = W::BITS as usize;
        let word_index = index / bits_per_word;
        let bit_index = index % bits_per_word;
        let bits = self.as_mut();
        // SAFETY: the caller guarantees `index < self.len()`, which implies
        // `word_index` is within the backing slice.
        unsafe {
            if value {
                *bits.get_unchecked_mut(word_index) |= W::ONE << bit_index;
            } else {
                *bits.get_unchecked_mut(word_index) &= !(W::ONE << bit_index);
            }
        }
    }
    /// Sets every bit to `value`.
    ///
    /// The unused high bits of the last word (beyond `self.len()`) are
    /// preserved.
    fn fill(&mut self, value: bool) {
        let bits_per_word = W::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_mut();
        let word_value: W = if value { !W::ZERO } else { W::ZERO };
        bits[..full_words].iter_mut().for_each(|x| *x = word_value);
        if residual != 0 {
            // Mask selecting the `residual` in-range bits of the last word.
            let mask = (W::ONE << residual) - W::ONE;
            bits[full_words] = (bits[full_words] & !mask) | (word_value & mask);
        }
    }
    /// Sets every bit to `value`, processing full words in parallel with
    /// Rayon.
    ///
    /// The unused high bits of the last word (beyond `self.len()`) are
    /// preserved.
    #[cfg(feature = "rayon")]
    fn par_fill(&mut self, value: bool) {
        let bits_per_word = W::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_mut();
        let word_value: W = if value { !W::ZERO } else { W::ZERO };
        bits[..full_words]
            .par_iter_mut()
            .with_min_len(crate::RAYON_MIN_LEN)
            .for_each(|x| *x = word_value);
        if residual != 0 {
            // Mask selecting the `residual` in-range bits of the last word.
            let mask = (W::ONE << residual) - W::ONE;
            bits[full_words] = (bits[full_words] & !mask) | (word_value & mask);
        }
    }
    /// Clears every bit.
    fn reset(&mut self) {
        self.fill(false);
    }
    /// Clears every bit, processing full words in parallel with Rayon.
    #[cfg(feature = "rayon")]
    fn par_reset(&mut self) {
        self.par_fill(false);
    }
    /// Inverts every bit.
    ///
    /// The unused high bits of the last word (beyond `self.len()`) are
    /// preserved.
    fn flip(&mut self) {
        let bits_per_word = W::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_mut();
        bits[..full_words].iter_mut().for_each(|x| *x = !*x);
        if residual != 0 {
            // Invert only the `residual` in-range bits of the last word.
            let mask = (W::ONE << residual) - W::ONE;
            bits[full_words] = (bits[full_words] & !mask) | (!bits[full_words] & mask);
        }
    }
    /// Inverts every bit, processing full words in parallel with Rayon.
    ///
    /// The unused high bits of the last word (beyond `self.len()`) are
    /// preserved.
    #[cfg(feature = "rayon")]
    fn par_flip(&mut self) {
        let bits_per_word = W::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_mut();
        bits[..full_words]
            .par_iter_mut()
            .with_min_len(crate::RAYON_MIN_LEN)
            .for_each(|x| *x = !*x);
        if residual != 0 {
            // Invert only the `residual` in-range bits of the last word.
            let mask = (W::ONE << residual) - W::ONE;
            bits[full_words] = (bits[full_words] & !mask) | (!bits[full_words] & mask);
        }
    }
}
/// Read access to fixed-width unsigned values packed into a bit vector.
pub trait BitVecValueOps<W: Word> {
    /// Returns the `width`-bit value starting at bit position `pos`.
    // NOTE(review): only the declaration is visible here; presumably
    // `width <= W::BITS` and out-of-range `pos` panics — confirm against
    // the implementors.
    fn get_value(&self, pos: usize, width: usize) -> W;
    /// Returns the `width`-bit value starting at bit position `pos`,
    /// without bounds checking.
    ///
    /// # Safety
    /// Implementation-defined; callers must uphold the implementor's bounds
    /// requirements (presumably `pos + width` within the bit length —
    /// confirm against the implementors).
    unsafe fn get_value_unchecked(&self, pos: usize, width: usize) -> W;
}
/// An iterator over all bits of a bit vector, in order of increasing
/// position.
#[derive(Debug, Clone, MemSize, MemDbg)]
pub struct BitIter<'a, W: Word, B: ?Sized> {
    // Backing storage, viewed as a sequence of words.
    bits: &'a B,
    // Total number of bits to yield.
    len: usize,
    // Position of the next bit to yield.
    next_bit_pos: usize,
    // Ties the word type `W` to the iterator without storing a `W`.
    _phantom: PhantomData<W>,
}
impl<'a, W: Word, B: ?Sized + AsRef<[W]>> BitIter<'a, W, B> {
    /// Creates an iterator over the first `len` bits of `bits`.
    ///
    /// In debug builds, asserts that `len` does not exceed the number of
    /// bits available in the backing slice.
    pub fn new(bits: &'a B, len: usize) -> Self {
        debug_assert!(len <= bits.as_ref().len() * W::BITS as usize);
        Self {
            next_bit_pos: 0,
            bits,
            len,
            _phantom: PhantomData,
        }
    }
}
impl<W: Word, B: ?Sized + AsRef<[W]>> Iterator for BitIter<'_, W, B> {
    type Item = bool;
    /// Yields the bit at the current position, or `None` past the end.
    fn next(&mut self) -> Option<bool> {
        let pos = self.next_bit_pos;
        if pos >= self.len {
            return None;
        }
        self.next_bit_pos = pos + 1;
        let bits_per_word = W::BITS as usize;
        // SAFETY: `pos < len`, and `new` debug-asserts that `len` fits in
        // the backing slice, so `pos / bits_per_word` is in bounds.
        let word = unsafe { *self.bits.as_ref().get_unchecked(pos / bits_per_word) };
        Some((word >> (pos % bits_per_word)) & W::ONE != W::ZERO)
    }
}
impl<W: Word, B: ?Sized + AsRef<[W]>> ExactSizeIterator for BitIter<'_, W, B> {
    // Number of bits not yet yielded.
    fn len(&self) -> usize {
        self.len - self.next_bit_pos
    }
}
// `next` keeps returning `None` once `next_bit_pos` reaches `len`.
impl<W: Word, B: ?Sized + AsRef<[W]>> FusedIterator for BitIter<'_, W, B> {}
/// An iterator over the positions of the set bits of a bit vector, in
/// increasing order.
#[derive(Debug, Clone, MemSize, MemDbg)]
pub struct OnesIter<'a, W: Word, B: ?Sized> {
    // Backing storage, viewed as a sequence of words.
    bits: &'a B,
    // Logical length in bits; positions at or beyond this are not yielded.
    len: usize,
    // Index of the word currently being scanned.
    word_idx: usize,
    // Copy of the current word with already-yielded bits cleared.
    word: W,
}
impl<'a, W: Word, B: ?Sized + AsRef<[W]>> OnesIter<'a, W, B> {
    /// Creates an iterator over the positions of the set bits among the
    /// first `len` bits of `bits`.
    ///
    /// In debug builds, asserts that `len` does not exceed the number of
    /// bits available in the backing slice.
    pub fn new(bits: &'a B, len: usize) -> Self {
        debug_assert!(len <= bits.as_ref().len() * W::BITS as usize);
        // Prime the scan with the first word, or zero if the slice is
        // empty. `first` replaces the previous `is_empty` check plus
        // `unsafe { get_unchecked(0) }` — same behavior, no unsafe.
        let word = bits.as_ref().first().copied().unwrap_or(W::ZERO);
        Self {
            bits,
            len,
            word_idx: 0,
            word,
        }
    }
}
impl<W: Word, B: ?Sized + AsRef<[W]>> Iterator for OnesIter<'_, W, B> {
    type Item = usize;
    /// Yields the position of the next set bit, or `None` when no set bit
    /// remains below `len`.
    fn next(&mut self) -> Option<Self::Item> {
        // Advance past words whose remaining bits are all zero.
        while self.word == W::ZERO {
            let words = self.bits.as_ref();
            self.word_idx += 1;
            if self.word_idx >= words.len() {
                return None;
            }
            // SAFETY: `word_idx < words.len()` was checked just above.
            self.word = unsafe { *words.get_unchecked(self.word_idx) };
        }
        let pos = self.word_idx * W::BITS as usize + self.word.trailing_zeros() as usize;
        if pos >= self.len {
            // A set bit past the logical end: stop (and stay stopped,
            // since `word` is left unchanged).
            None
        } else {
            // Clear the lowest set bit so the next call advances.
            self.word &= self.word - W::ONE;
            Some(pos)
        }
    }
}
// Once exhausted, `word` stays zero past the last word (or the found
// position stays past `len`), so `next` keeps returning `None`.
impl<W: Word, B: ?Sized + AsRef<[W]>> FusedIterator for OnesIter<'_, W, B> {}
/// An iterator over the positions of the unset bits of a bit vector, in
/// increasing order.
#[derive(Debug, Clone, MemSize, MemDbg)]
pub struct ZerosIter<'a, W: Word, B: ?Sized> {
    // Backing storage, viewed as a sequence of words.
    bits: &'a B,
    // Logical length in bits; positions at or beyond this are not yielded.
    len: usize,
    // Index of the word currently being scanned.
    word_idx: usize,
    // Complement of the current word (set bits mark zeros), with
    // already-yielded positions cleared.
    word: W,
}
impl<'a, W: Word, B: ?Sized + AsRef<[W]>> ZerosIter<'a, W, B> {
    /// Creates an iterator over the positions of the unset bits among the
    /// first `len` bits of `bits`.
    ///
    /// In debug builds, asserts that `len` does not exceed the number of
    /// bits available in the backing slice.
    pub fn new(bits: &'a B, len: usize) -> Self {
        debug_assert!(len <= bits.as_ref().len() * W::BITS as usize);
        // Prime the scan with the complement of the first word (set bits
        // mark zeros), or zero if the slice is empty. `first` replaces the
        // previous `is_empty` check plus `unsafe { get_unchecked(0) }` —
        // same behavior, no unsafe.
        let word = bits.as_ref().first().map_or(W::ZERO, |w| !*w);
        Self {
            bits,
            len,
            word_idx: 0,
            word,
        }
    }
}
impl<W: Word, B: ?Sized + AsRef<[W]>> Iterator for ZerosIter<'_, W, B> {
    type Item = usize;
    /// Yields the position of the next unset bit, or `None` when no unset
    /// bit remains below `len`.
    fn next(&mut self) -> Option<Self::Item> {
        // `word` holds the complement of the current word, so an all-zero
        // `word` means the word has no remaining zeros: advance past it.
        while self.word == W::ZERO {
            let words = self.bits.as_ref();
            self.word_idx += 1;
            if self.word_idx >= words.len() {
                return None;
            }
            // SAFETY: `word_idx < words.len()` was checked just above.
            self.word = unsafe { !*words.get_unchecked(self.word_idx) };
        }
        let pos = self.word_idx * W::BITS as usize + self.word.trailing_zeros() as usize;
        if pos >= self.len {
            // A zero past the logical end: stop (and stay stopped, since
            // `word` is left unchanged).
            None
        } else {
            // Clear the lowest set bit so the next call advances.
            self.word &= self.word - W::ONE;
            Some(pos)
        }
    }
}
// Once exhausted, `word` stays zero past the last word (or the found
// position stays past `len`), so `next` keeps returning `None`.
impl<W: Word, B: ?Sized + AsRef<[W]>> FusedIterator for ZerosIter<'_, W, B> {}
// Blanket implementation: anything that exposes its backing atomic words
// via `AsRef<[A]>` and reports a bit length gets the atomic bit operations.
impl<A: PrimitiveAtomicUnsigned<Value: Word>, T: ?Sized + AsRef<[A]> + BitLength> AtomicBitVecOps<A>
    for T
{
}
/// Operations on a bit vector backed by a slice of atomic words.
///
/// Single-bit reads and writes take `&self` and an explicit memory
/// [`Ordering`]; bulk operations (`fill`, `flip`, …) take `&mut self`,
/// so they run without concurrent writers.
pub trait AtomicBitVecOps<A: PrimitiveAtomicUnsigned<Value: Word>>: AsRef<[A]> + BitLength {
    /// Returns the bit at `index`, loading the containing word with
    /// `ordering`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`.
    fn get(&self, index: usize, ordering: Ordering) -> bool {
        panic_if_out_of_bounds!(index, self.len());
        // SAFETY: the bounds check above guarantees `index < self.len()`.
        unsafe { self.get_unchecked(index, ordering) }
    }
    /// Sets the bit at `index` to `value` with `ordering`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`.
    fn set(&self, index: usize, value: bool, ordering: Ordering) {
        panic_if_out_of_bounds!(index, self.len());
        // SAFETY: the bounds check above guarantees `index < self.len()`.
        unsafe { self.set_unchecked(index, value, ordering) }
    }
    /// Sets the bit at `index` to `value` and returns the previous bit,
    /// as a single atomic read-modify-write with `ordering`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`.
    fn swap(&self, index: usize, value: bool, ordering: Ordering) -> bool {
        panic_if_out_of_bounds!(index, self.len());
        // SAFETY: the bounds check above guarantees `index < self.len()`.
        unsafe { self.swap_unchecked(index, value, ordering) }
    }
    /// Returns the bit at `index` without bounds checking.
    ///
    /// # Safety
    /// `index` must be strictly less than `self.len()`.
    #[inline]
    unsafe fn get_unchecked(&self, index: usize, ordering: Ordering) -> bool {
        let bits_per_word = A::Value::BITS as usize;
        let word_index = index / bits_per_word;
        let bits = self.as_ref();
        // SAFETY: the caller guarantees `index < self.len()`, which implies
        // `word_index` is within the backing slice.
        let word = unsafe { bits.get_unchecked(word_index).load(ordering) };
        (word >> (index % bits_per_word)) & A::Value::ONE != A::Value::ZERO
    }
    /// Sets the bit at `index` to `value` without bounds checking.
    ///
    /// # Safety
    /// `index` must be strictly less than `self.len()`.
    #[inline]
    unsafe fn set_unchecked(&self, index: usize, value: bool, ordering: Ordering) {
        let bits_per_word = A::Value::BITS as usize;
        let word_index = index / bits_per_word;
        let bit_index = index % bits_per_word;
        let bits = self.as_ref();
        // SAFETY: the caller guarantees `index < self.len()`, which implies
        // `word_index` is within the backing slice. `fetch_or`/`fetch_and`
        // make the bit update a single atomic read-modify-write.
        unsafe {
            if value {
                bits.get_unchecked(word_index)
                    .fetch_or(A::Value::ONE << bit_index, ordering);
            } else {
                bits.get_unchecked(word_index)
                    .fetch_and(!(A::Value::ONE << bit_index), ordering);
            }
        }
    }
    /// Sets the bit at `index` to `value` and returns the previous bit,
    /// without bounds checking.
    ///
    /// # Safety
    /// `index` must be strictly less than `self.len()`.
    #[inline]
    unsafe fn swap_unchecked(&self, index: usize, value: bool, ordering: Ordering) -> bool {
        let bits_per_word = A::Value::BITS as usize;
        let word_index = index / bits_per_word;
        let bit_index = index % bits_per_word;
        let bits = self.as_ref();
        // SAFETY: the caller guarantees `index < self.len()`, which implies
        // `word_index` is within the backing slice. The fetch_* ops return
        // the word's previous value, from which the old bit is extracted.
        let old_word = unsafe {
            if value {
                bits.get_unchecked(word_index)
                    .fetch_or(A::Value::ONE << bit_index, ordering)
            } else {
                bits.get_unchecked(word_index)
                    .fetch_and(!(A::Value::ONE << bit_index), ordering)
            }
        };
        (old_word >> bit_index) & A::Value::ONE != A::Value::ZERO
    }
    /// Sets every bit to `value`.
    ///
    /// Takes `&mut self`, so no other thread can access the vector during
    /// the (non-atomic overall) load/store on the last word. The unused
    /// high bits of the last word (beyond `self.len()`) are preserved.
    fn fill(&mut self, value: bool, ordering: Ordering) {
        let bits_per_word = A::Value::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_ref();
        let word_value: A::Value = if value {
            !A::Value::ZERO
        } else {
            A::Value::ZERO
        };
        // NOTE(review): SeqCst fence before the bulk stores — presumably to
        // order them against earlier accesses by other threads; confirm the
        // intended memory-model contract.
        core::sync::atomic::fence(Ordering::SeqCst);
        bits[..full_words]
            .iter()
            .for_each(|x| x.store(word_value, ordering));
        if residual != 0 {
            // Mask selecting the `residual` in-range bits of the last word.
            let mask = (A::Value::ONE << residual) - A::Value::ONE;
            bits[full_words].store(
                (bits[full_words].load(ordering) & !mask) | (word_value & mask),
                ordering,
            );
        }
    }
    /// Sets every bit to `value`, processing full words in parallel with
    /// Rayon.
    ///
    /// Takes `&mut self`, so no other thread can access the vector during
    /// the (non-atomic overall) load/store on the last word. The unused
    /// high bits of the last word (beyond `self.len()`) are preserved.
    #[cfg(feature = "rayon")]
    fn par_fill(&mut self, value: bool, ordering: Ordering) {
        let bits_per_word = A::Value::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_ref();
        let word_value: A::Value = if value {
            !A::Value::ZERO
        } else {
            A::Value::ZERO
        };
        // NOTE(review): SeqCst fence before the bulk stores — presumably to
        // order them against earlier accesses by other threads; confirm the
        // intended memory-model contract.
        core::sync::atomic::fence(Ordering::SeqCst);
        bits[..full_words]
            .par_iter()
            .with_min_len(crate::RAYON_MIN_LEN)
            .for_each(|x| x.store(word_value, ordering));
        if residual != 0 {
            // Mask selecting the `residual` in-range bits of the last word.
            let mask = (A::Value::ONE << residual) - A::Value::ONE;
            bits[full_words].store(
                (bits[full_words].load(ordering) & !mask) | (word_value & mask),
                ordering,
            );
        }
    }
    /// Clears every bit.
    fn reset(&mut self, ordering: Ordering) {
        self.fill(false, ordering);
    }
    /// Clears every bit, processing full words in parallel with Rayon.
    #[cfg(feature = "rayon")]
    fn par_reset(&mut self, ordering: Ordering) {
        self.par_fill(false, ordering);
    }
    /// Inverts every bit.
    ///
    /// Takes `&mut self`, so no other thread can access the vector during
    /// the (non-atomic overall) load/store on the last word. The unused
    /// high bits of the last word (beyond `self.len()`) are preserved.
    fn flip(&mut self, ordering: Ordering) {
        let bits_per_word = A::Value::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_ref();
        // NOTE(review): SeqCst fence before the bulk updates — presumably to
        // order them against earlier accesses by other threads; confirm the
        // intended memory-model contract.
        core::sync::atomic::fence(Ordering::SeqCst);
        // XOR with all-ones inverts every bit of a full word.
        bits[..full_words]
            .iter()
            .for_each(|x| _ = x.fetch_xor(!A::Value::ZERO, ordering));
        if residual != 0 {
            // Invert only the `residual` in-range bits of the last word.
            let mask = (A::Value::ONE << residual) - A::Value::ONE;
            let last_word = bits[full_words].load(ordering);
            bits[full_words].store((last_word & !mask) | (!last_word & mask), ordering);
        }
    }
    /// Inverts every bit, processing full words in parallel with Rayon.
    ///
    /// Takes `&mut self`, so no other thread can access the vector during
    /// the (non-atomic overall) load/store on the last word. The unused
    /// high bits of the last word (beyond `self.len()`) are preserved.
    #[cfg(feature = "rayon")]
    fn par_flip(&mut self, ordering: Ordering) {
        let bits_per_word = A::Value::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_ref();
        // NOTE(review): SeqCst fence before the bulk updates — presumably to
        // order them against earlier accesses by other threads; confirm the
        // intended memory-model contract.
        core::sync::atomic::fence(Ordering::SeqCst);
        // XOR with all-ones inverts every bit of a full word.
        bits[..full_words]
            .par_iter()
            .with_min_len(crate::RAYON_MIN_LEN)
            .for_each(|x| _ = x.fetch_xor(!A::Value::ZERO, ordering));
        if residual != 0 {
            // Invert only the `residual` in-range bits of the last word.
            let mask = (A::Value::ONE << residual) - A::Value::ONE;
            let last_word = bits[full_words].load(ordering);
            bits[full_words].store((last_word & !mask) | (!last_word & mask), ordering);
        }
    }
    /// Counts the set bits, processing full words in parallel with Rayon.
    /// Words are loaded with `Relaxed` ordering after a `SeqCst` fence.
    #[cfg(feature = "rayon")]
    fn par_count_ones(&self) -> usize {
        let bits_per_word = A::Value::BITS as usize;
        let full_words = self.len() / bits_per_word;
        let residual = self.len() % bits_per_word;
        let bits = self.as_ref();
        let mut num_ones;
        // NOTE(review): SeqCst fence before the relaxed loads — presumably
        // to order them against earlier writes by other threads; confirm
        // the intended memory-model contract.
        core::sync::atomic::fence(Ordering::SeqCst);
        num_ones = bits[..full_words]
            .par_iter()
            .with_min_len(crate::RAYON_MIN_LEN)
            .map(|x| x.load(Ordering::Relaxed).count_ones() as usize)
            .sum();
        if residual != 0 {
            // Shift left so the bits at or beyond `len` fall off before
            // counting the partial last word.
            num_ones += (bits[full_words].load(Ordering::Relaxed) << (bits_per_word - residual))
                .count_ones() as usize
        }
        num_ones
    }
    /// Returns an iterator over all bits, in order of increasing position;
    /// each word is loaded with `Relaxed` ordering.
    #[inline(always)]
    fn iter(&self) -> AtomicBitIter<'_, A, [A]> {
        AtomicBitIter::new(self.as_ref(), self.len())
    }
}
/// An iterator over all bits of an atomic bit vector, in order of
/// increasing position; words are loaded with `Relaxed` ordering.
// NOTE(review): unlike `BitIter`, this does not derive `Clone` — a derived
// impl would require `A: Clone`, which atomic types do not satisfy.
#[derive(Debug, MemSize, MemDbg)]
pub struct AtomicBitIter<'a, A, B: ?Sized> {
    // Backing storage, viewed as a sequence of atomic words.
    bits: &'a B,
    // Total number of bits to yield.
    len: usize,
    // Position of the next bit to yield.
    next_bit_pos: usize,
    // Ties the atomic word type `A` to the iterator without storing an `A`.
    _phantom: PhantomData<A>,
}
impl<'a, A: PrimitiveAtomicUnsigned<Value: Word>, B: ?Sized + AsRef<[A]>> AtomicBitIter<'a, A, B> {
    /// Creates an iterator over the first `len` bits of `bits`.
    ///
    /// In debug builds, asserts that `len` does not exceed the number of
    /// bits available in the backing slice.
    pub fn new(bits: &'a B, len: usize) -> Self {
        debug_assert!(len <= bits.as_ref().len() * A::Value::BITS as usize);
        Self {
            next_bit_pos: 0,
            bits,
            len,
            _phantom: PhantomData,
        }
    }
}
impl<A: PrimitiveAtomicUnsigned<Value: Word>, B: ?Sized + AsRef<[A]>> Iterator
    for AtomicBitIter<'_, A, B>
{
    type Item = bool;
    /// Yields the bit at the current position (via a `Relaxed` load of its
    /// word), or `None` past the end.
    fn next(&mut self) -> Option<bool> {
        let pos = self.next_bit_pos;
        if pos >= self.len {
            return None;
        }
        self.next_bit_pos = pos + 1;
        let bits_per_word = A::Value::BITS as usize;
        // SAFETY: `pos < len`, and `new` debug-asserts that `len` fits in
        // the backing slice, so `pos / bits_per_word` is in bounds.
        let word = unsafe {
            self.bits
                .as_ref()
                .get_unchecked(pos / bits_per_word)
                .load(Ordering::Relaxed)
        };
        Some((word >> (pos % bits_per_word)) & A::Value::ONE != A::Value::ZERO)
    }
}
impl<A: PrimitiveAtomicUnsigned<Value: Word>, B: ?Sized + AsRef<[A]>> ExactSizeIterator
    for AtomicBitIter<'_, A, B>
{
    // Number of bits not yet yielded.
    fn len(&self) -> usize {
        self.len - self.next_bit_pos
    }
}
// `next` keeps returning `None` once `next_bit_pos` reaches `len`.
impl<A: PrimitiveAtomicUnsigned<Value: Word>, B: ?Sized + AsRef<[A]>> FusedIterator
    for AtomicBitIter<'_, A, B>
{
}