Skip to main content

liquid_cache/liquid_array/raw/
bit_pack_array.rs

1use std::mem::size_of;
2use std::num::NonZero;
3
4use arrow::array::{ArrowPrimitiveType, PrimitiveArray};
5use arrow::buffer::{BooleanBuffer, Buffer, NullBuffer, ScalarBuffer};
6use arrow::datatypes::ArrowNativeType;
7use bytes;
8use fastlanes::BitPacking;
9
10/// A bit-packed array.
11#[derive(Debug)]
12pub struct BitPackedArray<T: ArrowPrimitiveType>
13where
14    T::Native: BitPacking,
15{
16    packed_values: ScalarBuffer<T::Native>,
17    nulls: Option<NullBuffer>,
18    bit_width: Option<NonZero<u8>>, // if None, the array is entirely null
19    original_len: usize,
20}
21
22/// Implement Clone for any T that implements ArrowPrimitiveType and BitPacking
23/// This allows us to clone it without requiring T to implement Clone
24impl<T: ArrowPrimitiveType> Clone for BitPackedArray<T>
25where
26    T::Native: BitPacking,
27{
28    fn clone(&self) -> Self {
29        Self {
30            packed_values: self.packed_values.clone(),
31            nulls: self.nulls.clone(),
32            bit_width: self.bit_width,
33            original_len: self.original_len,
34        }
35    }
36}
37
38impl<T: ArrowPrimitiveType> BitPackedArray<T>
39where
40    T::Native: BitPacking,
41{
42    /// Creates a new null array with the given length.
43    pub fn new_null_array(len: usize) -> Self {
44        Self {
45            packed_values: vec![T::Native::usize_as(0); len].into(),
46            nulls: Some(NullBuffer::new_null(len)),
47            bit_width: None,
48            original_len: len,
49        }
50    }
51
52    pub(crate) fn len(&self) -> usize {
53        self.original_len
54    }
55
56    pub(crate) fn nulls(&self) -> Option<&NullBuffer> {
57        self.nulls.as_ref()
58    }
59
60    pub(crate) fn bit_width(&self) -> Option<NonZero<u8>> {
61        self.bit_width
62    }
63
64    /// Returns true if the array is nullable.
65    #[cfg(test)]
66    fn is_nullable(&self) -> bool {
67        self.nulls.is_some()
68    }
69
70    /// Creates a new bit-packed array from a primitive array and a bit width.
71    pub fn from_primitive(array: PrimitiveArray<T>, bit_width: NonZero<u8>) -> Self {
72        let original_len = array.len();
73        let (_data_type, values, nulls) = array.into_parts();
74
75        let bit_width_usize = bit_width.get() as usize;
76        let num_chunks = original_len.div_ceil(1024);
77        let num_full_chunks = original_len / 1024;
78        let packed_len = (1024 * bit_width_usize).div_ceil(size_of::<T::Native>() * 8);
79
80        let mut output = Vec::<T::Native>::with_capacity(num_chunks * packed_len);
81
82        (0..num_full_chunks).for_each(|i| {
83            let start_elem = i * 1024;
84
85            output.reserve(packed_len);
86            let output_len = output.len();
87            unsafe {
88                output.set_len(output_len + packed_len);
89                BitPacking::unchecked_pack(
90                    bit_width_usize,
91                    &values[start_elem..][..1024],
92                    &mut output[output_len..][..packed_len],
93                );
94            }
95        });
96
97        if num_chunks != num_full_chunks {
98            let last_chunk_size = values.len() % 1024;
99            let mut last_chunk = vec![T::Native::default(); 1024];
100            last_chunk[..last_chunk_size]
101                .copy_from_slice(&values[values.len() - last_chunk_size..]);
102
103            output.reserve(packed_len);
104            let output_len = output.len();
105            unsafe {
106                output.set_len(output_len + packed_len);
107                BitPacking::unchecked_pack(
108                    bit_width_usize,
109                    &last_chunk,
110                    &mut output[output_len..][..packed_len],
111                );
112            }
113        }
114
115        let buffer = Buffer::from(output);
116        let scalar_buffer = ScalarBuffer::new(buffer, 0, num_chunks * packed_len);
117
118        Self {
119            packed_values: scalar_buffer,
120            nulls,
121            bit_width: Some(bit_width),
122            original_len,
123        }
124    }
125
126    /// Converts the bit-packed array to a primitive array.
127    pub fn to_primitive(&self) -> PrimitiveArray<T> {
128        // Special case for all nulls, don't unpack
129        let bit_width = if let Some(bit_width) = self.bit_width {
130            bit_width.get() as usize
131        } else {
132            return PrimitiveArray::<T>::new_null(self.original_len);
133        };
134        let packed = self.packed_values.as_ref();
135        let length = self.original_len;
136        let offset = 0;
137
138        let num_chunks = (offset + length).div_ceil(1024);
139        let elements_per_chunk = (1024 * bit_width).div_ceil(size_of::<T::Native>() * 8);
140
141        let mut output = Vec::<T::Native>::with_capacity(num_chunks * 1024 - offset);
142
143        let first_full_chunk = if offset != 0 {
144            let chunk: &[T::Native] = &packed[0..elements_per_chunk];
145            let mut decoded = vec![T::Native::default(); 1024];
146            unsafe { BitPacking::unchecked_unpack(bit_width, chunk, &mut decoded) };
147            output.extend_from_slice(&decoded[offset..]);
148            1
149        } else {
150            0
151        };
152
153        (first_full_chunk..num_chunks).for_each(|i| {
154            let chunk: &[T::Native] = &packed[i * elements_per_chunk..][0..elements_per_chunk];
155            unsafe {
156                let output_len = output.len();
157                output.set_len(output_len + 1024);
158                BitPacking::unchecked_unpack(bit_width, chunk, &mut output[output_len..][..1024]);
159            }
160        });
161
162        output.truncate(length);
163        if output.len() < 1024 {
164            output.shrink_to_fit();
165        }
166
167        let nulls = self.nulls.clone();
168        PrimitiveArray::<T>::new(ScalarBuffer::from(output), nulls)
169    }
170
171    /// Returns the memory size of the bit-packed array.
172    pub fn get_array_memory_size(&self) -> usize {
173        std::mem::size_of::<Self>()
174            + self.packed_values.inner().capacity()
175            + self
176                .nulls
177                .as_ref()
178                .map_or(0, |nulls| nulls.buffer().capacity())
179    }
180
181    /*
182    Memory Layout (serialized):
183
184    +-----------------------------+  // Header (16 bytes total)
185    | original_len (4 bytes)      |  // Offset  0 -  3: Array length (u32)
186    +-----------------------------+  //
187    | bit_width (1 byte)          |  // Offset      4: Bit width (u8)
188    +-----------------------------+  //
189    | has_nulls (1 byte)          |  // Offset      5: Null flag (1 if nulls present)
190    +-----------------------------+  //
191    | nulls_len (4 bytes)         |  // Offset  6 -  9: Length of nulls buffer (u32)
192    +-----------------------------+  //
193    | values_len (4 bytes)        |  // Offset 10 - 13: Length of values buffer (u32)
194    +-----------------------------+  //
195    | padding (2 bytes)           |  // Offset 14 - 15: Padding to ensure 16-byte header
196    +-----------------------------+
197
198    [If has_nulls == 1]
199    +-----------------------------+  // Nulls Buffer
200    | nulls data (nulls_len bytes)|  // Offset 16 - (16 + nulls_len - 1)
201    +-----------------------------+
202
203    +-----------------------------+
204    | Padding for 8-byte alignment|  // Ensure values buffer is 8-byte aligned
205    +-----------------------------+
206
207    +-----------------------------+  // Values Buffer (bit-packed data)
208    | values data (values_len)    |  // Starts at the 8-byte aligned offset
209    +-----------------------------+
210    */
211    /// Serializes the bit-packed array to a byte buffer.
212    pub fn to_bytes(&self, buffer: &mut Vec<u8>) {
213        let has_nulls = self.nulls.is_some() as u8;
214
215        let nulls_sliced;
216        let nulls_bytes = if has_nulls == 1 {
217            let nulls = self.nulls.as_ref().unwrap();
218            if nulls.offset() == 0 {
219                nulls.buffer().as_slice()
220            } else {
221                nulls_sliced = Some(nulls.inner().sliced());
222                nulls_sliced.as_ref().unwrap().as_slice()
223            }
224        } else {
225            &[]
226        };
227
228        let values_bytes = self.packed_values.inner().as_slice();
229
230        let header_size = 16;
231
232        let values_offset_base = header_size + if has_nulls == 1 { nulls_bytes.len() } else { 0 };
233        let values_offset = (values_offset_base + 7) & !7;
234
235        let total_size = values_offset + values_bytes.len();
236        buffer.reserve(total_size);
237
238        let start_offset = buffer.len();
239
240        buffer.extend_from_slice(&(self.original_len as u32).to_le_bytes());
241        buffer.push(self.bit_width.map_or(0, |bit_width| bit_width.get()));
242        buffer.push(has_nulls);
243        buffer.extend_from_slice(&(nulls_bytes.len() as u32).to_le_bytes());
244        buffer.extend_from_slice(&(values_bytes.len() as u32).to_le_bytes());
245        buffer.extend_from_slice(&[0, 0]);
246
247        if has_nulls == 1 {
248            buffer.extend_from_slice(nulls_bytes);
249        }
250
251        while (buffer.len() - start_offset) < values_offset {
252            buffer.push(0);
253        }
254
255        buffer.extend_from_slice(values_bytes);
256    }
257
258    /// Deserializes a bit-packed array from a byte buffer.
259    pub fn from_bytes(bytes: bytes::Bytes) -> Self
260    where
261        T::Native: BitPacking,
262    {
263        use std::mem::size_of;
264
265        if bytes.len() < 16 {
266            panic!("Input buffer too small for header");
267        }
268
269        // Read header fields
270        let original_len = u32::from_le_bytes(bytes[0..4].try_into().unwrap()) as usize;
271        let bit_width = bytes[4];
272        let has_nulls = bytes[5] != 0;
273        let nulls_len = u32::from_le_bytes(bytes[6..10].try_into().unwrap()) as usize;
274        let values_len = u32::from_le_bytes(bytes[10..14].try_into().unwrap()) as usize;
275
276        // Calculate offsets
277        let header_size = 16;
278        let nulls_offset = if has_nulls { header_size } else { 0 };
279        let values_offset_base = header_size + if has_nulls { nulls_len } else { 0 };
280        let values_offset = (values_offset_base + 7) & !7; // 8-byte aligned
281
282        if values_len == 0 {
283            // if empty array, return a new null array
284            return Self::new_null_array(original_len);
285        }
286
287        // Validate offsets and lengths
288        if has_nulls {
289            if nulls_offset == 0 || nulls_len == 0 {
290                panic!("Array has nulls but null buffer is missing");
291            }
292            if nulls_offset + nulls_len > bytes.len() {
293                panic!("Null buffer extends beyond input buffer");
294            }
295        }
296
297        if values_offset == 0 || values_len == 0 {
298            panic!("Values buffer is required");
299        }
300        if values_offset + values_len > bytes.len() {
301            panic!("Values buffer extends beyond input buffer");
302        }
303
304        // Create the nulls buffer if present
305        let nulls = if has_nulls {
306            // Create a buffer view into the nulls section
307            let nulls_slice = bytes.slice(nulls_offset..nulls_offset + nulls_len);
308            let nulls_buffer = Buffer::from(nulls_slice);
309            let boolean_buffer = BooleanBuffer::new(nulls_buffer, 0, original_len);
310            Some(NullBuffer::from(boolean_buffer))
311        } else {
312            None
313        };
314
315        let values_slice = bytes.slice(values_offset..values_offset + values_len);
316        let values_buffer = Buffer::from(values_slice);
317
318        let element_size = size_of::<T::Native>();
319        let packed_len = values_len / element_size;
320
321        let packed_values = ScalarBuffer::<T::Native>::new(values_buffer, 0, packed_len);
322
323        if nulls.is_some() && nulls.as_ref().unwrap().null_count() == original_len {
324            return Self::new_null_array(original_len);
325        }
326
327        Self {
328            packed_values,
329            nulls,
330            bit_width: Some(NonZero::new(bit_width).unwrap()),
331            original_len,
332        }
333    }
334}
335
/// Returns the smallest of 8, 16, 32 or 64 that is at least `bit_width`.
///
/// # Panics
///
/// Panics if `bit_width` is greater than 64.
// Not referenced anywhere in this file; the allow keeps the build warning-free.
#[allow(dead_code)]
fn best_arrow_primitive_width(bit_width: NonZero<u8>) -> usize {
    // `NonZero` rules out 0, so the smallest bucket starts at 1.
    match bit_width.get() {
        1..=8 => 8,
        9..=16 => 16,
        17..=32 => 32,
        33..=64 => 64,
        unsupported => panic!("Unsupported bit width: {}", unsupported),
    }
}
346
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{
        array::Array,
        datatypes::{UInt16Type, UInt32Type},
    };

    /// Shorthand for a non-zero bit width; panics when `bits` is zero.
    fn width(bits: u8) -> NonZero<u8> {
        NonZero::new(bits).unwrap()
    }

    /// Builds a `UInt32` array without nulls from the given values.
    fn u32_array(values: impl Iterator<Item = u32>) -> PrimitiveArray<UInt32Type> {
        PrimitiveArray::<UInt32Type>::from(values.collect::<Vec<u32>>())
    }

    /// Asserts that `array` has `len` slots and that slot `i` holds `expected(i)`.
    fn assert_values(
        array: &PrimitiveArray<UInt32Type>,
        len: usize,
        expected: impl Fn(usize) -> u32,
    ) {
        assert_eq!(array.len(), len);
        for i in 0..len {
            assert_eq!(array.value(i), expected(i));
        }
    }

    /// Asserts that a deserialized array reports the same metadata as the original.
    fn assert_same_metadata(
        deserialized: &BitPackedArray<UInt32Type>,
        original: &BitPackedArray<UInt32Type>,
    ) {
        assert_eq!(deserialized.bit_width(), original.bit_width());
        assert_eq!(deserialized.len(), original.len());
        assert_eq!(deserialized.is_nullable(), original.is_nullable());
    }

    #[test]
    fn test_bit_pack_roundtrip() {
        // Exactly one full chunk (1024 elements).
        let input = u32_array(0..1024);
        let before_size = input.get_array_memory_size();
        let packed = BitPackedArray::from_primitive(input, width(10));
        let after_size = packed.get_array_memory_size();
        println!("before: {before_size}, after: {after_size}");

        assert_values(&packed.to_primitive(), 1024, |i| i as u32);
    }

    #[test]
    fn test_bit_pack_partial_chunk() {
        // Fewer elements than one chunk (500 elements).
        let packed = BitPackedArray::from_primitive(u32_array(0..500), width(10));

        assert_values(&packed.to_primitive(), 500, |i| i as u32);
    }

    #[test]
    fn test_bit_pack_multiple_chunks() {
        // Two full chunks (2048 elements).
        let packed = BitPackedArray::from_primitive(u32_array(0..2048), width(11));

        assert_values(&packed.to_primitive(), 2048, |i| i as u32);
    }

    #[test]
    fn test_bit_pack_with_nulls() {
        // Even slots hold their index, odd slots are null.
        let values: Vec<Option<u32>> = (0..1000_u32)
            .map(|i| i.is_multiple_of(2).then_some(i))
            .collect();
        let input = PrimitiveArray::<UInt32Type>::from(values);
        let unpacked = BitPackedArray::from_primitive(input, width(10)).to_primitive();

        assert_eq!(unpacked.len(), 1000);
        for i in 0..1000_usize {
            if !i.is_multiple_of(2) {
                assert!(unpacked.is_null(i));
            } else {
                assert_eq!(unpacked.value(i), i as u32);
            }
        }
    }

    #[test]
    fn test_different_bit_widths() {
        // The same input must survive packing at several bit widths.
        let input = u32_array((0..100).map(|i| i * 2));

        for bits in [8, 16, 24, 32] {
            let packed = BitPackedArray::from_primitive(input.clone(), width(bits));

            assert_values(&packed.to_primitive(), 100, |i| i as u32 * 2);
        }
    }

    #[test]
    fn test_to_bytes_from_bytes_roundtrip() {
        // An array without nulls.
        let original = BitPackedArray::from_primitive(u32_array(0..100), width(10));

        // Serialize and make sure more than just a header was written.
        let mut encoded = Vec::new();
        original.to_bytes(&mut encoded);
        assert!(!encoded.is_empty());
        assert!(encoded.len() > 16); // At least header size

        // Deserialize and compare metadata.
        let deserialized = BitPackedArray::<UInt32Type>::from_bytes(bytes::Bytes::from(encoded));
        assert_same_metadata(&deserialized, &original);

        // Compare the unpacked values.
        let expected = original.to_primitive();
        let actual = deserialized.to_primitive();
        assert_eq!(expected.len(), actual.len());
        for i in 0..expected.len() {
            assert_eq!(expected.value(i), actual.value(i));
        }
    }

    #[test]
    fn test_to_bytes_from_bytes_with_nulls() {
        // Every third slot is null.
        let values: Vec<Option<u32>> = (0..100_u32)
            .map(|i| (!i.is_multiple_of(3)).then_some(i))
            .collect();
        let input = PrimitiveArray::<UInt32Type>::from(values);
        let original = BitPackedArray::from_primitive(input, width(10));

        let mut encoded = Vec::new();
        original.to_bytes(&mut encoded);
        let deserialized = BitPackedArray::<UInt32Type>::from_bytes(bytes::Bytes::from(encoded));

        assert_same_metadata(&deserialized, &original);

        // Compare validity for every slot, and values for the valid ones.
        let expected = original.to_primitive();
        let actual = deserialized.to_primitive();
        assert_eq!(expected.len(), actual.len());
        for i in 0..expected.len() {
            assert_eq!(expected.is_null(i), actual.is_null(i));
            if expected.is_null(i) {
                continue;
            }
            assert_eq!(expected.value(i), actual.value(i));
        }
    }

    #[test]
    fn test_to_bytes_from_bytes_with_nulls_and_offset() {
        let values: Vec<Option<u16>> = (0..32_u16)
            .map(|i| (!i.is_multiple_of(3)).then_some(i))
            .collect();
        let array = PrimitiveArray::<UInt16Type>::from(values);

        // Slicing yields a non-zero offset, and with it a non-zero bit offset
        // into the null bitmap.
        let sliced = array.slice(1, 23);
        let original = BitPackedArray::from_primitive(sliced.clone(), width(16));

        let mut encoded = Vec::new();
        original.to_bytes(&mut encoded);
        let deserialized = BitPackedArray::<UInt16Type>::from_bytes(encoded.into());

        assert_eq!(deserialized.to_primitive(), sliced);
    }

    #[test]
    fn test_memory_size_calculation() {
        let scalar_buffer = ScalarBuffer::<u32>::new(Buffer::from(vec![0; 1024]), 0, 1024);
        let struct_size = size_of::<BitPackedArray<UInt32Type>>();
        let values_size = scalar_buffer.inner().capacity();

        // Without a null buffer: struct plus values.
        let without_nulls = BitPackedArray::<UInt32Type> {
            packed_values: scalar_buffer.clone(),
            nulls: None,
            bit_width: Some(width(10)),
            original_len: 1024,
        };
        assert_eq!(
            without_nulls.get_array_memory_size(),
            struct_size + values_size,
            "Memory size mismatch without nulls"
        );

        // With a null buffer: its actual capacity is added on top, because the
        // allocation may be larger than the bitmap strictly needs.
        let nulls = Some(NullBuffer::new_null(1024));
        let with_nulls = BitPackedArray::<UInt32Type> {
            packed_values: scalar_buffer.clone(),
            nulls: nulls.clone(),
            bit_width: Some(width(10)),
            original_len: 1024,
        };
        let nulls_size = nulls.as_ref().map_or(0, |nb| nb.buffer().capacity());
        assert_eq!(
            with_nulls.get_array_memory_size(),
            struct_size + values_size + nulls_size,
            "Memory size mismatch with nulls"
        );
    }
}