#![allow(clippy::precedence, clippy::verbose_bit_mask)]
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use super::util;
use crate::storage::*;
use crate::structure::bititer::BitIter;
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, Bytes, BytesMut};
use futures::io;
use futures::stream::{Stream, StreamExt, TryStreamExt};
use std::{convert::TryFrom, error, fmt};
use tokio_util::codec::{Decoder, FramedRead};
/// An immutable, in-memory bit array backed by a shared [`Bytes`] buffer.
///
/// Bits are addressed most-significant-bit first within each byte
/// (see [`BitArray::get`]).
#[derive(Clone, Debug)]
pub struct BitArray {
    /// Number of valid bits, as read from the trailing control word.
    len: u64,
    /// The data words, with the trailing 8-byte control word split off.
    buf: Bytes,
}
/// Errors produced while interpreting a serialized bit array buffer.
///
/// A serialized bit array consists of 64-bit data words followed by one
/// trailing 8-byte control word holding the number of bits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BitArrayError {
    /// The buffer (size in bytes) is too small to contain the 8-byte control word.
    InputBufferTooSmall(usize),
    /// The buffer size does not match the size implied by the bit count.
    ///
    /// Fields: actual buffer size, expected buffer size, bit count.
    UnexpectedInputBufferSize(u64, u64, u64),
}

impl BitArrayError {
    /// Checks that a buffer of `input_buf_size` bytes can hold the 8-byte control word.
    fn validate_input_buf_size(input_buf_size: usize) -> Result<(), Self> {
        if input_buf_size < 8 {
            return Err(BitArrayError::InputBufferTooSmall(input_buf_size));
        }
        Ok(())
    }

    /// Checks that `input_buf_size` (data words plus control word, in bytes)
    /// is exactly the size needed to store `len` bits.
    ///
    /// The expected size is 8 bytes for every started 64-bit data word plus
    /// 8 bytes for the control word.
    fn validate_len(input_buf_size: usize, len: u64) -> Result<(), Self> {
        let full_words = len / 64;
        let partial_word = if len % 64 == 0 { 0 } else { 1 };
        // `full_words <= 2^58 - 1`, so this cannot overflow even for `u64::MAX` bits.
        let expected_buf_size = (full_words + partial_word + 1) * 8;
        let input_buf_size = u64::try_from(input_buf_size)
            .expect("usize is at most 64 bits wide on supported platforms");
        if input_buf_size != expected_buf_size {
            return Err(BitArrayError::UnexpectedInputBufferSize(
                input_buf_size,
                expected_buf_size,
                len,
            ));
        }
        Ok(())
    }
}

impl fmt::Display for BitArrayError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            BitArrayError::InputBufferTooSmall(input_buf_size) => {
                write!(f, "expected input buffer size ({}) >= 8", input_buf_size)
            }
            BitArrayError::UnexpectedInputBufferSize(input_buf_size, expected_buf_size, len) => {
                write!(
                    f,
                    "expected input buffer size ({}) to be {} for {} bits",
                    input_buf_size, expected_buf_size, len
                )
            }
        }
    }
}

impl error::Error for BitArrayError {}

/// Bit array format errors surface as `InvalidData` I/O errors.
impl From<BitArrayError> for io::Error {
    fn from(err: BitArrayError) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, err)
    }
}
/// Decodes the big-endian bit count from the first 8 bytes of `buf`.
///
/// The count is then validated against `input_buf_size`, the total size of
/// the serialized bit array (data words plus control word).
///
/// # Panics
///
/// Panics if `buf` is shorter than 8 bytes.
fn read_control_word(buf: &[u8], input_buf_size: usize) -> Result<u64, BitArrayError> {
    let bit_count = BigEndian::read_u64(buf);
    BitArrayError::validate_len(input_buf_size, bit_count).map(|()| bit_count)
}
impl BitArray {
    /// Parses a serialized bit array.
    ///
    /// The input is data words followed by an 8-byte big-endian control word
    /// holding the bit count. The control word is split off, so the returned
    /// array's `bits()` contain only the data words.
    ///
    /// # Errors
    ///
    /// Returns [`BitArrayError::InputBufferTooSmall`] if `buf` is shorter than
    /// 8 bytes. Returns [`BitArrayError::UnexpectedInputBufferSize`] if its
    /// size does not match the bit count in the control word.
    pub fn from_bits(mut buf: Bytes) -> Result<BitArray, BitArrayError> {
        let total_size = buf.len();
        BitArrayError::validate_input_buf_size(total_size)?;
        let control_word = buf.split_off(total_size - 8);
        let len = read_control_word(&control_word, total_size)?;
        Ok(BitArray { len, buf })
    }

    /// Returns the raw data bytes (the control word is not included).
    pub fn bits(&self) -> &[u8] {
        &self.buf[..]
    }

    /// Returns the number of bits.
    ///
    /// # Panics
    ///
    /// Panics if the bit count does not fit in a `usize`.
    pub fn len(&self) -> usize {
        match usize::try_from(self.len) {
            Ok(len) => len,
            Err(_) => panic!(
                "expected length ({}) to fit in {} bytes",
                self.len,
                std::mem::size_of::<usize>()
            ),
        }
    }

    /// Returns `true` if the array holds no bits.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the bit at `index`. Bit 0 is the most significant bit of the first byte.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `index >= self.len()`. In release builds an
    /// out-of-range index reads whatever bit lies there in the buffer, or
    /// panics if it is beyond the last byte.
    pub fn get(&self, index: usize) -> bool {
        let len = self.len();
        debug_assert!(index < len, "expected index ({}) < length ({})", index, len);
        let byte = self.buf[index / 8];
        // Move the requested bit into the top position, then test it.
        let shifted = byte << (index % 8);
        shifted & 0b1000_0000 != 0
    }

    /// Returns an iterator over all bits, in order.
    ///
    /// The iterator owns a clone of this array (`Bytes` clones share the
    /// underlying buffer), so it does not borrow `self`.
    pub fn iter(&self) -> impl Iterator<Item = bool> {
        let owned = self.clone();
        let count = owned.len();
        (0..count).map(move |index| owned.get(index))
    }
}
/// Incrementally writes a bit array to `dest`.
///
/// The output is in the layout parsed by [`BitArray::from_bits`]: data words
/// followed by a control word holding the bit count.
pub struct BitArrayFileBuilder<W> {
    /// Destination that the encoded words are written to.
    dest: W,
    /// The word currently being filled, most significant bit first.
    current: u64,
    /// Total number of bits pushed so far.
    count: u64,
}
impl<W: SyncableFile> BitArrayFileBuilder<W> {
    /// Creates a builder that writes to `dest`, starting with no bits.
    pub fn new(dest: W) -> BitArrayFileBuilder<W> {
        Self {
            dest,
            current: 0,
            count: 0,
        }
    }

    /// Returns the number of bits pushed so far.
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Appends one bit.
    ///
    /// Bits fill the current word from its most significant end. Every 64th
    /// bit completes the word, which is then written to the destination and
    /// cleared.
    pub async fn push(&mut self, bit: bool) -> io::Result<()> {
        let offset = self.count % 64;
        if bit {
            self.current |= (1u64 << 63) >> offset;
        }
        self.count += 1;
        let word_complete = self.count % 64 == 0;
        if word_complete {
            util::write_u64(&mut self.dest, self.current).await?;
            self.current = 0;
        }
        Ok(())
    }

    /// Appends every bit produced by `stream`, stopping at the first error.
    pub async fn push_all<S: Stream<Item = io::Result<bool>> + Unpin>(
        &mut self,
        mut stream: S,
    ) -> io::Result<()> {
        loop {
            let bit = match stream.next().await {
                Some(item) => item?,
                None => break,
            };
            self.push(bit).await?;
        }
        Ok(())
    }

    /// Writes the partially filled last word, if there is one.
    ///
    /// Its unused low-order bits are zero.
    async fn finalize_data(&mut self) -> io::Result<()> {
        let has_partial_word = self.count % 64 != 0;
        if has_partial_word {
            util::write_u64(&mut self.dest, self.current).await?;
        }
        Ok(())
    }

    /// Completes the bit array and makes it durable.
    ///
    /// Writes any partial data word, then the bit count as the trailing
    /// control word, then flushes and syncs the destination.
    pub async fn finalize(mut self) -> io::Result<()> {
        let bit_count = self.count;
        self.finalize_data().await?;
        util::write_u64(&mut self.dest, bit_count).await?;
        self.dest.flush().await?;
        self.dest.sync_all().await?;
        Ok(())
    }
}
/// A [`Decoder`] that yields the 64-bit data words of a serialized bit array.
///
/// It always holds back the most recently read word. As a result, the final
/// word of the input (the control word) is never yielded.
pub struct BitArrayBlockDecoder {
    /// The word read most recently but not yet yielded.
    readahead: Option<u64>,
}
impl Decoder for BitArrayBlockDecoder {
    type Item = u64;
    type Error = io::Error;

    /// Yields the next data word, or `None` if more input is needed.
    ///
    /// This method itself never returns an error.
    fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<u64>, io::Error> {
        let block = decode_next_bitarray_block(bytes, &mut self.readahead);
        Ok(block)
    }
}
/// Reads big-endian 64-bit words from `bytes` and returns the word preceding
/// the one just read.
///
/// If fewer than 8 bytes remain, returns `None` without consuming anything.
/// Otherwise it reads a word into `readahead` and returns the word that was
/// displaced. If `readahead` was empty, it repeats this for one more word,
/// returning `None` if there is none.
fn decode_next_bitarray_block<B: Buf>(bytes: &mut B, readahead: &mut Option<u64>) -> Option<u64> {
    while bytes.remaining() >= 8 {
        if let Some(previous) = readahead.replace(bytes.get_u64()) {
            return Some(previous);
        }
    }
    None
}
/// Wraps `r` in a stream of the data words of a serialized bit array.
///
/// The trailing control word is not included.
pub fn bitarray_stream_blocks<R: AsyncRead + Unpin>(r: R) -> FramedRead<R, BitArrayBlockDecoder> {
    let decoder = BitArrayBlockDecoder { readahead: None };
    FramedRead::new(r, decoder)
}
/// Iterates over the data words held in `b`.
///
/// The trailing control word is not included.
pub fn bitarray_iter_blocks<B: Buf>(b: &mut B) -> BitArrayBlockIterator<'_, B> {
    let readahead = None;
    BitArrayBlockIterator { buf: b, readahead }
}
/// Iterator over the data words of a serialized bit array held in a buffer.
///
/// The trailing control word is not included. Created by
/// [`bitarray_iter_blocks`]. Only the `Iterator` impl requires `B: Buf`, so
/// the struct definition itself carries no trait bound.
pub struct BitArrayBlockIterator<'a, B> {
    /// The buffer being consumed.
    buf: &'a mut B,
    /// The word read most recently but not yet yielded.
    readahead: Option<u64>,
}
impl<B: Buf> Iterator for BitArrayBlockIterator<'_, B> {
    type Item = u64;

    /// Returns the next data word. The final word (the control word) is never returned.
    fn next(&mut self) -> Option<u64> {
        decode_next_bitarray_block(&mut *self.buf, &mut self.readahead)
    }
}
/// Reads the bit count from the trailing 8-byte control word of the bit array file `f`.
///
/// The file size is queried once. That single value is used both to locate
/// the control word and to validate the size against the bit count, so both
/// checks agree with each other.
///
/// # Errors
///
/// Returns an `InvalidData` error wrapping a [`BitArrayError`] if the file is
/// shorter than 8 bytes or its size does not match the bit count. I/O errors
/// from `f` are propagated.
pub(crate) async fn bitarray_len_from_file<F: FileLoad>(f: F) -> io::Result<u64> {
    let file_size = f.size().await?;
    BitArrayError::validate_input_buf_size(file_size)?;
    let mut control_word = [0u8; 8];
    f.open_read_from(file_size - 8)
        .await?
        .read_exact(&mut control_word)
        .await?;
    Ok(read_control_word(&control_word, file_size)?)
}
pub async fn bitarray_stream_bits<F: FileLoad>(
f: F,
) -> io::Result<impl Stream<Item = io::Result<bool>> + Unpin> {
let len = bitarray_len_from_file(f.clone()).await?;
Ok(bitarray_stream_blocks(f.open_read().await?)
.map_ok(|block| util::stream_iter_ok(BitIter::new(block)))
.try_flatten()
.into_stream()
.take(len as usize))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::*;
    use futures::future;

    // The async tests below `.await` futures directly. Nesting
    // `futures::executor::block_on` inside the tokio test runtime would block
    // the runtime thread until the inner future completes.

    #[test]
    fn bit_array_error() {
        assert_eq!(
            "expected input buffer size (7) >= 8",
            BitArrayError::InputBufferTooSmall(7).to_string()
        );
        assert_eq!(
            "expected input buffer size (9) to be 8 for 0 bits",
            BitArrayError::UnexpectedInputBufferSize(9, 8, 0).to_string()
        );
        assert_eq!(
            io::Error::new(
                io::ErrorKind::InvalidData,
                BitArrayError::InputBufferTooSmall(7)
            )
            .to_string(),
            io::Error::from(BitArrayError::InputBufferTooSmall(7)).to_string()
        );
    }

    #[test]
    fn validate_input_buf_size() {
        let val = |buf_size| BitArrayError::validate_input_buf_size(buf_size);
        let err = |buf_size| Err(BitArrayError::InputBufferTooSmall(buf_size));
        assert_eq!(err(7), val(7));
        assert_eq!(Ok(()), val(8));
        assert_eq!(Ok(()), val(9));
        assert_eq!(Ok(()), val(usize::max_value()));
    }

    #[test]
    fn validate_len() {
        let val = |buf_size, len| BitArrayError::validate_len(buf_size, len);
        let err = |buf_size, expected, len| {
            Err(BitArrayError::UnexpectedInputBufferSize(
                buf_size, expected, len,
            ))
        };
        assert_eq!(err(0, 8, 0), val(0, 0));
        assert_eq!(Ok(()), val(16, 1));
        assert_eq!(Ok(()), val(16, 2));
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Ok(()),
            val(
                usize::try_from(u128::from(u64::max_value()) + 65 >> 6 << 3).unwrap(),
                u64::max_value()
            )
        );
    }

    #[test]
    fn decode() {
        let mut decoder = BitArrayBlockDecoder { readahead: None };
        let mut bytes = BytesMut::from([0u8; 8].as_ref());
        assert_eq!(None, Decoder::decode(&mut decoder, &mut bytes).unwrap());
    }

    #[test]
    fn empty() {
        assert!(BitArray::from_bits(Bytes::from([0u8; 8].as_ref()))
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn construct_and_parse_small_bitarray() {
        let x = MemoryBackedStore::new();
        let contents = vec![true, true, false, false, true];
        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        builder
            .push_all(util::stream_iter_ok(contents))
            .await
            .unwrap();
        builder.finalize().await.unwrap();
        let loaded = x.map().await.unwrap();
        let bitarray = BitArray::from_bits(loaded).unwrap();
        assert_eq!(5, bitarray.len());
        assert!(bitarray.get(0));
        assert!(bitarray.get(1));
        assert!(!bitarray.get(2));
        assert!(!bitarray.get(3));
        assert!(bitarray.get(4));
    }

    #[tokio::test]
    async fn construct_and_parse_large_bitarray() {
        let x = MemoryBackedStore::new();
        let contents = (0..).map(|n| n % 3 == 0).take(123456);
        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        builder
            .push_all(util::stream_iter_ok(contents))
            .await
            .unwrap();
        builder.finalize().await.unwrap();
        let loaded = x.map().await.unwrap();
        let bitarray = BitArray::from_bits(loaded).unwrap();
        assert_eq!(123456, bitarray.len());
        for i in 0..bitarray.len() {
            assert_eq!(i % 3 == 0, bitarray.get(i));
        }
    }

    #[tokio::test]
    async fn bitarray_len_from_file_errors() {
        let store = MemoryBackedStore::new();
        let mut writer = store.open_write().await.unwrap();
        writer.write_all(&[0, 0, 0]).await.unwrap();
        writer.sync_all().await.unwrap();
        assert_eq!(
            io::Error::from(BitArrayError::InputBufferTooSmall(3)).to_string(),
            bitarray_len_from_file(store).await.unwrap_err().to_string()
        );

        let store = MemoryBackedStore::new();
        let mut writer = store.open_write().await.unwrap();
        writer.write_all(&[0, 0, 0, 0, 0, 0, 0, 2]).await.unwrap();
        writer.sync_all().await.unwrap();
        assert_eq!(
            io::Error::from(BitArrayError::UnexpectedInputBufferSize(8, 16, 2)).to_string(),
            bitarray_len_from_file(store).await.unwrap_err().to_string()
        );
    }

    #[tokio::test]
    async fn stream_blocks() {
        let x = MemoryBackedStore::new();
        let contents: Vec<bool> = (0..).map(|n| n % 4 == 1).take(256).collect();
        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        builder
            .push_all(util::stream_iter_ok(contents))
            .await
            .unwrap();
        builder.finalize().await.unwrap();
        let stream = bitarray_stream_blocks(x.open_read().await.unwrap());
        stream
            .try_for_each(|block| future::ok(assert_eq!(0x4444444444444444, block)))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn stream_bits() {
        let x = MemoryBackedStore::new();
        let contents: Vec<_> = (0..).map(|n| n % 4 == 1).take(123).collect();
        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        builder
            .push_all(util::stream_iter_ok(contents.clone()))
            .await
            .unwrap();
        builder.finalize().await.unwrap();
        let result: Vec<_> = bitarray_stream_bits(x)
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        assert_eq!(contents, result);
    }
}