// lance_encoding/v2/encodings/physical/bitpack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow::datatypes::{
7    Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
8};
9use arrow_array::{Array, ArrowPrimitiveType, PrimitiveArray};
10use arrow_buffer::bit_util::ceil;
11use arrow_buffer::ArrowNativeType;
12use arrow_schema::DataType;
13use bytes::Bytes;
14use futures::future::{BoxFuture, FutureExt};
15use log::trace;
16use num_traits::{AsPrimitive, PrimInt};
17use snafu::location;
18
19use lance_arrow::DataTypeExt;
20use lance_core::{Error, Result};
21
22use crate::buffer::LanceBuffer;
23use crate::compression_algo::fastlanes::BitPacking;
24use crate::data::BlockInfo;
25use crate::data::{DataBlock, FixedWidthDataBlock, NullableDataBlock};
26use crate::decoder::{PageScheduler, PrimitivePageDecoder};
27use crate::format::ProtobufUtils;
28use crate::v2::encoder::{ArrayEncoder, EncodedArray};
29use arrow::array::{ArrayRef, AsArray};
30use bytemuck::cast_slice;
31
/// log2 of the number of values bit-packed together as one fastlanes chunk.
const LOG_ELEMS_PER_CHUNK: u8 = 10;
/// Number of values in one fastlanes chunk (1024).
const ELEMS_PER_CHUNK: u64 = 1u64 << LOG_ELEMS_PER_CHUNK;
34
35// Compute the compressed_bit_width for a given array of integers
36// todo: compute all statistics before encoding
37// todo: see how to use rust macro to rewrite this function
38pub fn compute_compressed_bit_width_for_non_neg(arrays: &[ArrayRef]) -> u64 {
39    debug_assert!(!arrays.is_empty());
40
41    let res;
42
43    match arrays[0].data_type() {
44        DataType::UInt8 => {
45            let mut global_max: u8 = 0;
46            for array in arrays {
47                let primitive_array = array
48                    .as_any()
49                    .downcast_ref::<PrimitiveArray<UInt8Type>>()
50                    .unwrap();
51                let array_max = arrow::compute::bit_or(primitive_array);
52                global_max = global_max.max(array_max.unwrap_or(0));
53            }
54            let num_bits =
55                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
56            // we will have constant encoding later
57            if num_bits == 0 {
58                res = 1;
59            } else {
60                res = num_bits;
61            }
62        }
63
64        DataType::Int8 => {
65            let mut global_max_width: u64 = 0;
66            for array in arrays {
67                let primitive_array = array
68                    .as_any()
69                    .downcast_ref::<PrimitiveArray<Int8Type>>()
70                    .unwrap();
71                let array_max_width = arrow::compute::bit_or(primitive_array).unwrap_or(0);
72                global_max_width = global_max_width.max(8 - array_max_width.leading_zeros() as u64);
73            }
74            if global_max_width == 0 {
75                res = 1;
76            } else {
77                res = global_max_width;
78            }
79        }
80
81        DataType::UInt16 => {
82            let mut global_max: u16 = 0;
83            for array in arrays {
84                let primitive_array = array
85                    .as_any()
86                    .downcast_ref::<PrimitiveArray<UInt16Type>>()
87                    .unwrap();
88                let array_max = arrow::compute::bit_or(primitive_array).unwrap_or(0);
89                global_max = global_max.max(array_max);
90            }
91            let num_bits =
92                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
93            if num_bits == 0 {
94                res = 1;
95            } else {
96                res = num_bits;
97            }
98        }
99
100        DataType::Int16 => {
101            let mut global_max_width: u64 = 0;
102            for array in arrays {
103                let primitive_array = array
104                    .as_any()
105                    .downcast_ref::<PrimitiveArray<Int16Type>>()
106                    .unwrap();
107                let array_max_width = arrow::compute::bit_or(primitive_array).unwrap_or(0);
108                global_max_width =
109                    global_max_width.max(16 - array_max_width.leading_zeros() as u64);
110            }
111            if global_max_width == 0 {
112                res = 1;
113            } else {
114                res = global_max_width;
115            }
116        }
117
118        DataType::UInt32 => {
119            let mut global_max: u32 = 0;
120            for array in arrays {
121                let primitive_array = array
122                    .as_any()
123                    .downcast_ref::<PrimitiveArray<UInt32Type>>()
124                    .unwrap();
125                let array_max = arrow::compute::bit_or(primitive_array).unwrap_or(0);
126                global_max = global_max.max(array_max);
127            }
128            let num_bits =
129                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
130            if num_bits == 0 {
131                res = 1;
132            } else {
133                res = num_bits;
134            }
135        }
136
137        DataType::Int32 => {
138            let mut global_max_width: u64 = 0;
139            for array in arrays {
140                let primitive_array = array
141                    .as_any()
142                    .downcast_ref::<PrimitiveArray<Int32Type>>()
143                    .unwrap();
144                let array_max_width = arrow::compute::bit_or(primitive_array).unwrap_or(0);
145                global_max_width =
146                    global_max_width.max(32 - array_max_width.leading_zeros() as u64);
147            }
148            if global_max_width == 0 {
149                res = 1;
150            } else {
151                res = global_max_width;
152            }
153        }
154
155        DataType::UInt64 => {
156            let mut global_max: u64 = 0;
157            for array in arrays {
158                let primitive_array = array
159                    .as_any()
160                    .downcast_ref::<PrimitiveArray<UInt64Type>>()
161                    .unwrap();
162                let array_max = arrow::compute::bit_or(primitive_array).unwrap_or(0);
163                global_max = global_max.max(array_max);
164            }
165            let num_bits =
166                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
167            if num_bits == 0 {
168                res = 1;
169            } else {
170                res = num_bits;
171            }
172        }
173
174        DataType::Int64 => {
175            let mut global_max_width: u64 = 0;
176            for array in arrays {
177                let primitive_array = array
178                    .as_any()
179                    .downcast_ref::<PrimitiveArray<Int64Type>>()
180                    .unwrap();
181                let array_max_width = arrow::compute::bit_or(primitive_array).unwrap_or(0);
182                global_max_width =
183                    global_max_width.max(64 - array_max_width.leading_zeros() as u64);
184            }
185            if global_max_width == 0 {
186                res = 1;
187            } else {
188                res = global_max_width;
189            }
190        }
191        _ => {
192            panic!("BitpackedForNonNegArrayEncoder only supports data types of UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt64, Int64");
193        }
194    };
195    res
196}
197
// Bitpack integers using the fastlanes algorithm. The input is sliced into chunks of
// ELEMS_PER_CHUNK (1024) integers and bitpacked chunk by chunk. When the input is not a
// multiple of 1024, the last chunk is padded with zeros; this is fine because the number of
// rows is recorded alongside the data.
//
// Arguments:
// - `$self`: a `BitpackedForNonNegArrayEncoder` (only `compressed_bit_width` is read)
// - `$unpacked`: a mutable `FixedWidthDataBlock` holding the values to pack
// - `$data_type`: the unsigned storage type, one of u8, u16, u32, u64
// - `$buffer_index`: a `&mut u32`; the packed buffer is given index `*$buffer_index`, which is
//   then incremented
//
// Evaluates to a `Result<EncodedArray>` holding the fastlanes-bitpacked block.
macro_rules! encode_fixed_width {
    ($self:expr, $unpacked:expr, $data_type:ty, $buffer_index:expr) => {{
        let num_chunks = $unpacked.num_values.div_ceil(ELEMS_PER_CHUNK);
        let num_full_chunks = $unpacked.num_values / ELEMS_PER_CHUNK;
        let uncompressed_bit_width = std::mem::size_of::<$data_type>() as u64 * 8;

        // The output words have the same type as the input (e.g. u16 input -> Vec<u16>), so one
        // chunk of ELEMS_PER_CHUNK values packs into this many words.
        let packed_chunk_size = ELEMS_PER_CHUNK as usize * $self.compressed_bit_width
            / uncompressed_bit_width as usize;

        let input_slice = $unpacked.data.borrow_to_typed_slice::<$data_type>();
        let input = input_slice.as_ref();

        // Zero-initialise the whole output up front. Growing it with `set_len` and then handing
        // `unchecked_pack` a `&mut` slice over uninitialised integers is undefined behaviour,
        // even if every element is written afterwards.
        let mut output: Vec<$data_type> = vec![0; num_chunks as usize * packed_chunk_size];

        // Pack every full chunk straight from the input.
        for chunk_idx in 0..num_full_chunks as usize {
            let in_start = chunk_idx * ELEMS_PER_CHUNK as usize;
            let out_start = chunk_idx * packed_chunk_size;
            // SAFETY (assumed contract of `unchecked_pack`, TODO confirm against fastlanes): the
            // input is exactly ELEMS_PER_CHUNK values, the output is exactly `packed_chunk_size`
            // words, and `compressed_bit_width` does not exceed the bit width of `$data_type`.
            unsafe {
                BitPacking::unchecked_pack(
                    $self.compressed_bit_width,
                    &input[in_start..][..ELEMS_PER_CHUNK as usize],
                    &mut output[out_start..][..packed_chunk_size],
                );
            }
        }

        // Pack the trailing partial chunk, if any, after zero-padding it to a full chunk.
        if num_chunks != num_full_chunks {
            let tail_start = (num_full_chunks * ELEMS_PER_CHUNK) as usize;
            let tail_len = ($unpacked.num_values % ELEMS_PER_CHUNK) as usize;
            let mut last_chunk: Vec<$data_type> = vec![0; ELEMS_PER_CHUNK as usize];
            last_chunk[..tail_len].copy_from_slice(&input[tail_start..tail_start + tail_len]);

            let out_start = num_full_chunks as usize * packed_chunk_size;
            // SAFETY: same assumed contract as above; one full (padded) chunk in,
            // `packed_chunk_size` words out.
            unsafe {
                BitPacking::unchecked_pack(
                    $self.compressed_bit_width,
                    &last_chunk,
                    &mut output[out_start..][..packed_chunk_size],
                );
            }
        }

        let bitpacked_for_non_neg_buffer_index = *$buffer_index;
        *$buffer_index += 1;

        let encoding = ProtobufUtils::bitpacked_for_non_neg_encoding(
            $self.compressed_bit_width as u64,
            uncompressed_bit_width,
            bitpacked_for_non_neg_buffer_index,
        );
        let packed = DataBlock::FixedWidth(FixedWidthDataBlock {
            bits_per_value: $self.compressed_bit_width as u64,
            data: LanceBuffer::reinterpret_vec(output),
            num_values: $unpacked.num_values,
            block_info: BlockInfo::new(),
        });

        Result::Ok(EncodedArray {
            data: packed,
            encoding,
        })
    }};
}
273
274#[derive(Debug)]
275pub struct BitpackedForNonNegArrayEncoder {
276    pub compressed_bit_width: usize,
277    pub original_data_type: DataType,
278}
279
280impl BitpackedForNonNegArrayEncoder {
281    pub fn new(compressed_bit_width: usize, data_type: DataType) -> Self {
282        Self {
283            compressed_bit_width,
284            original_data_type: data_type,
285        }
286    }
287}
288
289impl ArrayEncoder for BitpackedForNonNegArrayEncoder {
290    fn encode(
291        &self,
292        data: DataBlock,
293        data_type: &DataType,
294        buffer_index: &mut u32,
295    ) -> Result<EncodedArray> {
296        match data {
297            DataBlock::AllNull(_) => {
298                let encoding = ProtobufUtils::basic_all_null_encoding();
299                Ok(EncodedArray { data, encoding })
300            }
301            DataBlock::FixedWidth(mut unpacked) => {
302                match data_type {
303                    DataType::UInt8 | DataType::Int8 => encode_fixed_width!(self, unpacked, u8, buffer_index),
304                    DataType::UInt16 | DataType::Int16 => encode_fixed_width!(self, unpacked, u16, buffer_index),
305                    DataType::UInt32 | DataType::Int32 => encode_fixed_width!(self, unpacked, u32, buffer_index),
306                    DataType::UInt64 | DataType::Int64 => encode_fixed_width!(self, unpacked, u64, buffer_index),
307                    _ => unreachable!("BitpackedForNonNegArrayEncoder only supports data types of UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt64, Int64"),
308                }
309            }
310            DataBlock::Nullable(nullable) => {
311                let validity_buffer_index = *buffer_index;
312                *buffer_index += 1;
313
314                let validity_desc = ProtobufUtils::flat_encoding(
315                    1,
316                    validity_buffer_index,
317                    /*compression=*/ None,
318                );
319                let encoded_values: EncodedArray;
320                match *nullable.data {
321                    DataBlock::FixedWidth(mut unpacked) => {
322                        match data_type {
323                            DataType::UInt8 | DataType::Int8 => encoded_values = encode_fixed_width!(self, unpacked, u8, buffer_index)?,
324                            DataType::UInt16 | DataType::Int16 => encoded_values = encode_fixed_width!(self, unpacked, u16, buffer_index)?,
325                            DataType::UInt32 | DataType::Int32 => encoded_values = encode_fixed_width!(self, unpacked, u32, buffer_index)?,
326                            DataType::UInt64 | DataType::Int64 => encoded_values = encode_fixed_width!(self, unpacked, u64, buffer_index)?,
327                            _ => unreachable!("BitpackedForNonNegArrayEncoder only supports data types of UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt64, Int64"),
328                        }
329                    }
330                    _ => {
331                        return Err(Error::InvalidInput {
332                            source: "Bitpacking only supports fixed width data blocks or a nullable data block with fixed width data block inside or a all null data block".into(),
333                            location: location!(),
334                        });
335                    }
336                }
337                let encoding =
338                    ProtobufUtils::basic_some_null_encoding(validity_desc, encoded_values.encoding);
339                let encoded = DataBlock::Nullable(NullableDataBlock {
340                    data: Box::new(encoded_values.data),
341                    nulls: nullable.nulls,
342                    block_info: BlockInfo::new(),
343                });
344                Ok(EncodedArray {
345                    data: encoded,
346                    encoding,
347                })
348            }
349            _ => {
350                Err(Error::InvalidInput {
351                    source: "Bitpacking only supports fixed width data blocks or a nullable data block with fixed width data block inside or a all null data block".into(),
352                    location: location!(),
353                })
354            }
355        }
356    }
357}
358
359#[derive(Debug)]
360pub struct BitpackedForNonNegScheduler {
361    compressed_bit_width: u64,
362    uncompressed_bits_per_value: u64,
363    buffer_offset: u64,
364}
365
366impl BitpackedForNonNegScheduler {
367    pub fn new(
368        compressed_bit_width: u64,
369        uncompressed_bits_per_value: u64,
370        buffer_offset: u64,
371    ) -> Self {
372        Self {
373            compressed_bit_width,
374            uncompressed_bits_per_value,
375            buffer_offset,
376        }
377    }
378
379    fn locate_chunk_start(&self, relative_row_num: u64) -> u64 {
380        let chunk_size = ELEMS_PER_CHUNK * self.compressed_bit_width / 8;
381        self.buffer_offset + (relative_row_num / ELEMS_PER_CHUNK * chunk_size)
382    }
383
384    fn locate_chunk_end(&self, relative_row_num: u64) -> u64 {
385        let chunk_size = ELEMS_PER_CHUNK * self.compressed_bit_width / 8;
386        self.buffer_offset + (relative_row_num / ELEMS_PER_CHUNK * chunk_size) + chunk_size
387    }
388}
389
390impl PageScheduler for BitpackedForNonNegScheduler {
391    fn schedule_ranges(
392        &self,
393        ranges: &[std::ops::Range<u64>],
394        scheduler: &Arc<dyn crate::EncodingsIo>,
395        top_level_row: u64,
396    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
397        assert!(!ranges.is_empty());
398
399        let mut byte_ranges = vec![];
400
401        // map one bytes to multiple ranges, one bytes has at least one range corresponding to it
402        let mut bytes_idx_to_range_indices = vec![];
403        let first_byte_range = std::ops::Range {
404            start: self.locate_chunk_start(ranges[0].start),
405            end: self.locate_chunk_end(ranges[0].end - 1),
406        }; // the ranges are half-open
407        byte_ranges.push(first_byte_range);
408        bytes_idx_to_range_indices.push(vec![ranges[0].clone()]);
409
410        for (i, range) in ranges.iter().enumerate().skip(1) {
411            let this_start = self.locate_chunk_start(range.start);
412            let this_end = self.locate_chunk_end(range.end - 1);
413
414            // when the current range start is in the same chunk as the previous range's end, we colaesce this two bytes ranges
415            // when the current range start is not in the same chunk as the previous range's end, we create a new bytes range
416            if this_start == self.locate_chunk_start(ranges[i - 1].end - 1) {
417                byte_ranges.last_mut().unwrap().end = this_end;
418                bytes_idx_to_range_indices
419                    .last_mut()
420                    .unwrap()
421                    .push(range.clone());
422            } else {
423                byte_ranges.push(this_start..this_end);
424                bytes_idx_to_range_indices.push(vec![range.clone()]);
425            }
426        }
427
428        trace!(
429            "Scheduling I/O for {} ranges spread across byte range {}..{}",
430            byte_ranges.len(),
431            byte_ranges[0].start,
432            byte_ranges.last().unwrap().end
433        );
434
435        let bytes = scheduler.submit_request(byte_ranges.clone(), top_level_row);
436
437        // copy the necessary data from `self` to move into the async block
438        let compressed_bit_width = self.compressed_bit_width;
439        let uncompressed_bits_per_value = self.uncompressed_bits_per_value;
440        let num_rows = ranges.iter().map(|range| range.end - range.start).sum();
441
442        async move {
443            let bytes = bytes.await?;
444            let decompressed_output = bitpacked_for_non_neg_decode(
445                compressed_bit_width,
446                uncompressed_bits_per_value,
447                &bytes,
448                &bytes_idx_to_range_indices,
449                num_rows,
450            );
451            Ok(Box::new(BitpackedForNonNegPageDecoder {
452                uncompressed_bits_per_value,
453                decompressed_buf: decompressed_output,
454            }) as Box<dyn PrimitivePageDecoder>)
455        }
456        .boxed()
457    }
458}
459
460#[derive(Debug)]
461struct BitpackedForNonNegPageDecoder {
462    // number of bits in the uncompressed value. E.g. this will be 32 for DataType::UInt32
463    uncompressed_bits_per_value: u64,
464
465    decompressed_buf: LanceBuffer,
466}
467
468impl PrimitivePageDecoder for BitpackedForNonNegPageDecoder {
469    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
470        if ![8, 16, 32, 64].contains(&self.uncompressed_bits_per_value) {
471            return Err(Error::InvalidInput {
472                source: "BitpackedForNonNegPageDecoder should only has uncompressed_bits_per_value of 8, 16, 32, or 64".into(),
473                location: location!(),
474            });
475        }
476
477        let elem_size_in_bytes = self.uncompressed_bits_per_value / 8;
478
479        Ok(DataBlock::FixedWidth(FixedWidthDataBlock {
480            data: self.decompressed_buf.slice_with_length(
481                (rows_to_skip * elem_size_in_bytes) as usize,
482                (num_rows * elem_size_in_bytes) as usize,
483            ),
484            bits_per_value: self.uncompressed_bits_per_value,
485            num_values: num_rows,
486            block_info: BlockInfo::new(),
487        }))
488    }
489}
490
// Unpacks fastlanes chunks back into `$uncompressed_type` values, keeping only the rows the
// scheduler asked for.
//
// - `$compressed_bit_width`: bits per packed value
// - `$data`: the fetched byte ranges; each must be a whole number of packed chunks
// - `$bytes_idx_to_range_indices`: for byte range `i`, the row ranges it serves
// - `$num_rows`: total number of requested rows (used to size the output)
//
// Evaluates to a `LanceBuffer` holding the requested rows back to back.
macro_rules! bitpacked_decode {
    ($uncompressed_type:ty, $compressed_bit_width:expr, $data:expr, $bytes_idx_to_range_indices:expr, $num_rows:expr) => {{
        let mut decompressed: Vec<$uncompressed_type> = Vec::with_capacity($num_rows as usize);
        let packed_chunk_size_in_byte: usize = (ELEMS_PER_CHUNK * $compressed_bit_width) as usize / 8;
        let mut decompress_chunk_buf = vec![0 as $uncompressed_type; ELEMS_PER_CHUNK as usize];
        // Typed, correctly aligned scratch buffer that each packed chunk is copied into.
        // `cast_slice` from a fresh `Vec<u8>` (alignment 1) to `&[$uncompressed_type]` only
        // works when the allocator happens to align the bytes, and panics otherwise.
        // Allocating the buffer once also avoids one heap allocation per chunk.
        let mut packed_chunk_buf = vec![
            0 as $uncompressed_type;
            packed_chunk_size_in_byte / std::mem::size_of::<$uncompressed_type>()
        ];

        for (i, bytes) in $data.iter().enumerate() {
            let mut ranges_idx = 0;
            let mut curr_range_start = $bytes_idx_to_range_indices[i][0].start;
            let mut chunk_num = 0;

            while chunk_num * packed_chunk_size_in_byte < bytes.len() {
                let chunk_bytes = &bytes[chunk_num * packed_chunk_size_in_byte..]
                    [..packed_chunk_size_in_byte];
                bytemuck::cast_slice_mut::<$uncompressed_type, u8>(&mut packed_chunk_buf)
                    .copy_from_slice(chunk_bytes);
                chunk_num += 1;
                // SAFETY (assumed contract of `unchecked_unpack`, TODO confirm against
                // fastlanes): the input is exactly one packed chunk of ELEMS_PER_CHUNK values at
                // `$compressed_bit_width` bits, and the output holds ELEMS_PER_CHUNK values.
                unsafe {
                    BitPacking::unchecked_unpack(
                        $compressed_bit_width as usize,
                        &packed_chunk_buf,
                        &mut decompress_chunk_buf,
                    );
                }

                loop {
                    // Case 1: All the elements after (curr_range_start % ELEMS_PER_CHUNK) inside this chunk are needed.
                    let elems_after_curr_range_start_in_this_chunk =
                        ELEMS_PER_CHUNK - curr_range_start % ELEMS_PER_CHUNK;
                    if curr_range_start + elems_after_curr_range_start_in_this_chunk
                        <= $bytes_idx_to_range_indices[i][ranges_idx].end
                    {
                        decompressed.extend_from_slice(
                            &decompress_chunk_buf[(curr_range_start % ELEMS_PER_CHUNK) as usize..],
                        );
                        curr_range_start += elems_after_curr_range_start_in_this_chunk;
                        break;
                    } else {
                        // Case 2: Only part of the elements after (curr_range_start % ELEMS_PER_CHUNK) inside this chunk are needed.
                        let elems_this_range_needed_in_this_chunk =
                            ($bytes_idx_to_range_indices[i][ranges_idx].end - curr_range_start)
                                .min(ELEMS_PER_CHUNK - curr_range_start % ELEMS_PER_CHUNK);
                        decompressed.extend_from_slice(
                            &decompress_chunk_buf[(curr_range_start % ELEMS_PER_CHUNK) as usize..]
                                [..elems_this_range_needed_in_this_chunk as usize],
                        );
                        if curr_range_start + elems_this_range_needed_in_this_chunk
                            == $bytes_idx_to_range_indices[i][ranges_idx].end
                        {
                            // This range is finished; move on to the next one served by
                            // these bytes, or stop if there is none.
                            ranges_idx += 1;
                            if ranges_idx == $bytes_idx_to_range_indices[i].len() {
                                break;
                            }
                            curr_range_start = $bytes_idx_to_range_indices[i][ranges_idx].start;
                        } else {
                            curr_range_start += elems_this_range_needed_in_this_chunk;
                        }
                    }
                }
            }
        }

        LanceBuffer::reinterpret_vec(decompressed)
    }};
}
558
559fn bitpacked_for_non_neg_decode(
560    compressed_bit_width: u64,
561    uncompressed_bits_per_value: u64,
562    data: &[Bytes],
563    bytes_idx_to_range_indices: &[Vec<std::ops::Range<u64>>],
564    num_rows: u64,
565) -> LanceBuffer {
566    match uncompressed_bits_per_value {
567        8 => bitpacked_decode!(
568            u8,
569            compressed_bit_width,
570            data,
571            bytes_idx_to_range_indices,
572            num_rows
573        ),
574        16 => bitpacked_decode!(
575            u16,
576            compressed_bit_width,
577            data,
578            bytes_idx_to_range_indices,
579            num_rows
580        ),
581        32 => bitpacked_decode!(
582            u32,
583            compressed_bit_width,
584            data,
585            bytes_idx_to_range_indices,
586            num_rows
587        ),
588        64 => bitpacked_decode!(
589            u64,
590            compressed_bit_width,
591            data,
592            bytes_idx_to_range_indices,
593            num_rows
594        ),
595        _ => unreachable!(
596            "bitpacked_for_non_neg_decode only supports 8, 16, 32, 64 uncompressed_bits_per_value"
597        ),
598    }
599}
600
/// Bit-packing parameters chosen for an array.
#[derive(Debug)]
pub struct BitpackParams {
    /// Number of bits each value is packed into (at least 1).
    pub num_bits: u64,
    /// Whether a sign bit is included because the array contains negative values.
    pub signed: bool,
}
607
608// Compute the number of bits to use for each item, if this array can be encoded using
609// bitpacking encoding. Returns `None` if the type or array data is not supported.
610pub fn bitpack_params(arr: &dyn Array) -> Option<BitpackParams> {
611    match arr.data_type() {
612        DataType::UInt8 => bitpack_params_for_type::<UInt8Type>(arr.as_primitive()),
613        DataType::UInt16 => bitpack_params_for_type::<UInt16Type>(arr.as_primitive()),
614        DataType::UInt32 => bitpack_params_for_type::<UInt32Type>(arr.as_primitive()),
615        DataType::UInt64 => bitpack_params_for_type::<UInt64Type>(arr.as_primitive()),
616        DataType::Int8 => bitpack_params_for_signed_type::<Int8Type>(arr.as_primitive()),
617        DataType::Int16 => bitpack_params_for_signed_type::<Int16Type>(arr.as_primitive()),
618        DataType::Int32 => bitpack_params_for_signed_type::<Int32Type>(arr.as_primitive()),
619        DataType::Int64 => bitpack_params_for_signed_type::<Int64Type>(arr.as_primitive()),
620        // TODO -- eventually we could support temporal types as well
621        _ => None,
622    }
623}
624
625// Compute the number bits to to use for bitpacking generically.
626// returns None if the array is empty or all nulls
627fn bitpack_params_for_type<T>(arr: &PrimitiveArray<T>) -> Option<BitpackParams>
628where
629    T: ArrowPrimitiveType,
630    T::Native: PrimInt + AsPrimitive<u64>,
631{
632    let max = arrow::compute::bit_or(arr);
633    let num_bits =
634        max.map(|max| arr.data_type().byte_width() as u64 * 8 - max.leading_zeros() as u64);
635
636    // we can't bitpack into 0 bits, so the minimum is 1
637    num_bits
638        .map(|num_bits| num_bits.max(1))
639        .map(|bits| BitpackParams {
640            num_bits: bits,
641            signed: false,
642        })
643}
644
645/// determine the minimum number of bits that can be used to represent
646/// an array of signed values. It includes all the significant bits for
647/// the value + plus 1 bit to represent the sign. If there are no negative values
648/// then it will not add a signed bit
649fn bitpack_params_for_signed_type<T>(arr: &PrimitiveArray<T>) -> Option<BitpackParams>
650where
651    T: ArrowPrimitiveType,
652    T::Native: PrimInt + AsPrimitive<i64>,
653{
654    let mut add_signed_bit = false;
655    let mut min_leading_bits: Option<u64> = None;
656    for val in arr.iter() {
657        if val.is_none() {
658            continue;
659        }
660        let val = val.unwrap();
661        if min_leading_bits.is_none() {
662            min_leading_bits = Some(u64::MAX);
663        }
664
665        if val.to_i64().unwrap() < 0i64 {
666            min_leading_bits = min_leading_bits.map(|bits| bits.min(val.leading_ones() as u64));
667            add_signed_bit = true;
668        } else {
669            min_leading_bits = min_leading_bits.map(|bits| bits.min(val.leading_zeros() as u64));
670        }
671    }
672
673    let mut min_leading_bits = arr.data_type().byte_width() as u64 * 8 - min_leading_bits?;
674    if add_signed_bit {
675        // Need extra sign bit
676        min_leading_bits += 1;
677    }
678    // cannot bitpack into <1 bit
679    let num_bits = min_leading_bits.max(1);
680    Some(BitpackParams {
681        num_bits,
682        signed: add_signed_bit,
683    })
684}
/// Array encoder that packs each value into `num_bits` bits, back to back in one byte buffer.
#[derive(Debug)]
pub struct BitpackedArrayEncoder {
    /// Bits each value is packed into.
    num_bits: u64,
    /// Whether the values are signed; recorded in the emitted encoding description.
    signed_type: bool,
}

impl BitpackedArrayEncoder {
    /// Creates an encoder that packs values into `num_bits` bits each.
    pub fn new(num_bits: u64, signed_type: bool) -> Self {
        Self {
            signed_type,
            num_bits,
        }
    }
}
699
700impl ArrayEncoder for BitpackedArrayEncoder {
701    fn encode(
702        &self,
703        data: DataBlock,
704        _data_type: &DataType,
705        buffer_index: &mut u32,
706    ) -> Result<EncodedArray> {
707        // calculate the total number of bytes we need to allocate for the destination.
708        // this will be the number of items in the source array times the number of bits.
709        let dst_bytes_total = ceil(data.num_values() as usize * self.num_bits as usize, 8);
710
711        let mut dst_buffer = vec![0u8; dst_bytes_total];
712        let mut dst_idx = 0;
713        let mut dst_offset = 0;
714
715        let DataBlock::FixedWidth(unpacked) = data else {
716            return Err(Error::InvalidInput {
717                source: "Bitpacking only supports fixed width data blocks".into(),
718                location: location!(),
719            });
720        };
721
722        pack_bits(
723            &unpacked.data,
724            self.num_bits,
725            &mut dst_buffer,
726            &mut dst_idx,
727            &mut dst_offset,
728        );
729
730        let packed = DataBlock::FixedWidth(FixedWidthDataBlock {
731            bits_per_value: self.num_bits,
732            data: LanceBuffer::Owned(dst_buffer),
733            num_values: unpacked.num_values,
734            block_info: BlockInfo::new(),
735        });
736
737        let bitpacked_buffer_index = *buffer_index;
738        *buffer_index += 1;
739
740        let encoding = ProtobufUtils::bitpacked_encoding(
741            self.num_bits,
742            unpacked.bits_per_value,
743            bitpacked_buffer_index,
744            self.signed_type,
745        );
746
747        Ok(EncodedArray {
748            data: packed,
749            encoding,
750        })
751    }
752}
753
/// Packs the values stored in `src` into `dst`, keeping only the low `num_bits` bits of
/// each value.
///
/// Bits are appended least-significant first, starting at byte `*dst_idx`, bit
/// `*dst_offset`. Both cursors are advanced as bits are written, so the caller can see
/// where packing stopped. `dst` is expected to be zero-initialised: bits are merged in
/// with `+=`.
///
/// * `src` - the unpacked values. The first byte of each value is treated as its least
///   significant byte (i.e. a little-endian layout is assumed).
/// * `num_bits` - packed width of each value; expected to be in `1..=64` (a value of 0
///   makes the `64 - num_bits` shift below overflow).
/// * `dst` - output buffer, large enough for all packed bits.
/// * `dst_idx` / `dst_offset` - write cursor (byte index, bit within that byte).
fn pack_bits(
    src: &LanceBuffer,
    num_bits: u64,
    dst: &mut [u8],
    dst_idx: &mut usize,
    dst_offset: &mut u8,
) {
    // FIXME(review): `bit_len` here and the stride at the bottom of the outer loop are
    // both derived from `src.len()`, the length of the WHOLE buffer, rather than from the
    // width of a single source value. After the first value has been packed, `src_idx`
    // always lands exactly on `src.len()`: the inner loop advances it by
    // `num_bits / 8` bytes, and the stride adds `src.len() - num_bits / 8`. So the outer
    // loop exits and only the first value of a multi-value buffer is packed. A fix needs
    // the source value width, e.g. a `src_bits_per_value` parameter that the caller fills
    // from `unpacked.bits_per_value`, with `bit_len = src_bits_per_value` and
    // `src_idx += src_bits_per_value as usize / 8 - partial_bytes_written + to_next_byte`.
    let bit_len = src.len() as u64 * 8;

    // selects the low `num_bits` bits of a value; shifted right by 8 per source byte
    let mask = u64::MAX >> (64 - num_bits);

    let mut src_idx = 0;
    while src_idx < src.len() {
        let mut curr_mask = mask;
        let mut curr_src = src[src_idx] & curr_mask as u8;
        let mut src_offset = 0;
        let mut src_bits_written = 0;

        while src_bits_written < num_bits {
            // bits shifted past the top of the u8 are dropped here and written on the
            // next iteration, once `dst_offset` has wrapped to a new byte
            dst[*dst_idx] += (curr_src >> src_offset) << *dst_offset as u64;
            // copy as many bits as fit in both the current source and destination bytes
            let bits_written = (num_bits - src_bits_written)
                .min(8 - src_offset)
                .min(8 - *dst_offset as u64);
            src_bits_written += bits_written;
            *dst_offset += bits_written as u8;
            src_offset += bits_written;

            if *dst_offset == 8 {
                *dst_idx += 1;
                *dst_offset = 0;
            }

            if src_offset == 8 {
                src_idx += 1;
                src_offset = 0;
                curr_mask >>= 8;
                if src_idx == src.len() {
                    break;
                }
                curr_src = src[src_idx] & curr_mask as u8;
            }
        }

        // advance source_offset to the next byte if we're not at the end..
        // note that we don't need to do this if we wrote the full number of bits
        // because source index would have been advanced by the inner loop above
        if bit_len != num_bits {
            let partial_bytes_written = ceil(num_bits as usize, 8);

            // we also want to move to the next location in src, unless we wrote something
            // byte-aligned in which case the logic above would have already advanced
            let mut to_next_byte = 1;
            if num_bits % 8 == 0 {
                to_next_byte = 0;
            }

            src_idx += src.len() - partial_bytes_written + to_next_byte;
        }
    }
}
814
/// Physical scheduler for bit-packed buffers.
///
/// Turns requested row ranges into byte ranges of the packed data and builds a decoder
/// over the fetched bytes.
#[derive(Debug, Clone, Copy)]
pub struct BitpackedScheduler {
    /// Width, in bits, of each packed value.
    bits_per_value: u64,
    /// Width, in bits, of each value once decoded (e.g. 32 for `u32`).
    uncompressed_bits_per_value: u64,
    /// Byte offset added to every byte range this scheduler requests.
    buffer_offset: u64,
    /// Whether decoded values are sign-extended from the packed most significant bit.
    signed: bool,
}
823
824impl BitpackedScheduler {
825    pub fn new(
826        bits_per_value: u64,
827        uncompressed_bits_per_value: u64,
828        buffer_offset: u64,
829        signed: bool,
830    ) -> Self {
831        Self {
832            bits_per_value,
833            uncompressed_bits_per_value,
834            buffer_offset,
835            signed,
836        }
837    }
838}
839
840impl PageScheduler for BitpackedScheduler {
841    fn schedule_ranges(
842        &self,
843        ranges: &[std::ops::Range<u64>],
844        scheduler: &Arc<dyn crate::EncodingsIo>,
845        top_level_row: u64,
846    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
847        let mut min = u64::MAX;
848        let mut max = 0;
849
850        let mut buffer_bit_start_offsets: Vec<u8> = vec![];
851        let mut buffer_bit_end_offsets: Vec<Option<u8>> = vec![];
852        let byte_ranges = ranges
853            .iter()
854            .map(|range| {
855                let start_byte_offset = range.start * self.bits_per_value / 8;
856                let mut end_byte_offset = range.end * self.bits_per_value / 8;
857                if range.end * self.bits_per_value % 8 != 0 {
858                    // If the end of the range is not byte-aligned, we need to read one more byte
859                    end_byte_offset += 1;
860
861                    let end_bit_offset = range.end * self.bits_per_value % 8;
862                    buffer_bit_end_offsets.push(Some(end_bit_offset as u8));
863                } else {
864                    buffer_bit_end_offsets.push(None);
865                }
866
867                let start_bit_offset = range.start * self.bits_per_value % 8;
868                buffer_bit_start_offsets.push(start_bit_offset as u8);
869
870                let start = self.buffer_offset + start_byte_offset;
871                let end = self.buffer_offset + end_byte_offset;
872                min = min.min(start);
873                max = max.max(end);
874
875                start..end
876            })
877            .collect::<Vec<_>>();
878
879        trace!(
880            "Scheduling I/O for {} ranges spread across byte range {}..{}",
881            byte_ranges.len(),
882            min,
883            max
884        );
885
886        let bytes = scheduler.submit_request(byte_ranges, top_level_row);
887
888        let bits_per_value = self.bits_per_value;
889        let uncompressed_bits_per_value = self.uncompressed_bits_per_value;
890        let signed = self.signed;
891        async move {
892            let bytes = bytes.await?;
893            Ok(Box::new(BitpackedPageDecoder {
894                buffer_bit_start_offsets,
895                buffer_bit_end_offsets,
896                bits_per_value,
897                uncompressed_bits_per_value,
898                signed,
899                data: bytes,
900            }) as Box<dyn PrimitivePageDecoder>)
901        }
902        .boxed()
903    }
904}
905
906#[derive(Debug)]
907struct BitpackedPageDecoder {
908    // bit offsets of the first value within each buffer
909    buffer_bit_start_offsets: Vec<u8>,
910
911    // bit offsets of the last value within each buffer. e.g. if there was a buffer
912    // with 2 values, packed into 5 bits, this would be [Some(3)], indicating that
913    // the bits from the 3rd->8th bit in the last byte shouldn't be decoded.
914    buffer_bit_end_offsets: Vec<Option<u8>>,
915
916    // the number of bits used to represent a compressed value. E.g. if the max value
917    // in the page was 7 (0b111), then this will be 3
918    bits_per_value: u64,
919
920    // number of bits in the uncompressed value. E.g. this will be 32 for u32
921    uncompressed_bits_per_value: u64,
922
923    // whether or not to use the msb as a sign bit during decoding
924    signed: bool,
925
926    data: Vec<Bytes>,
927}
928
929impl PrimitivePageDecoder for BitpackedPageDecoder {
930    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
931        let num_bytes = self.uncompressed_bits_per_value / 8 * num_rows;
932        let mut dest = vec![0; num_bytes as usize];
933
934        // current maximum supported bits per value = 64
935        debug_assert!(self.bits_per_value <= 64);
936
937        let mut rows_to_skip = rows_to_skip;
938        let mut rows_taken = 0;
939        let byte_len = self.uncompressed_bits_per_value / 8;
940        let mut dst_idx = 0; // index for current byte being written to destination buffer
941
942        // create bit mask for source bits
943        let mask = u64::MAX >> (64 - self.bits_per_value);
944
945        for i in 0..self.data.len() {
946            let src = &self.data[i];
947            let (mut src_idx, mut src_offset) = match compute_start_offset(
948                rows_to_skip,
949                src.len(),
950                self.bits_per_value,
951                self.buffer_bit_start_offsets[i],
952                self.buffer_bit_end_offsets[i],
953            ) {
954                StartOffset::SkipFull(rows_to_skip_here) => {
955                    rows_to_skip -= rows_to_skip_here;
956                    continue;
957                }
958                StartOffset::SkipSome(buffer_start_offset) => (
959                    buffer_start_offset.index,
960                    buffer_start_offset.bit_offset as u64,
961                ),
962            };
963
964            while src_idx < src.len() && rows_taken < num_rows {
965                rows_taken += 1;
966                let mut curr_mask = mask; // copy mask
967
968                // current source byte being written to destination
969                let mut curr_src = src[src_idx] & (curr_mask << src_offset) as u8;
970
971                // how many bits from the current source value have been written to destination
972                let mut src_bits_written = 0;
973
974                // the offset within the current destination byte to write to
975                let mut dst_offset = 0;
976
977                let is_negative = is_encoded_item_negative(
978                    src,
979                    src_idx,
980                    src_offset,
981                    self.bits_per_value as usize,
982                );
983
984                while src_bits_written < self.bits_per_value {
985                    // write bits from current source byte into destination
986                    dest[dst_idx] += (curr_src >> src_offset) << dst_offset;
987                    let bits_written = (self.bits_per_value - src_bits_written)
988                        .min(8 - src_offset)
989                        .min(8 - dst_offset);
990                    src_bits_written += bits_written;
991                    dst_offset += bits_written;
992                    src_offset += bits_written;
993                    curr_mask >>= bits_written;
994
995                    if dst_offset == 8 {
996                        dst_idx += 1;
997                        dst_offset = 0;
998                    }
999
1000                    if src_offset == 8 {
1001                        src_idx += 1;
1002                        src_offset = 0;
1003                        if src_idx == src.len() {
1004                            break;
1005                        }
1006                        curr_src = src[src_idx] & curr_mask as u8;
1007                    }
1008                }
1009
1010                // if the type is signed, need to pad out the rest of the byte with 1s
1011                let mut negative_padded_current_byte = false;
1012                if self.signed && is_negative && dst_offset > 0 {
1013                    negative_padded_current_byte = true;
1014                    while dst_offset < 8 {
1015                        dest[dst_idx] |= 1 << dst_offset;
1016                        dst_offset += 1;
1017                    }
1018                }
1019
1020                // advance destination offset to the next location
1021                // note that we don't need to do this if we wrote the full number of bits
1022                // because source index would have been advanced by the inner loop above
1023                if self.uncompressed_bits_per_value != self.bits_per_value {
1024                    let partial_bytes_written = ceil(self.bits_per_value as usize, 8);
1025
1026                    // we also want to move one location to the next location in destination,
1027                    // unless we wrote something byte-aligned in which case the logic above
1028                    // would have already advanced dst_idx
1029                    let mut to_next_byte = 1;
1030                    if self.bits_per_value % 8 == 0 {
1031                        to_next_byte = 0;
1032                    }
1033                    let next_dst_idx =
1034                        dst_idx + byte_len as usize - partial_bytes_written + to_next_byte;
1035
1036                    // pad remaining bytes with 1 for negative signed numbers
1037                    if self.signed && is_negative {
1038                        if !negative_padded_current_byte {
1039                            dest[dst_idx] = 0xFF;
1040                        }
1041                        for i in dest.iter_mut().take(next_dst_idx).skip(dst_idx + 1) {
1042                            *i = 0xFF;
1043                        }
1044                    }
1045
1046                    dst_idx = next_dst_idx;
1047                }
1048
1049                // If we've reached the last byte, there may be some extra bits from the
1050                // next value outside the range. We don't want to be taking those.
1051                if let Some(buffer_bit_end_offset) = self.buffer_bit_end_offsets[i] {
1052                    if src_idx == src.len() - 1 && src_offset >= buffer_bit_end_offset as u64 {
1053                        break;
1054                    }
1055                }
1056            }
1057        }
1058
1059        Ok(DataBlock::FixedWidth(FixedWidthDataBlock {
1060            data: LanceBuffer::from(dest),
1061            bits_per_value: self.uncompressed_bits_per_value,
1062            num_values: num_rows,
1063            block_info: BlockInfo::new(),
1064        }))
1065    }
1066}
1067
/// Returns `true` when the sign bit (the most significant packed bit) of a value is set.
///
/// Packed values are laid out least-significant bit first, so the sign bit is the last
/// bit of the value.
///
/// * `src` - packed bytes (any byte slice; `&Bytes` arguments coerce automatically)
/// * `src_idx` - index of the byte holding the value's first bit
/// * `src_offset` - bit position (0-7) of the value's first bit within `src[src_idx]`
/// * `num_bits` - packed width of the value; a zero-width value has no sign bit
fn is_encoded_item_negative(src: &[u8], src_idx: usize, src_offset: u64, num_bits: usize) -> bool {
    if num_bits == 0 {
        // Previously this underflowed when computing the last byte index.
        return false;
    }
    // Position of the value's last (sign) bit, counted from bit 0 of `src[src_idx]`.
    let sign_bit_pos = src_offset as usize + num_bits - 1;
    let sign_byte = src[src_idx + sign_bit_pos / 8];
    sign_byte & (1 << (sign_bit_pos % 8)) != 0
}
1083
/// Location of a packed value inside a buffer: the byte that holds its first bit, plus
/// the bit position (0-7) within that byte.
#[derive(Debug, PartialEq)]
struct BufferStartOffset {
    /// Index of the byte containing the value's first bit.
    index: usize,
    /// Bit position of the value's first bit within byte `index`.
    bit_offset: u8,
}
1089
1090#[derive(Debug, PartialEq)]
1091enum StartOffset {
1092    // skip the full buffer. The value is how many rows are skipped
1093    // by skipping the full buffer (e.g., # rows in buffer)
1094    SkipFull(u64),
1095
1096    // skip to some start offset in the buffer
1097    SkipSome(BufferStartOffset),
1098}
1099
1100/// compute how far ahead in this buffer should we skip ahead and start reading
1101///
1102/// * `rows_to_skip` - how many rows to skip
1103/// * `buffer_len` - length buf buffer (in bytes)
1104/// * `bits_per_value` - number of bits used to represent a single bitpacked value
1105/// * `buffer_start_bit_offset` - offset of the start of the first value within the
1106///   buffer's  first byte
1107/// * `buffer_end_bit_offset` - end bit of the last value within the buffer. Can be
1108///   `None` if the end of the last value is byte aligned with end of buffer.
1109fn compute_start_offset(
1110    rows_to_skip: u64,
1111    buffer_len: usize,
1112    bits_per_value: u64,
1113    buffer_start_bit_offset: u8,
1114    buffer_end_bit_offset: Option<u8>,
1115) -> StartOffset {
1116    let rows_in_buffer = rows_in_buffer(
1117        buffer_len,
1118        bits_per_value,
1119        buffer_start_bit_offset,
1120        buffer_end_bit_offset,
1121    );
1122    if rows_to_skip >= rows_in_buffer {
1123        return StartOffset::SkipFull(rows_in_buffer);
1124    }
1125
1126    let start_bit = rows_to_skip * bits_per_value + buffer_start_bit_offset as u64;
1127    let start_byte = start_bit / 8;
1128
1129    StartOffset::SkipSome(BufferStartOffset {
1130        index: start_byte as usize,
1131        bit_offset: (start_bit % 8) as u8,
1132    })
1133}
1134
/// Counts the complete packed values held by a buffer.
///
/// * `buffer_len` - length of the buffer in bytes
/// * `bits_per_value` - width of each packed value in bits
/// * `buffer_start_bit_offset` - bit within the first byte where the first value starts
/// * `buffer_end_bit_offset` - bit within the last byte where the last value ends, or
///   `None` when the last value ends exactly on the buffer's final byte boundary
fn rows_in_buffer(
    buffer_len: usize,
    bits_per_value: u64,
    buffer_start_bit_offset: u8,
    buffer_end_bit_offset: Option<u8>,
) -> u64 {
    // Bits at the top of the final byte that lie past the end of the last value.
    let unused_tail_bits = match buffer_end_bit_offset {
        Some(end_bit) => (8 - end_bit) as u64,
        None => 0,
    };
    let usable_bits =
        (buffer_len * 8) as u64 - buffer_start_bit_offset as u64 - unused_tail_bits;
    usable_bits / bits_per_value
}
1152
1153#[cfg(test)]
1154pub mod test {
1155    use crate::{
1156        format::pb,
1157        testing::{check_round_trip_encoding_generated, ArrayGeneratorProvider},
1158        version::LanceFileVersion,
1159    };
1160
1161    use super::*;
1162    use std::{collections::HashMap, marker::PhantomData, sync::Arc};
1163
1164    use arrow_array::{
1165        types::{UInt16Type, UInt8Type},
1166        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
1167        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
1168    };
1169
1170    use arrow_schema::Field;
1171    use lance_datagen::{
1172        array::{fill, rand_with_distribution},
1173        gen, ArrayGenerator, ArrayGeneratorExt, RowCount,
1174    };
1175    use rand::distributions::Uniform;
1176
1177    #[test]
1178    fn test_bitpack_params() {
1179        fn gen_array(generator: Box<dyn ArrayGenerator>) -> ArrayRef {
1180            let arr = gen()
1181                .anon_col(generator)
1182                .into_batch_rows(RowCount::from(10000))
1183                .unwrap()
1184                .column(0)
1185                .clone();
1186
1187            arr
1188        }
1189
1190        macro_rules! do_test {
1191            ($num_bits:expr, $data_type:ident, $null_probability:expr) => {
1192                let max = 1 << $num_bits - 1;
1193                let mut arr =
1194                    gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));
1195
1196                // ensure we don't randomly generate all nulls, that won't work
1197                while arr.null_count() == arr.len() {
1198                    arr = gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));
1199                }
1200                let result = bitpack_params(arr.as_ref());
1201                assert!(result.is_some());
1202                assert_eq!($num_bits, result.unwrap().num_bits);
1203            };
1204        }
1205
1206        let test_cases = vec![
1207            (5u64, 0.0f64),
1208            (5u64, 0.9f64),
1209            (1u64, 0.0f64),
1210            (1u64, 0.5f64),
1211            (8u64, 0.0f64),
1212            (8u64, 0.5f64),
1213        ];
1214
1215        for (num_bits, null_probability) in &test_cases {
1216            do_test!(*num_bits, UInt8Type, *null_probability);
1217            do_test!(*num_bits, UInt16Type, *null_probability);
1218            do_test!(*num_bits, UInt32Type, *null_probability);
1219            do_test!(*num_bits, UInt64Type, *null_probability);
1220        }
1221
1222        // do some test cases that that will only work on larger types
1223        let test_cases = vec![
1224            (13u64, 0.0f64),
1225            (13u64, 0.5f64),
1226            (16u64, 0.0f64),
1227            (16u64, 0.5f64),
1228        ];
1229        for (num_bits, null_probability) in &test_cases {
1230            do_test!(*num_bits, UInt16Type, *null_probability);
1231            do_test!(*num_bits, UInt32Type, *null_probability);
1232            do_test!(*num_bits, UInt64Type, *null_probability);
1233        }
1234        let test_cases = vec![
1235            (25u64, 0.0f64),
1236            (25u64, 0.5f64),
1237            (32u64, 0.0f64),
1238            (32u64, 0.5f64),
1239        ];
1240        for (num_bits, null_probability) in &test_cases {
1241            do_test!(*num_bits, UInt32Type, *null_probability);
1242            do_test!(*num_bits, UInt64Type, *null_probability);
1243        }
1244        let test_cases = vec![
1245            (48u64, 0.0f64),
1246            (48u64, 0.5f64),
1247            (64u64, 0.0f64),
1248            (64u64, 0.5f64),
1249        ];
1250        for (num_bits, null_probability) in &test_cases {
1251            do_test!(*num_bits, UInt64Type, *null_probability);
1252        }
1253
1254        // test that it returns None for datatypes that don't support bitpacking
1255        let arr = Float64Array::from_iter_values(vec![0.1, 0.2, 0.3]);
1256        let result = bitpack_params(&arr);
1257        assert!(result.is_none());
1258    }
1259
1260    #[test]
1261    fn test_num_compressed_bits_signed_types() {
1262        let values = Int32Array::from(vec![1, 2, -7]);
1263        let arr = values;
1264        let result = bitpack_params(&arr);
1265        assert!(result.is_some());
1266        let result = result.unwrap();
1267        assert_eq!(4, result.num_bits);
1268        assert!(result.signed);
1269
1270        // check that it doesn't add a sign bit if it doesn't need to
1271        let values = Int32Array::from(vec![1, 2, 7]);
1272        let arr = values;
1273        let result = bitpack_params(&arr);
1274        assert!(result.is_some());
1275        let result = result.unwrap();
1276        assert_eq!(3, result.num_bits);
1277        assert!(!result.signed);
1278    }
1279
1280    #[test]
1281    fn test_rows_in_buffer() {
1282        let test_cases = vec![
1283            (5usize, 5u64, 0u8, None, 8u64),
1284            (2, 3, 0, Some(5), 4),
1285            (2, 3, 7, Some(6), 2),
1286        ];
1287
1288        for (
1289            buffer_len,
1290            bits_per_value,
1291            buffer_start_bit_offset,
1292            buffer_end_bit_offset,
1293            expected,
1294        ) in test_cases
1295        {
1296            let result = rows_in_buffer(
1297                buffer_len,
1298                bits_per_value,
1299                buffer_start_bit_offset,
1300                buffer_end_bit_offset,
1301            );
1302            assert_eq!(expected, result);
1303        }
1304    }
1305
1306    #[test]
1307    fn test_compute_start_offset() {
1308        let result = compute_start_offset(0, 5, 5, 0, None);
1309        assert_eq!(
1310            StartOffset::SkipSome(BufferStartOffset {
1311                index: 0,
1312                bit_offset: 0
1313            }),
1314            result
1315        );
1316
1317        let result = compute_start_offset(10, 5, 5, 0, None);
1318        assert_eq!(StartOffset::SkipFull(8), result);
1319    }
1320
1321    #[test_log::test(test)]
1322    fn test_will_bitpack_allowed_types_when_possible() {
1323        let test_cases: Vec<(DataType, ArrayRef, u64)> = vec![
1324            (
1325                DataType::UInt8,
1326                Arc::new(UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 5])),
1327                3, // bits per value
1328            ),
1329            (
1330                DataType::UInt16,
1331                Arc::new(UInt16Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 8])),
1332                11,
1333            ),
1334            (
1335                DataType::UInt32,
1336                Arc::new(UInt32Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 16])),
1337                19,
1338            ),
1339            (
1340                DataType::UInt64,
1341                Arc::new(UInt64Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 32])),
1342                35,
1343            ),
1344            (
1345                DataType::Int8,
1346                Arc::new(Int8Array::from_iter_values(vec![0, 2, 3, 4, -5])),
1347                4,
1348            ),
1349            (
1350                // check it will not pack with signed bit if all values of signed type are positive
1351                DataType::Int8,
1352                Arc::new(Int8Array::from_iter_values(vec![0, 2, 3, 4, 5])),
1353                3,
1354            ),
1355            (
1356                DataType::Int16,
1357                Arc::new(Int16Array::from_iter_values(vec![0, 1, 2, 3, -4, 5 << 8])),
1358                12,
1359            ),
1360            (
1361                DataType::Int32,
1362                Arc::new(Int32Array::from_iter_values(vec![0, 1, 2, 3, 4, -5 << 16])),
1363                20,
1364            ),
1365            (
1366                DataType::Int64,
1367                Arc::new(Int64Array::from_iter_values(vec![
1368                    0,
1369                    1,
1370                    2,
1371                    -3,
1372                    -4,
1373                    -5 << 32,
1374                ])),
1375                36,
1376            ),
1377        ];
1378
1379        for (data_type, arr, bits_per_value) in test_cases {
1380            let mut buffed_index = 1;
1381            let params = bitpack_params(arr.as_ref()).unwrap();
1382            let encoder = BitpackedArrayEncoder {
1383                num_bits: params.num_bits,
1384                signed_type: params.signed,
1385            };
1386            let data = DataBlock::from_array(arr);
1387            let result = encoder.encode(data, &data_type, &mut buffed_index).unwrap();
1388
1389            let data = result.data.as_fixed_width().unwrap();
1390            assert_eq!(bits_per_value, data.bits_per_value);
1391
1392            let array_encoding = result.encoding.array_encoding.unwrap();
1393
1394            match array_encoding {
1395                pb::array_encoding::ArrayEncoding::Bitpacked(bitpacked) => {
1396                    assert_eq!(bits_per_value, bitpacked.compressed_bits_per_value);
1397                    assert_eq!(
1398                        (data_type.byte_width() * 8) as u64,
1399                        bitpacked.uncompressed_bits_per_value
1400                    );
1401                }
1402                _ => {
1403                    panic!("Array did not use bitpacking encoding")
1404                }
1405            }
1406        }
1407
1408        // check it will otherwise use flat encoding
1409        let test_cases: Vec<(DataType, ArrayRef)> = vec![
1410            // it should use flat encoding for datatypes that don't support bitpacking
1411            (
1412                DataType::Float32,
1413                Arc::new(Float32Array::from_iter_values(vec![0.1, 0.2, 0.3])),
1414            ),
1415            // it should still use flat encoding if bitpacked encoding would be packed
1416            // into the full byte range
1417            (
1418                DataType::UInt8,
1419                Arc::new(UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 250])),
1420            ),
1421            (
1422                DataType::UInt16,
1423                Arc::new(UInt16Array::from_iter_values(vec![0, 1, 2, 3, 4, 250 << 8])),
1424            ),
1425            (
1426                DataType::UInt32,
1427                Arc::new(UInt32Array::from_iter_values(vec![
1428                    0,
1429                    1,
1430                    2,
1431                    3,
1432                    4,
1433                    250 << 24,
1434                ])),
1435            ),
1436            (
1437                DataType::UInt64,
1438                Arc::new(UInt64Array::from_iter_values(vec![
1439                    0,
1440                    1,
1441                    2,
1442                    3,
1443                    4,
1444                    250 << 56,
1445                ])),
1446            ),
1447            (
1448                DataType::Int8,
1449                Arc::new(Int8Array::from_iter_values(vec![-100])),
1450            ),
1451            (
1452                DataType::Int16,
1453                Arc::new(Int16Array::from_iter_values(vec![-100 << 8])),
1454            ),
1455            (
1456                DataType::Int32,
1457                Arc::new(Int32Array::from_iter_values(vec![-100 << 24])),
1458            ),
1459            (
1460                DataType::Int64,
1461                Arc::new(Int64Array::from_iter_values(vec![-100 << 56])),
1462            ),
1463        ];
1464
1465        for (data_type, arr) in test_cases {
1466            if let Some(params) = bitpack_params(arr.as_ref()) {
1467                assert_eq!(params.num_bits, data_type.byte_width() as u64 * 8);
1468            }
1469        }
1470    }
1471
1472    struct DistributionArrayGeneratorProvider<
1473        DataType,
1474        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
1475    >
1476    where
1477        DataType::Native: Copy + 'static,
1478        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1479        DataType: ArrowPrimitiveType,
1480    {
1481        phantom: PhantomData<DataType>,
1482        distribution: Dist,
1483    }
1484
1485    impl<DataType, Dist> DistributionArrayGeneratorProvider<DataType, Dist>
1486    where
1487        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
1488        DataType::Native: Copy + 'static,
1489        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1490        DataType: ArrowPrimitiveType,
1491    {
1492        fn new(dist: Dist) -> Self {
1493            Self {
1494                distribution: dist,
1495                phantom: Default::default(),
1496            }
1497        }
1498    }
1499
1500    impl<DataType, Dist> ArrayGeneratorProvider for DistributionArrayGeneratorProvider<DataType, Dist>
1501    where
1502        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
1503        DataType::Native: Copy + 'static,
1504        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1505        DataType: ArrowPrimitiveType,
1506    {
1507        fn provide(&self) -> Box<dyn ArrayGenerator> {
1508            rand_with_distribution::<DataType, Dist>(self.distribution.clone())
1509        }
1510
1511        fn copy(&self) -> Box<dyn ArrayGeneratorProvider> {
1512            Box::new(Self {
1513                phantom: self.phantom,
1514                distribution: self.distribution.clone(),
1515            })
1516        }
1517    }
1518
1519    #[test_log::test(tokio::test)]
1520    async fn test_bitpack_primitive() {
1521        let bitpacked_test_cases: &Vec<(DataType, Box<dyn ArrayGeneratorProvider>)> = &vec![
1522            // check less than one byte for multi-byte type
1523            (
1524                DataType::UInt32,
1525                Box::new(
1526                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1527                        Uniform::new(0, 19),
1528                    ),
1529                ),
1530            ),
1531            // // check that more than one byte for multi-byte type
1532            (
1533                DataType::UInt32,
1534                Box::new(
1535                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1536                        Uniform::new(5 << 7, 6 << 7),
1537                    ),
1538                ),
1539            ),
1540            (
1541                DataType::UInt64,
1542                Box::new(
1543                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
1544                        Uniform::new(5 << 42, 6 << 42),
1545                    ),
1546                ),
1547            ),
1548            // check less than one byte for single-byte type
1549            (
1550                DataType::UInt8,
1551                Box::new(
1552                    DistributionArrayGeneratorProvider::<UInt8Type, Uniform<u8>>::new(
1553                        Uniform::new(0, 19),
1554                    ),
1555                ),
1556            ),
1557            // check less than one byte for single-byte type
1558            (
1559                DataType::UInt64,
1560                Box::new(
1561                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
1562                        Uniform::new(129, 259),
1563                    ),
1564                ),
1565            ),
1566            // check byte aligned for single byte
1567            (
1568                DataType::UInt32,
1569                Box::new(
1570                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1571                        // this range should always give 8 bits
1572                        Uniform::new(200, 250),
1573                    ),
1574                ),
1575            ),
1576            // check where the num_bits divides evenly into the bit length of the type
1577            (
1578                DataType::UInt64,
1579                Box::new(
1580                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
1581                        Uniform::new(1, 3), // 2 bits
1582                    ),
1583                ),
1584            ),
1585            // check byte aligned for multiple bytes
1586            (
1587                DataType::UInt32,
1588                Box::new(
1589                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1590                        // this range should always always give 16 bits
1591                        Uniform::new(200 << 8, 250 << 8),
1592                    ),
1593                ),
1594            ),
1595            // check byte aligned where the num bits doesn't divide evenly into the byte length
1596            (
1597                DataType::UInt64,
1598                Box::new(
1599                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
                        // this range should always give 24 bits
1601                        Uniform::new(200 << 16, 250 << 16),
1602                    ),
1603                ),
1604            ),
1605            // check that we can still encode an all-0 array
1606            (
1607                DataType::UInt32,
1608                Box::new(
1609                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1610                        Uniform::new(0, 1),
1611                    ),
1612                ),
1613            ),
1614            // check for signed types
1615            (
1616                DataType::Int16,
1617                Box::new(
1618                    DistributionArrayGeneratorProvider::<Int16Type, Uniform<i16>>::new(
1619                        Uniform::new(-5, 5),
1620                    ),
1621                ),
1622            ),
1623            (
1624                DataType::Int64,
1625                Box::new(
1626                    DistributionArrayGeneratorProvider::<Int64Type, Uniform<i64>>::new(
1627                        Uniform::new(-(5 << 42), 6 << 42),
1628                    ),
1629                ),
1630            ),
1631            (
1632                DataType::Int32,
1633                Box::new(
1634                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1635                        Uniform::new(-(5 << 7), 6 << 7),
1636                    ),
1637                ),
1638            ),
1639            // check signed where packed to < 1 byte for multi-byte type
1640            (
1641                DataType::Int32,
1642                Box::new(
1643                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1644                        Uniform::new(-19, 19),
1645                    ),
1646                ),
1647            ),
1648            // check signed byte aligned to single byte
1649            (
1650                DataType::Int32,
1651                Box::new(
1652                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1653                        // this range should always give 8 bits
1654                        Uniform::new(-120, 120),
1655                    ),
1656                ),
1657            ),
1658            // check signed byte aligned to multiple bytes
1659            (
1660                DataType::Int32,
1661                Box::new(
1662                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1663                        // this range should always give 16 bits
1664                        Uniform::new(-120 << 8, 120 << 8),
1665                    ),
1666                ),
1667            ),
1668            // check that it works for all positive integers even if type is signed
1669            (
1670                DataType::Int32,
1671                Box::new(
1672                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1673                        Uniform::new(10, 20),
1674                    ),
1675                ),
1676            ),
1677            // check that all 0 works for signed type
1678            (
1679                DataType::Int32,
1680                Box::new(
1681                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1682                        Uniform::new(0, 1),
1683                    ),
1684                ),
1685            ),
1686        ];
1687
1688        for (data_type, array_gen_provider) in bitpacked_test_cases {
1689            let mut metadata = HashMap::new();
1690            metadata.insert(
1691                "lance-encoding:compression".to_string(),
1692                "bitpacking".to_string(),
1693            );
1694            let field = Field::new("", data_type.clone(), false).with_metadata(metadata);
1695            check_round_trip_encoding_generated(
1696                field,
1697                array_gen_provider.copy(),
1698                LanceFileVersion::V2_1,
1699            )
1700            .await;
1701        }
1702    }
1703}