pub mod unary;
use std::io::{Read, Write};
use anyhow::{anyhow, Result};
use crate::bit_vectors::prelude::*;
use crate::broadword;
use crate::utils::MatrixView;
use crate::Serializable;
use unary::UnaryIter;
/// Number of bits held by one storage word (`usize`) on the target platform.
pub const WORD_LEN: usize = std::mem::size_of::<usize>() * 8;
/// Growable sequence of bits packed into `usize` words.
///
/// Bit `i` lives in `words[i / WORD_LEN]` at bit offset `i % WORD_LEN`,
/// i.e. the least-significant bit of a word holds the lowest position.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct BitVector {
// Packed storage; each element carries up to `WORD_LEN` bits.
words: Vec<usize>,
// Number of bits currently stored.
len: usize,
}
impl BitVector {
pub fn new() -> Self {
Self::default()
}
/// Creates an empty bit vector with storage reserved for at least `capa` bits.
pub fn with_capacity(capa: usize) -> Self {
    let words = Vec::with_capacity(Self::words_for(capa));
    Self { words, len: 0 }
}
/// Creates a vector of `len` bits that are all equal to `bit`.
pub fn from_bit(bit: bool, len: usize) -> Self {
    let fill = if bit { usize::MAX } else { 0 };
    let mut words = vec![fill; Self::words_for(len)];
    // Clear the unused high bits of the last word so they never read as ones.
    match len % WORD_LEN {
        0 => {}
        rem => {
            let keep = (1usize << rem) - 1;
            *words.last_mut().unwrap() &= keep;
        }
    }
    Self { words, len }
}
/// Builds a bit vector by appending every bit produced by `bits`, in order.
pub fn from_bits<I>(bits: I) -> Self
where
    I: IntoIterator<Item = bool>,
{
    let mut bv = Self::new();
    for bit in bits {
        bv.push_bit(bit);
    }
    bv
}
/// Returns the bit at position `pos`, or `None` when `pos` is out of bounds.
pub fn get_bit(&self, pos: usize) -> Option<bool> {
    if self.len <= pos {
        return None;
    }
    let word = self.words[pos / WORD_LEN];
    Some((word >> (pos % WORD_LEN)) & 1 == 1)
}
/// Overwrites the bit at position `pos` with `bit`.
///
/// # Errors
///
/// Returns an error when `pos >= self.len()`. Note that the message text says
/// "no greater than" even though `pos == self.len()` is rejected as well.
#[inline(always)]
pub fn set_bit(&mut self, pos: usize, bit: bool) -> Result<()> {
    let n = self.len();
    if pos >= n {
        return Err(anyhow!(
            "pos must be no greater than self.len()={}, but got {pos}.",
            n
        ));
    }
    let (block, shift) = (pos / WORD_LEN, pos % WORD_LEN);
    // Clear the target bit, then write the requested value into it.
    let cleared = self.words[block] & !(1usize << shift);
    self.words[block] = cleared | ((bit as usize) << shift);
    Ok(())
}
/// Appends `bit` to the end of the vector.
#[inline(always)]
pub fn push_bit(&mut self, bit: bool) {
    let value = bit as usize;
    let offset = self.len % WORD_LEN;
    if offset != 0 {
        // The last word still has room: OR the bit into its free slot.
        let last = self.words.last_mut().unwrap();
        *last |= value << offset;
    } else {
        // Every existing word is full (or there is none yet): start a new one.
        self.words.push(value);
    }
    self.len += 1;
}
/// Reads `len` consecutive bits starting at `pos` and returns them packed into
/// the low `len` bits of the result (the bit at `pos` becomes bit 0).
///
/// Returns `None` when `len` exceeds [`WORD_LEN`], when the range
/// `pos..pos + len` does not fit inside the vector, or when `pos + len`
/// overflows `usize`. A zero-length read inside the vector yields `Some(0)`.
#[inline(always)]
pub fn get_bits(&self, pos: usize, len: usize) -> Option<usize> {
    if WORD_LEN < len {
        return None;
    }
    // `checked_add` keeps huge positions from wrapping around and slipping
    // past the range check.
    let end = pos.checked_add(len)?;
    if self.len() < end {
        return None;
    }
    if len == 0 {
        return Some(0);
    }
    let (block, shift) = (pos / WORD_LEN, pos % WORD_LEN);
    // `len` is in 1..=WORD_LEN here, so the shift amount is below WORD_LEN.
    let mask = usize::MAX >> (WORD_LEN - len);
    let low = self.words[block] >> shift;
    let bits = if shift + len <= WORD_LEN {
        low & mask
    } else {
        // The range straddles two words; `shift` is non-zero in this branch,
        // so `low` already holds fewer than `len` bits and needs no masking.
        low | ((self.words[block + 1] << (WORD_LEN - shift)) & mask)
    };
    Some(bits)
}
/// Overwrites the `len` bits starting at `pos` with the low `len` bits of
/// `bits`; higher bits of `bits` are ignored.
///
/// # Errors
///
/// Returns an error when `len` exceeds [`WORD_LEN`], when `pos + len` is
/// greater than `self.len()`, or when `pos + len` overflows `usize`.
#[inline(always)]
pub fn set_bits(&mut self, pos: usize, bits: usize, len: usize) -> Result<()> {
    if WORD_LEN < len {
        return Err(anyhow!(
            "len must be no greater than {WORD_LEN}, but got {len}."
        ));
    }
    // Reject overflowing ranges explicitly instead of letting the sum wrap
    // and pass the bounds check below.
    let end = match pos.checked_add(len) {
        Some(end) => end,
        None => {
            return Err(anyhow!(
                "pos+len must be no greater than self.len()={}, but pos={pos} plus len={len} overflows usize.",
                self.len()
            ));
        }
    };
    if self.len() < end {
        return Err(anyhow!(
            "pos+len must be no greater than self.len()={}, but got {}.",
            self.len(),
            end
        ));
    }
    if len == 0 {
        return Ok(());
    }
    // `len` is in 1..=WORD_LEN here, so the shift amount is below WORD_LEN.
    let mask = usize::MAX >> (WORD_LEN - len);
    let value = bits & mask;
    let (block, shift) = (pos / WORD_LEN, pos % WORD_LEN);
    self.words[block] = (self.words[block] & !(mask << shift)) | (value << shift);
    // Number of bits that fitted into the first word.
    let room = WORD_LEN - shift;
    if room < len {
        // Spill the remaining high bits into the next word.
        let next = &mut self.words[block + 1];
        *next = (*next & !(mask >> room)) | (value >> room);
    }
    Ok(())
}
/// Appends the low `len` bits of `bits` to the end of the vector; higher bits
/// of `bits` are ignored.
///
/// # Errors
///
/// Returns an error when `len` exceeds [`WORD_LEN`].
#[inline(always)]
pub fn push_bits(&mut self, bits: usize, len: usize) -> Result<()> {
    if len > WORD_LEN {
        return Err(anyhow!(
            "len must be no greater than {WORD_LEN}, but got {len}."
        ));
    }
    if len == 0 {
        return Ok(());
    }
    // `len` is in 1..=WORD_LEN here, so the shift amount is below WORD_LEN.
    let value = bits & (usize::MAX >> (WORD_LEN - len));
    let offset = self.len % WORD_LEN;
    if offset == 0 {
        self.words.push(value);
    } else {
        // Free bits left in the current last word.
        let room = WORD_LEN - offset;
        *self.words.last_mut().unwrap() |= value << offset;
        if room < len {
            // The part that did not fit opens a new word.
            self.words.push(value >> room);
        }
    }
    self.len += len;
    Ok(())
}
/// Scans from `pos` towards position 0 for a set bit and returns its position.
///
/// Returns `None` when `pos` is out of bounds or no word at or below `pos`
/// contains a set bit in the scanned range. Assumes `broadword::msb` yields
/// the index of the highest set bit of a word (`None` for zero).
pub fn predecessor1(&self, pos: usize) -> Option<usize> {
    if pos >= self.len() {
        return None;
    }
    let mut block = pos / WORD_LEN;
    // Keep only bits `0..=pos % WORD_LEN` of the starting word.
    let keep = usize::MAX >> (WORD_LEN - 1 - pos % WORD_LEN);
    let mut word = self.words[block] & keep;
    loop {
        match broadword::msb(word) {
            Some(offset) => return Some(block * WORD_LEN + offset),
            None if block == 0 => return None,
            None => {
                block -= 1;
                word = self.words[block];
            }
        }
    }
}
/// Scans from `pos` towards position 0 for an unset bit and returns its position.
///
/// Returns `None` when `pos` is out of bounds or every scanned bit is set.
/// Works on the complemented words; assumes `broadword::msb` yields the index
/// of the highest set bit of a word (`None` for zero).
pub fn predecessor0(&self, pos: usize) -> Option<usize> {
    if pos >= self.len() {
        return None;
    }
    let mut block = pos / WORD_LEN;
    // Keep only bits `0..=pos % WORD_LEN` of the complemented starting word.
    let keep = usize::MAX >> (WORD_LEN - 1 - pos % WORD_LEN);
    let mut word = !self.words[block] & keep;
    loop {
        match broadword::msb(word) {
            Some(offset) => return Some(block * WORD_LEN + offset),
            None if block == 0 => return None,
            None => {
                block -= 1;
                word = !self.words[block];
            }
        }
    }
}
/// Scans from `pos` towards the end for a set bit and returns its position.
///
/// Returns `None` when `pos` is out of bounds, when no later set bit exists,
/// or when the first hit lies at or beyond `self.len()`. Assumes
/// `broadword::lsb` yields the index of the lowest set bit (`None` for zero).
pub fn successor1(&self, pos: usize) -> Option<usize> {
    if pos >= self.len() {
        return None;
    }
    let first = pos / WORD_LEN;
    // Drop the bits below `pos` in the starting word.
    let head = self.words[first] & (usize::MAX << (pos % WORD_LEN));
    if let Some(offset) = broadword::lsb(head) {
        return Some(first * WORD_LEN + offset).filter(|&i| i < self.len());
    }
    for (block, &word) in self.words.iter().enumerate().skip(first + 1) {
        if let Some(offset) = broadword::lsb(word) {
            return Some(block * WORD_LEN + offset).filter(|&i| i < self.len());
        }
    }
    None
}
/// Scans from `pos` towards the end for an unset bit and returns its position.
///
/// Returns `None` when `pos` is out of bounds, when no later unset bit exists,
/// or when the first hit lies at or beyond `self.len()` (the unused high bits
/// of the last word look like zeros to this scan). Assumes `broadword::lsb`
/// yields the index of the lowest set bit (`None` for zero).
pub fn successor0(&self, pos: usize) -> Option<usize> {
    if pos >= self.len() {
        return None;
    }
    let first = pos / WORD_LEN;
    // Drop the bits below `pos` in the complemented starting word.
    let head = !self.words[first] & (usize::MAX << (pos % WORD_LEN));
    if let Some(offset) = broadword::lsb(head) {
        return Some(first * WORD_LEN + offset).filter(|&i| i < self.len());
    }
    for (block, &word) in self.words.iter().enumerate().skip(first + 1) {
        if let Some(offset) = broadword::lsb(!word) {
            return Some(block * WORD_LEN + offset).filter(|&i| i < self.len());
        }
    }
    None
}
pub const fn iter(&self) -> Iter {
Iter::new(self)
}
/// Creates a [`UnaryIter`] over this vector, handing it the start position `pos`.
///
/// NOTE(review): the iteration semantics live in the `unary` submodule and are
/// not visible here — presumably it enumerates set-bit positions from `pos`;
/// confirm there.
pub fn unary_iter(&self, pos: usize) -> UnaryIter {
UnaryIter::new(self, pos)
}
/// Returns one word worth of bits starting at position `pos` (the bit at `pos`
/// becomes bit 0 of the result), or `None` when `pos` is out of bounds.
///
/// When fewer than [`WORD_LEN`] bits of storage remain after `pos`, the missing
/// high bits are filled with zeros.
#[inline(always)]
pub fn get_word64(&self, pos: usize) -> Option<usize> {
    if pos >= self.len {
        return None;
    }
    let (block, shift) = (pos / WORD_LEN, pos % WORD_LEN);
    let low = self.words[block] >> shift;
    // Pull the missing high bits from the following word, if there is one.
    // The shift is expressed with `WORD_LEN` (not a literal 64) so it stays
    // in range for whatever width `usize` has.
    let high = match self.words.get(block + 1) {
        Some(&next) if shift != 0 => next << (WORD_LEN - shift),
        _ => 0,
    };
    Some(low | high)
}
/// Returns the number of bits stored.
pub const fn len(&self) -> usize {
self.len
}
/// Returns `true` when the vector holds no bits.
pub const fn is_empty(&self) -> bool {
    self.len == 0
}
/// Borrows the underlying storage words.
pub fn words(&self) -> &[usize] {
    self.words.as_slice()
}
pub fn into_words(self) -> Vec<usize> {
self.words
}
/// Returns how many bits fit into the currently allocated storage words.
pub fn capacity(&self) -> usize {
    let allocated_words = self.words.capacity();
    allocated_words * WORD_LEN
}
/// Returns the number of storage words currently in use.
#[inline(always)]
pub fn num_words(&self) -> usize {
    self.words().len()
}
/// Asks the underlying word vector to release unused capacity
/// (forwards to `Vec::shrink_to_fit`).
pub fn shrink_to_fit(&mut self) {
self.words.shrink_to_fit();
}
/// Returns the number of words needed to hold `n` bits (`n / WORD_LEN`, rounded up).
///
/// Written as quotient plus a remainder flag so the computation cannot
/// overflow, even for `n` close to `usize::MAX`.
#[inline(always)]
const fn words_for(n: usize) -> usize {
    n / WORD_LEN + (n % WORD_LEN != 0) as usize
}
}
impl Build for BitVector {
    /// Builds a plain bit vector from `bits`.
    ///
    /// The three index flags are accepted for trait compatibility and ignored:
    /// this type keeps no auxiliary rank/select index.
    fn build_from_bits<I>(
        bits: I,
        _with_rank: bool,
        _with_select1: bool,
        _with_select0: bool,
    ) -> Result<Self>
    where
        I: IntoIterator<Item = bool>,
        Self: Sized,
    {
        let mut bv = Self::new();
        for bit in bits {
            bv.push_bit(bit);
        }
        Ok(bv)
    }
}
impl NumBits for BitVector {
    /// Returns the number of bits stored.
    fn num_bits(&self) -> usize {
        self.len
    }

    /// Returns the number of set bits, computed as `rank1` over the whole vector.
    fn num_ones(&self) -> usize {
        let total_bits = self.len();
        self.rank1(total_bits).unwrap()
    }
}
impl Access for BitVector {
    /// Returns the bit at position `pos`, or `None` when `pos` is out of bounds.
    fn access(&self, pos: usize) -> Option<bool> {
        if self.len <= pos {
            return None;
        }
        let word = self.words[pos / WORD_LEN];
        Some((word >> (pos % WORD_LEN)) & 1 == 1)
    }
}
impl Rank for BitVector {
    /// Counts the set bits in positions `0..pos` with a linear scan over the words.
    ///
    /// Returns `None` when `pos > self.len()`.
    fn rank1(&self, pos: usize) -> Option<usize> {
        if pos > self.len() {
            return None;
        }
        let (full_words, rem) = (pos / WORD_LEN, pos % WORD_LEN);
        let mut ones: usize = self.words[..full_words]
            .iter()
            .map(|&w| broadword::popcount(w))
            .sum();
        if rem != 0 {
            // Count only the low `rem` bits of the partially covered word.
            let low_bits = self.words[full_words] & (usize::MAX >> (WORD_LEN - rem));
            ones += broadword::popcount(low_bits);
        }
        Some(ones)
    }

    /// Counts the unset bits in positions `0..pos` as `pos - rank1(pos)`.
    ///
    /// Returns `None` when `pos > self.len()`.
    fn rank0(&self, pos: usize) -> Option<usize> {
        self.rank1(pos).map(|ones| pos - ones)
    }
}
impl Select for BitVector {
    /// Returns the position of the `k`-th set bit (0-based) with a linear scan
    /// over the words, or `None` when fewer than `k + 1` bits are set.
    fn select1(&self, k: usize) -> Option<usize> {
        let mut seen = 0;
        for (wpos, &word) in self.words.iter().enumerate() {
            let cnt = broadword::popcount(word);
            if k < seen + cnt {
                // The wanted bit is inside this word; locate it by its local rank.
                let offset = broadword::select_in_word(word, k - seen).unwrap();
                return Some(wpos * WORD_LEN + offset);
            }
            seen += cnt;
        }
        None
    }

    /// Returns the position of the `k`-th unset bit (0-based) with a linear scan
    /// over the complemented words, or `None` when no such position exists
    /// below `self.len()`.
    fn select0(&self, k: usize) -> Option<usize> {
        let mut seen = 0;
        for (wpos, &word) in self.words.iter().enumerate() {
            let flipped = !word;
            let cnt = broadword::popcount(flipped);
            if k < seen + cnt {
                let offset = broadword::select_in_word(flipped, k - seen).unwrap();
                // The unused high bits of the last word look like zeros, so a
                // hit may land past the logical end and must be discarded.
                return Some(wpos * WORD_LEN + offset).filter(|&sel| sel < self.len());
            }
            seen += cnt;
        }
        None
    }
}
/// Iterator over the bits of a [`BitVector`], yielded as `bool`s in position order.
pub struct Iter<'a> {
// Vector being traversed.
bv: &'a BitVector,
// Position of the next bit to yield.
pos: usize,
}
impl<'a> Iter<'a> {
/// Creates an iterator over `bv` that starts at position 0.
pub const fn new(bv: &'a BitVector) -> Self {
Self { bv, pos: 0 }
}
}
impl Iterator for Iter<'_> {
    type Item = bool;

    /// Yields the bit at the current position and advances, or `None` once the
    /// end of the vector has been reached.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        // `access` returns `None` exactly when the position is past the end.
        let bit = self.bv.access(self.pos)?;
        self.pos += 1;
        Some(bit)
    }

    /// Reports the exact number of bits still to be yielded.
    ///
    /// The hint has to shrink as the iterator advances; reporting the full
    /// vector length after some items were consumed would overstate the lower
    /// bound.
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.bv.len().saturating_sub(self.pos);
        (remaining, Some(remaining))
    }
}
impl std::iter::Extend<bool> for BitVector {
    /// Appends every bit produced by `bits` to the end of the vector, in order.
    fn extend<I>(&mut self, bits: I)
    where
        I: IntoIterator<Item = bool>,
    {
        for bit in bits {
            self.push_bit(bit);
        }
    }
}
impl std::fmt::Debug for BitVector {
    /// Formats the vector as a struct with its bits (expanded to one `u8` per
    /// bit and wrapped in a `MatrixView` with the argument 16) and its length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bits: Vec<u8> = (0..self.len())
            .map(|i| self.access(i).unwrap() as u8)
            .collect();
        let mut out = f.debug_struct("BitVector");
        out.field("bits", &MatrixView::new(&bits, 16));
        out.field("len", &self.len);
        out.finish()
    }
}
impl Serializable for BitVector {
fn serialize_into<W: Write>(&self, mut writer: W) -> Result<usize> {
let mut mem = self.words.serialize_into(&mut writer)?;
mem += self.len.serialize_into(&mut writer)?;
Ok(mem)
}
fn deserialize_from<R: Read>(mut reader: R) -> Result<Self> {
let words = Vec::<usize>::deserialize_from(&mut reader)?;
let len = usize::deserialize_from(&mut reader)?;
Ok(Self { words, len })
}
fn size_in_bytes(&self) -> usize {
self.words.size_in_bytes() + usize::size_of().unwrap()
}
}
// Unit tests for `BitVector`: error reporting, truncation of over-wide inputs,
// word-boundary handling and the serialization round trip.
// NOTE: the expected messages hard-code 64, i.e. they assume a 64-bit `usize`.
#[cfg(test)]
mod tests {
use super::*;
// `set_bit` at `pos == len` must fail with the documented message.
#[test]
fn test_set_bit_oob() {
let mut bv = BitVector::from_bit(false, 3);
let e = bv.set_bit(3, true);
assert_eq!(
e.err().map(|x| x.to_string()),
Some("pos must be no greater than self.len()=3, but got 3.".to_string())
);
}
// `set_bits` rejects a width larger than one word.
#[test]
fn test_set_bits_over_word() {
let mut bv = BitVector::from_bit(false, 100);
let e = bv.set_bits(0, 0b0, 65);
assert_eq!(
e.err().map(|x| x.to_string()),
Some("len must be no greater than 64, but got 65.".to_string())
);
}
// `set_bits` rejects a range that runs past the end of the vector.
#[test]
fn test_set_bits_oob() {
let mut bv = BitVector::from_bit(false, 3);
let e = bv.set_bits(2, 0b11, 2);
assert_eq!(
e.err().map(|x| x.to_string()),
Some("pos+len must be no greater than self.len()=3, but got 4.".to_string())
);
}
// Only the low `len` bits of the value are written.
#[test]
fn test_set_bits_truncation() {
let mut bv = BitVector::from_bit(false, 3);
bv.set_bits(0, 0b111, 2).unwrap();
assert_eq!(bv, BitVector::from_bits([true, true, false]));
}
// A write that straddles two storage words is read back intact.
#[test]
fn test_set_bits_accross_word() {
let mut bv = BitVector::from_bit(false, 100);
bv.set_bits(62, 0b11111, 5).unwrap();
assert_eq!(bv.get_bits(61, 7).unwrap(), 0b0111110);
}
// `push_bits` rejects a width larger than one word.
#[test]
fn test_push_bits_over_word() {
let mut bv = BitVector::new();
let e = bv.push_bits(0b0, 65);
assert_eq!(
e.err().map(|x| x.to_string()),
Some("len must be no greater than 64, but got 65.".to_string())
);
}
// Only the low `len` bits of the value are appended.
#[test]
fn test_push_bits_truncation() {
let mut bv = BitVector::new();
bv.push_bits(0b111, 2).unwrap();
assert_eq!(bv, BitVector::from_bits([true, true]));
}
// An append that straddles two storage words is read back intact.
#[test]
fn test_push_bits_accross_word() {
let mut bv = BitVector::from_bit(false, 62);
bv.push_bits(0b011111, 6).unwrap();
assert_eq!(bv.get_bits(61, 7).unwrap(), 0b0111110);
}
// `get_word64` at `pos == len` yields `None`.
#[test]
fn test_get_word64_oob() {
let bv = BitVector::from_bit(false, 3);
assert_eq!(bv.get_word64(3), None);
}
// Reading near the end returns the remaining bits padded with zeros.
#[test]
fn test_get_word64_overflow() {
let bv = BitVector::from_bit(true, 64);
assert_eq!(bv.get_word64(60), Some(0b1111));
}
// Serialization round-trips and the reported sizes agree with the bytes written.
#[test]
fn test_serialize() {
let mut bytes = vec![];
let bv = BitVector::from_bits([false, true, false, false, true]);
let size = bv.serialize_into(&mut bytes).unwrap();
let other = BitVector::deserialize_from(&bytes[..]).unwrap();
assert_eq!(bv, other);
assert_eq!(size, bytes.len());
assert_eq!(size, bv.size_in_bytes());
}
}