pub mod iter;
mod rleplus;
mod unvalidated;
pub use unvalidated::{UnvalidatedBitField, Validate};
use ahash::AHashSet;
use iter::{ranges_from_bits, RangeIterator};
use std::{
iter::FromIterator,
ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Range, Sub, SubAssign},
};
/// Result type returned by fallible `BitField` operations; errors are static message strings.
type Result<T> = std::result::Result<T, &'static str>;
/// A set of bit indices.
///
/// The contents are stored as base `ranges` plus two sets of single-bit edits
/// made through [`BitField::set`] and [`BitField::unset`]. The effective
/// contents (what [`BitField::ranges`] yields) are the base ranges together
/// with the `set` bits, minus the `unset` bits.
#[derive(Debug, Default, Clone)]
pub struct BitField {
    /// Base ranges of present bits. `get` binary-searches this list, so it is
    /// expected to be sorted and non-overlapping.
    ranges: Vec<Range<usize>>,
    /// Bits explicitly set after construction; `set`/`unset` keep this
    /// disjoint from `unset`.
    set: AHashSet<usize>,
    /// Bits explicitly cleared after construction; they override `ranges`.
    unset: AHashSet<usize>,
}
impl PartialEq for BitField {
    /// Compares the effective contents of both bit fields, i.e. the ranges
    /// produced by `ranges()`, rather than their internal representation.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.ranges(), other.ranges());
        Iterator::eq(lhs, rhs)
    }
}
impl FromIterator<usize> for BitField {
    /// Builds a bit field containing every index yielded by `iter`.
    ///
    /// The input may be in any order and may contain duplicates. Indices are
    /// sorted and then deduplicated before being turned into ranges. Every
    /// other caller feeds `ranges_from_bits` unique bits taken from a set. A
    /// repeated index is assumed to risk overlapping ranges, which would break
    /// `get`'s binary search and over-count `len`.
    fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
        let mut bits: Vec<usize> = iter.into_iter().collect();
        bits.sort_unstable();
        bits.dedup();
        Self::from_ranges(ranges_from_bits(bits))
    }
}
impl FromIterator<bool> for BitField {
    /// Builds a bit field in which bit `i` is present exactly when the `i`-th
    /// value yielded by `iter` is `true`.
    fn from_iter<I: IntoIterator<Item = bool>>(iter: I) -> Self {
        // Positions come out of `enumerate` already strictly increasing.
        let set_positions = iter
            .into_iter()
            .enumerate()
            .filter_map(|(position, is_set)| if is_set { Some(position) } else { None });
        Self::from_ranges(ranges_from_bits(set_positions))
    }
}
impl BitField {
    /// Creates an empty bit field.
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds a bit field whose base ranges are collected from `iter`, with no
    /// pending single-bit edits.
    pub fn from_ranges(iter: impl RangeIterator) -> Self {
        let ranges: Vec<_> = iter.collect();
        Self {
            ranges,
            ..Self::default()
        }
    }

    /// Marks `bit` as present, overriding an earlier `unset` of the same bit.
    pub fn set(&mut self, bit: usize) {
        self.set.insert(bit);
        self.unset.remove(&bit);
    }

    /// Marks `bit` as absent, overriding an earlier `set` of the same bit.
    pub fn unset(&mut self, bit: usize) {
        self.unset.insert(bit);
        self.set.remove(&bit);
    }

    /// Reports whether bit `index` is present.
    ///
    /// Pending single-bit edits take precedence. Otherwise the base ranges are
    /// binary-searched, which relies on them being sorted.
    pub fn get(&self, index: usize) -> bool {
        use std::cmp::Ordering;

        if self.set.contains(&index) {
            return true;
        }
        if self.unset.contains(&index) {
            return false;
        }
        self.ranges
            .binary_search_by(|range| match (index < range.start, index >= range.end) {
                (true, _) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (false, false) => Ordering::Equal,
            })
            .is_ok()
    }

    /// Returns the lowest present bit, or `None` when nothing is present.
    pub fn first(&self) -> Option<usize> {
        // `set` and `unset` are kept disjoint, so the smallest explicitly-set
        // bit is definitely present. Only base-range bits below it can win,
        // so merging in that single bit is enough.
        let lowest_set = self.set.iter().min().map(|&bit| bit..bit + 1);
        let candidates = self.inner_ranges().union(iter::Ranges::new(lowest_set));
        candidates.flatten().find(|bit| !self.unset.contains(bit))
    }

    /// Iterates over every present bit: the base ranges merged with the
    /// explicitly-set bits, skipping explicitly-cleared ones.
    pub fn iter(&self) -> impl Iterator<Item = usize> + '_ {
        let mut explicit: Vec<usize> = self.set.iter().copied().collect();
        explicit.sort_unstable();
        let cleared = &self.unset;
        self.inner_ranges()
            .union(ranges_from_bits(explicit))
            .flatten()
            .filter(move |bit| !cleared.contains(bit))
    }

    /// Like [`BitField::iter`], but refuses to iterate when more than `max`
    /// bits are present.
    ///
    /// # Errors
    /// Returns an error when `self.len()` exceeds `max`.
    pub fn bounded_iter(&self, max: usize) -> Result<impl Iterator<Item = usize> + '_> {
        if self.len() > max {
            return Err("Bits set exceeds max in retrieval");
        }
        Ok(self.iter())
    }

    /// The base ranges only, ignoring pending `set`/`unset` edits.
    fn inner_ranges(&self) -> impl RangeIterator + '_ {
        let base = self.ranges.iter().cloned();
        iter::Ranges::new(base)
    }

    /// Returns the effective contents as ranges: the base ranges unioned with
    /// the explicitly-set bits, minus the explicitly-cleared bits.
    pub fn ranges(&self) -> impl RangeIterator + '_ {
        let to_ranges = |bits: &AHashSet<usize>| {
            let mut sorted: Vec<usize> = bits.iter().copied().collect();
            sorted.sort_unstable();
            ranges_from_bits(sorted)
        };
        let added = to_ranges(&self.set);
        let removed = to_ranges(&self.unset);
        self.inner_ranges().union(added).difference(removed)
    }

    /// Reports whether no bit is present.
    pub fn is_empty(&self) -> bool {
        // Any explicitly-set bit is present, because `set` never overlaps `unset`.
        if !self.set.is_empty() {
            return false;
        }
        // Otherwise the field is empty only if every base bit was cleared.
        !self
            .inner_ranges()
            .flatten()
            .any(|bit| !self.unset.contains(&bit))
    }

    /// Returns a new bit field built by skipping `start` bits of the effective
    /// ranges and keeping the next `len` (via the `skip_bits`/`take_bits`
    /// range adaptors).
    ///
    /// # Errors
    /// Returns an error when fewer than `len` bits remain after skipping.
    pub fn slice(&self, start: usize, len: usize) -> Result<Self> {
        let window = self.ranges().skip_bits(start).take_bits(len);
        let sliced = Self::from_ranges(window);
        if sliced.len() != len {
            return Err("Not enough bits");
        }
        Ok(sliced)
    }

    /// Number of present bits.
    pub fn len(&self) -> usize {
        self.ranges().fold(0, |count, range| count + range.len())
    }

    /// Applies the `cut` range operation with `other`'s effective ranges and
    /// returns the result as a new bit field (see `RangeIterator::cut`).
    pub fn cut(&self, other: &Self) -> Self {
        let remaining = self.ranges().cut(other.ranges());
        Self::from_ranges(remaining)
    }

    /// Unions any number of bit fields into a new one. An empty input yields
    /// an empty bit field.
    pub fn union<'a>(bitfields: impl IntoIterator<Item = &'a Self>) -> Self {
        let mut merged = Self::new();
        for next in bitfields {
            merged = &merged | next;
        }
        merged
    }

    /// Reports whether at least one bit is present in both `self` and `other`.
    pub fn contains_any(&self, other: &BitField) -> bool {
        let mut shared = self.ranges().intersection(other.ranges());
        shared.next().is_some()
    }

    /// Reports whether every bit present in `other` is also present in `self`.
    pub fn contains_all(&self, other: &BitField) -> bool {
        let mut missing = other.ranges().difference(self.ranges());
        missing.next().is_none()
    }
}
impl BitOr<&BitField> for &BitField {
    type Output = BitField;

    /// Returns a new bit field holding every bit present in either operand.
    #[inline]
    fn bitor(self, rhs: &BitField) -> Self::Output {
        let merged = self.ranges().union(rhs.ranges());
        BitField::from_ranges(merged)
    }
}
impl BitOrAssign<&BitField> for BitField {
#[inline]
fn bitor_assign(&mut self, rhs: &BitField) {
*self = &*self | rhs;
}
}
impl BitAnd<&BitField> for &BitField {
    type Output = BitField;

    /// Returns a new bit field holding only the bits present in both operands.
    #[inline]
    fn bitand(self, rhs: &BitField) -> Self::Output {
        let common = self.ranges().intersection(rhs.ranges());
        BitField::from_ranges(common)
    }
}
impl BitAndAssign<&BitField> for BitField {
#[inline]
fn bitand_assign(&mut self, rhs: &BitField) {
*self = &*self & rhs;
}
}
impl Sub<&BitField> for &BitField {
    type Output = BitField;

    /// Returns a new bit field holding the bits of `self` that are not in `rhs`.
    #[inline]
    fn sub(self, rhs: &BitField) -> Self::Output {
        let remaining = self.ranges().difference(rhs.ranges());
        BitField::from_ranges(remaining)
    }
}
impl SubAssign<&BitField> for BitField {
#[inline]
fn sub_assign(&mut self, rhs: &BitField) {
*self = &*self - rhs;
}
}
impl BitXor<&BitField> for &BitField {
    type Output = BitField;

    /// Returns a new bit field holding the bits present in exactly one of the
    /// two operands (symmetric difference).
    // `#[inline]` for consistency with the other binary operator impls.
    #[inline]
    fn bitxor(self, rhs: &BitField) -> Self::Output {
        BitField::from_ranges(self.ranges().symmetric_difference(rhs.ranges()))
    }
}
impl BitXorAssign<&BitField> for BitField {
    /// Replaces `self` with the symmetric difference of `self` and `rhs`.
    // `#[inline]` for consistency with the other compound-assignment impls.
    #[inline]
    fn bitxor_assign(&mut self, rhs: &BitField) {
        *self = &*self ^ rhs;
    }
}
/// Builds a [`BitField`] from integer literals, one per bit position: `0`
/// leaves that position unset and any other value sets it.
///
/// ```ignore
/// let bf = bitfield![1, 0, 1]; // bits 0 and 2 set
/// ```
#[macro_export]
macro_rules! bitfield {
    // Internal rule: the end of the literal list.
    (@iter) => {
        ::std::iter::empty::<bool>()
    };
    // Internal rule: chain one bool per literal. The recursion goes through
    // `$crate::` so it still resolves when a downstream crate invokes this
    // macro by path without importing `bitfield` itself.
    (@iter $head:literal $(, $tail:literal)*) => {
        ::std::iter::once($head != 0_u32).chain($crate::bitfield!(@iter $($tail),*))
    };
    // Public entry point; a trailing comma is accepted.
    ($($val:literal),* $(,)?) => {
        $crate::bitfield!(@iter $($val),*).collect::<$crate::BitField>()
    };
}
/// JSON (de)serialization of [`BitField`] using run-length encoding.
///
/// A bit field is written as a flat array of alternating run lengths, starting
/// with a run of unset bits: `[zeros, ones, zeros, ones, ...]`. For example the
/// bits `0011110001` are written as `[2,4,3,1]`. An empty bit field is written
/// as `[0]`.
#[cfg(feature = "json")]
pub mod json {
    use super::*;
    use crate::iter::Ranges;
    use serde::ser::SerializeSeq;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Owned wrapper that (de)serializes a [`BitField`] in the run-length format.
    #[derive(Deserialize, Serialize, Debug, PartialEq)]
    #[serde(transparent)]
    pub struct BitFieldJson(#[serde(with = "self")] pub BitField);

    /// Borrowed wrapper for serializing a [`BitField`] without taking ownership.
    #[derive(Serialize)]
    #[serde(transparent)]
    pub struct BitFieldJsonRef<'a>(#[serde(with = "self")] pub &'a BitField);

    impl From<BitFieldJson> for BitField {
        fn from(wrapper: BitFieldJson) -> Self {
            wrapper.0
        }
    }

    impl From<BitField> for BitFieldJson {
        fn from(bitfield: BitField) -> Self {
            BitFieldJson(bitfield)
        }
    }

    /// Serializes `m` as alternating run lengths, starting with a run of zeros.
    ///
    /// Run lengths are written as `usize`, the same type `deserialize` reads
    /// back, so runs longer than 255 bits round-trip intact.
    fn serialize<S>(m: &BitField, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let ranges: Vec<Range<usize>> = m.ranges().collect();
        if ranges.is_empty() {
            // An empty field is encoded as a single zero-length run of zeros.
            let mut seq = serializer.serialize_seq(Some(1))?;
            seq.serialize_element(&0usize)?;
            return seq.end();
        }
        // Each range contributes a (gap, run) pair, so the exact element count
        // is known up front. Length-prefixed formats rely on this hint.
        let mut seq = serializer.serialize_seq(Some(ranges.len() * 2))?;
        let mut last_index = 0;
        for range in ranges {
            // Assumes `ranges()` is ascending and non-overlapping, so the gap
            // computation cannot underflow.
            seq.serialize_element(&(range.start - last_index))?;
            seq.serialize_element(&(range.end - range.start))?;
            last_index = range.end;
        }
        seq.end()
    }

    /// Deserializes the alternating run-length encoding produced by `serialize`.
    ///
    /// Zero-length runs of ones are skipped. A run of ones separated from the
    /// previous one by a zero-length gap is merged into it, so no empty or
    /// touching ranges are produced.
    ///
    /// # Errors
    /// Fails if the input is not an array of non-negative integers, or if the
    /// run lengths sum to more than `usize::MAX`.
    fn deserialize<'de, D>(deserializer: D) -> std::result::Result<BitField, D::Error>
    where
        D: Deserializer<'de>,
    {
        let run_lengths: Vec<usize> = Deserialize::deserialize(deserializer)?;
        let mut ranges: Vec<Range<usize>> = Vec::new();
        let mut position = 0usize;
        // Runs alternate between unset and set bits, starting with unset.
        let mut is_set_run = false;
        for run_len in run_lengths {
            // The input is external data: reject overflow instead of
            // panicking (debug) or silently wrapping (release).
            let end = position.checked_add(run_len).ok_or_else(|| {
                <D::Error as serde::de::Error>::custom("bitfield run lengths overflow usize")
            })?;
            if is_set_run && run_len > 0 {
                match ranges.last_mut() {
                    Some(last) if last.end == position => last.end = end,
                    _ => ranges.push(position..end),
                }
            }
            position = end;
            is_set_run = !is_set_run;
        }
        Ok(BitField::from_ranges(Ranges::new(ranges)))
    }

    #[test]
    fn serialization_starts_with_zeros() {
        let bf = BitFieldJson(bitfield![0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1]);
        let j = serde_json::to_string(&bf).unwrap();
        assert_eq!(j, "[2,4,3,2]");
        let bitfield: BitFieldJson = serde_json::from_str(&j).unwrap();
        assert_eq!(bf, bitfield);
    }

    #[test]
    fn serialization_starts_with_ones() {
        let bf = BitFieldJson(bitfield![1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1]);
        let j = serde_json::to_string(&bf).unwrap();
        assert_eq!(j, "[0,6,3,2]");
        let bitfield: BitFieldJson = serde_json::from_str(&j).unwrap();
        assert_eq!(bf, bitfield);
    }

    #[test]
    fn serialization_with_single_unit() {
        let bf = BitFieldJson(bitfield![]);
        let j = serde_json::to_string(&bf).unwrap();
        assert_eq!(j, "[0]");
        let bitfield: BitFieldJson = serde_json::from_str(&j).unwrap();
        assert_eq!(bf, bitfield);
    }

    #[test]
    fn serialization_with_runs_longer_than_u8() {
        let bf = BitFieldJson(BitField::from_ranges(Ranges::new(vec![300..700])));
        let j = serde_json::to_string(&bf).unwrap();
        assert_eq!(j, "[300,400]");
        let bitfield: BitFieldJson = serde_json::from_str(&j).unwrap();
        assert_eq!(bf, bitfield);
    }

    #[test]
    fn deserialization_merges_zero_length_gaps() {
        let bitfield: BitFieldJson = serde_json::from_str("[0,2,0,3]").unwrap();
        assert_eq!(bitfield.0.ranges().collect::<Vec<_>>(), vec![0..5]);
    }

    #[test]
    fn deserialization_rejects_overflowing_runs() {
        let json = format!("[{},{}]", usize::MAX, usize::MAX);
        assert!(serde_json::from_str::<BitFieldJson>(&json).is_err());
    }
}