lance_encoding/encodings/logical/struct.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::{collections::BinaryHeap, ops::Range, sync::Arc};

use crate::{
    decoder::{
        DecodedArray, FilterExpression, LoadedPage, NextDecodeTask, PageEncoding,
        ScheduledScanLine, SchedulerContext, StructuralDecodeArrayTask, StructuralFieldDecoder,
        StructuralFieldScheduler, StructuralSchedulingJob,
    },
    encoder::{EncodeTask, EncodedColumn, EncodedPage, FieldEncoder, OutOfLineBuffers},
    format::pb,
    repdef::RepDefBuilder,
};
use arrow_array::{cast::AsArray, Array, ArrayRef, StructArray};
use arrow_schema::{DataType, Fields};
use futures::{
    future::BoxFuture,
    stream::{FuturesOrdered, FuturesUnordered},
    FutureExt, StreamExt, TryStreamExt,
};
use itertools::Itertools;
use lance_arrow::deepcopy::deep_copy_nulls;
use lance_arrow::FieldExt;
use lance_core::Result;
use log::trace;

use super::{list::StructuralListDecoder, primitive::StructuralPrimitiveFieldDecoder};

#[derive(Debug)]
struct StructuralSchedulingJobWithStatus<'a> {
    col_idx: u32,
    col_name: &'a str,
    job: Box<dyn StructuralSchedulingJob + 'a>,
    rows_scheduled: u64,
    rows_remaining: u64,
}

impl PartialEq for StructuralSchedulingJobWithStatus<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.col_idx == other.col_idx
    }
}

impl Eq for StructuralSchedulingJobWithStatus<'_> {}

impl PartialOrd for StructuralSchedulingJobWithStatus<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for StructuralSchedulingJobWithStatus<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Note: this is reversed to make it a min-heap
        other.rows_scheduled.cmp(&self.rows_scheduled)
    }
}

/// Scheduling job for struct data
///
/// The order in which we schedule the children is important.  We always advance the child
/// that has scheduled the fewest rows so far (the child that is furthest behind).
///
/// This allows us to decode entire rows as quickly as possible
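///
/// For example (illustrative page sizes only, not taken from a real file), with two
/// children `a` and `b` the loop in `schedule_next` below might proceed like this:
///
/// ```text
/// start:  a = 0 rows scheduled, b = 0 rows scheduled, rows_scheduled = 0
/// pop a (fewest rows) -> a schedules a 500-row page -> a = 500, min over children = 0, keep going
/// pop b (fewest rows) -> b schedules a 700-row page -> b = 700, min over children = 500, stop
/// result: 500 complete struct rows are now fully scheduled
/// ```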
#[derive(Debug)]
struct RepDefStructSchedulingJob<'a> {
    /// A min-heap whose key is the # of rows currently scheduled
    children: BinaryHeap<StructuralSchedulingJobWithStatus<'a>>,
    rows_scheduled: u64,
}

impl<'a> RepDefStructSchedulingJob<'a> {
    fn new(
        scheduler: &'a StructuralStructScheduler,
        children: Vec<Box<dyn StructuralSchedulingJob + 'a>>,
        num_rows: u64,
    ) -> Self {
        let children = children
            .into_iter()
            .enumerate()
            .map(|(idx, job)| StructuralSchedulingJobWithStatus {
                col_idx: idx as u32,
                col_name: scheduler.child_fields[idx].name(),
                job,
                rows_scheduled: 0,
                rows_remaining: num_rows,
            })
            .collect::<BinaryHeap<_>>();
        Self {
            children,
            rows_scheduled: 0,
        }
    }
}

impl StructuralSchedulingJob for RepDefStructSchedulingJob<'_> {
    fn schedule_next(
        &mut self,
        mut context: &mut SchedulerContext,
    ) -> Result<Option<ScheduledScanLine>> {
        let mut decoders = Vec::new();
        let old_rows_scheduled = self.rows_scheduled;
        // Keep scheduling children until at least one complete row has been scheduled
        // (i.e. until the least-scheduled child advances)
        while old_rows_scheduled == self.rows_scheduled {
            let mut next_child = self.children.pop().unwrap();
            let scoped = context.push(next_child.col_name, next_child.col_idx);
            let child_scan = next_child.job.schedule_next(scoped.context)?;
            // next_child is the least-scheduled child and, if it's done, that
            // means we are completely done.
            if child_scan.is_none() {
                return Ok(None);
            }
            let child_scan = child_scan.unwrap();

            trace!(
                "Scheduled {} rows for child {}",
                child_scan.rows_scheduled,
                next_child.col_idx
            );
            next_child.rows_scheduled += child_scan.rows_scheduled;
            next_child.rows_remaining -= child_scan.rows_scheduled;
            decoders.extend(child_scan.decoders);
            self.children.push(next_child);
            self.rows_scheduled = self.children.peek().unwrap().rows_scheduled;
            context = scoped.pop();
        }
        let struct_rows_scheduled = self.rows_scheduled - old_rows_scheduled;
        Ok(Some(ScheduledScanLine {
            decoders,
            rows_scheduled: struct_rows_scheduled,
        }))
    }
}

/// A scheduler for structs
///
/// The implementation is a bit trickier than one might initially think.  We can't simply
/// schedule each column one after the other.  That would mean decoding couldn't start until
/// nearly all the data had arrived (since we need data from every column to yield a batch).
///
/// Instead, we schedule in row-major fashion.
///
/// Note: this scheduler is the starting point for all decoding.  This is because we treat the top-level
/// record batch as a non-nullable struct.
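///
/// As a rough sketch (page boundaries are illustrative, not literal), for two columns `A`
/// and `B` whose data is split into pages `A0..A2` and `B0..B2`:
///
/// ```text
/// column-major: A0 A1 A2 B0 B1 B2   (no row is complete until B0 arrives, near the end)
/// row-major:    A0 B0 A1 B1 A2 B2   (rows can be decoded as soon as A0 and B0 arrive)
/// ```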
#[derive(Debug)]
pub struct StructuralStructScheduler {
    children: Vec<Box<dyn StructuralFieldScheduler>>,
    child_fields: Fields,
}

impl StructuralStructScheduler {
    pub fn new(children: Vec<Box<dyn StructuralFieldScheduler>>, child_fields: Fields) -> Self {
        debug_assert!(!children.is_empty());
        Self {
            children,
            child_fields,
        }
    }
}

impl StructuralFieldScheduler for StructuralStructScheduler {
    fn schedule_ranges<'a>(
        &'a self,
        ranges: &[Range<u64>],
        filter: &FilterExpression,
    ) -> Result<Box<dyn StructuralSchedulingJob + 'a>> {
        let num_rows = ranges.iter().map(|r| r.end - r.start).sum();

        let child_schedulers = self
            .children
            .iter()
            .map(|child| child.schedule_ranges(ranges, filter))
            .collect::<Result<Vec<_>>>()?;

        Ok(Box::new(RepDefStructSchedulingJob::new(
            self,
            child_schedulers,
            num_rows,
        )))
    }

    fn initialize<'a>(
        &'a mut self,
        filter: &'a FilterExpression,
        context: &'a SchedulerContext,
    ) -> BoxFuture<'a, Result<()>> {
        let children_initialization = self
            .children
            .iter_mut()
            .map(|child| child.initialize(filter, context))
            .collect::<FuturesUnordered<_>>();
        async move {
            children_initialization
                .map(|res| res.map(|_| ()))
                .try_collect::<Vec<_>>()
                .await?;
            Ok(())
        }
        .boxed()
    }
}

#[derive(Debug)]
pub struct StructuralStructDecoder {
    children: Vec<Box<dyn StructuralFieldDecoder>>,
    data_type: DataType,
    child_fields: Fields,
    // The root decoder is slightly different because it cannot have nulls
    is_root: bool,
}

impl StructuralStructDecoder {
    pub fn new(fields: Fields, should_validate: bool, is_root: bool) -> Self {
        let children = fields
            .iter()
            .map(|field| Self::field_to_decoder(field, should_validate))
            .collect();
        let data_type = DataType::Struct(fields.clone());
        Self {
            data_type,
            children,
            child_fields: fields,
            is_root,
        }
    }

    fn field_to_decoder(
        field: &Arc<arrow_schema::Field>,
        should_validate: bool,
    ) -> Box<dyn StructuralFieldDecoder> {
        match field.data_type() {
            DataType::Struct(fields) => {
                if field.is_packed_struct() {
                    let decoder =
                        StructuralPrimitiveFieldDecoder::new(&field.clone(), should_validate);
                    Box::new(decoder)
                } else {
                    Box::new(Self::new(fields.clone(), should_validate, false))
                }
            }
            DataType::List(child_field) | DataType::LargeList(child_field) => {
                let child_decoder = Self::field_to_decoder(child_field, should_validate);
                Box::new(StructuralListDecoder::new(
                    child_decoder,
                    field.data_type().clone(),
                ))
            }
            DataType::RunEndEncoded(_, _) => todo!(),
            DataType::ListView(_) | DataType::LargeListView(_) => todo!(),
            DataType::Map(_, _) => todo!(),
            DataType::Union(_, _) => todo!(),
            _ => Box::new(StructuralPrimitiveFieldDecoder::new(field, should_validate)),
        }
    }

    pub fn drain_batch_task(&mut self, num_rows: u64) -> Result<NextDecodeTask> {
        let array_drain = self.drain(num_rows)?;
        Ok(NextDecodeTask {
            num_rows,
            task: Box::new(array_drain),
        })
    }
}

impl StructuralFieldDecoder for StructuralStructDecoder {
    fn accept_page(&mut self, mut child: LoadedPage) -> Result<()> {
        // Children with an empty path should not be delivered to this method
        let child_idx = child.path.pop_front().unwrap();
        // This decoder is intended for one of our children
        self.children[child_idx as usize].accept_page(child)?;
        Ok(())
    }

    fn drain(&mut self, num_rows: u64) -> Result<Box<dyn StructuralDecodeArrayTask>> {
        let child_tasks = self
            .children
            .iter_mut()
            .map(|child| child.drain(num_rows))
            .collect::<Result<Vec<_>>>()?;
        Ok(Box::new(RepDefStructDecodeTask {
            children: child_tasks,
            child_fields: self.child_fields.clone(),
            is_root: self.is_root,
        }))
    }

    fn data_type(&self) -> &DataType {
        &self.data_type
    }
}

#[derive(Debug)]
struct RepDefStructDecodeTask {
    children: Vec<Box<dyn StructuralDecodeArrayTask>>,
    child_fields: Fields,
    is_root: bool,
}

impl StructuralDecodeArrayTask for RepDefStructDecodeTask {
    fn decode(self: Box<Self>) -> Result<DecodedArray> {
        let arrays = self
            .children
            .into_iter()
            .map(|task| task.decode())
            .collect::<Result<Vec<_>>>()?;
        let mut children = Vec::with_capacity(arrays.len());
        let mut arrays_iter = arrays.into_iter();
        let first_array = arrays_iter.next().unwrap();
        let length = first_array.array.len();

        // The repdef should be identical across all children at this point
        let mut repdef = first_array.repdef;
        children.push(first_array.array);

        for array in arrays_iter {
            debug_assert_eq!(length, array.array.len());
            children.push(array.array);
        }

        let validity = if self.is_root {
            None
        } else {
            repdef.unravel_validity(length)
        };
        let array = StructArray::new(self.child_fields, children, validity);
        Ok(DecodedArray {
            array: Arc::new(array),
            repdef,
        })
    }
}

/// A structural encoder for struct fields
///
/// The struct's validity is added to the rep/def builder
/// and the builder is cloned to all children.
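///
/// A schematic of `maybe_encode` below (field names are illustrative only, not a literal trace):
///
/// ```text
/// maybe_encode(StructArray { validity: [1, 0, 1], columns: [score, location] })
///   -> repdef.add_validity_bitmap([1, 0, 1])                  (struct-level validity)
///   -> score child:    maybe_encode(score_values, repdef.clone(), ...)
///   -> location child: maybe_encode(location_values, repdef.clone(), ...)
/// ```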
pub struct StructStructuralEncoder {
    keep_original_array: bool,
    children: Vec<Box<dyn FieldEncoder>>,
}

impl StructStructuralEncoder {
    pub fn new(keep_original_array: bool, children: Vec<Box<dyn FieldEncoder>>) -> Self {
        Self {
            keep_original_array,
            children,
        }
    }
}

impl FieldEncoder for StructStructuralEncoder {
    fn maybe_encode(
        &mut self,
        array: ArrayRef,
        external_buffers: &mut OutOfLineBuffers,
        mut repdef: RepDefBuilder,
        row_number: u64,
        num_rows: u64,
    ) -> Result<Vec<EncodeTask>> {
        let struct_array = array.as_struct();
        if let Some(validity) = struct_array.nulls() {
            if self.keep_original_array {
                repdef.add_validity_bitmap(validity.clone())
            } else {
                repdef.add_validity_bitmap(deep_copy_nulls(Some(validity)).unwrap())
            }
        } else {
            repdef.add_no_null(struct_array.len());
        }
        let child_tasks = self
            .children
            .iter_mut()
            .zip(struct_array.columns().iter())
            .map(|(encoder, arr)| {
                encoder.maybe_encode(
                    arr.clone(),
                    external_buffers,
                    repdef.clone(),
                    row_number,
                    num_rows,
                )
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(child_tasks.into_iter().flatten().collect::<Vec<_>>())
    }

    fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result<Vec<EncodeTask>> {
        self.children
            .iter_mut()
            .map(|encoder| encoder.flush(external_buffers))
            .flatten_ok()
            .collect::<Result<Vec<_>>>()
    }

    fn num_columns(&self) -> u32 {
        self.children
            .iter()
            .map(|child| child.num_columns())
            .sum::<u32>()
    }

    fn finish(
        &mut self,
        external_buffers: &mut OutOfLineBuffers,
    ) -> BoxFuture<'_, Result<Vec<crate::encoder::EncodedColumn>>> {
        let mut child_columns = self
            .children
            .iter_mut()
            .map(|child| child.finish(external_buffers))
            .collect::<FuturesOrdered<_>>();
        async move {
            let mut encoded_columns = Vec::with_capacity(child_columns.len());
            while let Some(child_cols) = child_columns.next().await {
                encoded_columns.extend(child_cols?);
            }
            Ok(encoded_columns)
        }
        .boxed()
    }
}

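/// Encoder for struct fields used by the legacy encoding path
///
/// In addition to the encoded child columns, `finish` emits a header column
/// containing a single `SimpleStruct` page, which is why `num_columns` reports
/// one more column than the children alone.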
pub struct StructFieldEncoder {
    children: Vec<Box<dyn FieldEncoder>>,
    column_index: u32,
    num_rows_seen: u64,
}

impl StructFieldEncoder {
    #[allow(dead_code)]
    pub fn new(children: Vec<Box<dyn FieldEncoder>>, column_index: u32) -> Self {
        Self {
            children,
            column_index,
            num_rows_seen: 0,
        }
    }
}

impl FieldEncoder for StructFieldEncoder {
    fn maybe_encode(
        &mut self,
        array: ArrayRef,
        external_buffers: &mut OutOfLineBuffers,
        repdef: RepDefBuilder,
        row_number: u64,
        num_rows: u64,
    ) -> Result<Vec<EncodeTask>> {
        self.num_rows_seen += array.len() as u64;
        let struct_array = array.as_struct();
        let child_tasks = self
            .children
            .iter_mut()
            .zip(struct_array.columns().iter())
            .map(|(encoder, arr)| {
                encoder.maybe_encode(
                    arr.clone(),
                    external_buffers,
                    repdef.clone(),
                    row_number,
                    num_rows,
                )
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(child_tasks.into_iter().flatten().collect::<Vec<_>>())
    }

    fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result<Vec<EncodeTask>> {
        let child_tasks = self
            .children
            .iter_mut()
            .map(|encoder| encoder.flush(external_buffers))
            .collect::<Result<Vec<_>>>()?;
        Ok(child_tasks.into_iter().flatten().collect::<Vec<_>>())
    }

    fn num_columns(&self) -> u32 {
        self.children
            .iter()
            .map(|child| child.num_columns())
            .sum::<u32>()
            + 1
    }

    fn finish(
        &mut self,
        external_buffers: &mut OutOfLineBuffers,
    ) -> BoxFuture<'_, Result<Vec<crate::encoder::EncodedColumn>>> {
        let mut child_columns = self
            .children
            .iter_mut()
            .map(|child| child.finish(external_buffers))
            .collect::<FuturesOrdered<_>>();
        let num_rows_seen = self.num_rows_seen;
        let column_index = self.column_index;
        async move {
            let mut columns = Vec::new();
            // Add a column for the struct header
            let mut header = EncodedColumn::default();
            header.final_pages.push(EncodedPage {
                data: Vec::new(),
                description: PageEncoding::Legacy(pb::ArrayEncoding {
                    array_encoding: Some(pb::array_encoding::ArrayEncoding::Struct(
                        pb::SimpleStruct {},
                    )),
                }),
                num_rows: num_rows_seen,
                column_idx: column_index,
                row_number: 0, // Not used by legacy encoding
            });
            columns.push(header);
            // Now run finish on the children
            while let Some(child_cols) = child_columns.next().await {
                columns.extend(child_cols?);
            }
            Ok(columns)
        }
        .boxed()
    }
}

#[cfg(test)]
mod tests {

    use std::{collections::HashMap, sync::Arc};

    use arrow_array::{
        builder::{Int32Builder, ListBuilder},
        Array, ArrayRef, Int32Array, StructArray,
    };
    use arrow_buffer::NullBuffer;
    use arrow_schema::{DataType, Field, Fields};

    use crate::{
        testing::{check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases},
        version::LanceFileVersion,
    };

    #[test_log::test(tokio::test)]
    async fn test_simple_struct() {
        let data_type = DataType::Struct(Fields::from(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let field = Field::new("", data_type, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_nullable_struct() {
        // Test data struct<score: int32, location: struct<x: int32, y: int32>>
        // - score: null
        //   location:
        //     x: 1
        //     y: 6
        // - score: 12
        //   location:
        //     x: 2
        //     y: null
        // - score: 13
        //   location:
        //     x: 3
        //     y: 8
        // - score: 14
        //   location: null
        // - null
        //
        let inner_fields = Fields::from(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Int32, true),
        ]);
        let inner_struct = DataType::Struct(inner_fields.clone());
        let outer_fields = Fields::from(vec![
            Field::new("score", DataType::Int32, true),
            Field::new("location", inner_struct, true),
        ]);

        let x_vals = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        let y_vals = Int32Array::from(vec![Some(6), None, Some(8), Some(9), Some(10)]);
        let scores = Int32Array::from(vec![None, Some(12), Some(13), Some(14), Some(15)]);

        let location_validity = NullBuffer::from(vec![true, true, true, false, true]);
        let locations = StructArray::new(
            inner_fields,
            vec![Arc::new(x_vals), Arc::new(y_vals)],
            Some(location_validity),
        );

        let rows_validity = NullBuffer::from(vec![true, true, true, true, false]);
        let rows = StructArray::new(
            outer_fields,
            vec![Arc::new(scores), Arc::new(locations)],
            Some(rows_validity),
        );

        let test_cases = TestCases::default().with_file_version(LanceFileVersion::V2_1);

        check_round_trip_encoding_of_data(vec![Arc::new(rows)], &test_cases, HashMap::new()).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_struct_list() {
        let data_type = DataType::Struct(Fields::from(vec![
            Field::new(
                "inner_list",
                DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
                true,
            ),
            Field::new("outer_int", DataType::Int32, true),
        ]));
        let field = Field::new("row", data_type, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_empty_struct() {
        // It's technically legal for a struct to have 0 children, so we need to
        // make sure we support that
        let data_type = DataType::Struct(Fields::from(Vec::<Field>::default()));
        let field = Field::new("row", data_type, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_complicated_struct() {
        let data_type = DataType::Struct(Fields::from(vec![
            Field::new("int", DataType::Int32, true),
            Field::new(
                "inner",
                DataType::Struct(Fields::from(vec![
                    Field::new("inner_int", DataType::Int32, true),
                    Field::new(
                        "inner_list",
                        DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
                        true,
                    ),
                ])),
                true,
            ),
            Field::new("outer_binary", DataType::Binary, true),
        ]));
        let field = Field::new("row", data_type, false);
        check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_ragged_scheduling() {
        // This test covers scheduling when batches straddle page boundaries

        // Create a list with 10k nulls
        let items_builder = Int32Builder::new();
        let mut list_builder = ListBuilder::new(items_builder);
        for _ in 0..10000 {
            list_builder.append_null();
        }
        let list_array = Arc::new(list_builder.finish());
        let int_array = Arc::new(Int32Array::from_iter_values(0..10000));
        let fields = vec![
            Field::new("", list_array.data_type().clone(), true),
            Field::new("", int_array.data_type().clone(), true),
        ];
        let struct_array = Arc::new(StructArray::new(
            Fields::from(fields),
            vec![list_array, int_array],
            None,
        )) as ArrayRef;
        let struct_arrays = (0..10000)
            // Intentionally step by an arbitrary-looking amount to make the scheduling more ragged
            .step_by(437)
            .map(|offset| struct_array.slice(offset, 437.min(10000 - offset)))
            .collect::<Vec<_>>();
        check_round_trip_encoding_of_data(struct_arrays, &TestCases::default(), HashMap::new())
            .await;
    }
}