lance_encoding/previous/encodings/physical/
bitpack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow_array::types::{
7    Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
8};
9use arrow_array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
10use arrow_buffer::bit_util::ceil;
11use arrow_buffer::ArrowNativeType;
12use arrow_schema::DataType;
13use bytes::Bytes;
14use futures::future::{BoxFuture, FutureExt};
15use log::trace;
16use num_traits::{AsPrimitive, PrimInt};
17use snafu::location;
18
19use lance_arrow::DataTypeExt;
20use lance_bitpacking::BitPacking;
21use lance_core::{Error, Result};
22
23use crate::buffer::LanceBuffer;
24use crate::data::BlockInfo;
25use crate::data::{DataBlock, FixedWidthDataBlock, NullableDataBlock};
26use crate::decoder::{PageScheduler, PrimitivePageDecoder};
27use crate::format::ProtobufUtils;
28use crate::previous::encoder::{ArrayEncoder, EncodedArray};
29use bytemuck::cast_slice;
30
/// Base-2 logarithm of the number of values in one bitpacking chunk.
const LOG_ELEMS_PER_CHUNK: u8 = 10;
/// Number of values in one bitpacking chunk (`1 << 10`, i.e. 1024).
const ELEMS_PER_CHUNK: u64 = 1 << LOG_ELEMS_PER_CHUNK;
33
34// Compute the compressed_bit_width for a given array of integers
35// todo: compute all statistics before encoding
36// todo: see how to use rust macro to rewrite this function
37pub fn compute_compressed_bit_width_for_non_neg(arrays: &[ArrayRef]) -> u64 {
38    debug_assert!(!arrays.is_empty());
39
40    let res;
41
42    match arrays[0].data_type() {
43        DataType::UInt8 => {
44            let mut global_max: u8 = 0;
45            for array in arrays {
46                let primitive_array = array
47                    .as_any()
48                    .downcast_ref::<PrimitiveArray<UInt8Type>>()
49                    .unwrap();
50                let array_max = arrow_arith::aggregate::bit_or(primitive_array);
51                global_max = global_max.max(array_max.unwrap_or(0));
52            }
53            let num_bits =
54                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
55            // we will have constant encoding later
56            if num_bits == 0 {
57                res = 1;
58            } else {
59                res = num_bits;
60            }
61        }
62
63        DataType::Int8 => {
64            let mut global_max_width: u64 = 0;
65            for array in arrays {
66                let primitive_array = array
67                    .as_any()
68                    .downcast_ref::<PrimitiveArray<Int8Type>>()
69                    .unwrap();
70                let array_max_width = arrow_arith::aggregate::bit_or(primitive_array).unwrap_or(0);
71                global_max_width = global_max_width.max(8 - array_max_width.leading_zeros() as u64);
72            }
73            if global_max_width == 0 {
74                res = 1;
75            } else {
76                res = global_max_width;
77            }
78        }
79
80        DataType::UInt16 => {
81            let mut global_max: u16 = 0;
82            for array in arrays {
83                let primitive_array = array
84                    .as_any()
85                    .downcast_ref::<PrimitiveArray<UInt16Type>>()
86                    .unwrap();
87                let array_max = arrow_arith::aggregate::bit_or(primitive_array).unwrap_or(0);
88                global_max = global_max.max(array_max);
89            }
90            let num_bits =
91                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
92            if num_bits == 0 {
93                res = 1;
94            } else {
95                res = num_bits;
96            }
97        }
98
99        DataType::Int16 => {
100            let mut global_max_width: u64 = 0;
101            for array in arrays {
102                let primitive_array = array
103                    .as_any()
104                    .downcast_ref::<PrimitiveArray<Int16Type>>()
105                    .unwrap();
106                let array_max_width = arrow_arith::aggregate::bit_or(primitive_array).unwrap_or(0);
107                global_max_width =
108                    global_max_width.max(16 - array_max_width.leading_zeros() as u64);
109            }
110            if global_max_width == 0 {
111                res = 1;
112            } else {
113                res = global_max_width;
114            }
115        }
116
117        DataType::UInt32 => {
118            let mut global_max: u32 = 0;
119            for array in arrays {
120                let primitive_array = array
121                    .as_any()
122                    .downcast_ref::<PrimitiveArray<UInt32Type>>()
123                    .unwrap();
124                let array_max = arrow_arith::aggregate::bit_or(primitive_array).unwrap_or(0);
125                global_max = global_max.max(array_max);
126            }
127            let num_bits =
128                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
129            if num_bits == 0 {
130                res = 1;
131            } else {
132                res = num_bits;
133            }
134        }
135
136        DataType::Int32 => {
137            let mut global_max_width: u64 = 0;
138            for array in arrays {
139                let primitive_array = array
140                    .as_any()
141                    .downcast_ref::<PrimitiveArray<Int32Type>>()
142                    .unwrap();
143                let array_max_width = arrow_arith::aggregate::bit_or(primitive_array).unwrap_or(0);
144                global_max_width =
145                    global_max_width.max(32 - array_max_width.leading_zeros() as u64);
146            }
147            if global_max_width == 0 {
148                res = 1;
149            } else {
150                res = global_max_width;
151            }
152        }
153
154        DataType::UInt64 => {
155            let mut global_max: u64 = 0;
156            for array in arrays {
157                let primitive_array = array
158                    .as_any()
159                    .downcast_ref::<PrimitiveArray<UInt64Type>>()
160                    .unwrap();
161                let array_max = arrow_arith::aggregate::bit_or(primitive_array).unwrap_or(0);
162                global_max = global_max.max(array_max);
163            }
164            let num_bits =
165                arrays[0].data_type().byte_width() as u64 * 8 - global_max.leading_zeros() as u64;
166            if num_bits == 0 {
167                res = 1;
168            } else {
169                res = num_bits;
170            }
171        }
172
173        DataType::Int64 => {
174            let mut global_max_width: u64 = 0;
175            for array in arrays {
176                let primitive_array = array
177                    .as_any()
178                    .downcast_ref::<PrimitiveArray<Int64Type>>()
179                    .unwrap();
180                let array_max_width = arrow_arith::aggregate::bit_or(primitive_array).unwrap_or(0);
181                global_max_width =
182                    global_max_width.max(64 - array_max_width.leading_zeros() as u64);
183            }
184            if global_max_width == 0 {
185                res = 1;
186            } else {
187                res = global_max_width;
188            }
189        }
190        _ => {
191            panic!("BitpackedForNonNegArrayEncoder only supports data types of UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt64, Int64");
192        }
193    };
194    res
195}
196
// Bitpack integers using the fastlanes algorithm. The input is sliced into chunks of
// ELEMS_PER_CHUNK (1024) integers and bitpacked chunk by chunk. When the input length is
// not a multiple of 1024 the last chunk is padded with zeros; this is fine because the
// number of rows is recorded alongside the data.
//
// Arguments:
//   $self         - borrow of a BitpackedForNonNegArrayEncoder (supplies compressed_bit_width)
//   $unpacked     - the FixedWidthDataBlock holding the values to pack
//   $data_type    - one of u8, u16, u32, or u64; also the element type of the packed output
//   $buffer_index - mutable borrow of a u32; its current value becomes the buffer index of
//                   the output and it is then incremented by one
//
// Evaluates to `Result<EncodedArray>` holding the fastlanes bitpacked data.
macro_rules! encode_fixed_width {
    ($self:expr, $unpacked:expr, $data_type:ty, $buffer_index:expr) => {{
        let num_chunks = $unpacked.num_values.div_ceil(ELEMS_PER_CHUNK);
        let num_full_chunks = $unpacked.num_values / ELEMS_PER_CHUNK;
        let uncompressed_bit_width = std::mem::size_of::<$data_type>() as u64 * 8;

        // the output vector type is the same as the input type, for example, when input is u16, output is Vec<u16>
        let packed_chunk_size = ELEMS_PER_CHUNK as usize * $self.compressed_bit_width as usize
            / uncompressed_bit_width as usize;

        let input_slice = $unpacked.data.borrow_to_typed_slice::<$data_type>();
        let input = input_slice.as_ref();

        let mut output: Vec<$data_type> =
            Vec::with_capacity(num_chunks as usize * packed_chunk_size);

        // Loop over all but the last chunk.
        for i in 0..num_full_chunks {
            let start_elem = (i * ELEMS_PER_CHUNK) as usize;

            let output_len = output.len();
            // Grow with zeros rather than `set_len`, so no uninitialized element is ever
            // exposed through a slice. Capacity was reserved above, so this never reallocates.
            output.resize(output_len + packed_chunk_size, 0);
            // SAFETY: the input slice holds exactly ELEMS_PER_CHUNK values and the output
            // slice holds exactly `packed_chunk_size` values (slicing panics otherwise).
            // NOTE(review): assumes those are the only requirements of `unchecked_pack`
            // besides a bit width no larger than the element type - confirm.
            unsafe {
                BitPacking::unchecked_pack(
                    $self.compressed_bit_width,
                    &input[start_elem..][..ELEMS_PER_CHUNK as usize],
                    &mut output[output_len..][..packed_chunk_size],
                );
            }
        }

        if num_chunks != num_full_chunks {
            // Copy the trailing partial chunk into a zero-padded full-size chunk.
            let last_chunk_elem_num = $unpacked.num_values % ELEMS_PER_CHUNK;
            let mut last_chunk = vec![0 as $data_type; ELEMS_PER_CHUNK as usize];
            last_chunk[..last_chunk_elem_num as usize].copy_from_slice(
                &input[$unpacked.num_values as usize - last_chunk_elem_num as usize..],
            );

            let output_len = output.len();
            output.resize(output_len + packed_chunk_size, 0);
            // SAFETY: same reasoning as for the full chunks above; `last_chunk` has exactly
            // ELEMS_PER_CHUNK values.
            unsafe {
                BitPacking::unchecked_pack(
                    $self.compressed_bit_width,
                    &last_chunk,
                    &mut output[output_len..][..packed_chunk_size],
                );
            }
        }

        let bitpacked_for_non_neg_buffer_index = *$buffer_index;
        *$buffer_index += 1;

        let encoding = ProtobufUtils::bitpacked_for_non_neg_encoding(
            $self.compressed_bit_width as u64,
            uncompressed_bit_width,
            bitpacked_for_non_neg_buffer_index,
        );
        let packed = DataBlock::FixedWidth(FixedWidthDataBlock {
            bits_per_value: $self.compressed_bit_width as u64,
            data: LanceBuffer::reinterpret_vec(output),
            num_values: $unpacked.num_values,
            block_info: BlockInfo::new(),
        });

        Result::Ok(EncodedArray {
            data: packed,
            encoding,
        })
    }};
}
272
273#[derive(Debug)]
274pub struct BitpackedForNonNegArrayEncoder {
275    pub compressed_bit_width: usize,
276    pub original_data_type: DataType,
277}
278
279impl BitpackedForNonNegArrayEncoder {
280    pub fn new(compressed_bit_width: usize, data_type: DataType) -> Self {
281        Self {
282            compressed_bit_width,
283            original_data_type: data_type,
284        }
285    }
286}
287
288impl ArrayEncoder for BitpackedForNonNegArrayEncoder {
289    fn encode(
290        &self,
291        data: DataBlock,
292        data_type: &DataType,
293        buffer_index: &mut u32,
294    ) -> Result<EncodedArray> {
295        match data {
296            DataBlock::AllNull(_) => {
297                let encoding = ProtobufUtils::basic_all_null_encoding();
298                Ok(EncodedArray { data, encoding })
299            }
300            DataBlock::FixedWidth(unpacked) => {
301                match data_type {
302                    DataType::UInt8 | DataType::Int8 => encode_fixed_width!(self, unpacked, u8, buffer_index),
303                    DataType::UInt16 | DataType::Int16 => encode_fixed_width!(self, unpacked, u16, buffer_index),
304                    DataType::UInt32 | DataType::Int32 => encode_fixed_width!(self, unpacked, u32, buffer_index),
305                    DataType::UInt64 | DataType::Int64 => encode_fixed_width!(self, unpacked, u64, buffer_index),
306                    _ => unreachable!("BitpackedForNonNegArrayEncoder only supports data types of UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt64, Int64"),
307                }
308            }
309            DataBlock::Nullable(nullable) => {
310                let validity_buffer_index = *buffer_index;
311                *buffer_index += 1;
312
313                let validity_desc = ProtobufUtils::flat_encoding(
314                    1,
315                    validity_buffer_index,
316                    /*compression=*/ None,
317                );
318                let encoded_values: EncodedArray;
319                match *nullable.data {
320                    DataBlock::FixedWidth(unpacked) => {
321                        match data_type {
322                            DataType::UInt8 | DataType::Int8 => encoded_values = encode_fixed_width!(self, unpacked, u8, buffer_index)?,
323                            DataType::UInt16 | DataType::Int16 => encoded_values = encode_fixed_width!(self, unpacked, u16, buffer_index)?,
324                            DataType::UInt32 | DataType::Int32 => encoded_values = encode_fixed_width!(self, unpacked, u32, buffer_index)?,
325                            DataType::UInt64 | DataType::Int64 => encoded_values = encode_fixed_width!(self, unpacked, u64, buffer_index)?,
326                            _ => unreachable!("BitpackedForNonNegArrayEncoder only supports data types of UInt8, Int8, UInt16, Int16, UInt32, Int32, UInt64, Int64"),
327                        }
328                    }
329                    _ => {
330                        return Err(Error::InvalidInput {
331                            source: "Bitpacking only supports fixed width data blocks or a nullable data block with fixed width data block inside or a all null data block".into(),
332                            location: location!(),
333                        });
334                    }
335                }
336                let encoding =
337                    ProtobufUtils::basic_some_null_encoding(validity_desc, encoded_values.encoding);
338                let encoded = DataBlock::Nullable(NullableDataBlock {
339                    data: Box::new(encoded_values.data),
340                    nulls: nullable.nulls,
341                    block_info: BlockInfo::new(),
342                });
343                Ok(EncodedArray {
344                    data: encoded,
345                    encoding,
346                })
347            }
348            _ => {
349                Err(Error::InvalidInput {
350                    source: "Bitpacking only supports fixed width data blocks or a nullable data block with fixed width data block inside or a all null data block".into(),
351                    location: location!(),
352                })
353            }
354        }
355    }
356}
357
358#[derive(Debug)]
359pub struct BitpackedForNonNegScheduler {
360    compressed_bit_width: u64,
361    uncompressed_bits_per_value: u64,
362    buffer_offset: u64,
363}
364
365impl BitpackedForNonNegScheduler {
366    pub fn new(
367        compressed_bit_width: u64,
368        uncompressed_bits_per_value: u64,
369        buffer_offset: u64,
370    ) -> Self {
371        Self {
372            compressed_bit_width,
373            uncompressed_bits_per_value,
374            buffer_offset,
375        }
376    }
377
378    fn locate_chunk_start(&self, relative_row_num: u64) -> u64 {
379        let chunk_size = ELEMS_PER_CHUNK * self.compressed_bit_width / 8;
380        self.buffer_offset + (relative_row_num / ELEMS_PER_CHUNK * chunk_size)
381    }
382
383    fn locate_chunk_end(&self, relative_row_num: u64) -> u64 {
384        let chunk_size = ELEMS_PER_CHUNK * self.compressed_bit_width / 8;
385        self.buffer_offset + (relative_row_num / ELEMS_PER_CHUNK * chunk_size) + chunk_size
386    }
387}
388
389impl PageScheduler for BitpackedForNonNegScheduler {
390    fn schedule_ranges(
391        &self,
392        ranges: &[std::ops::Range<u64>],
393        scheduler: &Arc<dyn crate::EncodingsIo>,
394        top_level_row: u64,
395    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
396        assert!(!ranges.is_empty());
397
398        let mut byte_ranges = vec![];
399
400        // map one bytes to multiple ranges, one bytes has at least one range corresponding to it
401        let mut bytes_idx_to_range_indices = vec![];
402        let first_byte_range = std::ops::Range {
403            start: self.locate_chunk_start(ranges[0].start),
404            end: self.locate_chunk_end(ranges[0].end - 1),
405        }; // the ranges are half-open
406        byte_ranges.push(first_byte_range);
407        bytes_idx_to_range_indices.push(vec![ranges[0].clone()]);
408
409        for (i, range) in ranges.iter().enumerate().skip(1) {
410            let this_start = self.locate_chunk_start(range.start);
411            let this_end = self.locate_chunk_end(range.end - 1);
412
413            // when the current range start is in the same chunk as the previous range's end, we colaesce this two bytes ranges
414            // when the current range start is not in the same chunk as the previous range's end, we create a new bytes range
415            if this_start == self.locate_chunk_start(ranges[i - 1].end - 1) {
416                byte_ranges.last_mut().unwrap().end = this_end;
417                bytes_idx_to_range_indices
418                    .last_mut()
419                    .unwrap()
420                    .push(range.clone());
421            } else {
422                byte_ranges.push(this_start..this_end);
423                bytes_idx_to_range_indices.push(vec![range.clone()]);
424            }
425        }
426
427        trace!(
428            "Scheduling I/O for {} ranges spread across byte range {}..{}",
429            byte_ranges.len(),
430            byte_ranges[0].start,
431            byte_ranges.last().unwrap().end
432        );
433
434        let bytes = scheduler.submit_request(byte_ranges.clone(), top_level_row);
435
436        // copy the necessary data from `self` to move into the async block
437        let compressed_bit_width = self.compressed_bit_width;
438        let uncompressed_bits_per_value = self.uncompressed_bits_per_value;
439        let num_rows = ranges.iter().map(|range| range.end - range.start).sum();
440
441        async move {
442            let bytes = bytes.await?;
443            let decompressed_output = bitpacked_for_non_neg_decode(
444                compressed_bit_width,
445                uncompressed_bits_per_value,
446                &bytes,
447                &bytes_idx_to_range_indices,
448                num_rows,
449            );
450            Ok(Box::new(BitpackedForNonNegPageDecoder {
451                uncompressed_bits_per_value,
452                decompressed_buf: decompressed_output,
453            }) as Box<dyn PrimitivePageDecoder>)
454        }
455        .boxed()
456    }
457}
458
459#[derive(Debug)]
460struct BitpackedForNonNegPageDecoder {
461    // number of bits in the uncompressed value. E.g. this will be 32 for DataType::UInt32
462    uncompressed_bits_per_value: u64,
463
464    decompressed_buf: LanceBuffer,
465}
466
467impl PrimitivePageDecoder for BitpackedForNonNegPageDecoder {
468    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
469        if ![8, 16, 32, 64].contains(&self.uncompressed_bits_per_value) {
470            return Err(Error::InvalidInput {
471                source: "BitpackedForNonNegPageDecoder should only has uncompressed_bits_per_value of 8, 16, 32, or 64".into(),
472                location: location!(),
473            });
474        }
475
476        let elem_size_in_bytes = self.uncompressed_bits_per_value / 8;
477
478        Ok(DataBlock::FixedWidth(FixedWidthDataBlock {
479            data: self.decompressed_buf.slice_with_length(
480                (rows_to_skip * elem_size_in_bytes) as usize,
481                (num_rows * elem_size_in_bytes) as usize,
482            ),
483            bits_per_value: self.uncompressed_bits_per_value,
484            num_values: num_rows,
485            block_info: BlockInfo::new(),
486        }))
487    }
488}
489
// Unpacks the requested rows out of fastlanes-bitpacked chunks.
//
// Arguments:
//   $uncompressed_type          - u8, u16, u32 or u64; the element type of the output
//   $compressed_bit_width       - bits per value in the packed data
//   $data                       - the fetched byte buffers, each a whole number of packed chunks
//   $bytes_idx_to_range_indices - for each buffer in $data, the row ranges it serves
//   $num_rows                   - total number of rows requested (used to size the output)
//
// Evaluates to a LanceBuffer holding the selected values back to back.
macro_rules! bitpacked_decode {
    ($uncompressed_type:ty, $compressed_bit_width:expr, $data:expr, $bytes_idx_to_range_indices:expr, $num_rows:expr) => {{
        let mut decompressed: Vec<$uncompressed_type> = Vec::with_capacity($num_rows as usize);
        let packed_chunk_size_in_byte: usize = (ELEMS_PER_CHUNK * $compressed_bit_width) as usize / 8;
        let elem_size = std::mem::size_of::<$uncompressed_type>();
        let mut decompress_chunk_buf = vec![0 as $uncompressed_type; ELEMS_PER_CHUNK as usize];
        // Typed staging buffer, reused for every chunk whose bytes are not suitably aligned.
        let mut aligned_chunk_buf =
            vec![0 as $uncompressed_type; packed_chunk_size_in_byte / elem_size];

        for (i, bytes) in $data.iter().enumerate() {
            let mut ranges_idx = 0;
            let mut curr_range_start = $bytes_idx_to_range_indices[i][0].start;
            let mut chunk_num = 0;

            while chunk_num * packed_chunk_size_in_byte < bytes.len() {
                let packed_bytes: &[u8] = &bytes[chunk_num * packed_chunk_size_in_byte..]
                    [..packed_chunk_size_in_byte];
                chunk_num += 1;

                // `cast_slice` panics on a misaligned source, and a byte buffer carries no
                // alignment guarantee for wider integers. Reinterpret in place when the
                // bytes happen to be aligned; otherwise copy them into the typed buffer.
                let is_aligned = (packed_bytes.as_ptr() as usize)
                    % std::mem::align_of::<$uncompressed_type>()
                    == 0;
                let chunk: &[$uncompressed_type] = if is_aligned {
                    cast_slice(packed_bytes)
                } else {
                    for (dst, src) in aligned_chunk_buf
                        .iter_mut()
                        .zip(packed_bytes.chunks_exact(elem_size))
                    {
                        // native endianness, matching what a reinterpreting cast would read
                        *dst = <$uncompressed_type>::from_ne_bytes(
                            src.try_into()
                                .expect("chunks_exact yields exactly elem_size bytes"),
                        );
                    }
                    &aligned_chunk_buf[..]
                };
                // SAFETY: `chunk` holds exactly one packed chunk and `decompress_chunk_buf`
                // holds ELEMS_PER_CHUNK values.
                // NOTE(review): assumes these are the requirements of `unchecked_unpack` - confirm.
                unsafe {
                    BitPacking::unchecked_unpack(
                        $compressed_bit_width as usize,
                        chunk,
                        &mut decompress_chunk_buf,
                    );
                }

                loop {
                    // Case 1: All the elements after (curr_range_start % ELEMS_PER_CHUNK) inside this chunk are needed.
                    let elems_after_curr_range_start_in_this_chunk =
                        ELEMS_PER_CHUNK - curr_range_start % ELEMS_PER_CHUNK;
                    if curr_range_start + elems_after_curr_range_start_in_this_chunk
                        <= $bytes_idx_to_range_indices[i][ranges_idx].end
                    {
                        decompressed.extend_from_slice(
                            &decompress_chunk_buf[(curr_range_start % ELEMS_PER_CHUNK) as usize..],
                        );
                        curr_range_start += elems_after_curr_range_start_in_this_chunk;
                        // the rest of this range (if any) lives in the next chunk
                        break;
                    } else {
                        // Case 2: Only part of the elements after (curr_range_start % ELEMS_PER_CHUNK) inside this chunk are needed.
                        let elems_this_range_needed_in_this_chunk =
                            ($bytes_idx_to_range_indices[i][ranges_idx].end - curr_range_start)
                                .min(ELEMS_PER_CHUNK - curr_range_start % ELEMS_PER_CHUNK);
                        decompressed.extend_from_slice(
                            &decompress_chunk_buf[(curr_range_start % ELEMS_PER_CHUNK) as usize..]
                                [..elems_this_range_needed_in_this_chunk as usize],
                        );
                        if curr_range_start + elems_this_range_needed_in_this_chunk
                            == $bytes_idx_to_range_indices[i][ranges_idx].end
                        {
                            // this range is finished; move to the next range served by
                            // the same bytes, if there is one
                            ranges_idx += 1;
                            if ranges_idx == $bytes_idx_to_range_indices[i].len() {
                                break;
                            }
                            curr_range_start = $bytes_idx_to_range_indices[i][ranges_idx].start;
                        } else {
                            curr_range_start += elems_this_range_needed_in_this_chunk;
                        }
                    }
                }
            }
        }

        LanceBuffer::reinterpret_vec(decompressed)
    }};
}
557
558fn bitpacked_for_non_neg_decode(
559    compressed_bit_width: u64,
560    uncompressed_bits_per_value: u64,
561    data: &[Bytes],
562    bytes_idx_to_range_indices: &[Vec<std::ops::Range<u64>>],
563    num_rows: u64,
564) -> LanceBuffer {
565    match uncompressed_bits_per_value {
566        8 => bitpacked_decode!(
567            u8,
568            compressed_bit_width,
569            data,
570            bytes_idx_to_range_indices,
571            num_rows
572        ),
573        16 => bitpacked_decode!(
574            u16,
575            compressed_bit_width,
576            data,
577            bytes_idx_to_range_indices,
578            num_rows
579        ),
580        32 => bitpacked_decode!(
581            u32,
582            compressed_bit_width,
583            data,
584            bytes_idx_to_range_indices,
585            num_rows
586        ),
587        64 => bitpacked_decode!(
588            u64,
589            compressed_bit_width,
590            data,
591            bytes_idx_to_range_indices,
592            num_rows
593        ),
594        _ => unreachable!(
595            "bitpacked_for_non_neg_decode only supports 8, 16, 32, 64 uncompressed_bits_per_value"
596        ),
597    }
598}
599
/// Parameters describing how an integer array can be bitpacked.
#[derive(Debug)]
pub struct BitpackParams {
    /// Number of bits to use for each value.
    pub num_bits: u64,

    /// Whether the packed representation includes a sign bit.
    pub signed: bool,
}
606
607// Compute the number of bits to use for each item, if this array can be encoded using
608// bitpacking encoding. Returns `None` if the type or array data is not supported.
609pub fn bitpack_params(arr: &dyn Array) -> Option<BitpackParams> {
610    match arr.data_type() {
611        DataType::UInt8 => bitpack_params_for_type::<UInt8Type>(arr.as_primitive()),
612        DataType::UInt16 => bitpack_params_for_type::<UInt16Type>(arr.as_primitive()),
613        DataType::UInt32 => bitpack_params_for_type::<UInt32Type>(arr.as_primitive()),
614        DataType::UInt64 => bitpack_params_for_type::<UInt64Type>(arr.as_primitive()),
615        DataType::Int8 => bitpack_params_for_signed_type::<Int8Type>(arr.as_primitive()),
616        DataType::Int16 => bitpack_params_for_signed_type::<Int16Type>(arr.as_primitive()),
617        DataType::Int32 => bitpack_params_for_signed_type::<Int32Type>(arr.as_primitive()),
618        DataType::Int64 => bitpack_params_for_signed_type::<Int64Type>(arr.as_primitive()),
619        // TODO -- eventually we could support temporal types as well
620        _ => None,
621    }
622}
623
624// Compute the number bits to to use for bitpacking generically.
625// returns None if the array is empty or all nulls
626fn bitpack_params_for_type<T>(arr: &PrimitiveArray<T>) -> Option<BitpackParams>
627where
628    T: ArrowPrimitiveType,
629    T::Native: PrimInt + AsPrimitive<u64>,
630{
631    let max = arrow_arith::aggregate::bit_or(arr);
632    let num_bits =
633        max.map(|max| arr.data_type().byte_width() as u64 * 8 - max.leading_zeros() as u64);
634
635    // we can't bitpack into 0 bits, so the minimum is 1
636    num_bits
637        .map(|num_bits| num_bits.max(1))
638        .map(|bits| BitpackParams {
639            num_bits: bits,
640            signed: false,
641        })
642}
643
644/// determine the minimum number of bits that can be used to represent
645/// an array of signed values. It includes all the significant bits for
646/// the value + plus 1 bit to represent the sign. If there are no negative values
647/// then it will not add a signed bit
648fn bitpack_params_for_signed_type<T>(arr: &PrimitiveArray<T>) -> Option<BitpackParams>
649where
650    T: ArrowPrimitiveType,
651    T::Native: PrimInt + AsPrimitive<i64>,
652{
653    let mut add_signed_bit = false;
654    let mut min_leading_bits: Option<u64> = None;
655    for val in arr.iter() {
656        if val.is_none() {
657            continue;
658        }
659        let val = val.unwrap();
660        if min_leading_bits.is_none() {
661            min_leading_bits = Some(u64::MAX);
662        }
663
664        if val.to_i64().unwrap() < 0i64 {
665            min_leading_bits = min_leading_bits.map(|bits| bits.min(val.leading_ones() as u64));
666            add_signed_bit = true;
667        } else {
668            min_leading_bits = min_leading_bits.map(|bits| bits.min(val.leading_zeros() as u64));
669        }
670    }
671
672    let mut min_leading_bits = arr.data_type().byte_width() as u64 * 8 - min_leading_bits?;
673    if add_signed_bit {
674        // Need extra sign bit
675        min_leading_bits += 1;
676    }
677    // cannot bitpack into <1 bit
678    let num_bits = min_leading_bits.max(1);
679    Some(BitpackParams {
680        num_bits,
681        signed: add_signed_bit,
682    })
683}
/// Array encoder that packs each value into `num_bits` bits.
#[derive(Debug)]
pub struct BitpackedArrayEncoder {
    /// Bits written per value.
    num_bits: u64,
    /// Signedness flag recorded in the encoding description.
    signed_type: bool,
}

impl BitpackedArrayEncoder {
    /// Creates an encoder that writes `num_bits` bits per value and records
    /// `signed_type` in the encoding description.
    pub fn new(num_bits: u64, signed_type: bool) -> Self {
        let encoder = Self {
            num_bits,
            signed_type,
        };
        encoder
    }
}
698
699impl ArrayEncoder for BitpackedArrayEncoder {
700    fn encode(
701        &self,
702        data: DataBlock,
703        _data_type: &DataType,
704        buffer_index: &mut u32,
705    ) -> Result<EncodedArray> {
706        // calculate the total number of bytes we need to allocate for the destination.
707        // this will be the number of items in the source array times the number of bits.
708        let dst_bytes_total = ceil(data.num_values() as usize * self.num_bits as usize, 8);
709
710        let mut dst_buffer = vec![0u8; dst_bytes_total];
711        let mut dst_idx = 0;
712        let mut dst_offset = 0;
713
714        let DataBlock::FixedWidth(unpacked) = data else {
715            return Err(Error::InvalidInput {
716                source: "Bitpacking only supports fixed width data blocks".into(),
717                location: location!(),
718            });
719        };
720
721        pack_bits(
722            &unpacked.data,
723            self.num_bits,
724            &mut dst_buffer,
725            &mut dst_idx,
726            &mut dst_offset,
727        );
728
729        let packed = DataBlock::FixedWidth(FixedWidthDataBlock {
730            bits_per_value: self.num_bits,
731            data: LanceBuffer::from(dst_buffer),
732            num_values: unpacked.num_values,
733            block_info: BlockInfo::new(),
734        });
735
736        let bitpacked_buffer_index = *buffer_index;
737        *buffer_index += 1;
738
739        let encoding = ProtobufUtils::bitpacked_encoding(
740            self.num_bits,
741            unpacked.bits_per_value,
742            bitpacked_buffer_index,
743            self.signed_type,
744        );
745
746        Ok(EncodedArray {
747            data: packed,
748            encoding,
749        })
750    }
751}
752
/// Copies the low `num_bits` bits of a source value into `dst`, continuing from the
/// write cursor (`*dst_idx`, `*dst_offset`).
///
/// * `src` - buffer holding the unpacked bytes; bytes are read in increasing index order
/// * `num_bits` - number of bits kept from a value
/// * `dst` - destination buffer. Bits are merged in with `+=`, so the bits being
///   written must be zero on entry
/// * `dst_idx` - index of the destination byte currently being filled (advanced in place)
/// * `dst_offset` - bit offset inside `dst[*dst_idx]` of the next free bit (advanced in place)
///
/// NOTE(review): `bit_len` and the `src_idx` skip at the bottom of the outer loop are
/// derived from `src.len()` (the length of the whole buffer) rather than from the byte
/// width of a single value. As written, whenever `src.len() * 8 != num_bits` that skip
/// appears to push `src_idx` to the end of `src` after the first value, so only one
/// value gets packed. This looks like it was meant to use a per-value byte length,
/// which would have to be supplied by the caller -- TODO confirm.
fn pack_bits(
    src: &LanceBuffer,
    num_bits: u64,
    dst: &mut [u8],
    dst_idx: &mut usize,
    dst_offset: &mut u8,
) {
    // NOTE(review): bit length of the entire buffer, see the note in the function docs
    let bit_len = src.len() as u64 * 8;

    // mask with the low `num_bits` bits set (the shift overflows if `num_bits` is 0)
    let mask = u64::MAX >> (64 - num_bits);

    let mut src_idx = 0;
    while src_idx < src.len() {
        // `curr_mask` is shifted right by 8 each time a source byte is finished, so
        // its low byte always selects the bits still wanted from the current byte
        let mut curr_mask = mask;
        let mut curr_src = src[src_idx] & curr_mask as u8;
        let mut src_offset = 0;
        let mut src_bits_written = 0;

        while src_bits_written < num_bits {
            dst[*dst_idx] += (curr_src >> src_offset) << *dst_offset as u64;
            // a single step is limited by the bits left in the value, the bits left
            // in the current source byte and the room left in the destination byte
            let bits_written = (num_bits - src_bits_written)
                .min(8 - src_offset)
                .min(8 - *dst_offset as u64);
            src_bits_written += bits_written;
            *dst_offset += bits_written as u8;
            src_offset += bits_written;

            // destination byte is full, continue in the next one
            if *dst_offset == 8 {
                *dst_idx += 1;
                *dst_offset = 0;
            }

            // source byte is exhausted, continue with the next one
            if src_offset == 8 {
                src_idx += 1;
                src_offset = 0;
                curr_mask >>= 8;
                if src_idx == src.len() {
                    break;
                }
                curr_src = src[src_idx] & curr_mask as u8;
            }
        }

        // advance source_offset to the next byte if we're not at the end..
        // note that we don't need to do this if we wrote the full number of bits
        // because source index would have been advanced by the inner loop above
        if bit_len != num_bits {
            let partial_bytes_written = ceil(num_bits as usize, 8);

            // we also want to move to the next location in src, unless we wrote
            // something byte-aligned in which case the logic above would have
            // already advanced
            let mut to_next_byte = 1;
            if num_bits % 8 == 0 {
                to_next_byte = 0;
            }

            src_idx += src.len() - partial_bytes_written + to_next_byte;
        }
    }
}
813
/// A physical scheduler for bitpacked buffers
#[derive(Debug, Clone, Copy)]
pub struct BitpackedScheduler {
    /// number of bits each value occupies in the bitpacked buffer
    bits_per_value: u64,
    /// number of bits of each value once decoded; handed through to the page decoder
    uncompressed_bits_per_value: u64,
    /// added to every byte offset computed while scheduling I/O
    /// (presumably the position of the bitpacked buffer -- confirm with the caller)
    buffer_offset: u64,
    /// handed through to the page decoder, which uses it to decide whether the
    /// most significant packed bit is treated as a sign bit
    signed: bool,
}
822
impl BitpackedScheduler {
    /// Creates a scheduler from its four settings; the arguments are stored as-is,
    /// no validation is performed here.
    ///
    /// * `bits_per_value` - number of bits each value occupies in the packed buffer
    /// * `uncompressed_bits_per_value` - number of bits of each decoded value
    /// * `buffer_offset` - offset added to the byte ranges requested during scheduling
    /// * `signed` - whether decoding should treat the top packed bit as a sign bit
    pub fn new(
        bits_per_value: u64,
        uncompressed_bits_per_value: u64,
        buffer_offset: u64,
        signed: bool,
    ) -> Self {
        Self {
            bits_per_value,
            uncompressed_bits_per_value,
            buffer_offset,
            signed,
        }
    }
}
838
839impl PageScheduler for BitpackedScheduler {
840    fn schedule_ranges(
841        &self,
842        ranges: &[std::ops::Range<u64>],
843        scheduler: &Arc<dyn crate::EncodingsIo>,
844        top_level_row: u64,
845    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
846        let mut min = u64::MAX;
847        let mut max = 0;
848
849        let mut buffer_bit_start_offsets: Vec<u8> = vec![];
850        let mut buffer_bit_end_offsets: Vec<Option<u8>> = vec![];
851        let byte_ranges = ranges
852            .iter()
853            .map(|range| {
854                let start_byte_offset = range.start * self.bits_per_value / 8;
855                let mut end_byte_offset = range.end * self.bits_per_value / 8;
856                if range.end * self.bits_per_value % 8 != 0 {
857                    // If the end of the range is not byte-aligned, we need to read one more byte
858                    end_byte_offset += 1;
859
860                    let end_bit_offset = range.end * self.bits_per_value % 8;
861                    buffer_bit_end_offsets.push(Some(end_bit_offset as u8));
862                } else {
863                    buffer_bit_end_offsets.push(None);
864                }
865
866                let start_bit_offset = range.start * self.bits_per_value % 8;
867                buffer_bit_start_offsets.push(start_bit_offset as u8);
868
869                let start = self.buffer_offset + start_byte_offset;
870                let end = self.buffer_offset + end_byte_offset;
871                min = min.min(start);
872                max = max.max(end);
873
874                start..end
875            })
876            .collect::<Vec<_>>();
877
878        trace!(
879            "Scheduling I/O for {} ranges spread across byte range {}..{}",
880            byte_ranges.len(),
881            min,
882            max
883        );
884
885        let bytes = scheduler.submit_request(byte_ranges, top_level_row);
886
887        let bits_per_value = self.bits_per_value;
888        let uncompressed_bits_per_value = self.uncompressed_bits_per_value;
889        let signed = self.signed;
890        async move {
891            let bytes = bytes.await?;
892            Ok(Box::new(BitpackedPageDecoder {
893                buffer_bit_start_offsets,
894                buffer_bit_end_offsets,
895                bits_per_value,
896                uncompressed_bits_per_value,
897                signed,
898                data: bytes,
899            }) as Box<dyn PrimitivePageDecoder>)
900        }
901        .boxed()
902    }
903}
904
/// Decoder over the bytes fetched for a set of scheduled bitpacked ranges. The three
/// per-buffer vectors (`buffer_bit_start_offsets`, `buffer_bit_end_offsets`, `data`)
/// are indexed in parallel, one entry per scheduled range.
#[derive(Debug)]
struct BitpackedPageDecoder {
    // bit offsets of the first value within each buffer
    buffer_bit_start_offsets: Vec<u8>,

    // bit offset, within the last byte of each buffer, at which the last value ends
    // (`None` if the last value ends exactly on a byte boundary). e.g. if there was a
    // buffer with 2 values, packed into 5 bits each and starting at bit 0, this would
    // be [Some(2)] (10 bits, 10 % 8 = 2), indicating that the bits from the 3rd->8th
    // bit in the last byte shouldn't be decoded.
    buffer_bit_end_offsets: Vec<Option<u8>>,

    // the number of bits used to represent a compressed value. E.g. if the max value
    // in the page was 7 (0b111), then this will be 3
    bits_per_value: u64,

    // number of bits in the uncompressed value. E.g. this will be 32 for u32
    uncompressed_bits_per_value: u64,

    // whether or not to use the msb as a sign bit during decoding
    signed: bool,

    // the fetched bytes, one entry per scheduled range
    data: Vec<Bytes>,
}
927
928impl PrimitivePageDecoder for BitpackedPageDecoder {
929    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
930        let num_bytes = self.uncompressed_bits_per_value / 8 * num_rows;
931        let mut dest = vec![0; num_bytes as usize];
932
933        // current maximum supported bits per value = 64
934        debug_assert!(self.bits_per_value <= 64);
935
936        let mut rows_to_skip = rows_to_skip;
937        let mut rows_taken = 0;
938        let byte_len = self.uncompressed_bits_per_value / 8;
939        let mut dst_idx = 0; // index for current byte being written to destination buffer
940
941        // create bit mask for source bits
942        let mask = u64::MAX >> (64 - self.bits_per_value);
943
944        for i in 0..self.data.len() {
945            let src = &self.data[i];
946            let (mut src_idx, mut src_offset) = match compute_start_offset(
947                rows_to_skip,
948                src.len(),
949                self.bits_per_value,
950                self.buffer_bit_start_offsets[i],
951                self.buffer_bit_end_offsets[i],
952            ) {
953                StartOffset::SkipFull(rows_to_skip_here) => {
954                    rows_to_skip -= rows_to_skip_here;
955                    continue;
956                }
957                StartOffset::SkipSome(buffer_start_offset) => (
958                    buffer_start_offset.index,
959                    buffer_start_offset.bit_offset as u64,
960                ),
961            };
962
963            while src_idx < src.len() && rows_taken < num_rows {
964                rows_taken += 1;
965                let mut curr_mask = mask; // copy mask
966
967                // current source byte being written to destination
968                let mut curr_src = src[src_idx] & (curr_mask << src_offset) as u8;
969
970                // how many bits from the current source value have been written to destination
971                let mut src_bits_written = 0;
972
973                // the offset within the current destination byte to write to
974                let mut dst_offset = 0;
975
976                let is_negative = is_encoded_item_negative(
977                    src,
978                    src_idx,
979                    src_offset,
980                    self.bits_per_value as usize,
981                );
982
983                while src_bits_written < self.bits_per_value {
984                    // write bits from current source byte into destination
985                    dest[dst_idx] += (curr_src >> src_offset) << dst_offset;
986                    let bits_written = (self.bits_per_value - src_bits_written)
987                        .min(8 - src_offset)
988                        .min(8 - dst_offset);
989                    src_bits_written += bits_written;
990                    dst_offset += bits_written;
991                    src_offset += bits_written;
992                    curr_mask >>= bits_written;
993
994                    if dst_offset == 8 {
995                        dst_idx += 1;
996                        dst_offset = 0;
997                    }
998
999                    if src_offset == 8 {
1000                        src_idx += 1;
1001                        src_offset = 0;
1002                        if src_idx == src.len() {
1003                            break;
1004                        }
1005                        curr_src = src[src_idx] & curr_mask as u8;
1006                    }
1007                }
1008
1009                // if the type is signed, need to pad out the rest of the byte with 1s
1010                let mut negative_padded_current_byte = false;
1011                if self.signed && is_negative && dst_offset > 0 {
1012                    negative_padded_current_byte = true;
1013                    while dst_offset < 8 {
1014                        dest[dst_idx] |= 1 << dst_offset;
1015                        dst_offset += 1;
1016                    }
1017                }
1018
1019                // advance destination offset to the next location
1020                // note that we don't need to do this if we wrote the full number of bits
1021                // because source index would have been advanced by the inner loop above
1022                if self.uncompressed_bits_per_value != self.bits_per_value {
1023                    let partial_bytes_written = ceil(self.bits_per_value as usize, 8);
1024
1025                    // we also want to move one location to the next location in destination,
1026                    // unless we wrote something byte-aligned in which case the logic above
1027                    // would have already advanced dst_idx
1028                    let mut to_next_byte = 1;
1029                    if self.bits_per_value % 8 == 0 {
1030                        to_next_byte = 0;
1031                    }
1032                    let next_dst_idx =
1033                        dst_idx + byte_len as usize - partial_bytes_written + to_next_byte;
1034
1035                    // pad remaining bytes with 1 for negative signed numbers
1036                    if self.signed && is_negative {
1037                        if !negative_padded_current_byte {
1038                            dest[dst_idx] = 0xFF;
1039                        }
1040                        for i in dest.iter_mut().take(next_dst_idx).skip(dst_idx + 1) {
1041                            *i = 0xFF;
1042                        }
1043                    }
1044
1045                    dst_idx = next_dst_idx;
1046                }
1047
1048                // If we've reached the last byte, there may be some extra bits from the
1049                // next value outside the range. We don't want to be taking those.
1050                if let Some(buffer_bit_end_offset) = self.buffer_bit_end_offsets[i] {
1051                    if src_idx == src.len() - 1 && src_offset >= buffer_bit_end_offset as u64 {
1052                        break;
1053                    }
1054                }
1055            }
1056        }
1057
1058        Ok(DataBlock::FixedWidth(FixedWidthDataBlock {
1059            data: LanceBuffer::from(dest),
1060            bits_per_value: self.uncompressed_bits_per_value,
1061            num_values: num_rows,
1062            block_info: BlockInfo::new(),
1063        }))
1064    }
1065}
1066
1067fn is_encoded_item_negative(src: &Bytes, src_idx: usize, src_offset: u64, num_bits: usize) -> bool {
1068    let mut last_byte_idx = src_idx + ((src_offset as usize + num_bits) / 8);
1069    let shift_amount = (src_offset as usize + num_bits) % 8;
1070    let shift_amount = if shift_amount == 0 {
1071        last_byte_idx -= 1;
1072        7
1073    } else {
1074        shift_amount - 1
1075    };
1076    let last_byte = src[last_byte_idx];
1077    let sign_bit_mask = 1 << shift_amount;
1078    let sign_bit = last_byte & sign_bit_mask;
1079
1080    sign_bit > 0
1081}
1082
/// Position inside a buffer at which reading should begin.
#[derive(Debug, PartialEq)]
struct BufferStartOffset {
    // index of the byte holding the first bit to read
    index: usize,
    // offset of that first bit within the byte (0..8)
    bit_offset: u8,
}
1088
/// Outcome of applying a row skip to a single buffer.
#[derive(Debug, PartialEq)]
enum StartOffset {
    // skip the full buffer. The value is how many rows are skipped
    // by skipping the full buffer (e.g., # rows in buffer)
    SkipFull(u64),

    // skip to some start offset in the buffer
    SkipSome(BufferStartOffset),
}
1098
1099/// compute how far ahead in this buffer should we skip ahead and start reading
1100///
1101/// * `rows_to_skip` - how many rows to skip
1102/// * `buffer_len` - length buf buffer (in bytes)
1103/// * `bits_per_value` - number of bits used to represent a single bitpacked value
1104/// * `buffer_start_bit_offset` - offset of the start of the first value within the
1105///   buffer's  first byte
1106/// * `buffer_end_bit_offset` - end bit of the last value within the buffer. Can be
1107///   `None` if the end of the last value is byte aligned with end of buffer.
1108fn compute_start_offset(
1109    rows_to_skip: u64,
1110    buffer_len: usize,
1111    bits_per_value: u64,
1112    buffer_start_bit_offset: u8,
1113    buffer_end_bit_offset: Option<u8>,
1114) -> StartOffset {
1115    let rows_in_buffer = rows_in_buffer(
1116        buffer_len,
1117        bits_per_value,
1118        buffer_start_bit_offset,
1119        buffer_end_bit_offset,
1120    );
1121    if rows_to_skip >= rows_in_buffer {
1122        return StartOffset::SkipFull(rows_in_buffer);
1123    }
1124
1125    let start_bit = rows_to_skip * bits_per_value + buffer_start_bit_offset as u64;
1126    let start_byte = start_bit / 8;
1127
1128    StartOffset::SkipSome(BufferStartOffset {
1129        index: start_byte as usize,
1130        bit_offset: (start_bit % 8) as u8,
1131    })
1132}
1133
/// Counts how many complete bitpacked values a buffer holds.
///
/// * `buffer_len` - length of the buffer in bytes
/// * `bits_per_value` - width of one packed value in bits
/// * `buffer_start_bit_offset` - bits at the front of the first byte that come
///   before the first value
/// * `buffer_end_bit_offset` - bit offset within the last byte where the last value
///   ends, or `None` if it ends on the byte boundary
fn rows_in_buffer(
    buffer_len: usize,
    bits_per_value: u64,
    buffer_start_bit_offset: u8,
    buffer_end_bit_offset: Option<u8>,
) -> u64 {
    let total_bits = (buffer_len * 8) as u64;
    let leading_unused = buffer_start_bit_offset as u64;

    // bits of the final byte that lie past the end of the last value
    let trailing_unused = buffer_end_bit_offset.map_or(0, |end_offset| (8 - end_offset) as u64);

    (total_bits - leading_unused - trailing_unused) / bits_per_value
}
1151
1152#[cfg(test)]
1153pub mod test {
1154    use crate::{
1155        format::pb,
1156        testing::{check_round_trip_encoding_generated, ArrayGeneratorProvider, TestCases},
1157        version::LanceFileVersion,
1158    };
1159
1160    use super::*;
1161    use std::{marker::PhantomData, sync::Arc};
1162
1163    use arrow_array::{
1164        types::{UInt16Type, UInt8Type},
1165        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
1166        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
1167    };
1168
1169    use arrow_schema::Field;
1170    use lance_datagen::{
1171        array::{fill, rand_with_distribution},
1172        gen_batch, ArrayGenerator, ArrayGeneratorExt, RowCount,
1173    };
1174    use rand::distr::Uniform;
1175
    /// Checks that `bitpack_params` reports the expected bit width for unsigned
    /// arrays of several types, with and without nulls, and returns `None` for a
    /// float array.
    #[test]
    fn test_bitpack_params() {
        // builds a 10000-row single-column batch from the generator and returns
        // that column
        fn gen_array(generator: Box<dyn ArrayGenerator>) -> ArrayRef {
            let arr = gen_batch()
                .anon_col(generator)
                .into_batch_rows(RowCount::from(10000))
                .unwrap()
                .column(0)
                .clone();

            arr
        }

        macro_rules! do_test {
            ($num_bits:expr, $data_type:ident, $null_probability:expr) => {
                // `-` binds tighter than `<<`, so this is `1 << (num_bits - 1)`: a
                // value whose highest set bit is bit number `num_bits`
                let max = 1 << $num_bits - 1;
                let mut arr =
                    gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));

                // ensure we don't randomly generate all nulls, that won't work
                while arr.null_count() == arr.len() {
                    arr = gen_array(fill::<$data_type>(max).with_random_nulls($null_probability));
                }
                let result = bitpack_params(arr.as_ref());
                assert!(result.is_some());
                assert_eq!($num_bits, result.unwrap().num_bits);
            };
        }

        // (num_bits, null_probability) pairs that fit in every unsigned type
        let test_cases = vec![
            (5u64, 0.0f64),
            (5u64, 0.9f64),
            (1u64, 0.0f64),
            (1u64, 0.5f64),
            (8u64, 0.0f64),
            (8u64, 0.5f64),
        ];

        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt8Type, *null_probability);
            do_test!(*num_bits, UInt16Type, *null_probability);
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // do some test cases that will only work on larger types
        let test_cases = vec![
            (13u64, 0.0f64),
            (13u64, 0.5f64),
            (16u64, 0.0f64),
            (16u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt16Type, *null_probability);
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }
        // widths that need at least 32 bits
        let test_cases = vec![
            (25u64, 0.0f64),
            (25u64, 0.5f64),
            (32u64, 0.0f64),
            (32u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt32Type, *null_probability);
            do_test!(*num_bits, UInt64Type, *null_probability);
        }
        // widths that need 64 bits
        let test_cases = vec![
            (48u64, 0.0f64),
            (48u64, 0.5f64),
            (64u64, 0.0f64),
            (64u64, 0.5f64),
        ];
        for (num_bits, null_probability) in &test_cases {
            do_test!(*num_bits, UInt64Type, *null_probability);
        }

        // test that it returns None for datatypes that don't support bitpacking
        let arr = Float64Array::from_iter_values(vec![0.1, 0.2, 0.3]);
        let result = bitpack_params(&arr);
        assert!(result.is_none());
    }
1258
1259    #[test]
1260    fn test_num_compressed_bits_signed_types() {
1261        let values = Int32Array::from(vec![1, 2, -7]);
1262        let arr = values;
1263        let result = bitpack_params(&arr);
1264        assert!(result.is_some());
1265        let result = result.unwrap();
1266        assert_eq!(4, result.num_bits);
1267        assert!(result.signed);
1268
1269        // check that it doesn't add a sign bit if it doesn't need to
1270        let values = Int32Array::from(vec![1, 2, 7]);
1271        let arr = values;
1272        let result = bitpack_params(&arr);
1273        assert!(result.is_some());
1274        let result = result.unwrap();
1275        assert_eq!(3, result.num_bits);
1276        assert!(!result.signed);
1277    }
1278
1279    #[test]
1280    fn test_rows_in_buffer() {
1281        let test_cases = vec![
1282            (5usize, 5u64, 0u8, None, 8u64),
1283            (2, 3, 0, Some(5), 4),
1284            (2, 3, 7, Some(6), 2),
1285        ];
1286
1287        for (
1288            buffer_len,
1289            bits_per_value,
1290            buffer_start_bit_offset,
1291            buffer_end_bit_offset,
1292            expected,
1293        ) in test_cases
1294        {
1295            let result = rows_in_buffer(
1296                buffer_len,
1297                bits_per_value,
1298                buffer_start_bit_offset,
1299                buffer_end_bit_offset,
1300            );
1301            assert_eq!(expected, result);
1302        }
1303    }
1304
1305    #[test]
1306    fn test_compute_start_offset() {
1307        let result = compute_start_offset(0, 5, 5, 0, None);
1308        assert_eq!(
1309            StartOffset::SkipSome(BufferStartOffset {
1310                index: 0,
1311                bit_offset: 0
1312            }),
1313            result
1314        );
1315
1316        let result = compute_start_offset(10, 5, 5, 0, None);
1317        assert_eq!(StartOffset::SkipFull(8), result);
1318    }
1319
    /// Encodes small integer arrays with `BitpackedArrayEncoder` and checks the
    /// reported bit widths and encoding description; then checks that, for arrays
    /// that cannot be narrowed, `bitpack_params` (when it returns anything) reports
    /// the full width of the type.
    #[test_log::test(test)]
    fn test_will_bitpack_allowed_types_when_possible() {
        // (data type, input array, expected bits per packed value)
        let test_cases: Vec<(DataType, ArrayRef, u64)> = vec![
            (
                DataType::UInt8,
                Arc::new(UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 5])),
                3, // bits per value
            ),
            (
                DataType::UInt16,
                Arc::new(UInt16Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 8])),
                11,
            ),
            (
                DataType::UInt32,
                Arc::new(UInt32Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 16])),
                19,
            ),
            (
                DataType::UInt64,
                Arc::new(UInt64Array::from_iter_values(vec![0, 1, 2, 3, 4, 5 << 32])),
                35,
            ),
            (
                DataType::Int8,
                Arc::new(Int8Array::from_iter_values(vec![0, 2, 3, 4, -5])),
                4,
            ),
            (
                // check it will not pack with signed bit if all values of signed type are positive
                DataType::Int8,
                Arc::new(Int8Array::from_iter_values(vec![0, 2, 3, 4, 5])),
                3,
            ),
            (
                DataType::Int16,
                Arc::new(Int16Array::from_iter_values(vec![0, 1, 2, 3, -4, 5 << 8])),
                12,
            ),
            (
                DataType::Int32,
                Arc::new(Int32Array::from_iter_values(vec![0, 1, 2, 3, 4, -5 << 16])),
                20,
            ),
            (
                DataType::Int64,
                Arc::new(Int64Array::from_iter_values(vec![
                    0,
                    1,
                    2,
                    -3,
                    -4,
                    -5 << 32,
                ])),
                36,
            ),
        ];

        for (data_type, arr, bits_per_value) in test_cases {
            // buffer index handed to (and advanced by) the encoder
            let mut buffed_index = 1;
            let params = bitpack_params(arr.as_ref()).unwrap();
            let encoder = BitpackedArrayEncoder {
                num_bits: params.num_bits,
                signed_type: params.signed,
            };
            let data = DataBlock::from_array(arr);
            let result = encoder.encode(data, &data_type, &mut buffed_index).unwrap();

            // the encoded block itself must carry the packed width
            let data = result.data.as_fixed_width().unwrap();
            assert_eq!(bits_per_value, data.bits_per_value);

            let array_encoding = result.encoding.array_encoding.unwrap();

            // the encoding description must record both the packed and the original width
            match array_encoding {
                pb::array_encoding::ArrayEncoding::Bitpacked(bitpacked) => {
                    assert_eq!(bits_per_value, bitpacked.compressed_bits_per_value);
                    assert_eq!(
                        (data_type.byte_width() * 8) as u64,
                        bitpacked.uncompressed_bits_per_value
                    );
                }
                _ => {
                    panic!("Array did not use bitpacking encoding")
                }
            }
        }

        // check it will otherwise use flat encoding
        let test_cases: Vec<(DataType, ArrayRef)> = vec![
            // it should use flat encoding for datatypes that don't support bitpacking
            (
                DataType::Float32,
                Arc::new(Float32Array::from_iter_values(vec![0.1, 0.2, 0.3])),
            ),
            // it should still use flat encoding if bitpacked encoding would be packed
            // into the full byte range
            (
                DataType::UInt8,
                Arc::new(UInt8Array::from_iter_values(vec![0, 1, 2, 3, 4, 250])),
            ),
            (
                DataType::UInt16,
                Arc::new(UInt16Array::from_iter_values(vec![0, 1, 2, 3, 4, 250 << 8])),
            ),
            (
                DataType::UInt32,
                Arc::new(UInt32Array::from_iter_values(vec![
                    0,
                    1,
                    2,
                    3,
                    4,
                    250 << 24,
                ])),
            ),
            (
                DataType::UInt64,
                Arc::new(UInt64Array::from_iter_values(vec![
                    0,
                    1,
                    2,
                    3,
                    4,
                    250 << 56,
                ])),
            ),
            (
                DataType::Int8,
                Arc::new(Int8Array::from_iter_values(vec![-100])),
            ),
            (
                DataType::Int16,
                Arc::new(Int16Array::from_iter_values(vec![-100 << 8])),
            ),
            (
                DataType::Int32,
                Arc::new(Int32Array::from_iter_values(vec![-100 << 24])),
            ),
            (
                DataType::Int64,
                Arc::new(Int64Array::from_iter_values(vec![-100 << 56])),
            ),
        ];

        // only the reported parameters are checked here: when params are returned,
        // the width must equal the full width of the type. A `None` result is accepted.
        for (data_type, arr) in test_cases {
            if let Some(params) = bitpack_params(arr.as_ref()) {
                assert_eq!(params.num_bits, data_type.byte_width() as u64 * 8);
            }
        }
    }
1470
    /// Test helper that pairs an Arrow primitive type (`DataType`) with a random
    /// distribution (`Dist`) over that type's native values.
    struct DistributionArrayGeneratorProvider<
        DataType,
        Dist: rand::distr::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
    >
    where
        DataType::Native: Copy + 'static,
        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
        DataType: ArrowPrimitiveType,
    {
        // carries the `DataType` type parameter; holds no data
        phantom: PhantomData<DataType>,
        // distribution cloned into every generator this provider creates
        distribution: Dist,
    }
1483
    impl<DataType, Dist> DistributionArrayGeneratorProvider<DataType, Dist>
    where
        Dist: rand::distr::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
        DataType::Native: Copy + 'static,
        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
        DataType: ArrowPrimitiveType,
    {
        /// Creates a provider that hands out generators backed by `dist`.
        fn new(dist: Dist) -> Self {
            Self {
                distribution: dist,
                phantom: Default::default(),
            }
        }
    }
1498
    impl<DataType, Dist> ArrayGeneratorProvider for DistributionArrayGeneratorProvider<DataType, Dist>
    where
        Dist: rand::distr::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
        DataType::Native: Copy + 'static,
        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
        DataType: ArrowPrimitiveType,
    {
        /// Builds a new generator via `rand_with_distribution`, passing it a clone of
        /// the stored distribution.
        fn provide(&self) -> Box<dyn ArrayGenerator> {
            rand_with_distribution::<DataType, Dist>(self.distribution.clone())
        }

        /// Returns a boxed provider holding a clone of the stored distribution.
        fn copy(&self) -> Box<dyn ArrayGeneratorProvider> {
            Box::new(Self {
                phantom: self.phantom,
                distribution: self.distribution.clone(),
            })
        }
    }
1517
1518    #[test_log::test(tokio::test)]
1519    async fn test_bitpack_primitive() {
1520        let bitpacked_test_cases: &Vec<(DataType, Box<dyn ArrayGeneratorProvider>)> = &vec![
1521            // check less than one byte for multi-byte type
1522            (
1523                DataType::UInt32,
1524                Box::new(
1525                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1526                        Uniform::new(0, 19).unwrap(),
1527                    ),
1528                ),
1529            ),
1530            // // check that more than one byte for multi-byte type
1531            (
1532                DataType::UInt32,
1533                Box::new(
1534                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1535                        Uniform::new(5 << 7, 6 << 7).unwrap(),
1536                    ),
1537                ),
1538            ),
1539            (
1540                DataType::UInt64,
1541                Box::new(
1542                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
1543                        Uniform::new(5 << 42, 6 << 42).unwrap(),
1544                    ),
1545                ),
1546            ),
1547            // check less than one byte for single-byte type
1548            (
1549                DataType::UInt8,
1550                Box::new(
1551                    DistributionArrayGeneratorProvider::<UInt8Type, Uniform<u8>>::new(
1552                        Uniform::new(0, 19).unwrap(),
1553                    ),
1554                ),
1555            ),
1556            // check less than one byte for single-byte type
1557            (
1558                DataType::UInt64,
1559                Box::new(
1560                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
1561                        Uniform::new(129, 259).unwrap(),
1562                    ),
1563                ),
1564            ),
1565            // check byte aligned for single byte
1566            (
1567                DataType::UInt32,
1568                Box::new(
1569                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1570                        // this range should always give 8 bits
1571                        Uniform::new(200, 250).unwrap(),
1572                    ),
1573                ),
1574            ),
1575            // check where the num_bits divides evenly into the bit length of the type
1576            (
1577                DataType::UInt64,
1578                Box::new(
1579                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
1580                        Uniform::new(1, 3).unwrap(), // 2 bits
1581                    ),
1582                ),
1583            ),
1584            // check byte aligned for multiple bytes
1585            (
1586                DataType::UInt32,
1587                Box::new(
1588                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1589                        // this range should always always give 16 bits
1590                        Uniform::new(200 << 8, 250 << 8).unwrap(),
1591                    ),
1592                ),
1593            ),
1594            // check byte aligned where the num bits doesn't divide evenly into the byte length
1595            (
1596                DataType::UInt64,
1597                Box::new(
1598                    DistributionArrayGeneratorProvider::<UInt64Type, Uniform<u64>>::new(
1599                        // this range should always give 24 bits
1600                        Uniform::new(200 << 16, 250 << 16).unwrap(),
1601                    ),
1602                ),
1603            ),
1604            // check that we can still encode an all-0 array
1605            (
1606                DataType::UInt32,
1607                Box::new(
1608                    DistributionArrayGeneratorProvider::<UInt32Type, Uniform<u32>>::new(
1609                        Uniform::new(0, 1).unwrap(),
1610                    ),
1611                ),
1612            ),
1613            // check for signed types
1614            (
1615                DataType::Int16,
1616                Box::new(
1617                    DistributionArrayGeneratorProvider::<Int16Type, Uniform<i16>>::new(
1618                        Uniform::new(-5, 5).unwrap(),
1619                    ),
1620                ),
1621            ),
1622            (
1623                DataType::Int64,
1624                Box::new(
1625                    DistributionArrayGeneratorProvider::<Int64Type, Uniform<i64>>::new(
1626                        Uniform::new(-(5 << 42), 6 << 42).unwrap(),
1627                    ),
1628                ),
1629            ),
1630            (
1631                DataType::Int32,
1632                Box::new(
1633                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1634                        Uniform::new(-(5 << 7), 6 << 7).unwrap(),
1635                    ),
1636                ),
1637            ),
1638            // check signed where packed to < 1 byte for multi-byte type
1639            (
1640                DataType::Int32,
1641                Box::new(
1642                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1643                        Uniform::new(-19, 19).unwrap(),
1644                    ),
1645                ),
1646            ),
1647            // check signed byte aligned to single byte
1648            (
1649                DataType::Int32,
1650                Box::new(
1651                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1652                        // this range should always give 8 bits
1653                        Uniform::new(-120, 120).unwrap(),
1654                    ),
1655                ),
1656            ),
1657            // check signed byte aligned to multiple bytes
1658            (
1659                DataType::Int32,
1660                Box::new(
1661                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1662                        // this range should always give 16 bits
1663                        Uniform::new(-120 << 8, 120 << 8).unwrap(),
1664                    ),
1665                ),
1666            ),
1667            // check that it works for all positive integers even if type is signed
1668            (
1669                DataType::Int32,
1670                Box::new(
1671                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1672                        Uniform::new(10, 20).unwrap(),
1673                    ),
1674                ),
1675            ),
1676            // check that all 0 works for signed type
1677            (
1678                DataType::Int32,
1679                Box::new(
1680                    DistributionArrayGeneratorProvider::<Int32Type, Uniform<i32>>::new(
1681                        Uniform::new(0, 1).unwrap(),
1682                    ),
1683                ),
1684            ),
1685        ];
1686
1687        for (data_type, array_gen_provider) in bitpacked_test_cases {
1688            let field = Field::new("", data_type.clone(), false);
1689            let test_cases = TestCases::basic().with_min_file_version(LanceFileVersion::V2_1);
1690            check_round_trip_encoding_generated(field, array_gen_provider.copy(), test_cases).await;
1691        }
1692    }
1693}