use std::sync::Arc;
use arrow::array::ArrayData;
use arrow::datatypes::{ArrowPrimitiveType, UInt16Type, UInt32Type, UInt64Type, UInt8Type};
use arrow::util::bit_util::ceil;
use arrow_array::{cast::AsArray, Array, ArrayRef, PrimitiveArray};
use arrow_schema::DataType;
use bytes::{Bytes, BytesMut};
use futures::future::{BoxFuture, FutureExt};
use log::trace;
use num_traits::{AsPrimitive, PrimInt};
use snafu::{location, Location};
use lance_arrow::DataTypeExt;
use lance_core::{Error, Result};
use crate::encoder::EncodedBufferMeta;
use crate::{
decoder::{PageScheduler, PrimitivePageDecoder},
encoder::{BufferEncoder, EncodedBuffer},
};
/// Returns how many low bits are needed to store every non-null value of `arr`, or `None` when
/// `arr` is not an unsigned integer array (or has no non-null value to measure).
pub fn num_compressed_bits(arr: ArrayRef) -> Option<u64> {
    match *arr.data_type() {
        DataType::UInt8 => num_bits_for_type(arr.as_primitive::<UInt8Type>()),
        DataType::UInt16 => num_bits_for_type(arr.as_primitive::<UInt16Type>()),
        DataType::UInt32 => num_bits_for_type(arr.as_primitive::<UInt32Type>()),
        DataType::UInt64 => num_bits_for_type(arr.as_primitive::<UInt64Type>()),
        // Signed, floating point and non-primitive types are not bit-packed.
        _ => None,
    }
}
/// Returns the number of low bits (at least 1) needed to hold every non-null value of `arr`.
///
/// Returns `None` when `bit_or` yields no value (empty or all-null input).
///
/// OR-ing all values together produces a word whose highest set bit is the highest bit set in any
/// value. The required width is therefore the type's bit width minus that word's leading zeros.
fn num_bits_for_type<T>(arr: &PrimitiveArray<T>) -> Option<u64>
where
    T: ArrowPrimitiveType,
    T::Native: PrimInt + AsPrimitive<u64>,
{
    let combined = arrow::compute::bit_or(arr)?;
    let type_bits = arr.data_type().byte_width() as u64 * 8;
    let used_bits = type_bits - u64::from(combined.leading_zeros());
    // Even an all-zero column needs one bit per value.
    Some(used_bits.max(1))
}
/// Buffer encoder that stores each value using only its low `num_bits` bits.
#[derive(Debug)]
pub struct BitpackingBufferEncoder {
    /// Number of bits kept per value.
    num_bits: u64,
}

impl BitpackingBufferEncoder {
    /// Creates an encoder that keeps the low `num_bits` bits of every value.
    ///
    /// `num_bits` is not validated here; presumably callers obtain it from
    /// `num_compressed_bits` for the data being encoded.
    pub fn new(num_bits: u64) -> Self {
        BitpackingBufferEncoder { num_bits }
    }
}
impl BufferEncoder for BitpackingBufferEncoder {
fn encode(&self, arrays: &[ArrayRef]) -> Result<(EncodedBuffer, EncodedBufferMeta)> {
let count_items = arrays.iter().map(|arr| arr.len()).sum::<usize>();
let dst_bytes_total = ceil(count_items * self.num_bits as usize, 8);
let mut dst_buffer = vec![0u8; dst_bytes_total];
let mut dst_idx = 0;
let mut dst_offset = 0;
for arr in arrays {
pack_array(
arr.clone(),
self.num_bits,
&mut dst_buffer,
&mut dst_idx,
&mut dst_offset,
)?;
}
let data_type = arrays[0].data_type();
Ok((
EncodedBuffer {
parts: vec![dst_buffer.into()],
},
EncodedBufferMeta {
bits_per_value: (data_type.byte_width() * 8) as u64,
bitpacked_bits_per_value: Some(self.num_bits),
compression_scheme: None,
},
))
}
}
fn pack_array(
arr: ArrayRef,
num_bits: u64,
dst: &mut [u8],
dst_idx: &mut usize,
dst_offset: &mut u8,
) -> Result<()> {
match arr.data_type() {
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
pack_buffers(
arr.to_data(),
num_bits,
arr.data_type().byte_width(),
dst,
dst_idx,
dst_offset,
);
Ok(())
}
_ => Err(Error::InvalidInput {
source: format!("Invalid data type for bitpacking: {}", arr.data_type()).into(),
location: location!(),
}),
}
}
/// Packs every data buffer of `data` by delegating to `pack_bits` with the shared destination
/// cursor (`dst_idx`, `dst_offset`).
///
/// `byte_len` is the width in bytes of one unpacked value.
fn pack_buffers(
    data: ArrayData,
    num_bits: u64,
    byte_len: usize,
    dst: &mut [u8],
    dst_idx: &mut usize,
    dst_offset: &mut u8,
) {
    let value_buffers = data.buffers();
    // Primitive array data is expected to carry exactly one (values) buffer.
    debug_assert_eq!(value_buffers.len(), 1);
    for values in value_buffers.iter() {
        pack_bits(values.as_slice(), num_bits, byte_len, dst, dst_idx, dst_offset);
    }
}
/// Appends the low `num_bits` bits of every `byte_len`-byte little-endian value in `src` to the
/// packed bit stream in `dst`, starting at bit `*dst_offset` of byte `*dst_idx`.
///
/// `dst` must be zeroed where bits are written. Bits are merged with `+=`, which equals a bitwise
/// OR only because no two values share a bit position. On return, `*dst_idx` / `*dst_offset`
/// point just past the last bit written, so successive calls continue the same stream.
fn pack_bits(
    src: &[u8],
    num_bits: u64,
    byte_len: usize,
    dst: &mut [u8],
    dst_idx: &mut usize,
    dst_offset: &mut u8,
) {
    let mask = u64::MAX >> (64 - num_bits);

    let mut src_idx = 0;
    while src_idx < src.len() {
        // Bits of the current value still to copy, lowest first; shifted down by 8 per source
        // byte so `curr_mask as u8` selects the wanted bits of the byte being read.
        let mut curr_mask = mask;
        let mut curr_src = src[src_idx] & curr_mask as u8;
        let mut src_offset = 0;
        let mut src_bits_written = 0;

        while src_bits_written < num_bits {
            dst[*dst_idx] += (curr_src >> src_offset) << *dst_offset as u64;
            // Copy as many bits as fit in both the current source and destination bytes.
            let bits_written = (num_bits - src_bits_written)
                .min(8 - src_offset)
                .min(8 - *dst_offset as u64);
            src_bits_written += bits_written;
            *dst_offset += bits_written as u8;
            src_offset += bits_written;

            if *dst_offset == 8 {
                *dst_idx += 1;
                *dst_offset = 0;
            }

            if src_offset == 8 {
                src_idx += 1;
                src_offset = 0;
                curr_mask >>= 8;
                if src_idx == src.len() {
                    break;
                }
                curr_src = src[src_idx] & curr_mask as u8;
            }
        }

        // The loop above advanced `src_idx` once per fully consumed source byte, i.e.
        // `num_bits / 8` times. Step over the rest of the value (a partially read byte, if any,
        // plus the dropped high bytes) to reach the start of the next value. Deriving this from
        // `bit_len % num_bits` instead mis-steps for byte-aligned widths that do not divide the
        // value width (24 bits of a u32; 40/48/56 bits of a u64).
        src_idx += byte_len - (num_bits / 8) as usize;
    }
}
/// Page scheduler for a bit-packed buffer that starts `buffer_offset` bytes into the file.
#[derive(Debug, Clone, Copy)]
pub struct BitpackedScheduler {
    /// Packed width of each value, in bits.
    bits_per_value: u64,
    /// Width of each value after unpacking, in bits.
    uncompressed_bits_per_value: u64,
    /// Byte position of the packed buffer; added to every computed byte range.
    buffer_offset: u64,
}

impl BitpackedScheduler {
    /// Creates a scheduler for `bits_per_value`-bit packed values that unpack to
    /// `uncompressed_bits_per_value` bits, stored starting at byte `buffer_offset`.
    pub fn new(bits_per_value: u64, uncompressed_bits_per_value: u64, buffer_offset: u64) -> Self {
        BitpackedScheduler {
            buffer_offset,
            bits_per_value,
            uncompressed_bits_per_value,
        }
    }
}
impl PageScheduler for BitpackedScheduler {
    /// Requests, for every row range, the smallest run of whole bytes of the packed buffer that
    /// contains all of that range's bits, and returns a decoder over the fetched bytes.
    ///
    /// Values need not start or end on byte boundaries, so two bit offsets are recorded per range
    /// and handed to the decoder:
    /// - the bit offset of the first value inside the first fetched byte;
    /// - when the range does not end on a byte boundary, the bit offset at which it ends inside
    ///   the last fetched byte.
    fn schedule_ranges(
        &self,
        ranges: &[std::ops::Range<u64>],
        scheduler: &Arc<dyn crate::EncodingsIo>,
        top_level_row: u64,
    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
        let bits_per_value = self.bits_per_value;
        let uncompressed_bits_per_value = self.uncompressed_bits_per_value;

        let mut byte_ranges = Vec::with_capacity(ranges.len());
        let mut buffer_bit_start_offsets: Vec<u8> = Vec::with_capacity(ranges.len());
        let mut buffer_bit_end_offsets: Vec<Option<u8>> = Vec::with_capacity(ranges.len());
        // Overall span of the request; only used for tracing.
        let mut min = u64::MAX;
        let mut max = 0u64;

        for range in ranges {
            let first_bit = range.start * bits_per_value;
            let end_bit = range.end * bits_per_value;
            let end_bit_in_byte = end_bit % 8;
            // A range that ends mid-byte needs that final byte fetched as well.
            let ends_mid_byte = end_bit_in_byte != 0;

            buffer_bit_start_offsets.push((first_bit % 8) as u8);
            buffer_bit_end_offsets.push(ends_mid_byte.then_some(end_bit_in_byte as u8));

            let start = self.buffer_offset + first_bit / 8;
            let end = self.buffer_offset + (end_bit / 8 + u64::from(ends_mid_byte));
            min = min.min(start);
            max = max.max(end);
            byte_ranges.push(start..end);
        }

        trace!(
            "Scheduling I/O for {} ranges spread across byte range {}..{}",
            byte_ranges.len(),
            min,
            max
        );
        let bytes = scheduler.submit_request(byte_ranges, top_level_row);

        async move {
            let bytes = bytes.await?;
            Ok(Box::new(BitpackedPageDecoder {
                buffer_bit_start_offsets,
                buffer_bit_end_offsets,
                bits_per_value,
                uncompressed_bits_per_value,
                data: bytes,
            }) as Box<dyn PrimitivePageDecoder>)
        }
        .boxed()
    }
}
/// Decoder over the bytes fetched for one page of bit-packed values (built by
/// `BitpackedScheduler::schedule_ranges`).
#[derive(Debug)]
struct BitpackedPageDecoder {
    /// Per fetched range: bit position of the range's first value within its first byte.
    buffer_bit_start_offsets: Vec<u8>,
    /// Per fetched range: bit position within its last byte where the range's last value ends,
    /// or `None` when the range ends on a byte boundary.
    buffer_bit_end_offsets: Vec<Option<u8>>,
    /// Packed width of each value, in bits.
    bits_per_value: u64,
    /// Width of each unpacked output value, in bits.
    uncompressed_bits_per_value: u64,
    /// Fetched bytes; presumably one entry per requested byte range, in request order
    /// (indexed in step with the two offset vectors above).
    data: Vec<Bytes>,
}
impl PrimitivePageDecoder for BitpackedPageDecoder {
    /// Unpacks `num_rows` values, after skipping the first `rows_to_skip` values of the page,
    /// into a single buffer.
    ///
    /// Every value occupies `uncompressed_bits_per_value / 8` bytes, low bits first, with the
    /// bits above `bits_per_value` zeroed.
    ///
    /// The fetched ranges in `data` are walked in order. Ranges that fall entirely inside the
    /// skipped prefix are passed over. Once the skip is satisfied, every later range is read from
    /// its first value.
    fn decode(
        &self,
        rows_to_skip: u64,
        num_rows: u64,
        _all_null: &mut bool,
    ) -> Result<Vec<BytesMut>> {
        let num_bytes = self.uncompressed_bits_per_value / 8 * num_rows;
        let mut dest_buffers = vec![BytesMut::with_capacity(num_bytes as usize)];

        debug_assert!(self.bits_per_value <= 64);

        let mut rows_to_skip = rows_to_skip;
        let mut rows_taken = 0;
        let byte_len = self.uncompressed_bits_per_value / 8;
        let dst = &mut dest_buffers[0];
        let mut dst_idx = dst.len();
        let mask = u64::MAX >> (64 - self.bits_per_value);

        for i in 0..self.data.len() {
            let src = &self.data[i];
            let (mut src_idx, mut src_offset) = match compute_start_offset(
                rows_to_skip,
                src.len(),
                self.bits_per_value,
                self.buffer_bit_start_offsets[i],
                self.buffer_bit_end_offsets[i],
            ) {
                StartOffset::SkipFull(rows_to_skip_here) => {
                    rows_to_skip -= rows_to_skip_here;
                    continue;
                }
                StartOffset::SkipSome(buffer_start_offset) => {
                    // The skip ends inside this range, so nothing remains to be skipped in later
                    // ranges; without this reset they would skip the same rows again.
                    rows_to_skip = 0;
                    (
                        buffer_start_offset.index,
                        buffer_start_offset.bit_offset as u64,
                    )
                }
            };

            while src_idx < src.len() && rows_taken < num_rows {
                rows_taken += 1;
                // One bit set per still-unread bit of the current value, lowest first; it is
                // shifted into place for the starting byte and shifted down as bits are read.
                let mut curr_mask = mask;
                let mut curr_src = src[src_idx] & (curr_mask << src_offset) as u8;
                let mut src_bits_written = 0;
                let mut dst_offset = 0;

                while src_bits_written < self.bits_per_value {
                    // Grow the output lazily so the byte being written exists (zeroed).
                    if dst.len() <= dst_idx {
                        dst.resize(dst_idx + 1, 0);
                    }
                    // Bits never overlap, so `+=` acts as a bitwise OR into the zeroed byte.
                    dst[dst_idx] += (curr_src >> src_offset) << dst_offset;
                    let bits_written = (self.bits_per_value - src_bits_written)
                        .min(8 - src_offset)
                        .min(8 - dst_offset);
                    src_bits_written += bits_written;
                    dst_offset += bits_written;
                    src_offset += bits_written;
                    curr_mask >>= bits_written;

                    if dst_offset == 8 {
                        dst_idx += 1;
                        dst_offset = 0;
                    }

                    if src_offset == 8 {
                        src_idx += 1;
                        src_offset = 0;
                        if src_idx == src.len() {
                            break;
                        }
                        curr_src = src[src_idx] & curr_mask as u8;
                    }
                }

                // The inner loop advanced `dst_idx` once per completed destination byte, i.e.
                // `bits_per_value / 8` times. Step over the rest of the value's `byte_len` bytes
                // (the partially filled byte, if any, plus the zeroed high bytes). Deriving this
                // from `uncompressed % bits_per_value` instead mis-steps for byte-aligned widths
                // that do not divide the value width (e.g. 24 bits of a u32).
                dst_idx += (byte_len - self.bits_per_value / 8) as usize;

                // The last byte of a range may also hold bits of values beyond the requested
                // range; stop once the requested range's end bit has been reached.
                if let Some(buffer_bit_end_offset) = self.buffer_bit_end_offsets[i] {
                    if src_idx == src.len() - 1 && src_offset >= buffer_bit_end_offset as u64 {
                        break;
                    }
                }
            }
        }

        // Zero-fill the high bytes of the last value so the output holds exactly `byte_len`
        // bytes per decoded row, with no extra trailing byte (and nothing for zero rows).
        if dst.len() < dst_idx {
            dst.resize(dst_idx, 0);
        }
        Ok(dest_buffers)
    }

    fn num_buffers(&self) -> u32 {
        1
    }
}
/// Position at which decoding begins inside a fetched buffer: byte `index`, then `bit_offset`
/// bits into that byte.
#[derive(Debug, PartialEq)]
struct BufferStartOffset {
    index: usize,
    bit_offset: u8,
}

/// Result of applying a row skip to one fetched buffer.
#[derive(Debug, PartialEq)]
enum StartOffset {
    /// Every value in the buffer is skipped; carries how many values that was.
    SkipFull(u64),
    /// Decoding starts part-way into the buffer at the given position.
    SkipSome(BufferStartOffset),
}

/// Decides how `rows_to_skip` applies to a buffer of `buffer_len` bytes.
///
/// The buffer's values are `bits_per_value` bits wide. They begin `buffer_start_bit_offset`
/// bits into its first byte and, if `buffer_end_bit_offset` is set, end that many bits into its
/// last byte.
fn compute_start_offset(
    rows_to_skip: u64,
    buffer_len: usize,
    bits_per_value: u64,
    buffer_start_bit_offset: u8,
    buffer_end_bit_offset: Option<u8>,
) -> StartOffset {
    let available = rows_in_buffer(
        buffer_len,
        bits_per_value,
        buffer_start_bit_offset,
        buffer_end_bit_offset,
    );

    if rows_to_skip < available {
        let first_bit = rows_to_skip * bits_per_value + u64::from(buffer_start_bit_offset);
        StartOffset::SkipSome(BufferStartOffset {
            index: (first_bit / 8) as usize,
            bit_offset: (first_bit % 8) as u8,
        })
    } else {
        StartOffset::SkipFull(available)
    }
}

/// Counts the whole `bits_per_value`-bit values that fit in a buffer of `buffer_len` bytes.
///
/// Bits before `buffer_start_bit_offset` in the first byte are excluded, and so are bits at or
/// after `buffer_end_bit_offset` (when set) in the last byte.
fn rows_in_buffer(
    buffer_len: usize,
    bits_per_value: u64,
    buffer_start_bit_offset: u8,
    buffer_end_bit_offset: Option<u8>,
) -> u64 {
    let leading_unused = u64::from(buffer_start_bit_offset);
    let trailing_unused = buffer_end_bit_offset.map_or(0, |end| u64::from(8 - end));
    ((buffer_len * 8) as u64 - leading_unused - trailing_unused) / bits_per_value
}
#[cfg(test)]
pub mod test {
    use super::*;
    use std::sync::Arc;

    use arrow_array::{
        types::{UInt16Type, UInt8Type},
        Float64Array,
    };
    use lance_datagen::{array::fill, gen, ArrayGenerator, ArrayGeneratorExt, RowCount};

    #[test]
    fn test_num_compressed_bits() {
        /// Builds a 10,000-row single-column batch from `generator` and returns that column.
        fn gen_array(generator: Box<dyn ArrayGenerator>) -> ArrayRef {
            gen()
                .anon_col(generator)
                .into_batch_rows(RowCount::from(10000))
                .unwrap()
                .column(0)
                .clone()
        }

        // Fills an array of `$data_type` with the value whose highest set bit is bit
        // `$num_bits - 1` (plus random nulls) and checks the computed width is `$num_bits`.
        macro_rules! do_test {
            ($num_bits:expr, $data_type:ident, $null_probability:expr) => {
                let max = 1 << ($num_bits - 1);
                let mut arr =
                    gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));
                // Regenerate until at least one non-null value is present to measure.
                while arr.null_count() == arr.len() {
                    arr = gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));
                }
                let result = num_compressed_bits(arr);
                assert_eq!(Some($num_bits), result);
            };
        }

        // Widths representable by every unsigned type.
        let test_cases = [
            (5u64, 0.0f64),
            (5u64, 0.9f64),
            (1u64, 0.0f64),
            (1u64, 0.5f64),
            (8u64, 0.0f64),
            (8u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt8Type, *null_probability);
            do_test!(*num_bits, UInt16Type, *null_probability);
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // Widths that need at least 16 bits.
        let test_cases = [
            (13u64, 0.0f64),
            (13u64, 0.5f64),
            (16u64, 0.0f64),
            (16u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt16Type, *null_probability);
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // Widths that need at least 32 bits.
        let test_cases = [
            (25u64, 0.0f64),
            (25u64, 0.5f64),
            (32u64, 0.0f64),
            (32u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // Widths that need 64 bits.
        let test_cases = [
            (48u64, 0.0f64),
            (48u64, 0.5f64),
            (64u64, 0.0f64),
            (64u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // Non-integer types are not bit-packed.
        let arr = Float64Array::from_iter_values(vec![0.1, 0.2, 0.3]);
        let result = num_compressed_bits(Arc::new(arr));
        assert_eq!(None, result);
    }

    #[test]
    fn test_rows_in_buffer() {
        let cases = [
            (5usize, 5u64, 0u8, None, 8u64),
            (2, 3, 0, Some(5), 4),
            (2, 3, 7, Some(6), 2),
        ];
        for (buffer_len, bits_per_value, start_bit, end_bit, expected) in cases {
            assert_eq!(
                expected,
                rows_in_buffer(buffer_len, bits_per_value, start_bit, end_bit)
            );
        }
    }

    #[test]
    fn test_compute_start_offset() {
        let result = compute_start_offset(0, 5, 5, 0, None);
        assert_eq!(
            StartOffset::SkipSome(BufferStartOffset {
                index: 0,
                bit_offset: 0
            }),
            result
        );

        let result = compute_start_offset(10, 5, 5, 0, None);
        assert_eq!(StartOffset::SkipFull(8), result);
    }
}