/// A fixed-capacity set of bit indices, addressed as `0..SIZE`.
pub trait BitSet {
    /// Number of addressable bits in the set.
    const SIZE: usize;
    /// Iterator produced by [`BitSet::iter`]; it yields bit indices.
    type Iter: Iterator<Item = usize>;

    /// Creates a new set (the integer implementations in this crate start with every bit cleared).
    fn new() -> Self;
    /// Reports whether bit `i` is set.
    fn get_bit(&self, i: usize) -> bool;
    /// Sets bit `i`.
    fn set_bit(&mut self, i: usize);
    /// Clears bit `i`.
    fn clear_bit(&mut self, i: usize);
    /// Returns an iterator over the indices of the set bits.
    fn iter(&self) -> Self::Iter;
}
/// Iterator over the indices of the set bits of a [`BitSet`], in ascending order.
///
/// Owns the `T` it walks, so it does not borrow from the set it came from.
pub struct BitSetIter<T: BitSet> {
    /// The bits being walked.
    bits: T,
    /// Next index to examine; a value `>= T::SIZE` means the iterator is exhausted.
    curr: usize,
}

impl<T: BitSet> Iterator for BitSetIter<T> {
    type Item = usize;

    /// Returns the lowest set-bit index at or after the cursor, advancing past it.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let bits = &self.bits;
        let hit = (self.curr..T::SIZE).find(|&idx| bits.get_bit(idx));
        match hit {
            Some(idx) => {
                self.curr = idx + 1;
                Some(idx)
            }
            None => {
                // Park the cursor at the end so later calls bail out without rescanning.
                self.curr = self.curr.max(T::SIZE);
                None
            }
        }
    }
}
/// Implements [`BitSet`] for a primitive unsigned integer type.
///
/// `$int` is the integer type and `$bits` is used verbatim as `SIZE`. It should
/// equal the type's bit width, because the generated code shifts `1` by up to
/// `SIZE - 1`.
macro_rules! bitset_impl {
    ($int:ty, $bits:tt) => {
        impl BitSet for $int {
            const SIZE: usize = $bits;
            type Iter = BitSetIter<$int>;

            /// Returns the integer `0`, i.e. no bits set.
            #[inline]
            fn new() -> Self {
                0
            }

            /// Tests bit `i`; any `i >= SIZE` reads as unset.
            #[inline]
            fn get_bit(&self, i: usize) -> bool {
                // Short-circuiting keeps the shift amount in range.
                i < Self::SIZE && ((*self >> i) & 1) != 0
            }

            /// Sets bit `i`; indices `>= SIZE` are ignored.
            #[inline]
            fn set_bit(&mut self, i: usize) {
                if i >= Self::SIZE {
                    return;
                }
                let mask: $int = 1 << i;
                *self |= mask;
            }

            /// Clears bit `i`; indices `>= SIZE` are ignored.
            #[inline]
            fn clear_bit(&mut self, i: usize) {
                if i >= Self::SIZE {
                    return;
                }
                let mask: $int = 1 << i;
                *self &= !mask;
            }

            /// Iterates over a copy of the current bits, lowest index first.
            #[inline]
            fn iter(&self) -> Self::Iter {
                BitSetIter { curr: 0, bits: *self }
            }
        }
    };
}

// Every unsigned primitive, each paired with its width in bits.
bitset_impl!(u8, 8);
bitset_impl!(u16, 16);
bitset_impl!(u32, 32);
bitset_impl!(u64, 64);
bitset_impl!(u128, 128);
/// A sequence of bits addressed by a flat `usize` index.
pub trait BitVec {
    /// Word type the bits are packed into (the `Vec<T>` implementation uses `T`).
    type Rep: BitSet;

    /// Returns a new value holding the bitwise AND of `self` and `other`.
    fn and(&self, other: &Self) -> Self;
    /// Reports whether bit `i` is set.
    fn get_bit(&self, i: usize) -> bool;
    /// Sets bit `i`.
    fn set_bit(&mut self, i: usize);
    /// Clears bit `i`.
    fn clear_bit(&mut self, i: usize);
}
/// [`BitVec`] backed by a `Vec` of [`BitSet`] words.
///
/// Bit `i` lives in word `i / T::SIZE` at position `i % T::SIZE`. Every bit of
/// every word is used, whatever the word width. (Hard-coding 32 here would drop
/// bits for `u8`/`u16` words and waste the upper half of `u64`/`u128` words.)
/// Words past the end of the vector are treated as all-zero.
///
/// # Panics
///
/// `get_bit`, `set_bit` and `clear_bit` divide by `T::SIZE`, so they panic for a
/// word type whose `SIZE` is 0.
impl<T> BitVec for Vec<T>
where
    T: BitSet + Copy + std::ops::BitAnd<Output = T>,
{
    type Rep = T;

    /// Reports whether bit `i` is set; bits beyond the stored words read as unset.
    #[inline]
    fn get_bit(&self, i: usize) -> bool {
        match self.get(i / T::SIZE) {
            Some(word) => word.get_bit(i % T::SIZE),
            None => false,
        }
    }

    /// Sets bit `i`, first growing the vector with `T::new()` words if needed.
    #[inline]
    fn set_bit(&mut self, i: usize) {
        let word_idx = i / T::SIZE;
        if word_idx >= self.len() {
            self.resize(word_idx + 1, T::new());
        }
        self[word_idx].set_bit(i % T::SIZE);
    }

    /// Clears bit `i`. Does nothing (and does not grow the vector) when `i` is
    /// beyond the stored words.
    #[inline]
    fn clear_bit(&mut self, i: usize) {
        if let Some(word) = self.get_mut(i / T::SIZE) {
            word.clear_bit(i % T::SIZE);
        }
    }

    /// Word-by-word bitwise AND of two vectors.
    ///
    /// The result has `min(self.len(), other.len())` words. Beyond the shorter
    /// input every bit is implicitly zero, so the AND there is zero too.
    #[inline]
    fn and(&self, other: &Self) -> Self {
        self.iter().zip(other).map(|(&x, &y)| x & y).collect()
    }
}
#[cfg(test)]
mod tests {
    // `super` rather than `crate::bitset` so the tests keep compiling wherever
    // this module is mounted in the crate.
    use super::*;

    #[test]
    fn bitset() {
        let mut x: u32 = 0;
        for i in 0..32 {
            assert!(!x.get_bit(i));
        }
        x.set_bit(12);
        assert!(x.get_bit(12));
        for i in 0..32 {
            if i != 12 {
                assert!(!x.get_bit(i));
            }
        }
        x.clear_bit(12);
        for i in 0..32 {
            assert!(!x.get_bit(i));
        }
        x = 0xffffffff;
        for i in 0..32 {
            assert!(x.get_bit(i));
        }
        x.clear_bit(14);
        assert!(!x.get_bit(14));
        for i in 0..32 {
            if i != 14 {
                assert!(x.get_bit(i));
            }
        }
    }

    #[test]
    fn bitset_out_of_range_is_ignored() {
        let mut x: u32 = 0;
        x.set_bit(32);
        x.set_bit(1000);
        assert_eq!(x, 0);
        x = u32::MAX;
        x.clear_bit(32);
        assert_eq!(x, u32::MAX);
        assert!(!x.get_bit(32));
    }

    #[test]
    fn bitset_iter() {
        let x: u16 = 0b1010011101101010;
        let idxs = x.iter().collect::<Vec<_>>();
        assert_eq!(idxs, vec![1, 3, 5, 6, 8, 9, 10, 13, 15]);
    }

    #[test]
    fn bitset_iter_edges() {
        assert_eq!(0u8.iter().next(), None);
        let x: u128 = (1u128 << 127) | (1u128 << 64) | 1;
        assert_eq!(x.iter().collect::<Vec<_>>(), vec![0, 64, 127]);
    }

    #[test]
    fn bitvec_u32_words() {
        let mut v: Vec<u32> = Vec::new();
        assert!(!v.get_bit(0));
        v.set_bit(5);
        v.set_bit(70);
        assert_eq!(v.len(), 3);
        assert!(v.get_bit(5) && v.get_bit(70));
        assert!(!v.get_bit(6));
        v.clear_bit(70);
        v.clear_bit(1000);
        assert!(!v.get_bit(70));
        assert_eq!(v.len(), 3);

        let mut w: Vec<u32> = Vec::new();
        w.set_bit(5);
        w.set_bit(6);
        assert_eq!(v.and(&w), vec![1u32 << 5]);
    }
}