lance_encoding/previous/encodings/physical/packed_struct.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5
6use arrow_schema::{DataType, Fields};
7use bytes::Bytes;
8use bytes::BytesMut;
9use futures::{FutureExt, future::BoxFuture};
10use lance_arrow::DataTypeExt;
11use lance_core::{Error, Result};
12
13use crate::data::BlockInfo;
14use crate::data::FixedSizeListBlock;
15use crate::format::ProtobufUtils;
16use crate::{
17    EncodingsIo,
18    buffer::LanceBuffer,
19    data::{DataBlock, FixedWidthDataBlock, StructDataBlock},
20    decoder::{PageScheduler, PrimitivePageDecoder},
21    previous::encoder::{ArrayEncoder, EncodedArray},
22};
23
/// Page scheduler for struct pages stored in the packed (row-major) layout,
/// where every field value for a row is laid out contiguously.
#[derive(Debug)]
pub struct PackedStructPageScheduler {
    // We don't actually need these schedulers right now since we decode all the field bytes directly
    // But they can be useful if we actually need to use the decoders for the inner fields later
    // e.g. once bitpacking is added
    _inner_schedulers: Vec<Box<dyn PageScheduler>>,
    // Child fields of the struct datatype being decoded
    fields: Fields,
    // Byte offset of this page's packed buffer (read from the protobuf metadata)
    buffer_offset: u64,
}
33
34impl PackedStructPageScheduler {
35    pub fn new(
36        _inner_schedulers: Vec<Box<dyn PageScheduler>>,
37        struct_datatype: DataType,
38        buffer_offset: u64,
39    ) -> Self {
40        let DataType::Struct(fields) = struct_datatype else {
41            panic!("Struct datatype expected");
42        };
43        Self {
44            _inner_schedulers,
45            fields,
46            buffer_offset,
47        }
48    }
49}
50
51impl PageScheduler for PackedStructPageScheduler {
52    fn schedule_ranges(
53        &self,
54        ranges: &[std::ops::Range<u64>],
55        scheduler: &Arc<dyn EncodingsIo>,
56        top_level_row: u64,
57    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
58        let mut total_bytes_per_row: u64 = 0;
59
60        for field in &self.fields {
61            let bytes_per_field = field.data_type().byte_width() as u64;
62            total_bytes_per_row += bytes_per_field;
63        }
64
65        // Parts of the arrays in a page may be encoded in different encoding tasks
66        // In that case decoding two different sets of rows can result in the same ranges parameter being passed in
67        // e.g. we may get ranges[0..2] and ranges[0..2] to decode 4 rows through 2 tasks
68        // So to get the correct byte ranges we need to know the position of the buffer in the page (i.e. the buffer offset)
69        // This is computed directly from the buffer stored in the protobuf
70        let byte_ranges = ranges
71            .iter()
72            .map(|range| {
73                let start = self.buffer_offset + (range.start * total_bytes_per_row);
74                let end = self.buffer_offset + (range.end * total_bytes_per_row);
75                start..end
76            })
77            .collect::<Vec<_>>();
78
79        // Directly creates a future to decode the bytes
80        let bytes = scheduler.submit_request(byte_ranges, top_level_row);
81
82        let copy_struct_fields = self.fields.clone();
83
84        tokio::spawn(async move {
85            let bytes = bytes.await?;
86
87            let mut combined_bytes = BytesMut::default();
88            for byte_slice in bytes {
89                combined_bytes.extend_from_slice(&byte_slice);
90            }
91
92            Ok(Box::new(PackedStructPageDecoder {
93                data: combined_bytes.freeze(),
94                fields: copy_struct_fields,
95                total_bytes_per_row: total_bytes_per_row as usize,
96            }) as Box<dyn PrimitivePageDecoder>)
97        })
98        .map(|join_handle| join_handle.unwrap())
99        .boxed()
100    }
101}
102
/// Decodes rows out of a single contiguous buffer of packed (row-major)
/// struct data.
struct PackedStructPageDecoder {
    // All packed bytes for the scheduled ranges, fields interleaved per row
    data: Bytes,
    // Child fields of the struct being decoded
    fields: Fields,
    // Stride between consecutive rows (sum of all field byte widths)
    total_bytes_per_row: usize,
}
108
109impl PrimitivePageDecoder for PackedStructPageDecoder {
110    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
111        // Decoding workflow:
112        // rows 0-2: {x: [1, 2, 3], y: [4, 5, 6], z: [7, 8, 9]}
113        // rows 3-5: {x: [10, 11, 12], y: [13, 14, 15], z: [16, 17, 18]}
114        // packed encoding: [
115        // [1, 4, 7, 2, 5, 8, 3, 6, 9],
116        // [10, 13, 16, 11, 14, 17, 12, 15, 18]
117        // ]
118        // suppose bytes_per_field=1, 4, 8 for fields x, y, and z, respectively.
119        // Then total_bytes_per_row = 13
120        // Suppose rows_to_skip=1 and num_rows=2. Then we will slice bytes 13 to 39.
121        // Now we have [2, 5, 8, 3, 6, 9]
122        // We rearrange this to get [BytesMut(2, 3), BytesMut(5, 6), BytesMut(8, 9)] as a Vec<BytesMut>
123        // This is used to reconstruct the struct array later
124
125        let bytes_to_skip = (rows_to_skip as usize) * self.total_bytes_per_row;
126
127        let mut children = Vec::with_capacity(self.fields.len());
128
129        let mut start_index = 0;
130
131        for field in &self.fields {
132            let bytes_per_field = field.data_type().byte_width();
133            let mut field_bytes = Vec::with_capacity(bytes_per_field * num_rows as usize);
134
135            let mut byte_index = start_index;
136
137            for _ in 0..num_rows {
138                let start = bytes_to_skip + byte_index;
139                field_bytes.extend_from_slice(&self.data[start..(start + bytes_per_field)]);
140                byte_index += self.total_bytes_per_row;
141            }
142
143            start_index += bytes_per_field;
144            let child_block = FixedWidthDataBlock {
145                data: LanceBuffer::from(field_bytes),
146                bits_per_value: bytes_per_field as u64 * 8,
147                num_values: num_rows,
148                block_info: BlockInfo::new(),
149            };
150            let child_block = FixedSizeListBlock::from_flat(child_block, field.data_type());
151            children.push(child_block);
152        }
153        Ok(DataBlock::Struct(StructDataBlock {
154            children,
155            block_info: BlockInfo::default(),
156            validity: None,
157        }))
158    }
159}
160
/// Encodes a struct by encoding each child and zipping the results into a
/// single row-major (packed) buffer.
#[derive(Debug)]
pub struct PackedStructEncoder {
    // One encoder per child field, applied before the outputs are zipped
    inner_encoders: Vec<Box<dyn ArrayEncoder>>,
}
165
impl PackedStructEncoder {
    /// Creates an encoder that packs the output of the given per-field encoders.
    ///
    /// `inner_encoders` is zipped positionally against the struct's children
    /// during `encode`.
    pub fn new(inner_encoders: Vec<Box<dyn ArrayEncoder>>) -> Self {
        Self { inner_encoders }
    }
}
171
172impl ArrayEncoder for PackedStructEncoder {
173    fn encode(
174        &self,
175        data: DataBlock,
176        data_type: &DataType,
177        buffer_index: &mut u32,
178    ) -> Result<EncodedArray> {
179        let struct_data = data.as_struct().unwrap();
180
181        let DataType::Struct(child_types) = data_type else {
182            panic!("Struct datatype expected");
183        };
184
185        // Encode individual fields
186        let mut encoded_fields = Vec::with_capacity(struct_data.children.len());
187        for ((child, encoder), child_type) in struct_data
188            .children
189            .into_iter()
190            .zip(&self.inner_encoders)
191            .zip(child_types)
192        {
193            encoded_fields.push(encoder.encode(child, child_type.data_type(), &mut 0)?);
194        }
195
196        let (encoded_data_vec, child_encodings): (Vec<_>, Vec<_>) = encoded_fields
197            .into_iter()
198            .map(|field| (field.data, field.encoding))
199            .unzip();
200
201        // Zip together encoded data
202        //
203        // We can currently encode both FixedWidth and FixedSizeList.  In order
204        // to encode the latter we "flatten" it converting a FixedSizeList into
205        // a FixedWidth with very wide items.
206        let fixed_fields = encoded_data_vec
207            .into_iter()
208            .map(|child| match child {
209                DataBlock::FixedWidth(fixed) => Ok(fixed),
210                DataBlock::FixedSizeList(fixed_size_list) => {
211                    let flattened = fixed_size_list.try_into_flat().ok_or_else(|| {
212                        Error::invalid_input(
213                            "Packed struct encoder cannot pack nullable fixed-width data blocks",
214                        )
215                    })?;
216                    Ok(flattened)
217                }
218                _ => Err(Error::invalid_input(
219                    "Packed struct encoder currently only implemented for fixed-width data blocks",
220                )),
221            })
222            .collect::<Result<Vec<_>>>()?;
223        let total_bits_per_value = fixed_fields.iter().map(|f| f.bits_per_value).sum::<u64>();
224
225        let num_values = fixed_fields[0].num_values;
226        debug_assert!(
227            fixed_fields
228                .iter()
229                .all(|field| field.num_values == num_values)
230        );
231
232        let zipped_input = fixed_fields
233            .into_iter()
234            .map(|field| (field.data, field.bits_per_value))
235            .collect::<Vec<_>>();
236        let zipped = LanceBuffer::zip_into_one(zipped_input, num_values)?;
237
238        // Create encoding protobuf
239        let index = *buffer_index;
240        *buffer_index += 1;
241
242        let packed_data = DataBlock::FixedWidth(FixedWidthDataBlock {
243            data: zipped,
244            bits_per_value: total_bits_per_value,
245            num_values,
246            block_info: BlockInfo::new(),
247        });
248
249        let encoding = ProtobufUtils::packed_struct(child_encodings, index);
250
251        Ok(EncodedArray {
252            data: packed_data,
253            encoding,
254        })
255    }
256}
257
#[cfg(test)]
pub mod tests {

    use arrow_array::{ArrayRef, Int32Array, StructArray, UInt8Array, UInt64Array};
    use arrow_schema::{DataType, Field, Fields};
    use std::{collections::HashMap, sync::Arc, vec};

    use crate::testing::{TestCases, check_basic_random, check_round_trip_encoding_of_data};

    /// Field metadata that opts a struct into packed encoding.
    fn packed_metadata() -> HashMap<String, String> {
        let mut metadata = HashMap::new();
        metadata.insert("packed".to_string(), "true".to_string());
        metadata
    }

    #[test_log::test(tokio::test)]
    async fn test_random_packed_struct() {
        let fields = Fields::from(vec![
            Field::new("a", DataType::UInt64, false),
            Field::new("b", DataType::UInt32, false),
        ]);
        let field = Field::new("", DataType::Struct(fields), false)
            .with_metadata(packed_metadata());

        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_specific_packed_struct() {
        // Builds a {x: u64, y: i32, z: u8} struct array from three child vectors.
        fn make_struct(xs: Vec<u64>, ys: Vec<i32>, zs: Vec<u8>) -> Arc<StructArray> {
            Arc::new(StructArray::from(vec![
                (
                    Arc::new(Field::new("x", DataType::UInt64, false)),
                    Arc::new(UInt64Array::from(xs)) as ArrayRef,
                ),
                (
                    Arc::new(Field::new("y", DataType::Int32, false)),
                    Arc::new(Int32Array::from(ys)) as ArrayRef,
                ),
                (
                    Arc::new(Field::new("z", DataType::UInt8, false)),
                    Arc::new(UInt8Array::from(zs)) as ArrayRef,
                ),
            ]))
        }

        let struct_array1 =
            make_struct(vec![1, 2, 3, 4], vec![5, 6, 7, 8], vec![9, 10, 11, 12]);
        let struct_array2 = make_struct(
            vec![13, 14, 15, 16],
            vec![17, 18, 19, 20],
            vec![21, 22, 23, 24],
        );

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(0..6)
            .with_range(1..4)
            .with_indices(vec![1, 3, 7]);

        check_round_trip_encoding_of_data(
            vec![struct_array1, struct_array2],
            &test_cases,
            packed_metadata(),
        )
        .await;
    }
}