use std::fmt;
use std::iter;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A set of `usize` values stored as a bitmap.
///
/// Values below 64 are kept in a single inline word, so a set containing only
/// such values needs no heap allocation. Larger values live in an optional
/// heap-allocated vector of further 64-bit words.
///
/// The default value is the empty set.
#[derive(Clone, Default)]
pub struct SmallBitset {
    /// Bit `n` is set when the value `n` (0 to 63) is in the set.
    near: u64,
    /// Words for values of 64 and above: word `i` covers the values
    /// `64 * (i + 1)` through `64 * (i + 1) + 63`. `None` until such a value
    /// is first stored.
    // The extra `Box` keeps this field pointer-sized instead of the three
    // words an inline `Vec` would take; presumably values >= 64 are expected
    // to be rare enough that the double indirection is an acceptable price.
    #[allow(clippy::box_vec)]
    far: Option<Box<Vec<u64>>>,
}
/// Formats the set as `SmallBitset` immediately followed by the list of its
/// members in iteration order, e.g. `SmallBitset[0, 42]`.
impl fmt::Debug for SmallBitset {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("SmallBitset")?;
        let mut list = f.debug_list();
        for member in self.iter() {
            list.entry(&member);
        }
        list.finish()
    }
}
impl SmallBitset {
    /// Creates an empty set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `val` to the set.
    ///
    /// Returns `true` if `val` was not already a member.
    ///
    /// Values of 64 and above are stored in a heap-allocated vector which is
    /// grown far enough to hold the word for `val`, so memory use scales with
    /// the largest value ever inserted. Inserting an enormous value can
    /// therefore fail to allocate.
    pub fn insert(&mut self, val: usize) -> bool {
        let (word, mask) = self.addr_mut(val);
        let newly_added = (*word & mask) == 0;
        *word |= mask;
        newly_added
    }

    /// Removes `val` from the set.
    ///
    /// Returns `true` if `val` was a member. This never allocates: a value
    /// whose word has not been allocated cannot be a member, so the set is
    /// left untouched and `false` is returned.
    pub fn remove(&mut self, val: usize) -> bool {
        let mask = 1u64 << (val % 64);
        let word = if val < 64 {
            &mut self.near
        } else {
            // Look the word up without growing `far`; an out-of-range index
            // simply means the value was never stored.
            match self
                .far
                .as_deref_mut()
                .and_then(|far| far.get_mut(val / 64 - 1))
            {
                Some(word) => word,
                None => return false,
            }
        };
        let was_present = (*word & mask) != 0;
        *word &= !mask;
        was_present
    }

    /// Returns whether `val` is a member of the set.
    pub fn contains(&self, val: usize) -> bool {
        let (word, mask) = self.addr(val);
        (word & mask) != 0
    }

    /// Iterates over the members of the set in ascending order.
    pub fn iter<'a>(&'a self) -> impl Iterator<Item = usize> + 'a {
        // Yields nothing when `far` is `None`.
        let far_words = self.far.iter().flat_map(|far| far.iter().copied());
        iter::once(self.near)
            .chain(far_words)
            .enumerate()
            .flat_map(|(ix, word)| {
                (0..64usize)
                    .filter(move |&bit| (word & (1u64 << bit)) != 0)
                    .map(move |bit| bit + ix * 64)
            })
    }

    /// Returns a mutable reference to the word holding `val` together with
    /// the mask selecting `val`'s bit within that word.
    ///
    /// For `val >= 64` this allocates `far` if needed and zero-extends it so
    /// that the word exists.
    fn addr_mut(&mut self, val: usize) -> (&mut u64, u64) {
        if val < 64 {
            (&mut self.near, 1 << val)
        } else {
            // `near` covers the first 64 values, so `far` starts at word 1.
            let ix = val / 64 - 1;
            let far = self.far.get_or_insert_with(|| Box::new(Vec::new()));
            if far.len() <= ix {
                far.resize(ix + 1, 0);
            }
            (&mut far[ix], 1 << (val % 64))
        }
    }

    /// Returns a copy of the word holding `val` together with the mask
    /// selecting `val`'s bit within that word.
    ///
    /// Words that have not been allocated read as zero.
    fn addr(&self, val: usize) -> (u64, u64) {
        let word = if val < 64 {
            self.near
        } else {
            self.far
                .as_deref()
                .and_then(|far| far.get(val / 64 - 1))
                .copied()
                .unwrap_or(0)
        };
        (word, 1 << (val % 64))
    }
}
impl Serialize for SmallBitset {
fn serialize<S: Serializer>(
&self,
serializer: S,
) -> Result<S::Ok, S::Error> {
match self.far {
None => {
let near_array = [self.near];
if self.near == 0 {
&[] as &[u64]
} else {
&near_array as &[u64]
}
.serialize(serializer)
}
Some(ref far) => {
let mut elements = Vec::clone(far);
elements.push(self.near);
elements.serialize(serializer)
}
}
}
}
/// Reads a set from a sequence of `u64` words.
///
/// The last word of the sequence becomes the `near` word (zero if the
/// sequence is empty); any words before it become the `far` storage.
impl<'de> Deserialize<'de> for SmallBitset {
    fn deserialize<D: Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Self, D::Error> {
        let mut words = Vec::<u64>::deserialize(deserializer)?;
        let near = words.pop().unwrap_or(0);
        // Whatever is left after taking `near` is the far storage; keep
        // `far` as `None` when nothing remains.
        let far = Some(words).filter(|rest| !rest.is_empty()).map(Box::new);
        Ok(SmallBitset { near, far })
    }
}
#[cfg(test)]
mod test {
    use serde_cbor;

    use super::*;

    /// A single mutation applied between serialisation round trips.
    enum Step {
        Insert(usize),
        Remove(usize),
    }

    /// Collects the members of `bs` in iteration order.
    fn members(bs: &SmallBitset) -> Vec<usize> {
        bs.iter().collect()
    }

    #[test]
    fn basic_operations() {
        let mut bs = SmallBitset::new();

        // A fresh set contains nothing, whether near, far or absurdly far.
        for &absent in &[0, 100, usize::MAX] {
            assert!(!bs.contains(absent));
        }

        // Near values: insertion reports novelty.
        assert!(bs.insert(0));
        assert!(bs.insert(42));
        assert!(!bs.insert(42));
        assert!(bs.contains(0));
        assert!(!bs.contains(1));
        assert!(bs.contains(42));
        assert_eq!(vec![0, 42], members(&bs));

        // Removal reports prior membership and leaves other values alone.
        assert!(bs.remove(0));
        assert!(!bs.remove(0));
        assert!(!bs.contains(0));
        assert!(bs.contains(42));
        assert_eq!(vec![42], members(&bs));

        // Far values.
        assert!(bs.insert(100));
        assert!(bs.contains(100));
        assert_eq!(vec![42, 100], members(&bs));

        assert!(bs.insert(1000));
        assert!(bs.contains(1000));
        assert_eq!(vec![42, 100, 1000], members(&bs));

        assert!(bs.remove(100));
        assert!(!bs.contains(100));
        assert_eq!(vec![42, 1000], members(&bs));
    }

    /// Round-trips `bs` through CBOR and checks that the members survive.
    fn serde_flip(bs: &SmallBitset) {
        let encoded = serde_cbor::to_vec(bs).unwrap();
        let decoded: SmallBitset =
            serde_cbor::from_reader(&encoded[..]).unwrap();
        assert_eq!(members(bs), members(&decoded));
    }

    #[test]
    fn test_serde() {
        let steps = [
            Step::Insert(0),
            Step::Insert(42),
            Step::Remove(0),
            Step::Insert(100),
            Step::Insert(1000),
            Step::Remove(42),
            Step::Remove(1000),
        ];

        let mut bs = SmallBitset::new();
        // The empty set must round-trip too.
        serde_flip(&bs);

        for step in &steps {
            match *step {
                Step::Insert(val) => {
                    bs.insert(val);
                }
                Step::Remove(val) => {
                    bs.remove(val);
                }
            }
            serde_flip(&bs);
        }
    }
}