use core::array::from_fn;
use core::fmt::{Debug, Formatter};
use core::iter::{FusedIterator, Iterator};
use core::ops::{
BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Range, Shl, ShlAssign,
Shr, ShrAssign,
};
/// Number of bytes needed to store `bit_count` bits, i.e. `bit_count / 8` rounded up.
pub const fn bucket_count(bit_count: usize) -> usize {
    let whole_bytes = bit_count / 8;
    // A trailing partial byte needs one more bucket.
    let has_partial_byte = bit_count % 8 != 0;
    whole_bytes + has_partial_byte as usize
}
/// `const`-evaluable check that `bit_count` / `buckets` form a valid `BitMap` parameter pair.
///
/// Called from the `const fn` constructors, so when those are evaluated in a const context
/// an invalid pair is rejected at compile time.
///
/// # Panics
/// Panics if `bit_count` is zero, or if `buckets != bucket_count(bit_count)`.
pub(crate) const fn compile_assert_const_params(bit_count: usize, buckets: usize) {
    // `assert!` with a literal message is allowed in `const fn`. Unlike the previous
    // out-of-bounds-index trick, it reports the actual reason for the failure.
    assert!(bit_count != 0, "BIT_COUNT must be greater than zero.");
    assert!(
        bucket_count(bit_count) == buckets,
        "BUCKET_COUNT must match bucket_count(BIT_COUNT)."
    );
}
/// Runtime check that `bit_count` / `buckets` form a valid `BitMap` parameter pair.
///
/// # Panics
/// Panics if `bit_count` is zero, or if `buckets` differs from `bucket_count(bit_count)`.
pub(crate) fn runtime_assert_const_params(bit_count: usize, buckets: usize) {
    assert_ne!(bit_count, 0, "BIT_COUNT must be greater than zero.");
    let expected_buckets = bucket_count(bit_count);
    assert_eq!(
        expected_buckets, buckets,
        "BUCKET_COUNT must match bucket_count(BIT_COUNT)."
    );
}
/// Byte mask with `width` consecutive one-bits starting at bit `start_bit`.
///
/// Bits that would fall past bit 7 are dropped. So `ones_mask(3, 8)` is `0b1111_1000`,
/// and any `start_bit >= 8` yields `0`.
pub(crate) const fn ones_mask(start_bit: usize, width: usize) -> u8 {
    if start_bit >= 8 {
        // No part of the requested run lies inside the byte (and shifting a `u8` by 8+
        // would overflow).
        return 0;
    }
    // `1u8 << 8` would overflow, so a full-width run is spelled out directly.
    let low_run = if width >= 8 { !0u8 } else { (1u8 << width) - 1 };
    // Shifting (rather than returning `!0` outright) keeps bits below `start_bit` clear
    // even for wide runs.
    low_run << start_bit
}
/// Fixed-size bitmap of `BIT_COUNT` bits packed into `BUCKET_COUNT` bytes.
///
/// Bit `i` is stored in byte `i / 8` at position `i % 8`, so bit 0 is the least
/// significant bit of byte 0. `BUCKET_COUNT` is expected to equal
/// `bucket_count(BIT_COUNT)`; the constructors check this. The padding bits of the final
/// byte are kept zero by this type's methods.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitMap<const BIT_COUNT: usize, const BUCKET_COUNT: usize>(
    pub(crate) [u8; BUCKET_COUNT],
);
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> BitMap<BIT_COUNT, BUCKET_COUNT> {
    /// Creates a bitmap with every bit cleared.
    ///
    /// # Panics
    /// Panics if `BIT_COUNT` is zero or `BUCKET_COUNT != bucket_count(BIT_COUNT)`.
    pub fn new() -> Self {
        runtime_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        Self([0u8; BUCKET_COUNT])
    }

    /// `const` version of [`BitMap::new`]. When evaluated in a const context, an invalid
    /// `BIT_COUNT`/`BUCKET_COUNT` pair becomes a compile-time error.
    pub const fn const_empty() -> Self {
        compile_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        Self([0u8; BUCKET_COUNT])
    }

    /// Creates a bitmap with all `BIT_COUNT` bits set. Padding bits stay clear.
    ///
    /// # Panics
    /// Panics on an invalid `BIT_COUNT`/`BUCKET_COUNT` pair, like [`BitMap::new`].
    #[inline]
    pub fn with_all_set() -> Self {
        runtime_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        let mut bm = Self([!0u8; BUCKET_COUNT]);
        bm.clean_unused_bits();
        bm
    }

    /// `const` version of [`BitMap::with_all_set`].
    pub const fn const_full() -> Self {
        compile_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        let mut bm = Self([!0u8; BUCKET_COUNT]);
        bm.clean_unused_bits();
        bm
    }

    /// Builds a bitmap where bit `i` is set iff `bits[i]` is `true`.
    ///
    /// # Panics
    /// Panics if `bits.len() != BIT_COUNT` or the const parameters are invalid.
    #[inline]
    pub fn from_slice(bits: &[bool]) -> Self {
        runtime_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        assert_eq!(bits.len(), BIT_COUNT);
        let mut bm = Self([0u8; BUCKET_COUNT]);
        for (idx, bit) in bits.iter().enumerate() {
            if *bit {
                bm.set(idx)
            }
        }
        bm
    }

    /// Builds a bitmap with exactly the bit indices yielded by `iter` set.
    ///
    /// # Panics
    /// Panics if any yielded index is `>= BIT_COUNT` or the const parameters are invalid.
    pub fn from_ones_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
        runtime_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        let mut bitmap = Self::new();
        for idx in iter {
            assert!(idx < BIT_COUNT, "Bit index {idx} out of bounds");
            bitmap.set(idx);
        }
        bitmap
    }

    /// Sets bit `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= BIT_COUNT`.
    #[inline]
    pub fn set(&mut self, idx: usize) {
        assert!(idx < BIT_COUNT, "Bit index {idx} out of bounds");
        let (group_idx, item_idx) = Self::idxs(idx);
        self.0[group_idx] |= 1 << item_idx;
    }

    /// Sets every bit in `range`. A range with `start >= end` is a no-op.
    ///
    /// # Panics
    /// Panics if `range.start > BIT_COUNT` or `range.end > BIT_COUNT`.
    pub fn set_range(&mut self, range: Range<usize>) {
        let Some((start_byte, start_bit, end_byte, end_bit)) = Self::range_parts(&range) else {
            return;
        };
        if start_byte == end_byte {
            self.0[start_byte] |= ones_mask(start_bit, end_bit - start_bit + 1);
            return;
        }
        // Partial first byte, whole middle bytes, partial last byte.
        self.0[start_byte] |= !0u8 << start_bit;
        self.0[start_byte + 1..end_byte].fill(!0);
        self.0[end_byte] |= ones_mask(0, end_bit + 1);
    }

    /// Clears bit `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= BIT_COUNT`.
    #[inline]
    pub fn unset(&mut self, idx: usize) {
        assert!(idx < BIT_COUNT, "Bit index {idx} out of bounds");
        let (group_idx, item_idx) = Self::idxs(idx);
        self.0[group_idx] &= !(1 << item_idx);
    }

    /// Clears every bit in `range`. A range with `start >= end` is a no-op.
    ///
    /// # Panics
    /// Panics if `range.start > BIT_COUNT` or `range.end > BIT_COUNT`.
    pub fn unset_range(&mut self, range: Range<usize>) {
        let Some((start_byte, start_bit, end_byte, end_bit)) = Self::range_parts(&range) else {
            return;
        };
        if start_byte == end_byte {
            self.0[start_byte] &= !ones_mask(start_bit, end_bit - start_bit + 1);
            return;
        }
        // Keep only the bits below `start_bit` in the first byte.
        self.0[start_byte] &= (1u8 << start_bit) - 1;
        self.0[start_byte + 1..end_byte].fill(0);
        self.0[end_byte] &= !ones_mask(0, end_bit + 1);
    }

    /// Bounds-checks `range` and, if it is non-empty, returns
    /// `(start_byte, start_bit, end_byte, end_bit)` for its first and last (inclusive) bits.
    ///
    /// `range.start == BIT_COUNT` is accepted so that `BIT_COUNT..BIT_COUNT` is a valid
    /// empty range, mirroring slice indexing.
    fn range_parts(range: &Range<usize>) -> Option<(usize, usize, usize, usize)> {
        assert!(
            range.start <= BIT_COUNT,
            "Range start {} out of bounds",
            range.start
        );
        assert!(
            range.end <= BIT_COUNT,
            "Range end {} out of bounds",
            range.end
        );
        if range.start >= range.end {
            return None;
        }
        let (start_byte, start_bit) = Self::idxs(range.start);
        let (end_byte, end_bit) = Self::idxs(range.end - 1);
        Some((start_byte, start_bit, end_byte, end_bit))
    }

    /// Flips bit `idx` and returns its value *before* the flip.
    ///
    /// # Panics
    /// Panics if `idx >= BIT_COUNT`.
    #[inline]
    pub fn toggle(&mut self, idx: usize) -> bool {
        assert!(idx < BIT_COUNT, "Bit index {idx} out of bounds");
        let (group_idx, item_idx) = Self::idxs(idx);
        let bit = self.0[group_idx] & 1 << item_idx != 0;
        self.0[group_idx] ^= 1 << item_idx;
        bit
    }

    /// Returns whether bit `idx` is set.
    ///
    /// # Panics
    /// Panics if `idx >= BIT_COUNT`.
    #[inline]
    pub fn is_set(&self, idx: usize) -> bool {
        assert!(idx < BIT_COUNT, "Bit index {idx} out of bounds");
        let (group_idx, item_idx) = Self::idxs(idx);
        self.0[group_idx] & 1 << item_idx != 0
    }

    /// Splits a bit index into `(byte index, bit position within that byte)`.
    #[inline]
    fn idxs(idx: usize) -> (usize, usize) {
        (idx / 8, idx % 8)
    }

    /// Iterates over all `BIT_COUNT` bits in index order, yielding `true` for set bits.
    #[inline]
    pub fn iter(&self) -> BitMapIter<BIT_COUNT, BUCKET_COUNT> {
        BitMapIter {
            bytes: &self.0,
            group_idx: 0,
            item_idx: 0,
        }
    }

    /// Iterates over the indices of the set bits, in ascending order.
    #[inline]
    pub fn iter_ones(&self) -> IterOnes<BIT_COUNT, BUCKET_COUNT> {
        IterOnes {
            bytes: &self.0,
            byte_idx: 0,
            current: self.0[0],
            base_bit_idx: 0,
        }
    }

    /// Iterates over the indices of the clear bits, in ascending order.
    #[inline]
    pub fn iter_zeros(&self) -> IterZeros<BIT_COUNT, BUCKET_COUNT> {
        IterZeros {
            bytes: &self.0,
            byte_idx: 0,
            current: !self.0[0],
            base_bit_idx: 0,
        }
    }

    /// Returns the bitwise OR (union) of `self` and `other`.
    #[inline]
    pub fn bit_or(&self, other: &Self) -> Self {
        Self(from_fn(|i| self.0[i] | other.0[i]))
    }

    /// ORs `other` into `self`.
    #[inline]
    pub fn in_place_bit_or(&mut self, other: &Self) {
        for (self_byte, other_byte) in self.0.iter_mut().zip(other.0.iter()) {
            *self_byte |= other_byte
        }
    }

    /// Returns the bitwise AND (intersection) of `self` and `other`.
    #[inline]
    pub fn bit_and(&self, other: &Self) -> Self {
        Self(from_fn(|i| self.0[i] & other.0[i]))
    }

    /// ANDs `other` into `self`.
    #[inline]
    pub fn in_place_bit_and(&mut self, other: &Self) {
        for (self_byte, other_byte) in self.0.iter_mut().zip(other.0.iter()) {
            *self_byte &= other_byte
        }
    }

    /// Returns the bitwise XOR (symmetric difference) of `self` and `other`.
    #[inline]
    pub fn bit_xor(&self, other: &Self) -> Self {
        Self(from_fn(|i| self.0[i] ^ other.0[i]))
    }

    /// XORs `other` into `self`.
    #[inline]
    pub fn in_place_bit_xor(&mut self, other: &Self) {
        for (self_byte, other_byte) in self.0.iter_mut().zip(other.0.iter()) {
            *self_byte ^= other_byte
        }
    }

    /// Returns the complement of `self`. Padding bits stay clear.
    #[inline]
    pub fn bit_not(&self) -> Self {
        let mut result = Self(from_fn(|i| !self.0[i]));
        result.clean_unused_bits();
        result
    }

    /// Complements `self` in place. Padding bits stay clear.
    #[inline]
    pub fn in_place_bit_not(&mut self) {
        for byte in &mut self.0 {
            *byte = !*byte;
        }
        self.clean_unused_bits();
    }

    /// Number of set bits.
    ///
    /// Counts whole bytes, so it relies on the final byte's padding bits being clear.
    #[inline]
    pub fn popcount(&self) -> usize {
        self.0.iter().map(|b| b.count_ones() as usize).sum()
    }

    /// Index of the lowest set bit, or `None` if no bit is set.
    pub fn first_set_bit(&self) -> Option<usize> {
        for (i, byte) in self.0.iter().enumerate() {
            if *byte != 0 {
                let bit = byte.trailing_zeros() as usize;
                return Some(i * 8 + bit);
            }
        }
        None
    }

    /// Zeroes the padding bits of the last byte (those at indices `>= BIT_COUNT`).
    #[inline]
    const fn clean_unused_bits(&mut self) {
        let bits_in_last = BIT_COUNT % 8;
        if bits_in_last != 0 {
            let mask = (1 << bits_in_last) - 1;
            self.0[BUCKET_COUNT - 1] &= mask;
        }
    }

    /// Shifts every bit toward higher indices: bit `i` moves to `i + n`.
    ///
    /// Bits pushed past `BIT_COUNT - 1` are dropped and the lowest `n` bits become zero.
    /// `n >= BIT_COUNT` clears the bitmap.
    pub fn shift_left(&mut self, n: usize) {
        if n >= BIT_COUNT {
            self.0.fill(0);
            return;
        }
        let (byte_shift, bit_shift) = Self::idxs(n);
        if byte_shift > 0 {
            // With valid const parameters, `n < BIT_COUNT` implies
            // `byte_shift < BUCKET_COUNT`.
            self.0.copy_within(..BUCKET_COUNT - byte_shift, byte_shift);
            self.0[..byte_shift].fill(0);
        }
        if bit_shift > 0 {
            // Walk downward so each byte reads its lower neighbour before that neighbour
            // is itself shifted.
            for i in (0..BUCKET_COUNT).rev() {
                let carry = if i > 0 {
                    self.0[i - 1] >> (8 - bit_shift)
                } else {
                    0
                };
                self.0[i] = (self.0[i] << bit_shift) | carry;
            }
        }
        self.clean_unused_bits();
    }

    /// Shifts every bit toward lower indices: bit `i` moves to `i - n`.
    ///
    /// The lowest `n` bits are dropped and the highest `n` bits become zero.
    /// `n >= BIT_COUNT` clears the bitmap.
    pub fn shift_right(&mut self, n: usize) {
        if n >= BIT_COUNT {
            self.0.fill(0);
            return;
        }
        // Stray padding bits would otherwise be shifted into valid positions.
        self.clean_unused_bits();
        let (byte_shift, bit_shift) = Self::idxs(n);
        if byte_shift > 0 {
            self.0.copy_within(byte_shift.., 0);
            // Only the top `byte_shift` bytes are vacated by the move.
            self.0[BUCKET_COUNT - byte_shift..].fill(0);
        }
        if bit_shift > 0 {
            // Walk upward so each byte reads its upper neighbour before that neighbour
            // is itself shifted.
            for i in 0..BUCKET_COUNT {
                let carry = if i + 1 < BUCKET_COUNT {
                    self.0[i + 1] << (8 - bit_shift)
                } else {
                    0
                };
                self.0[i] = (self.0[i] >> bit_shift) | carry;
            }
        }
    }

    /// Rotates bits toward higher indices: bit `i` moves to `(i + n) % BIT_COUNT`.
    ///
    /// # Panics
    /// Panics on an invalid `BIT_COUNT`/`BUCKET_COUNT` pair.
    pub fn rotate_left(&mut self, n: usize) {
        runtime_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        let n = n % BIT_COUNT;
        if n == 0 {
            return;
        }
        // The top `n` bits wrap around to the bottom; everything else shifts up by `n`.
        let mut wrapped = *self;
        wrapped.shift_right(BIT_COUNT - n);
        self.shift_left(n);
        self.in_place_bit_or(&wrapped);
    }

    /// Rotates bits toward lower indices: bit `i` moves to
    /// `(i + BIT_COUNT - n % BIT_COUNT) % BIT_COUNT`.
    pub fn rotate_right(&mut self, n: usize) {
        self.rotate_left(BIT_COUNT - n % BIT_COUNT);
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Default
for BitMap<BIT_COUNT, BUCKET_COUNT>
{
fn default() -> Self {
Self::new()
}
}
impl<'bitmap, const BIT_COUNT: usize, const BUCKET_COUNT: usize> IntoIterator
for &'bitmap BitMap<BIT_COUNT, BUCKET_COUNT>
{
type Item = bool;
type IntoIter = BitMapIter<'bitmap, BIT_COUNT, BUCKET_COUNT>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Debug for BitMap<BIT_COUNT, BUCKET_COUNT> {
    /// Renders the bits from index 0 (LSB) upward, grouped per byte.
    ///
    /// Each group is prefixed with the index of its first bit. For example, a 10-bit map
    /// with bits 0, 2 and 9 set renders as `LSB -> 0: 10100000 8: 01 <- MSB`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str("LSB -> ")?;
        for (i, bit) in self.iter().enumerate() {
            let pos_in_byte = i % 8;
            if pos_in_byte == 0 {
                write!(f, "{i}: ")?;
            }
            f.write_str(if bit { "1" } else { "0" })?;
            // Separate byte groups, but not after the very last storage slot.
            if pos_in_byte == 7 && i < BUCKET_COUNT * 8 - 1 {
                f.write_str(" ")?;
            }
        }
        f.write_str(" <- MSB")
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> FromIterator<bool>
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    /// Builds a bitmap from exactly `BIT_COUNT` booleans, where the `i`-th item sets bit `i`.
    ///
    /// # Panics
    /// Panics if the iterator yields more or fewer than `BIT_COUNT` items, or if the const
    /// parameters are invalid.
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        runtime_assert_const_params(BIT_COUNT, BUCKET_COUNT);
        let mut bm = Self::new();
        let mut yielded = 0usize;
        for (idx, bit) in iter.into_iter().enumerate() {
            assert!(idx < BIT_COUNT, "Iterator yielded more than {BIT_COUNT} elements");
            if bit {
                bm.set(idx);
            }
            yielded = idx + 1;
        }
        assert!(
            yielded == BIT_COUNT,
            "Iterator yielded fewer than {BIT_COUNT} elements"
        );
        bm
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> BitAnd for BitMap<BIT_COUNT, BUCKET_COUNT> {
    type Output = Self;

    /// Bitwise intersection of the two bitmaps.
    fn bitand(self, rhs: Self) -> Self::Output {
        let mut result = self;
        result.in_place_bit_and(&rhs);
        result
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> BitAndAssign
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    /// Keeps only the bits that are also set in `rhs`.
    fn bitand_assign(&mut self, rhs: Self) {
        *self = self.bit_and(&rhs);
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> BitOr for BitMap<BIT_COUNT, BUCKET_COUNT> {
    type Output = Self;

    /// Bitwise union of the two bitmaps.
    fn bitor(self, rhs: Self) -> Self::Output {
        let mut result = self;
        result.in_place_bit_or(&rhs);
        result
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> BitOrAssign
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    /// Additionally sets every bit that is set in `rhs`.
    fn bitor_assign(&mut self, rhs: Self) {
        *self = self.bit_or(&rhs);
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> BitXor for BitMap<BIT_COUNT, BUCKET_COUNT> {
    type Output = Self;

    /// Bitwise symmetric difference of the two bitmaps.
    fn bitxor(self, rhs: Self) -> Self::Output {
        let mut result = self;
        result.in_place_bit_xor(&rhs);
        result
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> BitXorAssign
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    /// Flips every bit that is set in `rhs`.
    fn bitxor_assign(&mut self, rhs: Self) {
        *self = self.bit_xor(&rhs);
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Not for BitMap<BIT_COUNT, BUCKET_COUNT> {
    type Output = Self;

    /// Complement of the bitmap; padding bits stay clear.
    fn not(self) -> Self::Output {
        let mut inverted = self;
        inverted.in_place_bit_not();
        inverted
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Shl<usize>
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    type Output = Self;

    /// Returns a copy shifted toward higher indices by `rhs` (see `BitMap::shift_left`).
    fn shl(self, rhs: usize) -> Self::Output {
        let mut shifted = self;
        shifted.shift_left(rhs);
        shifted
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> ShlAssign<usize>
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    /// Shifts in place toward higher indices by `rhs` (see `BitMap::shift_left`).
    fn shl_assign(&mut self, rhs: usize) {
        Self::shift_left(self, rhs);
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Shr<usize>
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    type Output = Self;

    /// Returns a copy shifted toward lower indices by `rhs` (see `BitMap::shift_right`).
    fn shr(self, rhs: usize) -> Self::Output {
        let mut shifted = self;
        shifted.shift_right(rhs);
        shifted
    }
}
impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> ShrAssign<usize>
    for BitMap<BIT_COUNT, BUCKET_COUNT>
{
    /// Shifts in place toward lower indices by `rhs` (see `BitMap::shift_right`).
    fn shr_assign(&mut self, rhs: usize) {
        Self::shift_right(self, rhs);
    }
}
/// Iterator over every bit of a `BitMap` in index order, yielding `true` for set bits.
///
/// Yields exactly `BIT_COUNT` items, so it also implements [`ExactSizeIterator`].
#[derive(Clone, Copy, Debug)]
pub struct BitMapIter<'bitmap, const BIT_COUNT: usize, const BUCKET_COUNT: usize> {
    /// Backing bytes of the bitmap being iterated.
    bytes: &'bitmap [u8; BUCKET_COUNT],
    /// Byte holding the next bit to yield.
    group_idx: usize,
    /// Position (0..8) of the next bit within `bytes[group_idx]`.
    item_idx: usize,
}

impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Iterator
    for BitMapIter<'_, BIT_COUNT, BUCKET_COUNT>
{
    type Item = bool;

    fn next(&mut self) -> Option<Self::Item> {
        let absolute_idx = self.group_idx * 8 + self.item_idx;
        if absolute_idx >= BIT_COUNT {
            return None;
        }
        let bit = self.bytes[self.group_idx] & 1 << self.item_idx;
        self.item_idx += 1;
        if self.item_idx == 8 {
            self.item_idx = 0;
            self.group_idx += 1;
        }
        Some(bit != 0)
    }

    /// Exact: the iterator stops after bit `BIT_COUNT - 1`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let consumed = self.group_idx * 8 + self.item_idx;
        let remaining = BIT_COUNT.saturating_sub(consumed);
        (remaining, Some(remaining))
    }
}

impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> ExactSizeIterator
    for BitMapIter<'_, BIT_COUNT, BUCKET_COUNT>
{
}

impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> FusedIterator
    for BitMapIter<'_, BIT_COUNT, BUCKET_COUNT>
{
}
/// Iterator over the indices of the set bits of a `BitMap`, in ascending order.
///
/// Stops at the first set bit whose index is `>= BIT_COUNT`, so stray padding bits in
/// the final byte are never yielded.
#[derive(Clone, Copy, Debug)]
pub struct IterOnes<'bitmap, const BIT_COUNT: usize, const BUCKET_COUNT: usize> {
    /// Backing bytes of the bitmap being iterated.
    bytes: &'bitmap [u8; BUCKET_COUNT],
    /// Index of the byte that `current` was loaded from.
    byte_idx: usize,
    /// Not-yet-yielded set bits of `bytes[byte_idx]`.
    current: u8,
    /// Bit index that bit 0 of `current` corresponds to (`byte_idx * 8`).
    base_bit_idx: usize,
}

impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Iterator
    for IterOnes<'_, BIT_COUNT, BUCKET_COUNT>
{
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        while self.byte_idx < BUCKET_COUNT {
            if self.current != 0 {
                let idx = self.base_bit_idx + self.current.trailing_zeros() as usize;
                if idx >= BIT_COUNT {
                    return None;
                }
                // Clear the lowest set bit so the next call finds the following one.
                self.current &= self.current - 1;
                return Some(idx);
            }
            self.byte_idx += 1;
            self.base_bit_idx += 8;
            // Past the last byte there is nothing to load; 0 lets the loop terminate.
            self.current = *self.bytes.get(self.byte_idx).unwrap_or(&0);
        }
        None
    }
}

impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> FusedIterator
    for IterOnes<'_, BIT_COUNT, BUCKET_COUNT>
{
}
/// Iterator over the indices of the clear bits of a `BitMap`, in ascending order.
///
/// Padding bits of the final byte also read as clear, so the iterator ends as soon as it
/// reaches an index `>= BIT_COUNT`.
#[derive(Clone, Copy, Debug)]
pub struct IterZeros<'bitmap, const BIT_COUNT: usize, const BUCKET_COUNT: usize> {
    /// Backing bytes of the bitmap being iterated.
    bytes: &'bitmap [u8; BUCKET_COUNT],
    /// Index of the byte that `current` was derived from.
    byte_idx: usize,
    /// Not-yet-yielded clear bits of `bytes[byte_idx]`, stored inverted (as ones).
    current: u8,
    /// Bit index that bit 0 of `current` corresponds to (`byte_idx * 8`).
    base_bit_idx: usize,
}

impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> Iterator
    for IterZeros<'_, BIT_COUNT, BUCKET_COUNT>
{
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        while self.byte_idx < BUCKET_COUNT {
            if self.current != 0 {
                let idx = self.base_bit_idx + self.current.trailing_zeros() as usize;
                if idx >= BIT_COUNT {
                    // Every remaining candidate is padding; drop them so later calls
                    // step past the last byte and keep returning `None`.
                    self.current = 0;
                    return None;
                }
                // Clear the lowest pending bit so the next call finds the following one.
                self.current &= self.current - 1;
                return Some(idx);
            }
            self.byte_idx += 1;
            self.base_bit_idx += 8;
            // Past the last byte the value is irrelevant: the loop condition fails first.
            self.current = !*self.bytes.get(self.byte_idx).unwrap_or(&0);
        }
        None
    }
}

impl<const BIT_COUNT: usize, const BUCKET_COUNT: usize> FusedIterator
    for IterZeros<'_, BIT_COUNT, BUCKET_COUNT>
{
}