use std::io::Read;
use arrow::{
array::BooleanBufferBuilder,
buffer::{BooleanBuffer, NullBuffer},
};
use bytes::Bytes;
use crate::{error::Result, memory::EstimateMemory};
use super::{
byte::{ByteRleDecoder, ByteRleEncoder},
PrimitiveValueDecoder, PrimitiveValueEncoder,
};
/// Decodes booleans that were bit-packed eight per byte (most-significant bit
/// first) and whose packed bytes were then byte-RLE encoded.
pub struct BooleanDecoder<R: Read> {
    /// Current packed byte. Bits already handed out are shifted off the left
    /// edge, so the next bit to return always sits in the most-significant
    /// position.
    data: u8,
    /// How many bits of `data` have not been returned yet.
    bits_in_data: usize,
    /// Byte-level RLE decoder that supplies the packed bytes.
    decoder: ByteRleDecoder<R>,
}
impl<R: Read> BooleanDecoder<R> {
    /// Creates a decoder reading byte-RLE encoded, bit-packed booleans from
    /// `reader`. No byte is buffered initially.
    pub fn new(reader: R) -> Self {
        let decoder = ByteRleDecoder::new(reader);
        Self {
            data: 0,
            bits_in_data: 0,
            decoder,
        }
    }

    /// Pops the next boolean from the buffered byte, most-significant bit first.
    ///
    /// # Panics
    ///
    /// Must only be called while at least one bit of the current byte is still
    /// unread (`decode` refills before calling this). Otherwise the internal bit
    /// counter underflows: this panics when overflow checks are enabled and
    /// wraps silently when they are not.
    pub fn value(&mut self) -> bool {
        let is_set = self.data >> 7 == 1;
        self.data <<= 1;
        self.bits_in_data -= 1;
        is_set
    }
}
impl<R: Read> PrimitiveValueDecoder<bool> for BooleanDecoder<R> {
    /// Discards the next `n` booleans.
    ///
    /// Bits still buffered from the current byte are dropped first. Whole bytes
    /// are then skipped in the underlying byte decoder without unpacking them.
    /// Any remaining bits are taken from the front of one freshly decoded byte.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying byte decoder, e.g. when fewer
    /// than `n` booleans remain. The decoder's position is unspecified after an
    /// error.
    fn skip(&mut self, n: usize) -> Result<()> {
        // Consume what is already buffered first. On entry `bits_in_data` is at
        // most 7: `decode` hands out one bit immediately after loading a byte,
        // and the partial-byte path below leaves `8 - partial` with
        // `partial >= 1`. So this shift never reaches the width of `u8`.
        let buffered = n.min(self.bits_in_data);
        if buffered > 0 {
            self.data <<= buffered;
            self.bits_in_data -= buffered;
        }

        let remaining = n - buffered;
        if remaining == 0 {
            return Ok(());
        }

        // Whole bytes can be skipped at the RLE level without unpacking bits.
        let whole_bytes = remaining / 8;
        if whole_bytes > 0 {
            self.decoder.skip(whole_bytes)?;
        }

        // Leftover bits come from the next byte: load it and drop its leading
        // `partial` bits, keeping the rest buffered.
        let partial = remaining % 8;
        if partial > 0 {
            let mut byte = [0i8; 1];
            self.decoder.decode(&mut byte)?;
            self.data = (byte[0] as u8) << partial;
            self.bits_in_data = 8 - partial;
        }
        Ok(())
    }

    /// Fills `out` with the next `out.len()` booleans. A new packed byte is
    /// decoded from the underlying byte decoder whenever the buffered bits run
    /// out.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying byte decoder. Entries of `out`
    /// written before the error keep their decoded values.
    fn decode(&mut self, out: &mut [bool]) -> Result<()> {
        for slot in out.iter_mut() {
            if self.bits_in_data == 0 {
                let mut byte = [0i8; 1];
                self.decoder.decode(&mut byte)?;
                self.data = byte[0] as u8;
                self.bits_in_data = 8;
            }
            *slot = self.value();
        }
        Ok(())
    }
}
/// Accumulates booleans and, on `finish`, emits them bit-packed
/// (most-significant bit first) through a byte-level RLE encoder.
pub struct BooleanEncoder {
    /// Pending bits, held in Arrow's bit layout until `finish` repacks them.
    builder: BooleanBufferBuilder,
    /// Byte-level RLE encoder that receives the packed bytes.
    byte_encoder: ByteRleEncoder,
}
impl EstimateMemory for BooleanEncoder {
    /// Estimated size, in bytes, of the pending (not yet finished) bits.
    ///
    /// The bit count is divided by eight and rounded down. Only the bit builder
    /// is considered; the byte RLE encoder's state is not counted.
    fn estimate_memory_size(&self) -> usize {
        let pending_bits = self.builder.len();
        pending_bits / u8::BITS as usize
    }
}
impl Default for BooleanEncoder {
    fn default() -> Self {
        Self::new()
    }
}

impl BooleanEncoder {
    /// Creates an empty encoder.
    pub fn new() -> Self {
        Self {
            byte_encoder: ByteRleEncoder::new(),
            // Initial capacity in bits; the builder grows on demand.
            builder: BooleanBufferBuilder::new(8),
        }
    }

    /// Appends the validity bits of `null_buffer` (a set bit means the slot is
    /// non-null).
    pub fn extend(&mut self, null_buffer: &NullBuffer) {
        self.extend_bb(null_buffer.inner());
    }

    /// Appends every bit of `bb`.
    pub fn extend_bb(&mut self, bb: &BooleanBuffer) {
        self.builder.append_buffer(bb);
    }

    /// Appends `n` `true` bits (i.e. `n` present values).
    pub fn extend_present(&mut self, n: usize) {
        self.builder.append_n(n, true);
    }

    /// Appends a single boolean.
    pub fn extend_boolean(&mut self, b: bool) {
        self.builder.append(b);
    }

    /// Packs all pending bits into bytes, RLE-encodes them, and returns the
    /// encoded bytes. The bit builder is reset by `BooleanBufferBuilder::finish`.
    pub fn finish(&mut self) -> Bytes {
        let bits = self.builder.finish();
        // Arrow stores bits least-significant first, while the decoder reads
        // them most-significant first, so each byte is bit-reversed. Bytes are
        // streamed straight to the RLE encoder; no intermediate Vec is needed.
        for &byte in bits.values() {
            self.byte_encoder.write_one(byte.reverse_bits() as i8);
        }
        self.byte_encoder.take_inner()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let expected = vec![false; 800];
        let data = [0x61u8, 0x00];
        let data = &mut data.as_ref();
        let mut decoder = BooleanDecoder::new(data);
        let mut actual = vec![true; expected.len()];
        decoder.decode(&mut actual).unwrap();
        assert_eq!(actual, expected)
    }

    #[test]
    fn literals() {
        let expected = vec![
            false, true, false, false, false, true, false, false, // 0b01000100
            false, true, false, false, false, true, false, true, // 0b01000101
        ];
        let data = [0xfeu8, 0b01000100, 0b01000101];
        let data = &mut data.as_ref();
        let mut decoder = BooleanDecoder::new(data);
        let mut actual = vec![true; expected.len()];
        decoder.decode(&mut actual).unwrap();
        assert_eq!(actual, expected)
    }

    #[test]
    fn another() {
        let expected = vec![true, false, false, false, false, false, false, false];
        let data = [0xff, 0x80];
        let data = &mut data.as_ref();
        let mut decoder = BooleanDecoder::new(data);
        let mut actual = vec![true; expected.len()];
        decoder.decode(&mut actual).unwrap();
        assert_eq!(actual, expected)
    }

    #[test]
    fn test_skip_run() {
        let data = [0x61u8, 0x00];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        let mut batch = vec![true; 10];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![false; 10]);
        decoder.skip(80).unwrap();
        let mut batch = vec![true; 10];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![false; 10]);
    }

    #[test]
    fn test_skip_all() {
        let data = [0xffu8, 0x00u8];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        decoder.skip(8).unwrap();
        let mut batch = vec![true; 1];
        let result = decoder.decode(&mut batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_skip_partial_bits() {
        let data = [0xfeu8, 0b01000100, 0b01000101];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        decoder.skip(3).unwrap();
        let mut batch = vec![true; 5];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![false, false, true, false, false]);
    }

    #[test]
    fn test_skip_cross_byte_boundary() {
        let data = [0xfeu8, 0b01000100, 0b01000101];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        decoder.skip(6).unwrap();
        let mut batch = vec![true; 4];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![false, false, false, true]);
    }

    #[test]
    fn test_skip_zero() {
        let data = [0x61u8, 0x00];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        decoder.skip(0).unwrap();
        let mut batch = vec![true; 10];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![false; 10]);
    }

    #[test]
    fn test_skip_exact_byte() {
        let data = [0x61u8, 0x00];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        decoder.skip(8).unwrap();
        let mut batch = vec![true; 10];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![false; 10]);
    }

    #[test]
    fn test_skip_more_than_available() {
        let data = [0xffu8, 0x00u8];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        let result = decoder.skip(9);
        assert!(result.is_err());
    }

    /// Exercises `skip` when bits are already buffered from a previous `decode`.
    #[test]
    fn test_skip_after_partial_decode() {
        let data = [0xfeu8, 0b01000100, 0b01000101];
        let mut decoder = BooleanDecoder::new(data.as_ref());
        let mut batch = vec![true; 3];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![false, true, false]);
        // Drops the 5 remaining bits of the first byte plus 1 bit of the second.
        decoder.skip(6).unwrap();
        let mut batch = vec![false; 3];
        decoder.decode(&mut batch).unwrap();
        assert_eq!(batch, vec![true, false, false]);
    }

    /// Whatever `BooleanEncoder` produces must decode back to the same values,
    /// covering both literal bytes and long runs of identical bytes.
    #[test]
    fn test_encoder_decoder_roundtrip() {
        let mut expected: Vec<bool> = (0..37).map(|i| i % 3 == 0 || i % 5 == 0).collect();
        let mut encoder = BooleanEncoder::new();
        for &b in &expected {
            encoder.extend_boolean(b);
        }
        encoder.extend_present(100);
        expected.resize(expected.len() + 100, true);
        let encoded = encoder.finish();

        let mut decoder = BooleanDecoder::new(&encoded[..]);
        let mut actual = vec![false; expected.len()];
        decoder.decode(&mut actual).unwrap();
        assert_eq!(actual, expected);
    }
}