lance/io/exec/take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};

use arrow::array::AsArray;
use arrow::compute::{concat_batches, TakeOptions};
use arrow::datatypes::UInt64Type;
use arrow_array::{Array, UInt32Array};
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::{Schema as ArrowSchema, SchemaRef};
use datafusion::common::Statistics;
use datafusion::error::{DataFusionError, Result};
use datafusion::physical_plan::metrics::{
    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet,
};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
};
use datafusion_physical_expr::EquivalenceProperties;
use futures::stream::{FuturesOrdered, Stream, StreamExt, TryStreamExt};
use futures::FutureExt;
use lance_arrow::RecordBatchExt;
use lance_core::datatypes::{Field, OnMissing, Projection};
use lance_core::error::{DataFusionResult, LanceOptionExt};
use lance_core::utils::address::RowAddress;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::{ROW_ADDR, ROW_ID};
use lance_io::scheduler::{ScanScheduler, SchedulerConfig};

use crate::dataset::fragment::{FragReadConfig, FragmentReader};
use crate::dataset::rowids::get_row_id_index;
use crate::dataset::Dataset;
use crate::datatypes::Schema;

use super::utils::IoMetrics;

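/// Per-partition metrics recorded by [`TakeStream`]: baseline compute/output metrics,
/// a count of processed batches, and I/O metrics from the scan scheduler.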
#[derive(Debug, Clone)]
struct TakeStreamMetrics {
    baseline_metrics: BaselineMetrics,
    batches_processed: Count,
    io_metrics: IoMetrics,
}

impl TakeStreamMetrics {
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        let batches_processed = Count::new();
        MetricBuilder::new(metrics)
            .with_partition(partition)
            .build(MetricValue::Count {
                name: Cow::Borrowed("batches_processed"),
                count: batches_processed.clone(),
            });
        Self {
            baseline_metrics: BaselineMetrics::new(metrics, partition),
            batches_processed,
            io_metrics: IoMetrics::new(metrics, partition),
        }
    }
}

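/// Streaming state for a take operation.
///
/// For each input batch, the rows identified by `_rowid` / `_rowaddr` are read from the
/// dataset and the newly read columns are merged into the batch.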
struct TakeStream {
    /// The dataset to take from
    dataset: Arc<Dataset>,
    /// The fields to take from the dataset
69    fields_to_take: Arc<Schema>,
70    /// The output schema, needed for us to merge the new columns
71    /// into the input data in the correct order
72    output_schema: SchemaRef,
73    /// A cache of opened file readers
74    ///
75    /// This is a map from fragment id to a reader.
76    readers_cache: Arc<Mutex<HashMap<u32, Arc<FragmentReader>>>>,
77    /// The scan scheduler to use for reading fragments
78    scan_scheduler: Arc<ScanScheduler>,
79    /// The metrics for the stream
80    metrics: TakeStreamMetrics,
81}
82
83impl TakeStream {
84    fn new(
85        dataset: Arc<Dataset>,
86        fields_to_take: Arc<Schema>,
87        output_schema: SchemaRef,
88        scan_scheduler: Arc<ScanScheduler>,
89        metrics: &ExecutionPlanMetricsSet,
90        partition: usize,
91    ) -> Self {
92        Self {
93            dataset,
94            fields_to_take,
95            output_schema,
96            readers_cache: Arc::new(Mutex::new(HashMap::new())),
97            scan_scheduler,
98            metrics: TakeStreamMetrics::new(metrics, partition),
99        }
100    }
101
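    /// Open a reader over `fields_to_take` for the given fragment and cache it for reuse.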
    async fn do_open_reader(&self, fragment_id: u32) -> DataFusionResult<Arc<FragmentReader>> {
        let fragment = self
            .dataset
            .get_fragment(fragment_id as usize)
            .ok_or_else(|| {
                DataFusionError::Execution(format!("The input to a take operation specified fragment id {} but this fragment does not exist in the dataset", fragment_id))
            })?;

        let reader = Arc::new(
            fragment
                .open(
                    &self.fields_to_take,
                    FragReadConfig::default().with_scan_scheduler(self.scan_scheduler.clone()),
                )
                .await?,
        );

        let mut readers = self.readers_cache.lock().unwrap();
        readers.insert(fragment_id, reader.clone());
        Ok(reader)
    }

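    /// Get a cached reader for the fragment, opening (and caching) one if necessary.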
    async fn open_reader(&self, fragment_id: u32) -> DataFusionResult<Arc<FragmentReader>> {
        if let Some(reader) = self
            .readers_cache
            .lock()
            .unwrap()
            .get(&fragment_id)
            .cloned()
        {
            return Ok(reader);
        }

        self.do_open_reader(fragment_id).await
    }

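    /// Get the row addresses for a batch.
    ///
    /// If the batch already carries a `_rowaddr` column it is used directly.  Otherwise the
    /// `_rowid` column is translated through the row id index when one exists (stable row
    /// ids); if there is no index, row ids are already addresses and are used as-is.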
    async fn get_row_addrs(&self, batch: &RecordBatch) -> Result<Arc<dyn Array>> {
        if let Some(row_addr_array) = batch.column_by_name(ROW_ADDR) {
            Ok(row_addr_array.clone())
        } else {
            let row_id_array = batch.column_by_name(ROW_ID).expect_ok()?;

            if let Some(row_id_index) = get_row_id_index(&self.dataset).await? {
                let row_id_array = row_id_array.as_primitive::<UInt64Type>();
                let addresses = row_id_array
                    .values()
                    .iter()
                    .filter_map(|id| row_id_index.get(*id).map(|address| address.into()))
                    .collect::<Vec<u64>>();
                Ok(Arc::new(UInt64Array::from(addresses)))
            } else {
                Ok(row_id_array.clone())
            }
        }
    }

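    /// Take the requested columns for a single input batch.
    ///
    /// Row addresses are sorted, grouped by fragment, and read with one take per fragment;
    /// the results are then restored to the batch's original row order and merged into the
    /// output schema.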
    async fn map_batch(
        self: Arc<Self>,
        batch: RecordBatch,
        batch_number: u32,
    ) -> DataFusionResult<RecordBatch> {
        let compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer();
        let row_addrs_arr = self.get_row_addrs(&batch).await?;
        let row_addrs = row_addrs_arr.as_primitive::<UInt64Type>();

        debug_assert!(
            row_addrs.null_count() == 0,
            "{} nulls in row addresses",
            row_addrs.null_count()
        );
        // Check if the row addresses are already sorted to avoid unnecessary reorders
        let is_sorted = row_addrs.values().is_sorted();

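        // Declared outside the branch below so that the taken array outlives the
        // primitive-array reference produced in the `else` arm.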
        let sorted_addrs: Arc<dyn Array>;
        let (sorted_addrs, permutation) = if is_sorted {
            (row_addrs, None)
        } else {
            let permutation = arrow::compute::sort_to_indices(&row_addrs_arr, None, None).unwrap();
            sorted_addrs = arrow::compute::take(
                &row_addrs_arr,
                &permutation,
                Some(TakeOptions {
                    check_bounds: false,
                }),
            )
            .unwrap();
            // Calculate the inverse permutation to restore the original order
            let mut inverse_permutation = vec![0; permutation.len()];
            for (i, p) in permutation.values().iter().enumerate() {
                inverse_permutation[*p as usize] = i as u32;
            }
            (
                sorted_addrs.as_primitive::<UInt64Type>(),
                Some(UInt32Array::from(inverse_permutation)),
            )
        };

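        // Walk the sorted addresses, grouping consecutive rows from the same fragment into
        // a single take against that fragment's reader.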
        let mut futures = FuturesOrdered::new();
        let mut current_offsets = Vec::new();
        let mut current_fragment_id = None;

        for row_addr in sorted_addrs.values() {
            let addr = RowAddress::new_from_u64(*row_addr);

            if Some(addr.fragment_id()) != current_fragment_id {
                // Start a new group
                if let Some(fragment_id) = current_fragment_id {
                    let reader = self.open_reader(fragment_id).await?;
                    let offsets = std::mem::take(&mut current_offsets);
                    futures.push_back(
                        async move { reader.take_as_batch(&offsets, Some(batch_number)).await }
                            .boxed(),
                    );
                }
                current_fragment_id = Some(addr.fragment_id());
            }

            current_offsets.push(addr.row_offset());
        }

        // Handle the last group
        if let Some(fragment_id) = current_fragment_id {
            let reader = self.open_reader(fragment_id).await?;
            futures.push_back(
                async move {
                    reader
                        .take_as_batch(&current_offsets, Some(batch_number))
                        .await
                }
                .boxed(),
            );
        }

        // Stop the compute timer, don't count I/O time
        drop(compute_timer);

        let batches = futures.try_collect::<Vec<_>>().await?;

        if batches.is_empty() {
            return Ok(RecordBatch::new_empty(self.output_schema.clone()));
        }

        let _compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer();
        let schema = batches.first().expect_ok()?.schema();
        let mut new_data = concat_batches(&schema, batches.iter())?;

        // Restore previous order (if addresses were out of order originally)
        if let Some(permutation) = permutation {
            new_data = arrow_select::take::take_record_batch(&new_data, &permutation).unwrap();
        }

        self.metrics
            .baseline_metrics
            .record_output(new_data.num_rows());
        self.metrics.batches_processed.add(1);
        Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?)
    }

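    /// Apply the take to every batch of the input stream.
    ///
    /// Batches are mapped on spawned tasks and buffered up to the number of
    /// compute-intensive CPUs, so several batches can be in flight at once while output
    /// order is preserved.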
    fn apply<S: Stream<Item = Result<RecordBatch>> + Send + 'static>(
        self: Arc<Self>,
        input: S,
    ) -> impl Stream<Item = Result<RecordBatch>> {
        let scan_scheduler = self.scan_scheduler.clone();
        let metrics = self.metrics.clone();
        let batches = input
            .enumerate()
            .map(move |(batch_index, batch)| {
                let batch = batch?;
                let this = self.clone();
                Ok(
                    tokio::task::spawn(this.map_batch(batch, batch_index as u32))
                        .map(|res| res.unwrap()),
                )
            })
            .boxed();
        batches
            .inspect_ok(move |_| metrics.io_metrics.record(&scan_scheduler))
            .try_buffered(get_num_compute_intensive_cpus())
    }
}

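/// An [`ExecutionPlan`] node that reads extra columns from the dataset for the rows
/// identified by the `_rowid` or `_rowaddr` column of its input.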
#[derive(Debug)]
pub struct TakeExec {
    // The dataset to take from
    dataset: Arc<Dataset>,
    // The desired output projection of the relation (input schema + take schema)
    //
    // This is kept so that schema_to_take and the output schema can be re-calculated
    // when with_new_children is called.
    output_projection: Projection,
    // The schema of the extra columns to take from the dataset
    schema_to_take: Arc<Schema>,
    // The schema of the output
    output_schema: SchemaRef,
    input: Arc<dyn ExecutionPlan>,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
}

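// In the rendered plan, column names wrapped in parentheses are the ones this node takes
// from the dataset; bare names pass through from the input.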
impl DisplayAs for TakeExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let extra_fields = self
            .schema_to_take
            .fields
            .iter()
            .map(|f| f.name.clone())
            .collect::<HashSet<_>>();
        let columns = self
            .output_schema
            .fields
            .iter()
            .map(|f| {
                let name = f.name();
                if extra_fields.contains(name) {
                    format!("({})", name)
                } else {
                    name.clone()
                }
            })
            .collect::<Vec<_>>()
            .join(", ");
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "Take: columns={:?}", columns)
            }
            DisplayFormatType::TreeRender => {
                write!(f, "Take\ncolumns={:?}", columns)
            }
        }
    }
}

impl TakeExec {
    /// Create a [`TakeExec`] node.
    ///
    /// - dataset: the dataset to read from
    /// - input: the upstream [`ExecutionPlan`] feeding data in
    /// - projection: the desired output projection; it may overlap with the input schema
    ///
    /// Returns None if no extra columns are required (everything in the projection already
    /// exists in the input schema).
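    ///
    /// A minimal usage sketch (mirroring the tests in this file; `dataset` and `input` are
    /// assumed to already exist, and `input` must expose `_rowid` or `_rowaddr`):
    ///
    /// ```ignore
    /// let projection = dataset
    ///     .empty_projection()
    ///     .union_column("s", OnMissing::Error)
    ///     .unwrap();
    /// if let Some(take) = TakeExec::try_new(dataset.clone(), input, projection)? {
    ///     let plan: Arc<dyn ExecutionPlan> = Arc::new(take);
    /// }
    /// ```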
    pub fn try_new(
        dataset: Arc<Dataset>,
        input: Arc<dyn ExecutionPlan>,
        projection: Projection,
    ) -> Result<Option<Self>> {
        let original_projection = projection.clone();
        let projection =
            projection.subtract_arrow_schema(input.schema().as_ref(), OnMissing::Ignore)?;
        if !projection.has_data_fields() {
            return Ok(None);
        }

        // We actually need a take, so make sure the input provides a row id or row address
        if input.schema().column_with_name(ROW_ADDR).is_none()
            && input.schema().column_with_name(ROW_ID).is_none()
        {
            return Err(DataFusionError::Plan(format!(
                "TakeExec requires the input plan to have a column named '{}' or '{}'",
                ROW_ADDR, ROW_ID
            )));
        }

        // At this point the projection must contain data fields, and take cannot be used
        // to add row_id / row_addr
        assert!(
            !projection.with_row_id && !projection.with_row_addr,
            "Take should not be used to insert row_id / row_addr: {:#?}",
            projection
        );

        let output_schema = Arc::new(Self::calculate_output_schema(
            dataset.schema(),
            &input.schema(),
            &projection,
        ));
        let output_arrow = Arc::new(ArrowSchema::from(output_schema.as_ref()));
        let properties = input
            .properties()
            .clone()
            .with_eq_properties(EquivalenceProperties::new(output_arrow.clone()));

        Ok(Some(Self {
            dataset,
            output_projection: original_projection,
            schema_to_take: projection.into_schema_ref(),
            input,
            output_schema: output_arrow,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        }))
    }

    /// The output of a take operation will be all columns from the input schema followed
    /// by any new columns from the dataset.
    ///
    /// The output fields will always be added in dataset schema order
    ///
    /// Nested columns in the input schema may have new fields inserted into them.
    ///
    /// If this happens the order of the new nested fields will match the order defined in
    /// the dataset schema.
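    ///
    /// For example, with the test fixture's dataset schema `[i, f, s, struct]`, an input of
    /// `[i, _rowid]` and a projection of `{s, f}` yield `[i, _rowid, f, s]`: new columns are
    /// appended in dataset schema order, not projection order.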
    fn calculate_output_schema(
        dataset_schema: &Schema,
        input_schema: &ArrowSchema,
        projection: &Projection,
    ) -> Schema {
        // TakeExec doesn't reorder top-level fields and so the first thing we need to do is determine the
        // top-level field order.
        let mut top_level_fields_added = HashSet::with_capacity(input_schema.fields.len());
        let projected_schema = projection.to_schema();

        let mut output_fields =
            Vec::with_capacity(input_schema.fields.len() + projected_schema.fields.len());
        // Start with the input schema's fields, keeping their order, and merge in any
        // newly projected sub-fields
        output_fields.extend(input_schema.fields.iter().map(|f| {
            let f = Field::try_from(f.as_ref()).unwrap();
            if let Some(ds_field) = dataset_schema.field(&f.name) {
                top_level_fields_added.insert(ds_field.id);
                // Field is in the dataset, it might have new fields added to it
                if let Some(projected_field) = ds_field.apply_projection(projection) {
                    f.merge_with_reference(&projected_field, ds_field)
                } else {
                    // No new fields added, keep as-is
                    f
                }
            } else {
                // Field not in dataset, not possible to add extra fields, use as-is
                f
            }
        }));

        // Now we add to the end any brand new top-level fields.  These will be added in
        // dataset schema order.
        output_fields.extend(
            projected_schema
                .fields
                .into_iter()
                .filter(|f| !top_level_fields_added.contains(&f.id)),
        );
        Schema {
            fields: output_fields,
            metadata: dataset_schema.metadata.clone(),
        }
    }

    /// Get the dataset.
    pub fn dataset(&self) -> &Arc<Dataset> {
        &self.dataset
    }
}

impl ExecutionPlan for TakeExec {
    fn name(&self) -> &str {
        "TakeExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        // This is an I/O bound operation and wouldn't really benefit from partitioning
        //
        // Plus, if we did that, we would be creating multiple schedulers which could use
        // a lot of RAM.
        vec![false]
    }

    /// This preserves the output schema.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Internal(
                "TakeExec wrong number of children".to_string(),
            ));
        }

        let projection = self.output_projection.clone();

        let plan = Self::try_new(self.dataset.clone(), children[0].clone(), projection)?;

        if let Some(plan) = plan {
            Ok(Arc::new(plan))
        } else {
            // Is this legal or do we need to insert a no-op node?
            Ok(children[0].clone())
        }
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        let dataset = self.dataset.clone();
        let schema_to_take = self.schema_to_take.clone();
        let output_schema = self.output_schema.clone();
        let metrics = self.metrics.clone();

        // ScanScheduler::new launches the I/O scheduler in the background.
        // We aren't allowed to do work in `execute` and so we defer creation of the
        // TakeStream until the stream is polled.
        let lazy_take_stream = futures::stream::once(async move {
            let obj_store = dataset.object_store.clone();
            let scheduler_config = SchedulerConfig::max_bandwidth(&obj_store);
            let scan_scheduler = ScanScheduler::new(obj_store, scheduler_config);

            let take_stream = Arc::new(TakeStream::new(
                dataset,
                schema_to_take,
                output_schema,
                scan_scheduler,
                &metrics,
                partition,
            ));
            take_stream.apply(input_stream)
        });
        let output_schema = self.output_schema.clone();
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            output_schema,
            lazy_take_stream.flatten(),
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

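    // Take adds columns but never changes the number of rows, so the input's row count is
    // passed through; all other statistics are unknown.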
    fn partition_statistics(
        &self,
        partition: Option<usize>,
    ) -> Result<datafusion::physical_plan::Statistics> {
        Ok(Statistics {
            num_rows: self.input.partition_statistics(partition)?.num_rows,
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{
        ArrayRef, Float32Array, Int32Array, RecordBatchIterator, StringArray, StructArray,
    };
    use arrow_schema::{DataType, Field, Fields};
    use datafusion::execution::TaskContext;
    use lance_arrow::SchemaExt;
    use lance_core::utils::tempfile::TempStrDir;
    use lance_core::{datatypes::OnMissing, ROW_ID};
    use lance_datafusion::{datagen::DatafusionDatagenExt, exec::OneShotExec, utils::MetricsExt};
    use lance_datagen::{BatchCount, RowCount};
    use rstest::rstest;

    use crate::{
        dataset::WriteParams,
        io::exec::{LanceScanConfig, LanceScanExec},
        utils::test::NoContextTestFixture,
    };

    struct TestFixture {
        dataset: Arc<Dataset>,
        _tmp_dir_guard: TempStrDir,
    }

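    // Builds a dataset of 30 rows (3 batches of 10, max_rows_per_file = 10, so 3 fragments)
    // with columns i, f, s, and struct { x, y }.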
    async fn test_fixture() -> TestFixture {
        let struct_fields = Fields::from(vec![
            Arc::new(Field::new("x", DataType::Int32, false)),
            Arc::new(Field::new("y", DataType::Int32, false)),
        ]);

        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("f", DataType::Float32, false),
            Field::new("s", DataType::Utf8, false),
            Field::new("struct", DataType::Struct(struct_fields.clone()), false),
        ]));

        // Write 3 batches.
        let expected_batches: Vec<RecordBatch> = (0..3)
            .map(|batch_id| {
                let value_range = batch_id * 10..batch_id * 10 + 10;
                let columns: Vec<ArrayRef> = vec![
                    Arc::new(Int32Array::from_iter_values(value_range.clone())),
                    Arc::new(Float32Array::from_iter(
                        value_range.clone().map(|v| v as f32),
                    )),
                    Arc::new(StringArray::from_iter_values(
                        value_range.clone().map(|v| format!("str-{v}")),
                    )),
                    Arc::new(StructArray::new(
                        struct_fields.clone(),
                        vec![
                            Arc::new(Int32Array::from_iter(value_range.clone())),
                            Arc::new(Int32Array::from_iter(value_range)),
                        ],
                        None,
                    )),
                ];
                RecordBatch::try_new(schema.clone(), columns).unwrap()
            })
            .collect();

        let test_dir = TempStrDir::default();
        let test_uri = test_dir.as_str();
        let params = WriteParams {
            max_rows_per_file: 10,
            ..Default::default()
        };
        let reader =
            RecordBatchIterator::new(expected_batches.clone().into_iter().map(Ok), schema.clone());
        Dataset::write(reader, test_uri, Some(params))
            .await
            .unwrap();

        TestFixture {
            dataset: Arc::new(Dataset::open(test_uri).await.unwrap()),
            _tmp_dir_guard: test_dir,
        }
    }

    #[tokio::test]
    async fn test_take_schema() {
        let TestFixture { dataset, .. } = test_fixture().await;

        let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]);
        let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap());

        // With row id
        let config = LanceScanConfig {
            with_row_id: true,
            ..Default::default()
        };
        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            config,
        ));

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();
        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();
        let schema = take_exec.schema();
        assert_eq!(
            schema.fields.iter().map(|f| f.name()).collect::<Vec<_>>(),
            vec!["i", ROW_ID, "s"]
        );
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum TakeInput {
        Ids,
        Addrs,
        IdsAndAddrs,
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_take(
        #[values(TakeInput::Ids, TakeInput::Addrs, TakeInput::IdsAndAddrs)] take_input: TakeInput,
    ) {
        let TestFixture {
            dataset,
            _tmp_dir_guard,
        } = test_fixture().await;

        let scan_schema = Arc::new(dataset.schema().project(&["i"]).unwrap());
        let config = LanceScanConfig {
            with_row_address: take_input != TakeInput::Ids,
            with_row_id: take_input != TakeInput::Addrs,
            ..Default::default()
        };
        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            config,
        ));

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();
        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();
        let schema = take_exec.schema();

        let mut expected_fields = vec!["i"];
        if take_input != TakeInput::Addrs {
            expected_fields.push(ROW_ID);
        }
        if take_input != TakeInput::Ids {
            expected_fields.push(ROW_ADDR);
        }
        expected_fields.push("s");
        assert_eq!(&schema.field_names(), &expected_fields);

        let mut stream = take_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();

        while let Some(batch) = stream.try_next().await.unwrap() {
            assert_eq!(&batch.schema().field_names(), &expected_fields);
        }
    }

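    // The input addresses are deliberately shuffled; this checks that TakeExec restores each
    // batch to its original row order and that metrics are reported for all three batches.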
    #[tokio::test]
    async fn test_take_order() {
        let TestFixture {
            dataset,
            _tmp_dir_guard,
        } = test_fixture().await;

        // Grab all row addresses, shuffle them, and select the first 15 (half of the rows)
        let data = dataset
            .scan()
            .project(&["s"])
            .unwrap()
            .with_row_address()
            .try_into_batch()
            .await
            .unwrap();
        let indices = UInt64Array::from(vec![8, 13, 1, 7, 4, 5, 12, 9, 10, 2, 11, 6, 3, 0, 28]);
        let data = arrow_select::take::take_record_batch(&data, &indices).unwrap();

        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            ROW_ADDR,
            DataType::UInt64,
            true,
        )]));
        let row_addrs = data.project_by_schema(&schema).unwrap();

        // Split into 3 batches of 5
        let batches = (0..3)
            .map(|i| {
                let start = i * 5;
                row_addrs.slice(start, 5)
            })
            .collect::<Vec<_>>();

        let row_addr_stream = futures::stream::iter(batches.clone().into_iter().map(Ok));
        let row_addr_stream = Box::pin(RecordBatchStreamAdapter::new(schema, row_addr_stream));

        let input = Arc::new(OneShotExec::new(row_addr_stream));

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();
        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();

        let stream = take_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();

        let expected = vec![data.slice(0, 5), data.slice(5, 5), data.slice(10, 5)];

        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(batches.len(), 3);
        for (batch, expected) in batches.into_iter().zip(expected) {
            assert_eq!(batch.schema().field_names(), vec![ROW_ADDR, "s"]);
            let expected = expected.project_by_schema(&batch.schema()).unwrap();
            assert_eq!(batch, expected);
        }

        let metrics = take_exec.metrics().unwrap();
        assert_eq!(metrics.output_rows(), Some(15));
        assert_eq!(metrics.find_count("batches_processed").unwrap().value(), 3);
    }

    #[tokio::test]
    async fn test_take_struct() {
        // When taking fields into an existing struct, the field order should be maintained
        // according to the schema of the struct.
        let TestFixture {
            dataset,
            _tmp_dir_guard,
        } = test_fixture().await;

        let scan_schema = Arc::new(dataset.schema().project(&["struct.y"]).unwrap());

        let config = LanceScanConfig {
            with_row_address: true,
            ..Default::default()
        };
        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            config,
        ));

        let projection = dataset
            .empty_projection()
            .union_column("struct.x", OnMissing::Error)
            .unwrap();

        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();

        let expected_schema = ArrowSchema::new(vec![
            Field::new(
                "struct",
                DataType::Struct(Fields::from(vec![
                    Arc::new(Field::new("x", DataType::Int32, false)),
                    Arc::new(Field::new("y", DataType::Int32, false)),
                ])),
                false,
            ),
            Field::new(ROW_ADDR, DataType::UInt64, true),
        ]);
        let schema = take_exec.schema();
        assert_eq!(schema.as_ref(), &expected_schema);

        let mut stream = take_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();

        while let Some(batch) = stream.try_next().await.unwrap() {
            assert_eq!(batch.schema().as_ref(), &expected_schema);
        }
    }

    #[tokio::test]
    async fn test_take_no_row_addr() {
        let TestFixture { dataset, .. } = test_fixture().await;

        let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]);
        let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap());

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();

        // No row address
        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            LanceScanConfig::default(),
        ));
        assert!(TakeExec::try_new(dataset, input, projection).is_err());
    }

    #[tokio::test]
    async fn test_with_new_children() -> Result<()> {
        let TestFixture { dataset, .. } = test_fixture().await;

        let config = LanceScanConfig {
            with_row_id: true,
            ..Default::default()
        };

        let input_schema = Arc::new(dataset.schema().project(&["i"])?);
        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();

        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            input_schema,
            config,
        ));

        assert_eq!(input.schema().field_names(), vec!["i", ROW_ID],);
        let take_exec = TakeExec::try_new(dataset.clone(), input.clone(), projection)?.unwrap();
        assert_eq!(take_exec.schema().field_names(), vec!["i", ROW_ID, "s"],);

        let projection = dataset
            .empty_projection()
            .union_columns(["s", "f"], OnMissing::Error)
            .unwrap();

        let outer_take =
            Arc::new(TakeExec::try_new(dataset, Arc::new(take_exec), projection)?.unwrap());
        assert_eq!(
            outer_take.schema().field_names(),
            vec!["i", ROW_ID, "s", "f"],
        );

        // with_new_children keeps the same set of output columns, but the newly taken
        // fields are re-appended in dataset schema order.
        let edited = outer_take.with_new_children(vec![input])?;
        assert_eq!(edited.schema().field_names(), vec!["i", ROW_ID, "f", "s"],);
        Ok(())
    }

    #[test]
    fn no_context_take() {
        // These tests ensure we can create nodes and call execute without a tokio Runtime
        // being active.  This is a requirement for proper implementation of a Datafusion foreign
        // table provider.
        let fixture = NoContextTestFixture::new();
        let arc_dataset = Arc::new(fixture.dataset);

        let input = lance_datagen::gen_batch()
            .col(ROW_ID, lance_datagen::array::step::<UInt64Type>())
            .into_df_exec(RowCount::from(50), BatchCount::from(2));

        let take = TakeExec::try_new(
            arc_dataset.clone(),
            input,
            arc_dataset
                .empty_projection()
                .union_column("text", OnMissing::Error)
                .unwrap(),
        )
        .unwrap()
        .unwrap();

        take.execute(0, Arc::new(TaskContext::default())).unwrap();
    }
}