#![allow(clippy::comparison_chain)]
pub mod iter;
mod ops;
mod range;
mod rleplus;
mod unvalidated;
use std::collections::BTreeSet;
use std::ops::Range;
use iter::{ranges_from_bits, RangeIterator};
pub(crate) use range::RangeSize;
pub use rleplus::Error;
use thiserror::Error;
pub use unvalidated::{UnvalidatedBitField, Validate};
/// Crate-wide cap on the size of an encoded bitfield: 32 * 1024 = 32768.
/// NOTE(review): presumably a byte count — confirm against the encoding/validation code that reads it.
pub(crate) const MAX_ENCODED_SIZE: usize = 32 * 1024;
#[derive(Clone, Error, Debug)]
#[error("bitfields may not include u64::MAX")]
pub struct OutOfRangeError;
impl From<OutOfRangeError> for Error {
fn from(_: OutOfRangeError) -> Self {
Error::RLEOverflow
}
}
/// A set of `u64` bit indices; bit `u64::MAX` can never be set.
///
/// The contents are stored as base `ranges` plus pending per-bit
/// modifications: the effective set of bits is the base ranges, plus every
/// bit in `set`, minus every bit in `unset` (this is what [`BitField::ranges`]
/// computes).
#[derive(Debug, Default, Clone)]
pub struct BitField {
    /// Base ranges of set bits. `get` binary-searches this vector, so it is
    /// expected to be sorted and non-overlapping.
    ranges: Vec<Range<u64>>,
    /// Bits explicitly set on top of `ranges`; `try_set`/`unset` keep this
    /// disjoint from `unset`.
    set: BTreeSet<u64>,
    /// Bits explicitly cleared from `ranges`; `try_set`/`unset` keep this
    /// disjoint from `set`.
    unset: BTreeSet<u64>,
}
impl PartialEq for BitField {
    /// Two bitfields are equal when they contain the same bits, regardless of
    /// how those bits are split between the base ranges and the pending
    /// `set`/`unset` modifications.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.ranges();
        let rhs = other.ranges();
        Iterator::eq(lhs, rhs)
    }
}
/// Outcome of collecting bits (via `FromIterator`) into a bitfield: either a
/// valid [`BitField`], or a marker that the input would have required bit
/// `u64::MAX` or beyond.
///
/// Convert it with `unwrap`, `expect`, or `BitField::try_from`.
#[doc(hidden)]
pub enum MaybeBitField {
    /// Every requested bit was representable.
    Ok(BitField),
    /// The input would have set bit `u64::MAX` (or a later one), which is not allowed.
    OutOfBounds,
}
impl MaybeBitField {
pub fn unwrap(self) -> BitField {
use MaybeBitField::*;
match self {
Ok(bf) => bf,
OutOfBounds => panic!("bitfield bit out of bounds"),
}
}
pub fn expect(self, message: &str) -> BitField {
use MaybeBitField::*;
match self {
Ok(bf) => bf,
OutOfBounds => panic!("{}", message),
}
}
}
impl TryFrom<MaybeBitField> for BitField {
    type Error = OutOfRangeError;

    /// Extracts the bitfield, failing with [`OutOfRangeError`] when the
    /// collected bits were out of bounds.
    fn try_from(value: MaybeBitField) -> Result<Self, Self::Error> {
        if let MaybeBitField::Ok(bitfield) = value {
            Ok(bitfield)
        } else {
            Err(OutOfRangeError)
        }
    }
}
impl FromIterator<bool> for MaybeBitField {
    /// Treats the `i`-th item as bit `i` (set when `true`).
    ///
    /// Produces [`MaybeBitField::OutOfBounds`] when the iterator yields more
    /// than `u64::MAX` items, since the next one would be bit `u64::MAX`.
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> MaybeBitField {
        let mut bools = iter.into_iter().fuse();
        // `zip` polls the index range first and stops without touching `bools`
        // once every valid index (0..u64::MAX) has been handed out.
        let set_bits = (0u64..u64::MAX)
            .zip(bools.by_ref())
            .filter_map(|(index, is_set)| if is_set { Some(index) } else { None });
        let bitfield = BitField::from_ranges(ranges_from_bits(set_bits));
        // Any item left over would have had to live at index u64::MAX or beyond.
        match bools.next() {
            None => MaybeBitField::Ok(bitfield),
            Some(_) => MaybeBitField::OutOfBounds,
        }
    }
}
impl FromIterator<u64> for MaybeBitField {
    /// Treats every item as the index of a set bit; duplicates and ordering
    /// do not matter.
    ///
    /// Produces [`MaybeBitField::OutOfBounds`] if `u64::MAX` is among the indices.
    fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> MaybeBitField {
        let mut indices = iter.into_iter().collect::<Vec<u64>>();
        if indices.is_empty() {
            return MaybeBitField::Ok(BitField::new());
        }
        // Sort and dedup so the indices are strictly increasing before building ranges.
        indices.sort_unstable();
        indices.dedup();
        // After sorting, u64::MAX (never a valid bit) can only be the final element.
        let includes_max = indices.last().copied() == Some(u64::MAX);
        if includes_max {
            MaybeBitField::OutOfBounds
        } else {
            MaybeBitField::Ok(BitField::from_ranges(ranges_from_bits(indices)))
        }
    }
}
impl BitField {
    /// Creates an empty bitfield.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a bitfield whose base ranges are collected from `iter`, with no
    /// pending `set`/`unset` modifications.
    pub fn from_ranges(iter: impl RangeIterator) -> Self {
        Self {
            ranges: iter.collect(),
            ..Default::default()
        }
    }

    /// Creates a bitfield from an iterator of bits, using whichever
    /// `FromIterator` impl [`MaybeBitField`] has for the item type (e.g. `bool`s,
    /// where item `i` says whether bit `i` is set, or `u64` indices of set bits).
    ///
    /// # Errors
    ///
    /// Returns [`OutOfRangeError`] if the input would set bit `u64::MAX` or beyond.
    pub fn try_from_bits<I>(iter: I) -> Result<Self, OutOfRangeError>
    where
        I: IntoIterator,
        MaybeBitField: FromIterator<I::Item>,
    {
        iter.into_iter().collect::<MaybeBitField>().try_into()
    }

    /// Sets `bit`.
    ///
    /// # Panics
    ///
    /// Panics if `bit == u64::MAX`; use [`BitField::try_set`] for a fallible
    /// version. The panic is reported at the caller's location.
    #[track_caller]
    pub fn set(&mut self, bit: u64) {
        self.try_set(bit).unwrap()
    }

    /// Sets `bit`, recording it as a pending modification on top of the base ranges.
    ///
    /// # Errors
    ///
    /// Returns [`OutOfRangeError`] if `bit == u64::MAX`, which bitfields may not contain.
    pub fn try_set(&mut self, bit: u64) -> Result<(), OutOfRangeError> {
        if bit == u64::MAX {
            return Err(OutOfRangeError);
        }
        // Keep `set` and `unset` disjoint: the latest operation on a bit wins.
        self.unset.remove(&bit);
        self.set.insert(bit);
        Ok(())
    }

    /// Unsets `bit`. Unsetting `u64::MAX` is a no-op because that bit can never be set.
    pub fn unset(&mut self, bit: u64) {
        if bit == u64::MAX {
            return;
        }
        // Keep `set` and `unset` disjoint: the latest operation on a bit wins.
        self.set.remove(&bit);
        self.unset.insert(bit);
    }

    /// Returns `true` if bit `index` is set.
    pub fn get(&self, index: u64) -> bool {
        use std::cmp::Ordering;

        // Pending modifications take precedence over the base ranges.
        if self.set.contains(&index) {
            return true;
        }
        if self.unset.contains(&index) {
            return false;
        }
        // The base ranges are expected to be sorted and non-overlapping, so a
        // binary search finds the range containing `index`, if any.
        self.ranges
            .binary_search_by(|range| {
                if index < range.start {
                    Ordering::Greater
                } else if index >= range.end {
                    Ordering::Less
                } else {
                    Ordering::Equal
                }
            })
            .is_ok()
    }

    /// Returns the lowest set bit, or `None` if no bits are set.
    pub fn first(&self) -> Option<u64> {
        // `BTreeSet` iterates in ascending order, so its first element is its minimum.
        let first_set = self.set.iter().next().copied();
        // Lowest base-range bit that has not been explicitly unset.
        let first_in_ranges = self
            .ranges
            .iter()
            .find_map(|r| r.clone().find(|i| !self.unset.contains(i)));
        first_set.into_iter().chain(first_in_ranges).min()
    }

    /// Returns the highest set bit, or `None` if no bits are set.
    pub fn last(&self) -> Option<u64> {
        // `BTreeSet` iterates in ascending order, so its last element is its maximum.
        let last_set = self.set.iter().next_back().copied();
        // Highest base-range bit that has not been explicitly unset.
        let last_in_ranges = self
            .ranges
            .iter()
            .rev()
            .flat_map(|range| range.clone().rev())
            .find(|i| !self.unset.contains(i));
        last_set.into_iter().chain(last_in_ranges).max()
    }

    /// Iterates over the indices of all set bits.
    pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
        self.inner_ranges()
            .union(ranges_from_bits(self.set.iter().copied()))
            .flatten()
            .filter(move |i| !self.unset.contains(i))
    }

    /// Like [`BitField::iter`], but returns `None` if more than `max` bits are set.
    pub fn bounded_iter(&self, max: u64) -> Option<impl Iterator<Item = u64> + '_> {
        if self.len() <= max {
            Some(self.iter())
        } else {
            None
        }
    }

    /// The base ranges only, ignoring pending `set`/`unset` modifications.
    fn inner_ranges(&self) -> impl RangeIterator + '_ {
        iter::Ranges::new(self.ranges.iter().cloned())
    }

    /// Returns the effective set bits as ranges: base ranges, plus `set`, minus `unset`.
    pub fn ranges(&self) -> impl RangeIterator + '_ {
        self.inner_ranges()
            .union(ranges_from_bits(self.set.iter().copied()))
            .difference(ranges_from_bits(self.unset.iter().copied()))
    }

    /// Returns `true` if no bits are set.
    ///
    /// Walks base-range bits only while each one is explicitly unset, stopping
    /// at the first surviving bit.
    pub fn is_empty(&self) -> bool {
        self.set.is_empty()
            && self
                .inner_ranges()
                .flatten()
                .all(|bit| self.unset.contains(&bit))
    }

    /// Cheap, sufficient (but not necessary) emptiness check: no pending sets
    /// and no base ranges. It ignores `unset`, so a bitfield whose base bits are
    /// all unset is empty without being "trivially" empty.
    fn is_trivially_empty(&self) -> bool {
        self.set.is_empty() && self.ranges.is_empty()
    }

    /// Builds a bitfield from [`BitField::ranges`] after
    /// `RangeIterator::skip_bits(start)` and `RangeIterator::take_bits(len)`.
    /// Returns `None` unless the result contains exactly `len` set bits.
    pub fn slice(&self, start: u64, len: u64) -> Option<Self> {
        let slice = BitField::from_ranges(self.ranges().skip_bits(start).take_bits(len));
        if slice.len() == len {
            Some(slice)
        } else {
            None
        }
    }

    /// Returns the number of set bits.
    pub fn len(&self) -> u64 {
        self.ranges().map(|range| range.size()).sum()
    }

    /// Applies `RangeIterator::cut` to this bitfield's ranges using `other`'s ranges.
    /// NOTE(review): presumably removes `other`'s bits and shifts the remaining
    /// bits down to close the gaps — confirm against the `iter` module.
    pub fn cut(&self, other: &Self) -> Self {
        Self::from_ranges(self.ranges().cut(other.ranges()))
    }

    /// Returns the union of all given bitfields (empty if none are given).
    pub fn union<'a>(bitfields: impl IntoIterator<Item = &'a Self>) -> Self {
        bitfields.into_iter().fold(Self::new(), |a, b| a | b)
    }

    /// Returns `true` if `self` and `other` have at least one set bit in common.
    pub fn contains_any(&self, other: &BitField) -> bool {
        !(self.is_trivially_empty() || other.is_trivially_empty())
            && self.ranges().intersection(other.ranges()).next().is_some()
    }

    /// Returns `true` if every bit set in `other` is also set in `self`.
    pub fn contains_all(&self, other: &BitField) -> bool {
        other.is_trivially_empty() || other.ranges().difference(self.ranges()).next().is_none()
    }
}
/// Builds a `BitField` from a comma-separated list of integer literals, where
/// the `i`-th literal describes bit `i` and any non-zero value means "set".
///
/// For example, `bitfield![0, 1, 1]` sets bits 1 and 2.
#[macro_export]
macro_rules! bitfield {
    // Internal helper producing an iterator of `bool`s, kept for backward
    // compatibility. It recurses through `$crate::` so the inner call resolves
    // even when the caller has not imported `bitfield` by name.
    (@iter) => {
        std::iter::empty::<bool>()
    };
    (@iter $head:literal $(, $tail:literal)*) => {
        std::iter::once($head != 0_u32).chain($crate::bitfield!(@iter $($tail),*))
    };
    // Expand all literals into one slice instead of recursing once per literal,
    // so long literal lists do not run into the macro recursion limit.
    ($($val:literal),* $(,)?) => {{
        let bits: &[bool] = &[$($val != 0_u32),*];
        bits.iter().copied().collect::<$crate::MaybeBitField>().unwrap()
    }};
}
#[cfg(feature = "json")]
pub mod json {
    //! JSON (de)serialization of [`BitField`] as alternating run lengths.
    //!
    //! A bitfield is encoded as a flat array of `u64` run lengths that alternate
    //! between unset and set bits, always starting with an unset run (which may
    //! be `0`). For example `[2, 4, 3, 2]` means two unset bits, four set bits,
    //! three unset bits, then two set bits. The empty bitfield is encoded as `[0]`.
    use serde::ser::SerializeSeq;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    use super::*;
    use crate::iter::Ranges;

    /// Owned [`BitField`] that (de)serializes using the run-length JSON encoding.
    #[derive(Deserialize, Serialize, Debug, PartialEq)]
    #[serde(transparent)]
    pub struct BitFieldJson(#[serde(with = "self")] pub BitField);

    /// Borrowed [`BitField`] that serializes using the run-length JSON encoding.
    #[derive(Serialize)]
    #[serde(transparent)]
    pub struct BitFieldJsonRef<'a>(#[serde(with = "self")] pub &'a BitField);

    impl From<BitFieldJson> for BitField {
        fn from(wrapper: BitFieldJson) -> Self {
            wrapper.0
        }
    }

    impl From<BitField> for BitFieldJson {
        fn from(wrapper: BitField) -> Self {
            BitFieldJson(wrapper)
        }
    }

    /// Serializes `m` as alternating unset/set run lengths (see the module docs).
    fn serialize<S>(m: &BitField, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if !m.is_empty() {
            let mut seq = serializer.serialize_seq(None)?;
            m.ranges().try_fold(0, |last_index, range| {
                // Unset bits since the previous range, then this range's set bits.
                let zero_index = range.start - last_index;
                let nonzero_index = range.end - range.start;
                seq.serialize_element(&zero_index)?;
                seq.serialize_element(&nonzero_index)?;
                Ok(range.end)
            })?;
            seq.end()
        } else {
            // The empty bitfield is a single zero-length unset run.
            let mut seq = serializer.serialize_seq(Some(1))?;
            seq.serialize_element(&0)?;
            seq.end()
        }
    }

    /// Deserializes alternating unset/set run lengths (see the module docs).
    ///
    /// Zero-length set runs are dropped, and set runs separated by a
    /// zero-length unset run are merged, so the stored ranges contain no empty
    /// or touching entries.
    ///
    /// # Errors
    ///
    /// Fails if the input is not an array of `u64`s, or if the cumulative run
    /// lengths overflow `u64`. This input is untrusted, so overflow must be
    /// rejected rather than panicking in debug builds or wrapping into invalid
    /// ranges in release builds.
    fn deserialize<'de, D>(deserializer: D) -> std::result::Result<BitField, D::Error>
    where
        D: Deserializer<'de>,
    {
        let runs: Vec<u64> = Deserialize::deserialize(deserializer)?;
        let mut ranges: Vec<Range<u64>> = Vec::new();
        let mut position = 0u64;
        for (i, &run) in runs.iter().enumerate() {
            let end = position.checked_add(run).ok_or_else(|| {
                <D::Error as serde::de::Error>::custom("bitfield run lengths overflow u64")
            })?;
            // Runs alternate unset/set, starting with unset at index 0.
            let is_set_run = i % 2 == 1;
            if is_set_run && run > 0 {
                match ranges.last_mut() {
                    // Previous set run ended exactly here (zero-length gap): extend it.
                    Some(last) if last.end == position => last.end = end,
                    _ => ranges.push(position..end),
                }
            }
            position = end;
        }
        Ok(BitField::from_ranges(Ranges::new(ranges.into_iter())))
    }

    #[test]
    fn serialization_starts_with_zeros() {
        let bf = BitFieldJson(bitfield![0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1]);
        let j = serde_json::to_string(&bf).unwrap();
        assert_eq!(j, "[2,4,3,2]");
        let bitfield: BitFieldJson = serde_json::from_str(&j).unwrap();
        assert_eq!(bf, bitfield);
    }

    #[test]
    fn serialization_starts_with_ones() {
        let bf = BitFieldJson(bitfield![1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1]);
        let j = serde_json::to_string(&bf).unwrap();
        assert_eq!(j, "[0,6,3,2]");
        let bitfield: BitFieldJson = serde_json::from_str(&j).unwrap();
        assert_eq!(bf, bitfield);
    }

    #[test]
    fn serialization_of_empty_bitfield() {
        let bf = BitFieldJson(bitfield![]);
        let j = serde_json::to_string(&bf).unwrap();
        assert_eq!(j, "[0]");
        let bitfield: BitFieldJson = serde_json::from_str(&j).unwrap();
        assert_eq!(bf, bitfield);
    }

    #[test]
    fn deserialization_merges_adjacent_and_skips_empty_runs() {
        let bf: BitFieldJson = serde_json::from_str("[1,2,0,3,0,0]").unwrap();
        assert_eq!(bf, BitFieldJson(bitfield![0, 1, 1, 1, 1, 1]));
        assert_eq!(serde_json::to_string(&bf).unwrap(), "[1,5]");
    }

    #[test]
    fn deserialization_rejects_overflowing_runs() {
        let json = format!("[{},1]", u64::MAX);
        assert!(serde_json::from_str::<BitFieldJson>(&json).is_err());
    }
}