use std::{
io::{self, Cursor, Write},
ops::Range,
};
use azalea_buf::{AzBuf, BufReadError};
/// A heap-allocated set of bits stored as 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64`, counting from
/// the least significant bit. The number of words is fixed once the set has
/// been created.
#[derive(AzBuf, Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct BitSet {
/// Backing storage; each element holds 64 consecutive bits.
data: Box<[u64]>,
}
/// Base-2 logarithm of the number of bits in one storage word (`1 << 6 == 64`).
/// Shifting a bit index right by this amount yields the index of its word.
const LOG2_BITS_PER_WORD: usize = 6;
impl BitSet {
    /// Creates a bitset with room for at least `num_bits` bits, all unset.
    ///
    /// Storage is allocated in whole 64-bit words, so [`len`](Self::len)
    /// reports `num_bits` rounded up to the next multiple of 64.
    #[inline]
    pub fn new(num_bits: usize) -> Self {
        BitSet {
            data: vec![0; num_bits.div_ceil(64)].into(),
        }
    }

    /// Returns whether the bit at `index` is set.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than [`len`](Self::len). Use
    /// [`get`](Self::get) for a non-panicking lookup.
    #[inline]
    pub fn index(&self, index: usize) -> bool {
        self.get(index).unwrap_or_else(|| {
            let len = self.len();
            panic!("index out of bounds: the len is {len} but the index is {index}")
        })
    }

    /// Returns whether the bit at `index` is set, or `None` if `index` lies
    /// beyond the allocated words.
    #[inline]
    pub fn get(&self, index: usize) -> Option<bool> {
        self.data
            .get(index / 64)
            .map(|word| (word & (1u64 << (index % 64))) != 0)
    }

    /// Unsets every bit in `range` (start inclusive, end exclusive).
    ///
    /// The part of the range that lies beyond the end of the bitset is
    /// ignored, so clearing past the end is not an error.
    ///
    /// # Panics
    ///
    /// Panics if `range.start > range.end`.
    pub fn clear(&mut self, range: Range<usize>) {
        assert!(
            range.start <= range.end,
            "Range ends before it starts; {} must be less than or equal to {}",
            range.start,
            range.end
        );

        let from_idx = range.start;
        let mut to_idx = range.end;
        if from_idx == to_idx {
            return;
        }

        let start_word_idx = self.word_index(from_idx);
        if start_word_idx >= self.data.len() {
            return;
        }

        // `to_idx` is exclusive, so the last affected bit is `to_idx - 1`.
        let mut end_word_idx = self.word_index(to_idx - 1);
        if end_word_idx >= self.data.len() {
            // Clamp the range to the storage we actually have.
            to_idx = self.len();
            end_word_idx = self.data.len() - 1;
        }

        // Only the offset inside the word matters for the masks. Reducing
        // the index modulo 64 first keeps the shift in range without having
        // to squeeze the whole bit index into a `u32`.
        let first_word_mask = u64::MAX << (from_idx % 64);
        // When `to_idx` is word-aligned the whole last word is covered, which
        // the outer `% 64` turns into a shift by zero.
        let last_word_mask = u64::MAX >> ((64 - to_idx % 64) % 64);

        if start_word_idx == end_word_idx {
            self.data[start_word_idx] &= !(first_word_mask & last_word_mask);
        } else {
            self.data[start_word_idx] &= !first_word_mask;
            // Words strictly between the first and last are cleared entirely.
            self.data[start_word_idx + 1..end_word_idx].fill(0);
            self.data[end_word_idx] &= !last_word_mask;
        }
    }

    /// Returns the index of the first unset bit at or after `from_index`.
    ///
    /// If `from_index` lies beyond the allocated words it is returned
    /// unchanged; if every remaining bit is set, [`len`](Self::len) is
    /// returned.
    pub fn next_clear_bit(&self, from_index: usize) -> usize {
        let mut u = self.word_index(from_index);
        if u >= self.data.len() {
            return from_index;
        }

        // Invert so that clear bits become ones, then hide the bits that sit
        // below `from_index` inside the first word.
        let mut word = !self.data[u] & (u64::MAX << (from_index % 64));
        loop {
            if word != 0 {
                return (u * 64) + word.trailing_zeros() as usize;
            }
            u += 1;
            if u == self.data.len() {
                return self.data.len() * 64;
            }
            word = !self.data[u];
        }
    }

    /// Returns the index of the storage word that contains `bit_index`.
    #[inline]
    fn word_index(&self, bit_index: usize) -> usize {
        bit_index >> LOG2_BITS_PER_WORD
    }

    /// Sets the bit at `bit_index`.
    ///
    /// # Panics
    ///
    /// Panics if `bit_index` lies beyond the allocated words.
    #[inline]
    pub fn set(&mut self, bit_index: usize) {
        self.data[bit_index / 64] |= 1u64 << (bit_index % 64);
    }

    /// Iterates over the indices of all set bits in ascending order.
    ///
    /// Words without any set bit are skipped as a whole instead of being
    /// probed bit by bit.
    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
        self.data.iter().enumerate().flat_map(|(word_idx, &word)| {
            let base = word_idx * 64;
            let mut remaining = word;
            std::iter::from_fn(move || {
                if remaining == 0 {
                    return None;
                }
                let bit = remaining.trailing_zeros() as usize;
                // Drop the lowest set bit so the next call finds the following one.
                remaining &= remaining - 1;
                Some(base + bit)
            })
        })
    }

    /// Returns the number of bits the set can hold, which is always a
    /// multiple of 64.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len() * 64
    }

    /// Returns `true` if the set has no storage at all (`len() == 0`).
    ///
    /// This does not report whether any bit is set.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl From<Vec<u64>> for BitSet {
fn from(data: Vec<u64>) -> Self {
BitSet { data: data.into() }
}
}
impl From<Vec<u8>> for BitSet {
fn from(data: Vec<u8>) -> Self {
let mut words = vec![0; data.len().div_ceil(8)];
for (i, byte) in data.iter().enumerate() {
words[i / 8] |= (*byte as u64) << ((i % 8) * 8);
}
BitSet { data: words.into() }
}
}
/// A bitset whose size in bits, `N`, is fixed at compile time and whose
/// storage is an inline array of `bits_to_bytes(N)` bytes.
///
/// Bit `i` lives in byte `i / 8` at bit position `i % 8`, counting from the
/// least significant bit.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct FixedBitSet<const N: usize>
where
[u8; bits_to_bytes(N)]: Sized,
{
/// Backing bytes; each element holds 8 consecutive bits.
data: [u8; bits_to_bytes(N)],
}
impl<const N: usize> FixedBitSet<N>
where
    [u8; bits_to_bytes(N)]: Sized,
{
    /// Creates a bitset with every bit unset.
    pub const fn new() -> Self {
        Self {
            data: [0; bits_to_bytes(N)],
        }
    }

    /// Creates a bitset that uses `data` as its backing bytes.
    pub const fn new_with_data(data: [u8; bits_to_bytes(N)]) -> Self {
        Self { data }
    }

    /// Returns whether the bit at `index` is set.
    ///
    /// # Panics
    ///
    /// Panics if `index / 8` is outside the backing array.
    #[inline]
    pub fn index(&self, index: usize) -> bool {
        let byte = self.data[index / 8];
        ((byte >> (index % 8)) & 1) == 1
    }

    /// Sets the bit at `bit_index`.
    ///
    /// # Panics
    ///
    /// Panics if `bit_index / 8` is outside the backing array.
    #[inline]
    pub fn set(&mut self, bit_index: usize) {
        let mask = 1u8 << (bit_index % 8);
        self.data[bit_index / 8] |= mask;
    }
}
impl<const N: usize> AzBuf for FixedBitSet<N>
where
    [u8; bits_to_bytes(N)]: Sized,
{
    /// Reads `bits_to_bytes(N)` bytes from `buf`, one `u8` at a time, and
    /// uses them as the backing storage. The first failing read is returned.
    fn azalea_read(buf: &mut Cursor<&[u8]>) -> Result<Self, BufReadError> {
        let mut bytes = [0u8; bits_to_bytes(N)];
        for slot in &mut bytes {
            *slot = u8::azalea_read(buf)?;
        }
        Ok(Self { data: bytes })
    }

    /// Writes every backing byte to `buf` in order, stopping at the first
    /// write error.
    fn azalea_write(&self, buf: &mut impl Write) -> io::Result<()> {
        for byte in &self.data {
            byte.azalea_write(buf)?;
        }
        Ok(())
    }
}
impl<const N: usize> Default for FixedBitSet<N>
where
[u8; bits_to_bytes(N)]: Sized,
{
fn default() -> Self {
Self::new()
}
}
/// Returns the number of bytes needed to store `n` bits, i.e. `n / 8`
/// rounded up.
pub const fn bits_to_bytes(n: usize) -> usize {
    let full_bytes = n / 8;
    if n % 8 == 0 {
        full_bytes
    } else {
        // A partially used byte still needs to be allocated.
        full_bytes + 1
    }
}
/// A bitset whose size in bits, `N`, is fixed at compile time and whose
/// storage is an inline array of `bits_to_longs(N)` 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64`, counting from
/// the least significant bit.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct FastFixedBitSet<const N: usize>
where
[u64; bits_to_longs(N)]: Sized,
{
/// Backing words; each element holds 64 consecutive bits.
data: [u64; bits_to_longs(N)],
}
impl<const N: usize> FastFixedBitSet<N>
where
    [u64; bits_to_longs(N)]: Sized,
{
    /// Creates a bitset with every bit unset.
    pub const fn new() -> Self {
        Self {
            data: [0; bits_to_longs(N)],
        }
    }

    /// Returns whether the bit at `index` is set.
    ///
    /// # Panics
    ///
    /// Panics if `index / 64` is outside the backing array.
    #[inline]
    pub fn index(&self, index: usize) -> bool {
        let word = self.data[index / 64];
        ((word >> (index % 64)) & 1) == 1
    }

    /// Sets the bit at `bit_index`.
    ///
    /// # Panics
    ///
    /// Panics if `bit_index` is not smaller than `N`.
    #[inline]
    pub fn set(&mut self, bit_index: usize) {
        assert!(bit_index < N);
        let mask = 1u64 << (bit_index % 64);
        self.data[bit_index / 64] |= mask;
    }
}
impl<const N: usize> Default for FastFixedBitSet<N>
where
[u64; bits_to_longs(N)]: Sized,
{
fn default() -> Self {
Self::new()
}
}
/// Returns the number of 64-bit words needed to store `n` bits, i.e.
/// `n / 64` rounded up.
pub const fn bits_to_longs(n: usize) -> usize {
    let full_words = n / 64;
    if n % 64 == 0 {
        full_words
    } else {
        // A partially used word still needs to be allocated.
        full_words + 1
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(bit, expected)` pair matches the bitset's state.
    fn assert_bits(bitset: &BitSet, expected: &[(usize, bool)]) {
        for &(bit, want) in expected {
            assert_eq!(bitset.index(bit), want, "unexpected state for bit {bit}");
        }
    }

    /// A fresh bitset is all zeroes, and setting one bit leaves its
    /// neighbours untouched.
    #[test]
    fn test_bitset() {
        let mut bitset = BitSet::new(64);
        assert_bits(&bitset, &[(0, false), (1, false), (2, false)]);

        bitset.set(1);
        assert_bits(&bitset, &[(0, false), (1, true), (2, false)]);
    }

    /// Clearing a range that straddles the boundary between two words only
    /// touches the bits inside the range.
    #[test]
    fn test_clear() {
        let mut bitset = BitSet::new(128);
        for bit in 62..=66 {
            bitset.set(bit);
        }

        bitset.clear(63..65);

        assert_bits(
            &bitset,
            &[(62, true), (63, false), (64, false), (65, true), (66, true)],
        );
    }

    /// Clearing a range that sits entirely inside the second word only
    /// touches the bits inside the range.
    #[test]
    fn test_clear_2() {
        let mut bitset = BitSet::new(128);
        for bit in 64..=68 {
            bitset.set(bit);
        }

        bitset.clear(65..67);

        assert_bits(
            &bitset,
            &[(64, true), (65, false), (66, false), (67, true), (68, true)],
        );
    }
}