use std::{
alloc::Layout,
fmt::{self, Debug, Display},
mem,
ptr::{self, NonNull},
slice,
};
use crate::{Allocator, Box, CloneIn};
/// Number of bits in one `usize`. Bit indices in this file are split into a word index
/// (`bit / USIZE_BITS`) and a position within that word (`bit % USIZE_BITS`).
const USIZE_BITS: usize = usize::BITS as usize;
/// A set of bits packed into `usize` words, with its storage obtained from an [`Allocator`].
///
/// The number of words is fixed when the set is created by `new_in`; none of the methods
/// in this file grow it afterwards.
pub struct BitSet<'alloc> {
/// Backing words. Bit `n` is stored in word `n / USIZE_BITS`, at position `n % USIZE_BITS`
/// counting from the least significant bit.
entries: Box<'alloc, [usize]>,
/// Bit count requested at construction. Used as the upper bound by the `Ones` iterator;
/// `has_bit` / `set_bit` / `unset_bit` do not check indices against it.
max_bit_count: usize,
}
impl<'alloc> BitSet<'alloc> {
    /// Creates a [`BitSet`] able to hold `max_bit_count` bits, all initially unset,
    /// with its storage obtained from `allocator`.
    ///
    /// # Panics
    ///
    /// Panics if the layout of the backing word array cannot be computed
    /// (i.e. `Layout::array` fails because the total size is too large).
    pub fn new_in(max_bit_count: usize, allocator: &'alloc Allocator) -> Self {
        // One word per `USIZE_BITS` bits, rounded up so a partial last word is still backed.
        let word_count = max_bit_count.div_ceil(USIZE_BITS);
        let word_layout = Layout::array::<usize>(word_count).unwrap();
        let first_word = allocator.alloc_layout(word_layout).cast::<usize>().as_ptr();

        // SAFETY: relies on `alloc_layout` returning memory that is valid for `word_layout`,
        // i.e. room for `word_count` suitably aligned `usize`s (TODO confirm against
        // `Allocator::alloc_layout`). The words are zeroed before a slice is formed over
        // them, so the slice never exposes uninitialized memory.
        let words = unsafe {
            ptr::write_bytes(first_word, 0, word_count);
            slice::from_raw_parts_mut(first_word, word_count)
        };
        // SAFETY: `words` points at the allocation made above from `allocator`.
        let entries = unsafe { Box::from_non_null(NonNull::from(words)) };

        Self { entries, max_bit_count }
    }

    /// Returns `true` if `bit` is set.
    ///
    /// # Panics
    ///
    /// Panics if `bit / USIZE_BITS` is not a valid index into the backing words.
    #[inline]
    pub fn has_bit(&self, bit: usize) -> bool {
        let word = self.entries[bit / USIZE_BITS];
        let mask = 1 << (bit % USIZE_BITS);
        (word & mask) != 0
    }

    /// Sets `bit`.
    ///
    /// # Panics
    ///
    /// Panics if `bit / USIZE_BITS` is not a valid index into the backing words.
    #[inline]
    pub fn set_bit(&mut self, bit: usize) {
        let mask = 1 << (bit % USIZE_BITS);
        self.entries[bit / USIZE_BITS] |= mask;
    }

    /// Clears `bit`.
    ///
    /// # Panics
    ///
    /// Panics if `bit / USIZE_BITS` is not a valid index into the backing words.
    #[inline]
    pub fn unset_bit(&mut self, bit: usize) {
        let mask = 1 << (bit % USIZE_BITS);
        self.entries[bit / USIZE_BITS] &= !mask;
    }

    /// Sets in `self` every bit that is set in `other` (in-place union).
    ///
    /// Words are combined pairwise; if the two sets have different numbers of words,
    /// only the words present in both take part.
    pub fn union(&mut self, other: &Self) {
        self.entries
            .iter_mut()
            .zip(other.entries.iter())
            .for_each(|(target, source)| *target |= *source);
    }

    /// Unsets every bit.
    #[inline]
    pub fn clear(&mut self) {
        for word in self.entries.iter_mut() {
            *word = 0;
        }
    }

    /// Returns an iterator over the indices of the set bits, starting from bit 0.
    #[inline]
    pub fn ones(&self) -> Ones<'_, '_> {
        Ones { idx: 0, bitset: self }
    }
}
/// Renders the set as binary digits, most significant byte first, with bytes separated
/// by `_` (e.g. `00000010_10000011`).
///
/// Leading all-zero bytes are omitted. A set whose bits are all unset renders as
/// `00000000`, and a set with no backing words renders as the empty string.
impl Display for BitSet<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.entries.is_empty() {
            return Ok(());
        }

        // Walk the words from most to least significant, dropping leading all-zero words.
        // `skip_while` only skips the leading run, so zero words further down are kept.
        let mut words = self.entries.iter().rev().skip_while(|&&word| word == 0);
        let Some(highest_word) = words.next() else {
            // Every word is zero.
            return f.write_str("00000000");
        };

        // `to_be_bytes` yields the most significant byte first on every target, so no
        // endianness-specific handling is needed.
        let highest_bytes = highest_word.to_be_bytes();
        let mut bytes = highest_bytes.iter().skip_while(|&&byte| byte == 0);

        // `highest_word` is non-zero, so at least one byte survives the skip.
        // Bytes are written straight into the formatter; no temporary `String` is built.
        if let Some(highest_byte) = bytes.next() {
            write!(f, "{highest_byte:08b}")?;
        }
        for byte in bytes {
            write!(f, "_{byte:08b}")?;
        }

        // All lower words are printed in full, including their zero bytes.
        for lower_word in words {
            for byte in lower_word.to_be_bytes() {
                write!(f, "_{byte:08b}")?;
            }
        }
        Ok(())
    }
}
/// Debug output is the `Display` rendering wrapped in a `BitSet(..)` tuple,
/// e.g. `BitSet("00000010_10000011")`.
impl Debug for BitSet<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.to_string();
        let mut tuple = f.debug_tuple("BitSet");
        tuple.field(&rendered);
        tuple.finish()
    }
}
impl<'new_alloc> CloneIn<'new_alloc> for BitSet<'_> {
    type Cloned = BitSet<'new_alloc>;

    /// Creates a copy of this bit set whose words live in memory obtained from `allocator`.
    ///
    /// The word contents and `max_bit_count` are copied as they are.
    fn clone_in(&self, allocator: &'new_alloc Allocator) -> BitSet<'new_alloc> {
        let source = self.entries.as_ref();
        let word_count = source.len();

        // SAFETY: the size and alignment describe an existing `[usize]`, so together
        // they form a valid layout.
        let layout = unsafe {
            Layout::from_size_align_unchecked(mem::size_of_val(source), align_of::<usize>())
        };
        let destination = allocator.alloc_layout(layout).cast::<usize>().as_ptr();

        // SAFETY: relies on `alloc_layout` returning memory valid for `layout`, i.e. room
        // for `word_count` aligned `usize`s (TODO confirm against `Allocator::alloc_layout`).
        // `source` is a live slice of exactly `word_count` words, and every destination
        // word is written by the copy before the slice over it is formed.
        let copied = unsafe {
            ptr::copy_nonoverlapping(source.as_ptr(), destination, word_count);
            slice::from_raw_parts_mut(destination, word_count)
        };
        // SAFETY: `copied` points at the allocation made above from `allocator`.
        let entries = unsafe { Box::from_non_null(NonNull::from(copied)) };

        BitSet { max_bit_count: self.max_bit_count, entries }
    }
}
/// Iterator over the indices of the set bits of a [`BitSet`], in ascending order.
///
/// Only indices below the set's `max_bit_count` are examined. Created by `BitSet::ones`.
#[derive(Clone)]
pub struct Ones<'a, 'b> {
/// The bit set being iterated.
bitset: &'b BitSet<'a>,
/// Index of the next bit to examine.
idx: usize,
}
impl Iterator for Ones<'_, '_> {
    type Item = usize;

    /// Advances to the next set bit and returns its index, or `None` once every
    /// index below `max_bit_count` has been examined.
    #[inline]
    fn next(&mut self) -> Option<usize> {
        let limit = self.bitset.max_bit_count;
        loop {
            if self.idx >= limit {
                return None;
            }
            let candidate = self.idx;
            let is_set = self.bitset.has_bit(candidate);
            // Move past the examined bit whether or not it was set.
            self.idx = candidate + 1;
            if is_set {
                return Some(candidate);
            }
        }
    }

    /// Lower bound is 0 (the remaining bits may all be unset); upper bound is the number
    /// of indices not yet examined.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.bitset.max_bit_count - self.idx))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The rendering grows byte by byte as higher bits are set, for a 64-bit and a 65-bit set.
    #[test]
    fn basic() {
        let allocator = Allocator::default();

        // 64 bits.
        let mut bs = BitSet::new_in(64, &allocator);
        assert_eq!(bs.to_string(), "00000000");
        for bit in [0, 1, 7] {
            bs.set_bit(bit);
        }
        assert_eq!(bs.to_string(), "10000011");
        for (bit, expected) in [
            (9, "00000010_10000011"),
            (63, "10000000_00000000_00000000_00000000_00000000_00000000_00000010_10000011"),
        ] {
            bs.set_bit(bit);
            assert_eq!(bs.to_string(), expected);
        }

        // 65 bits.
        let mut bs = BitSet::new_in(65, &allocator);
        assert_eq!(bs.to_string(), "00000000");
        for bit in [0, 1, 7] {
            bs.set_bit(bit);
        }
        assert_eq!(bs.to_string(), "10000011");
        for (bit, expected) in [
            (8, "00000001_10000011"),
            (15, "10000001_10000011"),
            (63, "10000000_00000000_00000000_00000000_00000000_00000000_10000001_10000011"),
            (
                64,
                "00000001_10000000_00000000_00000000_00000000_00000000_00000000_10000001_10000011",
            ),
        ] {
            bs.set_bit(bit);
            assert_eq!(bs.to_string(), expected);
        }
    }

    /// A clone carries over the bits set so far and can be modified afterwards.
    #[test]
    fn clone_in() {
        let allocator = Allocator::default();

        let mut original = BitSet::new_in(16, &allocator);
        assert_eq!(original.to_string(), "00000000");
        for bit in [0, 1, 7] {
            original.set_bit(bit);
        }
        assert_eq!(original.to_string(), "10000011");
        original.set_bit(8);
        assert_eq!(original.to_string(), "00000001_10000011");

        let mut copy = original.clone_in(&allocator);
        copy.set_bit(15);
        assert_eq!(copy.to_string(), "10000001_10000011");
    }

    /// Clearing bits one at a time shrinks the rendering back down to all zeros.
    #[test]
    fn unset_bit() {
        let allocator = Allocator::default();

        let mut bs = BitSet::new_in(16, &allocator);
        for bit in [0, 1, 7, 8] {
            bs.set_bit(bit);
        }
        assert_eq!(bs.to_string(), "00000001_10000011");

        for (bit, expected) in [
            (1, "00000001_10000001"),
            (8, "10000001"),
            (0, "10000000"),
            (7, "00000000"),
        ] {
            bs.unset_bit(bit);
            assert_eq!(bs.to_string(), expected);
        }
    }

    /// `union` ORs the bits of the argument into the receiver.
    #[test]
    fn union() {
        let allocator = Allocator::default();

        let mut left = BitSet::new_in(16, &allocator);
        for bit in [0, 3, 8] {
            left.set_bit(bit);
        }
        assert_eq!(left.to_string(), "00000001_00001001");

        let mut right = BitSet::new_in(16, &allocator);
        for bit in [1, 3, 9] {
            right.set_bit(bit);
        }
        assert_eq!(right.to_string(), "00000010_00001010");

        left.union(&right);
        assert_eq!(left.to_string(), "00000011_00001011");
    }

    /// `clear` unsets every bit and leaves the set usable.
    #[test]
    fn clear() {
        let allocator = Allocator::default();
        let probes = [0, 7, 64, 127];

        let mut bs = BitSet::new_in(128, &allocator);
        for bit in probes {
            bs.set_bit(bit);
        }
        for bit in probes {
            assert!(bs.has_bit(bit));
        }

        bs.clear();
        assert_eq!(bs.to_string(), "00000000");
        for bit in probes {
            assert!(!bs.has_bit(bit));
        }

        bs.set_bit(42);
        assert!(bs.has_bit(42));
    }

    /// `ones` yields the indices of the set bits in ascending order.
    #[test]
    fn ones() {
        let allocator = Allocator::default();

        let mut bs = BitSet::new_in(3, &allocator);
        for bit in [0, 2] {
            bs.set_bit(bit);
        }
        let set_bits: Vec<_> = bs.ones().collect();
        assert_eq!(set_bits, [0, 2]);
    }
}